In [2]:
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import pickle
import torch_geometric
from tqdm import tqdm, trange
import torch.optim as optim
from torch_geometric.nn import GCNConv, GATConv
import matplotlib.pyplot as plt

import itertools
import gc
import joblib
from joblib import delayed, Parallel
import glob

In [3]:
%matplotlib inline

In [4]:
PATH_DATA0 = './data/00.01'
PATH_DATA = './data/00.02'
RANDOM_SEED =0
np.random.seed(RANDOM_SEED)  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
CRITERION = nn.BCEWithLogitsLoss()
LR = 0.001
TOLERANCE = 20
LR_TOLERANCE= 5
MAX_EPOCHS = 200
BATCH_SIZE =2

In [6]:
loader_train = torch_geometric.loader.DataLoader(
    pd.read_pickle(os.path.join('max_prob_10_subsample_0.1','graphs_train.pkl')).tolist(),
    batch_size = BATCH_SIZE,shuffle = True)
loader_val = torch_geometric.loader.DataLoader(
    pd.read_pickle(os.path.join('max_prob_10_subsample_0.1','graphs_val.pkl')).tolist(),batch_size = BATCH_SIZE
    ,shuffle = False)
loader_test = torch_geometric.loader.DataLoader(
    pd.read_pickle(os.path.join('max_prob_10_subsample_0.1','graphs_test.pkl')).tolist(),batch_size = BATCH_SIZE
    ,shuffle = False)

In [7]:
loader_train.dataset[0].edge_attr

tensor([[-2.3207e-01, -1.9634e-02,  2.4826e-01, -6.9013e-01],
        [-4.6413e-01, -4.6385e-02,  4.9653e-01, -6.6614e-01],
        [-6.2136e-01, -6.2315e-02,  6.6473e-01, -6.5209e-01],
        ...,
        [-4.2989e-01,  5.0476e+00, -4.9280e-01,  4.4754e+00],
        [-4.9400e-01, -1.7182e-01, -5.0646e-01, -4.8015e-01],
        [ 6.6460e-03, -1.7567e-03,  3.7240e-03, -5.6318e-01]])

In [8]:
class GraphAttentionModel(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, activation, heads,
                 node_emb_dim, edge_attr_dim, dropout = 0): #add edge_attr_dim
        #first get node embeddings via attention ADD EDGES TO GATCONV
        super().__init__()
        self.convs = nn.ModuleList()
        self.acts = nn.ModuleList()

        self.convs.append(
            GATConv(in_channels, hidden_channels, heads=heads, edge_dim = edge_attr_dim, dropout=dropout)
        )
        self.acts.append(activation)

        for _ in range(num_layers - 2):
            self.convs.append(
                GATConv(hidden_channels * heads,
                        hidden_channels,
                        heads=heads,
                        dropout=dropout,
                        edge_dim = edge_attr_dim,
                        concat = True)
            )
            self.acts.append(activation)

        self.convs.append(
            GATConv(hidden_channels * heads,
                    node_emb_dim,
                    heads=heads,
                    edge_dim = edge_attr_dim,
                    concat=False,
                    dropout=dropout)
        )

        #second, get edge classification via MLP
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * node_emb_dim + edge_attr_dim, hidden_channels),
            activation,
            #nn.Dropout(dropout),
            nn.Linear(hidden_channels, 1)
        )

    def forward(self, x, edge_index, edge_attr):
        for conv, act in zip(self.convs[:-1], self.acts):
            x = conv(x, edge_index, edge_attr)
            x = act(x)
        node_emb = self.convs[-1](x, edge_index, edge_attr)


        src, dst = edge_index
        src, dst = torch.minimum(src, dst), torch.maximum(src, dst) #make sure one edge embedding is invariant
        h_src = node_emb[src]
        h_dst = node_emb[dst]

        edge_feat = torch.cat([h_src, h_dst, edge_attr], dim = 1)
        logits = self.edge_mlp(edge_feat).view(-1)
        return logits


In [10]:
def count_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable_params

In [11]:
def train(model, loader, optimizer, loss_fn, device=None):
    model.train()
    total_loss = 0.0


    for batch in loader:
        if device is not None:
            batch = batch.to(device)

        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.edge_attr)

        loss = loss_fn(out, batch.y.view(-1).float()) #batch.y.view(-1).float()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch.num_graphs

    return total_loss / len(loader)


