In [1]:
# util 

import os
import pickle
import numpy as np
import networkx as nx
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn.functional as F
import torch.nn as nn

from typing import Callable, Tuple, Union

from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing, NNConv, CGConv, GINEConv
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset, zeros
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size

from torch_geometric.nn import global_mean_pool, global_add_pool
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader, NeighborLoader, ClusterData, ClusterLoader
from torch_geometric.utils import from_networkx, to_networkx

from IPython.display import clear_output

import random

DATA_FOLDER = "./data"

# Seed Python, NumPy and torch (CPU + every CUDA device) RNGs so the random
# unobserved-node split, weight init and dropout masks are reproducible.
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# For bit-exact GPU reruns also uncomment:
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

In [2]:
def _link_adjacent_sizes(graph, dict_graph_size, axis):
    """Add grid edges between size groups that are adjacent along ``axis``.

    ``dict_graph_size`` maps ``(cycle_size, branch_size)`` keys to lists of node ids.
    Keys sharing the *other* coordinate are sorted along ``axis``; every node of each key
    is linked to every node of the next key in that order. Forward edges get ``+1`` at
    position ``axis`` of the ``e`` attribute, reverse edges get ``-1``.
    """
    other = 1 - axis
    groups = defaultdict(list)
    for key in dict_graph_size:
        groups[key[other]].append(key)

    for keys in groups.values():
        keys.sort(key=lambda k: k[axis])
        for key_lo, key_hi in zip(keys, keys[1:]):
            for gid_lo in dict_graph_size[key_lo]:
                for gid_hi in dict_graph_size[key_hi]:
                    # Fresh lists per edge so no two edges share a mutable attribute.
                    forward, backward = [0, 0], [0, 0]
                    forward[axis], backward[axis] = 1, -1
                    graph.add_edge(gid_lo, gid_hi, e=forward)
                    graph.add_edge(gid_hi, gid_lo, e=backward)


