### Standard functions for dataloading and meta-classifier training 

In [100]:
import os
os.environ["DGLBACKEND"] = "pytorch"
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.data import DGLDataset
import numpy as np
import graph_tool as gt
import ast
import pandas as pd
import json
import random
from torch.utils.data.sampler import SubsetRandomSampler
from dgl.dataloading import GraphDataLoader
import wandb

In [114]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
g = torch.Generator()
g.manual_seed(42)

<torch._C.Generator at 0x33a3bf7b0>

In [83]:
def extract_results_df(data_dir):
  # read results into dataframe
  NSHARDS = 10

  dfs = []
  for shard_idx in range(NSHARDS):
    filename = 'results.ndjson-%s-of-%s' % (str(shard_idx).zfill(5), str(NSHARDS).zfill(5))
    print(filename)

    with open(f'{data_dir}/{filename}', 'r') as f:
      lines = f.readlines()
      records = map(json.loads, lines)
      dfs.append(pd.DataFrame.from_records(records))

  # Construct df and remove nans
  results_df = pd.concat(dfs)
  results_df.drop(['marginal_param', 'fixed_params'], axis=1, inplace=True)
  results_df.dropna(axis=0, inplace=True)
  del dfs
  return results_df

In [89]:
def load_graph_data(graph_id, data_dir, results_df):
    
    gt_graph = gt.load_graph(data_dir + '{}_graph.gt'.format(graph_id))
    # Extract edges
    src, dst = gt_graph.get_edges().T
    src, dst = torch.tensor(src, dtype = torch.int64), torch.tensor(dst, dtype = torch.int64)

    # Load node features
    node_feats = torch.tensor(np.loadtxt(data_dir+'{}_node_features.txt'.format(graph_id)), dtype=torch.float32)

    # Verify shape (should be [num_nodes, 16])
    assert node_feats.dim() == 2 and node_feats.size(1) == 16
    
    # Instead of adding self-loops later, rebuild the edge list with self-loops.
    num_nodes = node_feats.shape[0]
    # Convert src and dst to lists (or numpy arrays) and add self-loops.
    src_list = src.tolist() + list(range(num_nodes))
    dst_list = dst.tolist() + list(range(num_nodes))
    
    # Build a new graph with the combined edge list.
    g = dgl.graph((src_list, dst_list), num_nodes=num_nodes)

    g.ndata['feat'] = node_feats

    # Load labels
    matched_df = results_df[results_df['sample_id'] == int(graph_id)]
    assert matched_df.shape[0] == 1, "Expected exactly one matching row in results_df, however found {}".format(matched_df.shape[0])
    label_mlp = matched_df[''].item()
    label_gcn = matched_df['GCN__test_accuracy'].item()
    probs = F.softmax(torch.tensor([label_mlp, label_gcn], dtype=torch.float32), dim=0)
    g.label = probs 

    g.graph_metrics_results = matched_df.to_dict()
    
    return g

In [90]:
class MyGraphDataset(DGLDataset):
    def __init__(self, graph_dir, num_graphs = None):
        self.graph_dir = graph_dir
        self.num_graphs = num_graphs
        self.dim_nfeats = None
        self.gclasses = None
        self.results_df = extract_results_df(graph_dir)
        super().__init__(name='my_graph_dataset')

    def _gen_graphs_ids(self):

        def cut_to_five(id:str):
            if len(id) == 5:
                return id
            else:
                id = id[1:]
                return cut_to_five(id)
    
        IDS = ['0000' + str(i) for i in range(0,self.num_graphs)]
        IDS = [cut_to_five(id) for id in IDS]

        return IDS 

    def process(self):
        self.graphs = []
        self.labels = []

        graph_ids = self._gen_graphs_ids()

        for gid in graph_ids:

            g = load_graph_data(gid, self.graph_dir, self.results_df)

            self.graphs.append(g)
            self.labels.append(g.label)
        
        self.dim_nfeats = self.graphs[0].ndata['feat'].shape[1]
        self.gclasses = self.graphs[0].label.shape[0]

    def __getitem__(self, idx):
        return self.graphs[idx], self.labels[idx]

    def __len__(self):
        return len(self.graphs)
    
    def has_cache(self): return False
    def download(self): pass
    def save(self): pass
    def load(self): pass

In [116]:
def train(model: nn.Module, train_dataloader: GraphDataLoader, test_dataloader: GraphDataLoader, val_dataloader:GraphDataLoader, optimizer: torch.optim.Optimizer, epochs : int):
    
    wandb.init(project="meta-classifier dev", reinit=True)
    wandb.watch(model)
    for epoch in range(epochs):
        print("Epoch:", epoch)
        for batched_graph, labels in train_dataloader:            
            pred = model(batched_graph, batched_graph.ndata["feat"].float())
            log_pred = F.log_softmax(pred, dim=1)
            loss = F.kl_div(log_pred, labels, reduction='batchmean')
            wandb.log({"train_loss": loss})
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        num_correct = 0
        num_tests = 0
        for batched_graph, labels in val_dataloader:
            pred = model(batched_graph, batched_graph.ndata["feat"].float())
            num_correct += (pred.argmax(1) == labels.argmax(1)).sum().item()
            num_tests += len(labels)
        wandb.log({"val_accuracy": num_correct / num_tests})
        print("Validation accuracy:", num_correct / num_tests)

    num_correct = 0
    num_tests = 0
    for batched_graph, labels in test_dataloader:
        pred = model(batched_graph, batched_graph.ndata["feat"].float())
        num_correct += (pred.argmax(1) == labels.argmax(1)).sum().item()
        num_tests += len(labels)
    wandb.log({"test_accuracy": num_correct / num_tests})

    print("Test accuracy:", num_correct / num_tests)