In [12]:
@torch.no_grad()
def test(model, loader, loss_fn, device=None):
  model.eval()
  preds, actuals = [], []

  for batch in loader:
    if device is not None:
      batch = batch.to(device)

    logits = model(batch.x, batch.edge_index, batch.edge_attr)
    probs = torch.sigmoid(logits)
    preds.append(probs.cpu())
    actuals.append(batch.y.cpu())

  preds = torch.cat(preds).flatten()
  actuals = torch.cat(actuals).flatten()
  acc = ((preds > 0.5) == (actuals > 0.5)).type(torch.float).mean().item()
  loss = loss_fn(preds, actuals.float()).item()

  return preds.numpy(), actuals.numpy(), acc, loss

In [13]:
loader_train.dataset[0].edge_attr.shape[-1]

4

In [14]:
def find_best_gat_hyperparams(in_channels, out_channels, n_layers, activation, edge_attr_dim, target_params, head_choices = (1,2,4,8),
                              hidden_choices = (16,32,64,128,256,512), emb_choices = (16,32,64,128,256,512)):
    best = None
    for heads in head_choices:
      for hidden in hidden_choices:
        for emb in emb_choices:
          model = GraphAttentionModel(
              in_channels = in_channels,
              hidden_channels = hidden,
              out_channels = out_channels,
              num_layers = n_layers,
              activation = activation,
              heads = heads,
              node_emb_dim = emb,
              edge_attr_dim = edge_attr_dim,
          )
          p = count_parameters(model)
          diff = abs(p-target_params)
          if best is None or diff < best[0]:
            best = (diff, heads, hidden, emb, p)
    return best[1], best[2], best[3], best[4]

In [15]:
EDGE_ATTR_DIM = loader_train.dataset[0].edge_attr.size(-1)

In [16]:
#paper for GAT uses 4 heads for inductive task, so we do that too.
layers_to_hyperparams_2 = {} #key = # (layers, target_params), value = (#heads, hidden dim, embedding dim, # params) output of find_best_gat_hyperparams
for i,j in itertools.product([2,3,4], [100000, 500000, 1000000]):
  layers_to_hyperparams_2[(i,j)] = find_best_gat_hyperparams(in_channels= loader_train.dataset[0].num_node_features,
                                                      out_channels=loader_train.dataset[0].y.size(-1),  # or number of classes
                                                      n_layers=i,
                                                      activation=nn.ReLU(),
                                                      edge_attr_dim = EDGE_ATTR_DIM,
                                                      target_params=j,
                                                      head_choices = (4,),
                                                      hidden_choices = (64,128,256), 
                                                      emb_choices = (64,128,256))

In [None]:
layers_to_hyperparams_2

