In [202]:
# util 

import os
import pickle
import random
import numpy as np
import networkx as nx
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn.functional as F
import torch.nn as nn

from typing import Callable, Tuple, Union

from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing, NNConv, CGConv, GINEConv
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset, zeros
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size

from torch_geometric.nn import global_mean_pool, global_add_pool
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader, NeighborLoader, ClusterData, ClusterLoader
from torch_geometric.utils import from_networkx, to_networkx

from IPython.display import clear_output

DATA_FOLDER = "./data"

seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# random.seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic=True
# torch.backends.cudnn.benchmark = False

In [203]:
def construct_bigLITTLE_graph(data_folder, no_duplicate=False, unobserved=0.0, unobserved_edge=0.0):
    list_files = os.listdir(data_folder)
    list_files = list(filter(lambda x: x.endswith(".pkl"), list_files))
    list_labels = pickle.load(open("labels.pkl", "rb"))
    
    dict_graph_size = defaultdict(lambda: [])
    set_cycle_size = set([])
    set_branch_size = set([])
    
    # Directed or Undirected? Edge weights?
    bigLITTLE_graph = nx.DiGraph()
    gid = 0
    
    for gf in list_files:
        idx, cycle_size, branch_size, _ = gf.split("_")
        cycle_size = int(cycle_size)
        branch_size = int(branch_size)
        
        if no_duplicate and len(dict_graph_size[(cycle_size, branch_size)]) > 0:
            continue
            
        # graph = nx.read_gpickle(os.path.join(DATA_FOLDER, gf))
        
        # For testing
        # if cycle_size > 5 or branch_size > 2:
        #     continue
        
        bigLITTLE_graph.add_node(gid, features=[float(cycle_size), float(branch_size)], label=list_labels[int(idx)])
        dict_graph_size[(cycle_size, branch_size)].append(gid)
        
        set_cycle_size.add(cycle_size)
        set_branch_size.add(branch_size)
        gid += 1
        
    # Same cycle size & Different branch size ==> Edge: [0, 1]
    # Filter lists of same cycle_size
    for cs in set_cycle_size:
        list_same_cycle_size = list(filter(lambda x: x[0]==cs, dict_graph_size.keys()))
        list_same_cycle_size = list(sorted(list_same_cycle_size, key=lambda x:x[1]))
        # print(list_same_cycle_size)
        for bs_idx1 in range(len(list_same_cycle_size)-1):
            bs_idx2 = bs_idx1 + 1
            key_cb1 = list_same_cycle_size[bs_idx1] # e.g. (3, 1)
            key_cb2 = list_same_cycle_size[bs_idx2] # e.g. (3, 2)

            for gid1 in dict_graph_size[key_cb1]:
                for gid2 in dict_graph_size[key_cb2]:
                    # print(key_cb2[1]-key_cb1[1])
                    bigLITTLE_graph.add_edge(gid1, gid2, e=[0,1])
                    bigLITTLE_graph.add_edge(gid2, gid1, e=[0,-1])
                        
    # Different cycle size & Same branch size ==> Edge: [1, 0]
    for bs in set_branch_size:
        list_same_branch_size = list(filter(lambda x: x[1]==bs, dict_graph_size.keys()))
        list_same_branch_size = list(sorted(list_same_branch_size, key=lambda x:x[0]))
        
        for cs_idx1 in range(len(list_same_branch_size)-1):
            cs_idx2 = cs_idx1 + 1
            key_cb1 = list_same_branch_size[cs_idx1] # e.g. (3, 1)
            key_cb2 = list_same_branch_size[cs_idx2] # e.g. (4, 1)

            for gid1 in dict_graph_size[key_cb1]:
                for gid2 in dict_graph_size[key_cb2]:
                    bigLITTLE_graph.add_edge(gid1, gid2, e=[1,0])
                    bigLITTLE_graph.add_edge(gid2, gid1, e=[-1,0])
                    
    
    # Add all other edge as [0,0]
    list_nodes = list(bigLITTLE_graph.nodes)
    for i in range(len(list_nodes)):
        for j in range(i, len(list_nodes)):
            nid_i = list_nodes[i]
            nid_j = list_nodes[j]
            if not bigLITTLE_graph.has_edge(nid_i, nid_j):
                bigLITTLE_graph.add_edge(gid1, gid2, e=[0,0])
                bigLITTLE_graph.add_edge(gid2, gid1, e=[0,0])
            
    if unobserved_edge > 0:
        # Remove unobserved edges
        
        num_edge_to_remove = int(bigLITTLE_graph.number_of_edges() * unobserved_edge)
            
        list_edges = np.array([list(e) for e in bigLITTLE_graph.edges])
        # >>> n_edges * 2
        list_unique_edges = list_edges[list_edges[:,1] >= list_edges[:,0]]
        edge_idxs = np.random.choice([0,1], 
                                size=list_unique_edges.shape[0], 
                                p=[unobserved_edge, 1-unobserved_edge])
        list_remove_edges = list_unique_edges[edge_idxs==0]
        
        for edge in list_remove_edges:
            u, v = edge
            bigLITTLE_graph.remove_edge(u,v)
            bigLITTLE_graph.remove_edge(v,u)
        
        return bigLITTLE_graph, []
    
    return bigLITTLE_graph

In [204]:
def get_edge_color(e):
    if e == [0,1]:
        return "black"
    elif e == [0,-1]:
        return "red"
    elif e == [1,0]:
        return "green"
    elif e == [-1,0]:
        return "blue"
    else:
        return "yellow"
    