In [119]:
from dgl.nn import GraphConv


class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata["h"] = h
        return dgl.mean_nodes(g, "h")

### Training GCN on small (100 graphs) locally generated

In [47]:
locally_gen_dir = "../../graph_gen/locally_gen/nodeclassification/sbm/"
dataset = MyGraphDataset(locally_gen_dir, num_graphs = 100)


In [60]:
num_examples = len(dataset)
num_train = int(num_examples * 0.9)

train_sampler = SubsetRandomSampler(torch.arange(num_train), generator=g)
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples), generator=g)

train_dataloader = GraphDataLoader(dataset, sampler=train_sampler, batch_size=5, drop_last=False, generator=g)
test_dataloader = GraphDataLoader(dataset, sampler=test_sampler, batch_size=5, drop_last=False, generator=g)


In [None]:
# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata["feat"].float())
        labels = labels.argmax(dim=1)
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph, batched_graph.ndata["feat"].float())
    num_correct += (pred.argmax(1) == labels.argmax(1)).sum().item()
    num_tests += len(labels)

print("Using cross entropy loss i.e. hard labels")
print("Test accuracy:", num_correct / num_tests)

Using cross entropy loss i.e. hard labels
Test accuracy: 0.6


In [66]:
# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata["feat"].float())
        log_pred = F.log_softmax(pred, dim=1)
        loss = F.kl_div(log_pred, labels, reduction='batchmean')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph, batched_graph.ndata["feat"].float())
    num_correct += (pred.argmax(1) == labels.argmax(1)).sum().item()
    num_tests += len(labels)

print("Using KL_div loss i.e. soft labels")
print("Test accuracy:", num_correct / num_tests)

Using KL_div loss i.e. soft labels
Test accuracy: 0.7


### Training on first Snellius batch: 500 examples, feature dimension 16

In [91]:
snellius_data_dir = "../../graph_gen/snellius_gen/nodeclassification/sbm/"
snellius_dataset = MyGraphDataset(snellius_data_dir, num_graphs = 500)

results.ndjson-00000-of-00010
results.ndjson-00001-of-00010
results.ndjson-00002-of-00010
results.ndjson-00003-of-00010
results.ndjson-00004-of-00010
results.ndjson-00005-of-00010
results.ndjson-00006-of-00010
results.ndjson-00007-of-00010
results.ndjson-00008-of-00010
results.ndjson-00009-of-00010


In [126]:
num_examples = len(snellius_dataset)
num_train = int(num_examples * 0.6)
num_val = int(num_examples * 0.2)

train_sampler = SubsetRandomSampler(torch.arange(num_train), generator=g)
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_train+num_val), generator=g)
val_sampler = SubsetRandomSampler(torch.arange(num_train+num_val, num_examples), generator=g)

train_dataloader = GraphDataLoader(snellius_dataset, sampler=train_sampler, batch_size=20, drop_last=False, generator=g)
test_dataloader = GraphDataLoader(snellius_dataset, sampler=test_sampler, batch_size=20, drop_last=False, generator=g)
val_dataloader = GraphDataLoader(snellius_dataset, sampler=val_sampler, batch_size=20, drop_last=False, generator=g)

In [127]:
model = GCN(snellius_dataset.dim_nfeats, 16, snellius_dataset.gclasses)
train(model, train_dataloader, test_dataloader, val_dataloader, torch.optim.Adam(model.parameters(), lr=0.01), 50)

0,1
test_accuracy,▁
train_loss,▇█▅▃▄▄▃▂▃▂▂▂▂▁▂▁▂▂▁▂▁▂▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▂▁
val_accuracy,▁█▁▁▃▃▄▄▆▆▆▄▄▄▆▄▄▄▄▄▄▄▄▄▄▄▄▄█▆█▆█▆▄▄▄█▆▄

0,1
test_accuracy,0.35
train_loss,0.0001
val_accuracy,0.45


Epoch: 0
Validation accuracy: 0.37
Epoch: 1
Validation accuracy: 0.46
Epoch: 2
Validation accuracy: 0.41
Epoch: 3
Validation accuracy: 0.47
Epoch: 4
Validation accuracy: 0.52
Epoch: 5
Validation accuracy: 0.53
Epoch: 6
Validation accuracy: 0.57
Epoch: 7
Validation accuracy: 0.52
Epoch: 8
Validation accuracy: 0.55
Epoch: 9
Validation accuracy: 0.58
Epoch: 10
Validation accuracy: 0.44
Epoch: 11
Validation accuracy: 0.5
Epoch: 12
Validation accuracy: 0.46
Epoch: 13
Validation accuracy: 0.47
Epoch: 14
Validation accuracy: 0.5
Epoch: 15
Validation accuracy: 0.42
Epoch: 16
Validation accuracy: 0.45
Epoch: 17
Validation accuracy: 0.46
Epoch: 18
Validation accuracy: 0.46
Epoch: 19
Validation accuracy: 0.48
Epoch: 20
Validation accuracy: 0.44
Epoch: 21
Validation accuracy: 0.45
Epoch: 22
Validation accuracy: 0.45
Epoch: 23
Validation accuracy: 0.48
Epoch: 24
Validation accuracy: 0.49
Epoch: 25
Validation accuracy: 0.52
Epoch: 26
Validation accuracy: 0.4
Epoch: 27
Validation accuracy: 0.48
Epoch