def construct_bigLITTLE_graph(data_folder, no_duplicate=False, unobserved=0.0, unobserved_edge=0.0,
                              labels_path="labels.pkl", fill_missing_edges=True):
    """Build the directed "bigLITTLE" graph whose nodes are the small graphs in ``data_folder``.

    Each ``*.pkl`` file name is parsed as ``<idx>_<cycle_size>_<branch_size>_<rest>``.
    One node is added per file, with ``features=[cycle_size, branch_size]`` (floats) and
    ``label=labels[idx]``. Only the file names are read, not their contents.

    Edge attribute ``e``:

    * ``[0, 1]`` / ``[0, -1]``: same cycle size, adjacent (sorted) branch sizes, up / down;
    * ``[1, 0]`` / ``[-1, 0]``: same branch size, adjacent (sorted) cycle sizes, up / down;
    * ``[0, 0]``: every other ordered pair, including self-loops (when ``fill_missing_edges``).

    Args:
        data_folder: directory containing the per-graph ``.pkl`` files.
        no_duplicate: keep only the first file for each ``(cycle_size, branch_size)``.
        unobserved: fraction of nodes sampled (without replacement) as unobserved.
        unobserved_edge: probability of dropping each undirected edge pair (both directions).
        labels_path: pickle file holding the label sequence indexed by ``idx``.
        fill_missing_edges: add the neutral ``[0, 0]`` edges. Note that this makes the graph complete.

    Returns:
        ``(graph, unobserved_node_idxs)``. The second item is a NumPy array of node ids,
        or ``None`` when ``unobserved <= 0``.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted label files.
    with open(labels_path, "rb") as label_file:
        list_labels = pickle.load(label_file)

    # Sorted so node ids (and therefore the seeded unobserved split) do not depend on
    # the arbitrary order returned by os.listdir.
    list_files = sorted(name for name in os.listdir(data_folder) if name.endswith(".pkl"))

    dict_graph_size = defaultdict(list)
    bigLITTLE_graph = nx.DiGraph()
    for gf in list_files:
        idx, cycle_size, branch_size, _ = gf.split("_")
        size_key = (int(cycle_size), int(branch_size))
        if no_duplicate and dict_graph_size[size_key]:
            continue
        gid = bigLITTLE_graph.number_of_nodes()  # node ids are consecutive from 0
        bigLITTLE_graph.add_node(gid, features=[float(size_key[0]), float(size_key[1])],
                                 label=list_labels[int(idx)])
        dict_graph_size[size_key].append(gid)

    # Same cycle size, neighbouring branch sizes -> e=[0, +-1]
    _link_adjacent_sizes(bigLITTLE_graph, dict_graph_size, axis=1)
    # Same branch size, neighbouring cycle sizes -> e=[+-1, 0]
    _link_adjacent_sizes(bigLITTLE_graph, dict_graph_size, axis=0)

    if fill_missing_edges:
        # Connect every still-unconnected pair with [0, 0]. The previous version added
        # (gid1, gid2) -- names leaked from the grid loops -- instead of (nid_i, nid_j),
        # so it only overwrote one grid edge. j starts at i (as in the original loop
        # bounds), so each node also receives a [0, 0] self-loop.
        list_nodes = list(bigLITTLE_graph.nodes)
        for i, nid_i in enumerate(list_nodes):
            for nid_j in list_nodes[i:]:
                if not bigLITTLE_graph.has_edge(nid_i, nid_j):
                    bigLITTLE_graph.add_edge(nid_i, nid_j, e=[0, 0])
                    bigLITTLE_graph.add_edge(nid_j, nid_i, e=[0, 0])

    num_nodes = bigLITTLE_graph.number_of_nodes()
    unobserved_node_idxs = None
    if unobserved > 0:
        unobserved_node_idxs = np.random.choice(num_nodes, size=int(unobserved * num_nodes),
                                                replace=False)

    if unobserved_edge > 0:
        # Consider each undirected pair once (u <= v) and drop both directions together.
        candidate_pairs = [(u, v) for u, v in bigLITTLE_graph.edges if u <= v]
        keep_flags = np.random.choice([0, 1], size=len(candidate_pairs),
                                      p=[unobserved_edge, 1 - unobserved_edge])
        for (u, v), keep in zip(candidate_pairs, keep_flags):
            if keep == 0:
                # remove_edges_from ignores absent edges, so self-loops (u == v) do not raise.
                bigLITTLE_graph.remove_edges_from([(u, v), (v, u)])

    return bigLITTLE_graph, unobserved_node_idxs

In [3]:
def get_edge_color(e):
    """Map an edge direction attribute to a matplotlib colour name.

    [0, 1] -> black, [0, -1] -> red, [1, 0] -> green, [-1, 0] -> blue.
    Anything else (e.g. the neutral [0, 0] edges) -> yellow.
    """
    palette = (
        ([0, 1], "black"),
        ([0, -1], "red"),
        ([1, 0], "green"),
        ([-1, 0], "blue"),
    )
    for pattern, colour in palette:
        if e == pattern:
            return colour
    return "yellow"
    
def draw_graph(graph):
    """Draw ``graph`` with a Kamada-Kawai layout.

    Nodes are grey and labelled with their ``"label"`` attribute. Edges are coloured
    by ``get_edge_color`` applied to their ``"e"`` attribute.
    """
    labels = {}
    for node_id in graph.nodes:
        labels[node_id] = graph.nodes[node_id]["label"]
    edge_colours = list(map(lambda edge: get_edge_color(graph.edges[edge]["e"]), graph.edges))
    layout = nx.kamada_kawai_layout(graph)
    nx.draw(
        graph, layout,
        edge_color=edge_colours, width=1, linewidths=0.1,
        node_size=500, node_color="grey", alpha=0.9,
        labels=labels,
    )
    
def transform_func(graph, device="cuda:0"):
    """Move a graph's tensors to ``device`` and expose ``label`` as the target ``y``.

    Mutates ``graph`` in place and returns it, so it can be used as a PyG-style transform.

    Args:
        graph: object with ``x``, ``label``, ``edge_index`` and, optionally, ``edge_attr`` tensors.
        device: target device. The default keeps the original hard-coded ``"cuda:0"``.
    """
    graph.x = graph.x.to(device)
    graph.y = graph.label.to(device)
    # edge_attr is optional; a graph without edge features should not crash here.
    if getattr(graph, "edge_attr", None) is not None:
        graph.edge_attr = graph.edge_attr.to(device)
    graph.edge_index = graph.edge_index.to(device)
    return graph

In [4]:
class TrickyNNConv(MessagePassing):
    """Edge-conditioned message passing layer, a simplified variant of PyG's ``NNConv``.

    Node features are first projected to ``out_channels`` by ``self.lin``. Each message
    is then ``self.nn(edge_attr + x_j)``: the (already ``out_channels``-wide) edge
    embedding is added to the projected source-node features before a shared map.
    Messages are combined with ``aggr``, and a learnable bias is added.

    Args:
        in_channels: input node feature size (an int, or a pair whose second entry is used).
        out_channels: output size. ``edge_attr`` must reshape to ``(-1, out_channels)``.
        nn: optional module applied to ``edge_attr + x_j``. Defaults to a bias-free
            ``Linear(out_channels, out_channels)``.
        aggr: aggregation scheme forwarded to ``MessagePassing``.
        root_weight: if True, project node features with ``self.lin``. If False they
            are used as-is, so ``in_channels`` must equal ``out_channels``.
        bias: if True, add a learnable per-channel bias to the aggregated output.
    """

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, nn: Callable = None, aggr: str = 'add',
                 root_weight: bool = True, bias: bool = True, **kwargs):
        super().__init__(aggr=aggr, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.root_weight = root_weight

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # Honour a caller-supplied message network (previously ignored).
        if nn is not None:
            self.nn = nn
        else:
            self.nn = Linear(out_channels, out_channels, bias=False, weight_initializer='uniform')
        if root_weight:
            self.lin = Linear(in_channels[1], out_channels, bias=False,
                              weight_initializer='uniform')
        elif in_channels[1] != out_channels:
            # Without self.lin, x_j is added directly to the out_channels-wide edge embedding.
            raise ValueError(
                "root_weight=False uses node features unprojected, so in_channels "
                f"({in_channels[1]}) must equal out_channels ({out_channels})")

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the message network, the projection and the bias (to zero)."""
        reset(self.nn)
        if self.root_weight:
            self.lin.reset_parameters()
        zeros(self.bias)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        """Project ``x``, aggregate edge-conditioned messages and add the bias.

        ``edge_attr`` is required in practice: ``message`` reshapes it to
        ``(-1, out_channels)``.
        """
        x_r = self.lin(x) if self.root_weight else x

        # propagate_type: (x: OptTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x_r, edge_attr=edge_attr, size=size)

        # The bias used to be allocated (and counted as a parameter) but never applied.
        if self.bias is not None:
            out = out + self.bias
        return out

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        """Message from source node j: ``nn(edge_attr + x_j)``.

        With the default ``nn``, the result has shape ``(num_edges, out_channels)``.
        """
        edge_emb = edge_attr.view(-1, self.out_channels)
        # No squeeze: the former ``.squeeze(1)`` dropped the channel dim when out_channels == 1.
        return self.nn(edge_emb + x_j)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, aggr={self.aggr}, nn={self.nn})')