def draw_graph(graph):
    nodeLabels = {nid:graph.nodes[nid]["label"] for nid in graph.nodes}
    nodeColors = "grey"
    edgeColor = [get_edge_color(graph.edges[eid]["e"])for eid in graph.edges]

    nx.draw(graph, nx.kamada_kawai_layout(graph), edge_color=edgeColor, width=1, linewidths=0.1,
              node_size=500, node_color=nodeColors, alpha=0.9,
              labels=nodeLabels)
    
def transform_func(graph, unobserved_edge_mask=None):
    graph.x = graph.x.to("cuda:0")
    graph.edge_index = graph.edge_index.to("cuda:0")
    graph.edge_attr = graph.edge_attr.to("cuda:0")
    return graph

In [205]:
class TrickyActivation(nn.Module):
    def __init__(self, low = -1, high=1, inplace=False):
        super().__init__()
        self.low = -1
        self.high = 1
        self.zero = torch.zeros(1).to("cuda:0")
        self.inplace = inplace

    def forward(self, e):
        # F.threshold: y = x if x > theshold else default_value
        e = F.threshold(e, self.low, 0, self.inplace)
        e = -e
        e = F.threshold(e, -self.high, 0, self.inplace)
        e = -e
        return e

In [206]:
class GNN(nn.Module):
    def __init__(self, node_channels, edge_channels, hidden_channels):
        super(GNN, self).__init__()
        
        self.node_embed = nn.Linear(node_channels, hidden_channels, bias=False)
        
        self.lin1 = nn.Linear(node_channels, hidden_channels, bias=True)
        self.lin2 = nn.Linear(hidden_channels, edge_channels, bias=True)
        self.out_act = TrickyActivation()

    def forward(self, x, observed_edge_nid, batch):
        
        '''
        x: (num_nodes, 2)
        observed_edge_nid: (num_observed_edge, 2)
        '''
        # 1. Obtain node embeddings
        z = self.node_embed(x)
        z = z.relu()
        
        # 2. Apply a final classifier
        # z = F.dropout(z, p=0.1, training=True)
        
        # observed_edge_nid: [(1,2), (3,4)] ==> (x[1] <-> x[2])
        # (40.1 30.1) ==> e:[1,0]
        # (50.1 70.1) ==> e:[2,0]
        head = x[observed_edge_nid[1]]
        tail = x[observed_edge_nid[0]]
        
        e = head - tail
        e = self.lin1(e)
        e = self.lin2(e)
        return e
    
def train(loader):
    model.train()
    total_loss = 0
    steps = 0

    # Iterate in batches over the training dataset
    for data in loader:
        # Perform a single forward pass
        out = model(data.x, data.edge_index, data.batch)
        
        # Compute the loss
        loss = criterion(
            out, 
            data.edge_attr
        )
        total_loss += loss
        
        loss.backward(); optimizer.step(); optimizer.zero_grad(); steps += 1

    return total_loss / steps

def test(loader, mc_dropout_sample=100):
    model.eval()
    mse = 0
    steps = 0

    # Iterate in batches over the training/test dataset
    for data in loader:
        
        out = []
        for _ in range(mc_dropout_sample):
            out.append(model(data.x, data.edge_index, data.batch))
        out = torch.stack(out)
        
        # Check against ground-truth labels
        mse += criterion(out.mean(0), data.edge_attr)
        
        steps += 1
        
        
        # out_std = out.std(0)
    return mse / steps  # Derive ratio of correct predictions.

In [207]:
generate_data = True
unobserved_fraction = 0.2

if generate_data:
    bL_graph_train, unobserved_edge_mask = construct_bigLITTLE_graph(DATA_FOLDER, unobserved_edge=unobserved_fraction)
    # draw_graph(bL_graph_train)
    
    data_train = from_networkx(bL_graph_train)
    data_train.x = data_train.features.type(torch.FloatTensor)
    data_train.edge_attr = data_train.e.type(torch.FloatTensor)
    data_train = transform_func(data_train, unobserved_edge_mask)
    
    c_data_train = ClusterData(data_train, num_parts=1)
    train_loader = ClusterLoader(c_data_train)
    
    
    
    bL_graph_test = construct_bigLITTLE_graph(DATA_FOLDER)
    # draw_graph(bL_graph)
    
    data_test = from_networkx(bL_graph_test)
    data_test.x = data_test.features.type(torch.FloatTensor)
    data_test.edge_attr = data_test.e.type(torch.FloatTensor)
    data_test = transform_func(data_test)
    
    c_data_test = ClusterData(data_test, num_parts=1)
    test_loader = ClusterLoader(c_data_test)

Computing METIS partitioning...
Done!
Computing METIS partitioning...
Done!


In [210]:
min_mse = 1e10
min_epoch = 0
epochs = 100000
lr = 0.0005
device = "cuda:0"
hidden_channels = 64

model = GNN(node_channels=2, edge_channels=2, 
            hidden_channels=hidden_channels).to(device)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(model)
print("Number of parameters: ", params)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.L1Loss()


for epoch in tqdm(range(epochs)):
    train_mse = train(train_loader)
    
    if (epoch+1) % 1000 == 0:
        clear_output(wait=True)
        test_mse = test(test_loader)
        if test_mse < min_mse:
            min_mse = test_mse
            min_epoch = epoch
        print(f'Epoch: {epoch+1:03d}, Train MAE: {train_mse:.4f},',
              f'Test MAE: {test_mse:.4f}, Min MAE: {min_mse:.4f}')
    else: test_mse = 0

100%|██████████| 100000/100000 [03:22<00:00, 494.01it/s]

Epoch: 100000, Train MAE: 0.0012, Test MAE: 0.0011, Min MAE: 0.0004