In [17]:
def train_gat_for(n_layers, target_params, lth = layers_to_hyperparams_2, path_data = PATH_DATA, **model_kwargs):
    gc.collect(); torch.cuda.empty_cache()
    print(f"\n>> Layers={n_layers}, target≈{target_params:,} params")

    # ---- grab the precomputed tuple, instead of re‐searching ----
    try:
        heads, hidden, emb, actual = lth[(n_layers, target_params)]
    except KeyError:
        raise KeyError(f"No entry in lth for {(n_layers, target_params)}")

    print(f"→ Using cached heads={heads}, hidden={hidden}, emb={emb} → {actual:,} params")


    in_dim  = loader_train.dataset[0].num_node_features
    out_dim = loader_train.dataset[0].y.size(-1)
    edge_attr_dim = 4
    ACT, LR = nn.ReLU(), 1e-3
    lr = LR

    model = GraphAttentionModel(
        in_channels    = in_dim,
        hidden_channels= hidden,
        out_channels   = out_dim,
        num_layers     = n_layers,
        activation     = ACT,
        heads          = heads,
        node_emb_dim   = emb,
        edge_attr_dim  = edge_attr_dim,
        **model_kwargs
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=LR)
    best_val, epochs_no_imp, epochs_no_imp2 = float('inf'), 0, 0
    stats, best = [], {}

    print(f"{'E':>3} | {'Train L':>7} | {'Val L':>6} | {'Val Acc':>7} | {'Test Acc':>8}")
    print("-"*45)
    for epoch in range(MAX_EPOCHS):
        tr_loss = train(model, loader_train, optimizer, CRITERION, device)
        pv, av, accv, lv = test (model, loader_val,   CRITERION, device)
        pt, at, acct, lt = test (model, loader_test,  CRITERION, device)

        stats.append({'train_loss':tr_loss,'val_loss':lv,'acc_val':accv,'acc_test':acct})
        star = ""
        if lv < best_val:
            best_val = lv
            epochs_no_imp = epochs_no_imp2 = 0
            best = {
                'state':     {k:v.cpu() for k,v in model.state_dict().items()},
                'preds_val': pv, 'preds_test': pt
            }
            star = "*"
        else:
            epochs_no_imp  += 1
            epochs_no_imp2 += 1

        print(f"{epoch+1:3d} | {tr_loss:7.4f} | {lv:6.4f} | {accv:7.4f} | {acct:8.4f} {star}")

        if epochs_no_imp  >= TOLERANCE:
            print(f"→ Early stopping @ epoch {epoch+1}")
            break
        if epochs_no_imp2 >= LR_TOLERANCE:
            if lr >=1.0e-8:
                lr /= 10
                for g in optimizer.param_groups:
                    g['lr'] = lr
            print(f"→ LR reduced to {lr:e}")

    best['stats'] = pd.DataFrame(stats)
    os.makedirs(path_data, exist_ok=True)
    joblib.dump(best, os.path.join(path_data, f"gat_{n_layers}_{target_params}.pkl"))

    # optional plotting
    #best['stats'][['train_loss','val_loss']].plot(figsize=(12,3)); plt.show()
    #best['stats'][['acc_val','acc_test']].plot(figsize=(12,3)); plt.show()

    del model, optimizer, stats
    gc.collect(); torch.cuda.empty_cache()


In [21]:
PATH_DATA3 = './data/00.03'

In [22]:
Parallel(n_jobs=1)(delayed(train_gat_for)(n_layers,target_params, lth = layers_to_hyperparams_2, path_data = PATH_DATA3)
                    for n_layers in tqdm([2,3,4], leave = False)
                    for target_params in tqdm([100000, 500000,1000000], leave = False))

  0%|                                                                                            | 0/1 [00:00<?, ?it/s]