In [5]:
class GNN(nn.Module):
    """Two-layer edge-conditioned GNN that regresses one bounded scalar per node.

    Edge attributes are embedded to ``hidden_channels`` and passed through two
    ``TrickyNNConv`` layers. Between the layers the edge embedding is re-projected by
    ``conv1_e``. An MLP head then maps every node to ``sigmoid(.) * output_scale``.

    Args:
        node_channels: input node feature size.
        edge_channels: input edge attribute size.
        hidden_channels: width of all hidden layers.
        dropout: dropout probability applied before the MLP head.
        mc_dropout: if True (the default, matching the original behaviour), dropout
            stays active even in ``eval()`` mode, so repeated forward passes yield
            Monte-Carlo samples (which ``test`` averages). Set False for deterministic eval.
        output_scale: predictions lie in ``(0, output_scale)``.
    """

    def __init__(self, node_channels, edge_channels, hidden_channels,
                 dropout=0.1, mc_dropout=True, output_scale=110.0):
        super().__init__()
        self.dropout = dropout
        self.mc_dropout = mc_dropout
        self.output_scale = output_scale

        # Submodules are created in the same order as before, so seeded init is unchanged.
        self.edge_embed = nn.Linear(edge_channels, hidden_channels, bias=False)

        self.conv1 = TrickyNNConv(node_channels, hidden_channels, aggr="mean")
        self.conv1_e = nn.Linear(hidden_channels, hidden_channels, bias=False)

        self.conv2 = TrickyNNConv(hidden_channels, hidden_channels, aggr="mean")

        self.lin1 = nn.Linear(hidden_channels, hidden_channels, bias=True)
        self.lin2 = nn.Linear(hidden_channels, hidden_channels, bias=True)
        self.lin3 = nn.Linear(hidden_channels, 1, bias=True)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return per-node predictions with a single output column.

        ``edge_attr`` holds one row of ``edge_channels`` features per edge (not per batch).
        ``batch`` is accepted for API compatibility but unused.
        """
        # 1. Node embeddings via two edge-conditioned convolutions.
        e = self.edge_embed(edge_attr)
        z = self.conv1(x=x, edge_index=edge_index, edge_attr=e).relu()

        # Re-project the edge embedding for the second layer (no non-linearity).
        e = self.conv1_e(e)
        z = self.conv2(x=z, edge_index=edge_index, edge_attr=e).relu()

        # 2. MLP head. Dropout may stay on in eval mode for MC-dropout (see mc_dropout).
        z = F.dropout(z, p=self.dropout, training=self.training or self.mc_dropout)
        z = self.lin1(z).relu()
        z = self.lin2(z).relu()
        return torch.sigmoid(self.lin3(z)) * self.output_scale
    
def train(loader, random_mask=0, observed_idxs=None):
    """Run one training epoch over ``loader`` and return the mean per-batch loss.

    Uses the module-level ``model``, ``optimizer`` and ``criterion``.

    Args:
        loader: iterable of batches exposing ``x``, ``edge_index``, ``edge_attr``,
            ``batch`` and ``y``.
        random_mask: fraction of nodes whose features are flipped (``1 - x``) as
            augmentation. The batch's own ``x`` tensor is left untouched.
        observed_idxs: node indices that contribute to the loss; ``None`` uses all nodes.

    Returns:
        Mean loss over batches, as a Python float.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    steps = 0

    for data in loader:
        x = data.x
        if random_mask > 0:
            num_nodes = x.shape[0]
            idxs = np.random.choice(num_nodes, size=int(random_mask * num_nodes), replace=False)
            # Clone so the augmentation does not leak into the loader's stored tensors.
            x = x.clone()
            x[idxs] = 1 - x[idxs]

        optimizer.zero_grad()
        out = model(x, data.edge_index, data.edge_attr, data.batch)

        if observed_idxs is not None:
            out = out[observed_idxs]
            target = data.y[observed_idxs]
        else:
            target = data.y
        loss = criterion(out, target.view(-1, 1))

        loss.backward()
        optimizer.step()
        # .item() detaches: accumulating the tensor would keep every batch's graph alive.
        total_loss += loss.item()
        steps += 1

    if steps == 0:
        raise ValueError("train() received an empty loader")
    return total_loss / steps

def test(loader, mc_dropout_sample=100):
    model.eval()
    mse = 0
    steps = 0

    # Iterate in batches over the training/test dataset
    for data in loader:
        
        out = []
        for _ in range(mc_dropout_sample):
            out.append(model(data.x, data.edge_index, data.edge_attr, data.batch))
        out = torch.stack(out)
        
        # Check against ground-truth labels
        mse += criterion(out.mean(0), data.y.view(-1, 1))
        
        steps += 1
        
        
        # out_std = out.std(0)
    return mse / steps  # Derive ratio of correct predictions.

In [7]:
generate_data = True
unobserved_fraction = 0.2


def _graph_to_cluster_loader(graph):
    """Convert a bigLITTLE networkx graph into (data, cluster_data, loader).

    Every node gets the constant feature vector [1, 0]. Edge attributes come from the
    ``e`` edge attribute. Tensors are moved by ``transform_func``, and the whole graph is
    wrapped in a single-partition ``ClusterLoader``.
    """
    data = from_networkx(graph)
    features = torch.ones(len(data.label), 2).type(torch.FloatTensor)
    features[:, 1] = 0
    data.x = features
    data.edge_attr = data.e.type(torch.FloatTensor)
    data = transform_func(data)
    cluster_data = ClusterData(data, num_parts=1, recursive=True)
    return data, cluster_data, ClusterLoader(cluster_data)