[A%|                                                                                            | 0/2 [00:00<?, ?it/s]


>> Layers=4, target≈500,000 params
→ Using cached heads=4, hidden=64, emb=256 → 440,705 params
  E | Train L |  Val L | Val Acc | Test Acc
---------------------------------------------
  1 |  0.6205 | 0.4348 |  0.9042 |   0.9067 *
  2 |  0.4006 | 0.4217 |  0.9205 |   0.9233 *
  3 |  0.3527 | 0.4167 |  0.9265 |   0.9291 *
  4 |  0.3268 | 0.4139 |  0.9315 |   0.9339 *
  5 |  0.3026 | 0.4112 |  0.9378 |   0.9399 *
  6 |  0.2799 | 0.4079 |  0.9437 |   0.9456 *
  7 |  0.2567 | 0.4057 |  0.9481 |   0.9499 *
  8 |  0.2418 | 0.4026 |  0.9501 |   0.9518 *
  9 |  0.2293 | 0.4020 |  0.9524 |   0.9538 *
 10 |  0.2183 | 0.4009 |  0.9542 |   0.9551 *
 11 |  0.2079 | 0.3991 |  0.9570 |   0.9582 *
 12 |  0.1999 | 0.3985 |  0.9577 |   0.9584 *
 13 |  0.1928 | 0.3975 |  0.9584 |   0.9598 *
 14 |  0.1870 | 0.3969 |  0.9614 |   0.9625 *
 15 |  0.1808 | 0.3959 |  0.9621 |   0.9631 *
 16 |  0.1744 | 0.3955 |  0.9615 |   0.9627 *
 17 |  0.1699 | 0.3951 |  0.9637 |   0.9644 *
 18 |  0.1662 | 0.3945 |  0.9642


[A%|█████████████████████████████████████████                                         | 1/2 [17:10<17:10, 1030.41s/it]


>> Layers=4, target≈1,000,000 params
→ Using cached heads=4, hidden=128, emb=256 → 1,136,129 params
  E | Train L |  Val L | Val Acc | Test Acc
---------------------------------------------
  1 |  0.6059 | 0.4307 |  0.9049 |   0.9079 *
  2 |  0.3747 | 0.4169 |  0.9271 |   0.9292 *
  3 |  0.3170 | 0.4105 |  0.9386 |   0.9407 *
  4 |  0.2700 | 0.4051 |  0.9467 |   0.9484 *
  5 |  0.2411 | 0.4038 |  0.9521 |   0.9536 *
  6 |  0.2228 | 0.4000 |  0.9539 |   0.9557 *
  7 |  0.2066 | 0.3982 |  0.9591 |   0.9601 *
  8 |  0.1912 | 0.3974 |  0.9593 |   0.9607 *
  9 |  0.1785 | 0.3955 |  0.9637 |   0.9650 *
 10 |  0.1679 | 0.3937 |  0.9650 |   0.9664 *
 11 |  0.1568 | 0.3927 |  0.9688 |   0.9695 *
 12 |  0.1489 | 0.3912 |  0.9695 |   0.9704 *
 13 |  0.1404 | 0.3908 |  0.9687 |   0.9699 *
 14 |  0.1340 | 0.3899 |  0.9721 |   0.9728 *
 15 |  0.1277 | 0.3889 |  0.9727 |   0.9733 *
 16 |  0.1210 | 0.3882 |  0.9744 |   0.9750 *
 17 |  0.1151 | 0.3872 |  0.9744 |   0.9749 *
 18 |  0.1115 | 0.3870 |  0


[A%|██████████████████████████████████████████████████████████████████████████████████| 2/2 [35:22<00:00, 1066.57s/it]
                                                                                                                       

[None, None]

In [25]:
PICKLE_DIR3 = "./data/00.03" 


records_3 = []

print('here')
for path in tqdm(glob.glob(os.path.join(PICKLE_DIR3, "*.pkl"))):
    run = joblib.load(path)
    stats_df = run['stats'].copy()
    ep = stats_df['val_loss'].idxmin()
    name = os.path.splitext(os.path.basename(path))[0]
    records_3.append({'run': name, 'best_epoch': ep, 'test_accuracy': stats_df.loc[ep, 'acc_test'],
                   'validation_accuracy': stats_df.loc[ep, 'acc_val']})
    
summary = pd.DataFrame(records_3)

here


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 39.29it/s]


In [26]:
summary

Unnamed: 0,run,best_epoch,test_accuracy,validation_accuracy
0,gat_2_100000,109,0.971456,0.970404
1,gat_2_1000000,84,0.977348,0.976752
2,gat_2_500000,102,0.977895,0.977084
3,gat_3_100000,129,0.978504,0.97784
4,gat_3_1000000,87,0.977994,0.977495
5,gat_3_500000,121,0.980475,0.980034
6,gat_4_100000,88,0.981028,0.980677
7,gat_4_1000000,61,0.981645,0.981479
8,gat_4_500000,82,0.974438,0.973589


In [27]:
summary[['prefix', 'layers', 'params']] = summary['run'].str.split('_', expand=True)
summary['layers'] = summary['layers'].astype(int)
summary['params'] = summary['params'].astype(int)

In [28]:
output_df = summary[['layers', 'params', 'test_accuracy', 'validation_accuracy']].rename(
    columns={
        'layers': '# layers',
        'params': '#params',
        'validation_accuracy': 'val_accuracy'
    }
)

In [32]:
output_df = output_df.sort_values(by = 'test_accuracy', ascending = False)
output_df

Unnamed: 0,# layers,#params,test_accuracy,val_accuracy
7,4,1000000,0.981645,0.981479
6,4,100000,0.981028,0.980677
5,3,500000,0.980475,0.980034
3,3,100000,0.978504,0.97784
4,3,1000000,0.977994,0.977495
2,2,500000,0.977895,0.977084
1,2,1000000,0.977348,0.976752
8,4,500000,0.974438,0.973589
0,2,100000,0.971456,0.970404


In [33]:
csv_path = 'gat_performance.csv'
output_df.to_csv(csv_path, index=False)