if generate_data:
    # Training graph: a random fraction of nodes is marked unobserved.
    bL_graph_train, unobserved_idxs = construct_bigLITTLE_graph(DATA_FOLDER, unobserved=unobserved_fraction)
    data_train, c_data_train, train_loader = _graph_to_cluster_loader(bL_graph_train)

    # Test graph: same construction, no nodes held out.
    bL_graph_test, _ = construct_bigLITTLE_graph(DATA_FOLDER)
    data_test, c_data_test, test_loader = _graph_to_cluster_loader(bL_graph_test)

Computing METIS partitioning...
Done!
Computing METIS partitioning...
Done!


In [8]:
# Number of nodes (one per small graph) in the training bigLITTLE graph.
len(bL_graph_train)

685

In [9]:
# Sanity check: every node must have at least one observed neighbour; otherwise its
# messages could only come from unobserved nodes.
unobserved_set = set() if unobserved_idxs is None else set(map(int, unobserved_idxs))

neighbors_list = []
for nid in bL_graph_train.nodes:
    # neighbors() returns a one-shot iterator; materialise it once. The old code consumed
    # it in the first set(), so the unobserved-neighbour count was always 0.
    node_neighbors = set(bL_graph_train.neighbors(nid))
    node_observed_neighbors = node_neighbors - unobserved_set
    node_unobserved_neighbors = node_neighbors & unobserved_set
    neighbors_list.append([nid, len(node_observed_neighbors), len(node_unobserved_neighbors)])

# reshape keeps the (rows, 3) layout even for an empty graph.
neighbors_list = np.array(neighbors_list, dtype=int).reshape(-1, 3)
nodes_without_observed = np.where(neighbors_list[:, 1] == 0)
print(nodes_without_observed)
assert nodes_without_observed[0].shape[0] == 0, "some nodes have no observed neighbours"


(array([], dtype=int64),)


In [10]:
min_mse = 1e10
min_epoch = 0
epochs = 100000
lr = 0.0005
device = "cuda:0"
hidden_channels = 64
eval_every = 1000  # epochs between MC-dropout evaluations on the test graph

model = GNN(node_channels=2, edge_channels=2,
            hidden_channels=hidden_channels).to(device)
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(model)
print("Number of parameters: ", params)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# L1 loss: the *_mse variables below actually hold mean absolute errors (printed as MAE).
criterion = torch.nn.L1Loss()

# Loss is computed on observed nodes only (node ids double as tensor indices).
unobserved_nodes = set() if unobserved_idxs is None else set(map(int, unobserved_idxs))
observed_idxs = sorted(set(bL_graph_train.nodes) - unobserved_nodes)

for epoch in tqdm(range(epochs)):
    train_mse = train(train_loader, random_mask=0, observed_idxs=observed_idxs)

    if (epoch + 1) % eval_every == 0:
        clear_output(wait=True)
        test_mse = test(test_loader)
        if test_mse < min_mse:
            min_mse = test_mse
            min_epoch = epoch
        print(f'Epoch: {epoch+1:03d}, Train MAE: {train_mse:.4f},',
              f'Test MAE: {test_mse:.4f}, Min MAE: {min_mse:.4f}')

GNN(
  (edge_embed): Linear(in_features=2, out_features=64, bias=False)
  (conv1): TrickyNNConv(2, 64, aggr=mean, nn=Linear(64, 64, bias=False))
  (conv1_e): Linear(in_features=64, out_features=64, bias=False)
  (conv2): TrickyNNConv(64, 64, aggr=mean, nn=Linear(64, 64, bias=False))
  (lin1): Linear(in_features=64, out_features=64, bias=True)
  (lin2): Linear(in_features=64, out_features=64, bias=True)
  (lin3): Linear(in_features=64, out_features=1, bias=True)
)
Number of parameters:  25153


  0%|▌                                                                                                                                      | 458/100000 [00:22<1:19:57, 20.75it/s]


KeyboardInterrupt: 

In [None]:
# NOTE(review): experimental alternative to the ClusterLoaders built in the data-generation
# cell. As written it cannot run on a fresh kernel: `data` is not defined at top level in
# any visible cell (the data-generation cell creates `data_train` / `data_test`), and the
# `train_mask` / `test_mask` attributes appear only in commented-out code.
# TODO: point this at `data_train` and build the masks (e.g. from `unobserved_idxs`) first.
# Running this cell also overwrites `train_loader` / `test_loader` used by the training loop.
train_loader = NeighborLoader(
    data,
    # num_neighbors has one entry -> one hop of sampling, up to 10 neighbors per node
    num_neighbors=[10],
    # Batches of 128 seed nodes taken from the training mask
    batch_size=128,
    input_nodes=data.train_mask
)

test_loader = NeighborLoader(
    data,
    # num_neighbors has one entry -> one hop of sampling, up to 10 neighbors per node
    num_neighbors=[10],
    # Batches of 128 seed nodes taken from the test mask
    batch_size=128,
    input_nodes=data.test_mask
)

In [None]:
# Observed nodes and masked nodes (experimental; everything below is commented out)


# obs_answers_id = [ np.concatenate([
#     # np.array([True]),
#     np.random.choice([True]*(allow)+ [False]*(n_questions-allow), (n_questions), replace=False), \
#     # np.random.choice([True]*1 + [False]*(n_treatments-1), (n_treatments), replace=False) \
#     ]) for _ in range(n_sample)]

# obs_answers_id = np.stack(obs_answers_id)
# obs_answers_id = torch.tensor(obs_answers_id)

# obs_outcomes_id = [
#     np.random.choice([True]*1 + [False]*(n_treatments-1), (n_treatments), replace=False) \
#     for _ in range(n_sample)]
# obs_outcomes_id = np.stack(obs_outcomes_id)
# obs_outcomes_id = torch.tensor(obs_outcomes_id)

# mask_obs_answers_id = copy.deepcopy(obs_answers_id)
# for sample in range(n_sample):
#     all_true_id = torch.where(obs_answers_id[sample])[0]
#     flip = np.random.choice(all_true_id)
#     mask_obs_answers_id[sample, flip] = False

# obs_answers = copy.deepcopy(answers)
# obs_answers[torch.logical_not(obs_answers_id)] = torch.tensor([0, 1], device=device).double()

# mask_obs_answers = copy.deepcopy(answers)
# mask_obs_answers[torch.logical_not(mask_obs_answers_id)] = torch.tensor([0, 1], device=device).double()
