# Train Baseline - DSNN

### Setup

In [None]:
from tqdm import tqdm
import pickle
import traceback
import time
import random
from tqdm import tqdm

import networkx as nx
import torch
from torch_geometric.utils import to_networkx
import numpy as np
import scipy.linalg
import glob
import traceback

from torch_geometric.nn.conv import x_conv
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential
from torch_geometric.nn import GCN, GIN, PNA


# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

### Hyperparameters

In [None]:
HIDDEN_DIM = 64
LAYER_NUM = 5
DROPOUT = 0.05
LR = 0.00001
USE_LAYERNORM = False  #unused
NUM_EPOCHS = 501
USE_RESIDUAL = False  #unused
NOISE_VAR = 0.0 #unused

### Get Dataset

In [None]:
from torch_geometric.datasets import TUDataset

def get_dataset(name):
    """Load a TUDataset benchmark and split it into train / test / val lists.

    Graphs are shuffled with a fixed seed (1234), so the split is identical
    on every call: the first 80 % go to training, and the remaining 20 % are
    halved into test (first half) and validation (second half).

    Args:
        name: one of 'MUTAG', 'PROTEINS', 'ENZYMES', 'IMDB-BINARY'.

    Returns:
        Tuple ``(dataset_train, dataset_test, dataset_val)`` of lists of graphs.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    supported = ('MUTAG', 'PROTEINS', 'ENZYMES', 'IMDB-BINARY')
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would turn a bad name into a confusing error later.
    if name not in supported:
        raise ValueError(f"Unsupported dataset {name!r}; expected one of {supported}")

    raw = TUDataset(root='data/TUDataset', name=name)

    def has_content(graph):
        # Keep only graphs with more than one feature value and more than one edge.
        return graph.x.numel() > 1 and graph.edge_index.shape[1] > 1

    if name == 'PROTEINS':
        dataset = []
        for data in raw:
            try:
                # Force a 2-D feature matrix with 3 columns.
                data.x = data.x.reshape(-1, 3)
            except RuntimeError:
                print(data)
                print("Cannot reshape data", data.x.shape, data)
            if has_content(data):
                dataset.append(data)
            else:
                print(data)
                print("Illegal data", data.x.shape, data)
    elif name == 'IMDB-BINARY':
        dataset = []
        for data in raw:
            # Every node gets a constant feature of 1 (presumably because the
            # raw graphs carry no node attributes -- confirm).
            data.x = torch.ones(data.num_nodes).reshape(-1, 1)
            if has_content(data):
                dataset.append(data)
    else:
        dataset = list(raw)

    print(f"Length of dataset: {len(dataset)}")

    random.Random(1234).shuffle(dataset)
    n_train = int(0.8 * len(dataset))
    dataset_train, dataset_testval = dataset[:n_train], dataset[n_train:]

    n_test = int(0.5 * len(dataset_testval))
    dataset_test, dataset_val = dataset_testval[:n_test], dataset_testval[n_test:]

    return dataset_train, dataset_test, dataset_val



## NN Architectures

In [None]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCNnet(torch.nn.Module):
    """Graph-level classifier: PyG GCN backbone, mean pooling, linear head.

    Args:
        atom_dim: number of input node features.
        output_num: number of outputs; 1 yields a sigmoid probability, more
            yields a softmax distribution over ``output_num`` classes.
    """

    def __init__(self, atom_dim, output_num):
        super().__init__()
        try:
            self.gcn = GCN(in_channels=atom_dim, hidden_channels=HIDDEN_DIM,
                           out_channels=HIDDEN_DIM, num_layers=LAYER_NUM,
                           dropout=DROPOUT)
        except TypeError:
            # NOTE(review): fallback for PyG versions whose GCN signature
            # rejects `out_channels`; confirm `output_size` is accepted there.
            self.gcn = GCN(in_channels=atom_dim, hidden_channels=HIDDEN_DIM,
                           output_size=HIDDEN_DIM, num_layers=LAYER_NUM,
                           dropout=DROPOUT)
        self.lin = Linear(HIDDEN_DIM, output_num)
        self.output_num = output_num

    def forward(self, x, edge_index):
        """Score one graph.

        Returns:
            1-D tensor of length ``output_num``: a sigmoid probability when
            ``output_num == 1``, otherwise a softmax over the classes.
        """
        h = self.gcn(x, edge_index)
        # The whole input is a single graph, so mean pooling is a plain mean
        # over nodes (what global_mean_pool computes for an all-zero batch).
        # Computing it directly keeps everything on the input's own device
        # instead of allocating a batch vector on the global DEVICE.
        pooled = h.mean(dim=0, keepdim=True)  # [1, hidden]
        logits = self.lin(pooled).flatten()
        if self.output_num == 1:
            return torch.sigmoid(logits)
        return torch.softmax(logits, dim=0)

class GINnet(torch.nn.Module):
    """Graph-level classifier: PyG GIN backbone, mean pooling, linear head.

    Args:
        atom_dim: number of input node features.
        output_num: number of outputs; 1 yields a sigmoid probability, more
            yields a softmax distribution over ``output_num`` classes.
    """

    def __init__(self, atom_dim, output_num):
        super().__init__()
        self.gin = GIN(in_channels=atom_dim, hidden_channels=HIDDEN_DIM,
                       out_channels=HIDDEN_DIM, num_layers=LAYER_NUM,
                       dropout=DROPOUT)
        self.lin = Linear(HIDDEN_DIM, output_num)
        self.output_num = output_num

    def forward(self, x, edge_index):
        """Score one graph.

        Returns:
            1-D tensor of length ``output_num``: a sigmoid probability when
            ``output_num == 1``, otherwise a softmax over the classes.
        """
        h = self.gin(x, edge_index)
        # Single-graph mean pooling (equivalent to global_mean_pool with an
        # all-zero batch), computed on the input's own device rather than on
        # the global DEVICE.
        pooled = h.mean(dim=0, keepdim=True)  # [1, hidden]
        logits = self.lin(pooled).flatten()
        if self.output_num == 1:
            return torch.sigmoid(logits)
        return torch.softmax(logits, dim=0)

    
class PNAnet(torch.nn.Module):
    """Graph-level classifier: PyG PNA backbone, mean pooling, linear head.

    Args:
        atom_dim: number of input node features.
        output_num: number of outputs; 1 yields a sigmoid probability, more
            yields a softmax distribution over ``output_num`` classes.
        deg: degree histogram forwarded unchanged to PNA.
    """

    def __init__(self, atom_dim, output_num ,deg):
        super().__init__()
        aggregators = ['mean', 'min', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']
        self.pna = PNA(in_channels=atom_dim, hidden_channels=HIDDEN_DIM,
                       out_channels=HIDDEN_DIM, num_layers=LAYER_NUM,
                       dropout=DROPOUT, aggregators=aggregators,
                       scalers=scalers, deg=deg)
        self.lin = Linear(HIDDEN_DIM, output_num)
        self.output_num = output_num

    def forward(self, x, edge_index):
        """Score one graph.

        Returns:
            1-D tensor of length ``output_num``: a sigmoid probability when
            ``output_num == 1``, otherwise a softmax over the classes.
        """
        h = self.pna(x, edge_index)
        # Single-graph mean pooling (equivalent to global_mean_pool with an
        # all-zero batch), computed on the input's own device rather than on
        # the global DEVICE.
        pooled = h.mean(dim=0, keepdim=True)  # [1, hidden]
        logits = self.lin(pooled).flatten()
        if self.output_num == 1:
            return torch.sigmoid(logits)
        return torch.softmax(logits, dim=0)
    

In [None]:
def count_parameters(model):
    """Return the number of trainable scalar parameters in *model*.

    Parameters with ``requires_grad=False`` are excluded; objects without a
    ``parameters()`` method yield 0 instead of raising.
    """
    try:
        params = model.parameters()
    except (AttributeError, TypeError):
        # Not a module-like object: report 0 rather than crash the caller.
        return 0
    # Objects lacking `requires_grad` are counted (the old fallback counted all).
    return sum(p.numel() for p in params if getattr(p, "requires_grad", True))

## Run Training

In [None]:
def train(dataset, optimizer, model, use_softmax=False, device=None):
    """Run one training epoch, taking an optimiser step after every graph.

    Args:
        dataset: iterable of graph objects exposing ``x``, ``edge_index``,
            ``y`` and ``.to(device)``.
        optimizer: torch optimiser over ``model``'s parameters.
        model: module called as ``model(x, edge_index)``.
        use_softmax: if True, ``out`` is read as class probabilities and the
            loss is ``(1 - out[y]) ** 2``; otherwise the loss is ``|out - y|``.
        device: where each graph is moved; defaults to the module-level DEVICE.

    Returns:
        Mean loss over graphs whose loss was not NaN, or NaN if there were none.
        Graphs that raise are reported (message + traceback) and skipped.
    """
    if device is None:
        device = DEVICE
    model.train()
    losses = []
    for data in dataset:  # one graph per step (no batching)
        try:
            # Zero first so an exception in a previous iteration (e.g. after
            # backward() but before step()) cannot leak stale gradients here.
            optimizer.zero_grad()
            data = data.to(device)
            y = data.y
            out = model(data.x, data.edge_index)
            if use_softmax:
                loss = (1.0 - out[y.item()]) ** 2
            else:
                loss = torch.abs(out - y)
            if torch.isnan(loss).any():
                continue
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
        except Exception as e:
            # Best-effort: a failing graph is reported and skipped.
            print(f"An error occurred while running data: {e}")
            traceback.print_exc()
    if not losses:
        # Avoid numpy's "Mean of empty slice" warning.
        return float('nan')
    return np.nanmean(losses)


def test(dataset, model, use_softmax=False, device=None):
    """Return the accuracy of *model* on *dataset*, evaluated one graph at a time.

    Args:
        dataset: sized iterable of graph objects exposing ``x``, ``edge_index``,
            ``y`` and ``.to(device)``.
        model: module called as ``model(x, edge_index)``.
        use_softmax: if True, a graph is correct when ``argmax(out) == y``;
            otherwise ``out`` is read as a probability and a graph is correct
            when ``|out - y| < 0.5`` (labels must lie in [0, 1]).
        device: where each graph is moved; defaults to the module-level DEVICE.

    Returns:
        ``correct / len(dataset)``. Graphs that give a NaN score or raise are
        counted as incorrect. An empty dataset returns NaN instead of raising
        ZeroDivisionError.
    """
    if device is None:
        device = DEVICE
    model.eval()
    if len(dataset) == 0:
        return float('nan')
    correct = 0
    # Evaluation only: no need to build autograd graphs.
    with torch.no_grad():
        for data in dataset:
            try:
                data = data.to(device)
                y = data.y
                out = model(data.x, data.edge_index)
                if use_softmax:
                    # A NaN score for the true class makes the output unusable.
                    if torch.isnan(out[y.item()]).any():
                        continue
                    if torch.argmax(out) == y.item():
                        correct += 1
                else:
                    error = torch.abs(out - y).item()
                    label = y.item()
                    # Explicit check: an `assert` would vanish under `python -O`.
                    if not 0.0 <= label <= 1.0:
                        raise ValueError(f"binary label out of range [0, 1]: {label}")
                    if error < 0.5:
                        correct += 1
            except Exception as e:
                print(f"An error occurred while running data: {e}")
                traceback.print_exc()
    return correct / len(dataset)


def _degree_histogram(dataset):
    """Return a histogram of node degrees over every graph in *dataset*.

    ``hist[k]`` is the number of nodes (across all graphs) that occur exactly
    ``k`` times in ``edge_index[1]``. PNAnet receives this as ``deg``.
    Each graph's degree vector is computed once and reused for both passes.
    """
    from torch_geometric.utils import degree

    per_graph = [
        degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
        for data in dataset
    ]
    max_degree = max((int(d.max()) for d in per_graph), default=-1)
    hist = torch.zeros(max_degree + 1, dtype=torch.long)
    for d in per_graph:
        hist += torch.bincount(d, minlength=hist.numel())
    return hist


def start_agent(name="HIV", use_softmax=True, modeltype="GIN"):
    """Train one model type on one dataset and report accuracies each epoch.

    Args:
        name: dataset name forwarded to get_dataset (NOTE: the default "HIV"
            is not in get_dataset's supported list and will fail there).
        use_softmax: ignored -- recomputed from the number of outputs; kept
            for call-site compatibility.
        modeltype: "GIN", "GCN" or "PNA".

    Returns:
        Test accuracy of the LAST epoch (not of the best-validation epoch),
        or None if no epoch ran.

    Raises:
        ValueError: for an unknown ``modeltype``.
    """
    trainset, testset, valset = get_dataset(name)
    atom_dim = trainset[0].x.shape[1]
    output_num = 6 if name == "ENZYMES" else 1
    name = name + modeltype
    # Multi-class datasets get a softmax head, binary ones a sigmoid head;
    # this deliberately overrides whatever the caller passed.
    use_softmax = output_num > 1
    print("Train", name, 'with input dim', atom_dim, "and output dim", output_num)

    if modeltype == "GIN":
        model = GINnet(atom_dim, output_num)
    elif modeltype == "GCN":
        model = GCNnet(atom_dim, output_num)
    elif modeltype == "PNA":
        # Only PNA needs the degree histogram, so compute it only here.
        model = PNAnet(atom_dim, output_num, _degree_histogram(trainset))
    else:
        raise ValueError(f"Unknown modeltype {modeltype!r}; expected 'GIN', 'GCN' or 'PNA'")

    model = model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    optimizer.zero_grad()
    best_val_acc = -1.0
    best_test_acc = -1.0
    mean_correct_test = None  # stays None if NUM_EPOCHS <= 1

    print("run model on ", name, " with num parameters: ", count_parameters(model))

    for epoch in range(1, NUM_EPOCHS):
        loss = train(trainset, optimizer, model, use_softmax=use_softmax)
        mean_correct_train = test(trainset, model, use_softmax=use_softmax)
        mean_correct_test = test(testset, model, use_softmax=use_softmax)
        mean_correct_val = test(valset, model, use_softmax=use_softmax)
        # Model selection: remember the test accuracy at the best validation epoch.
        if mean_correct_val > best_val_acc:
            best_val_acc = mean_correct_val
            best_test_acc = mean_correct_test
        if epoch % 10 == 0:
            print(f'({name}) Epoch: {epoch:03d}, Train Loss: {loss:.4f}, Train Acc: {mean_correct_train:.4f}, Test Acc: {mean_correct_test:.4f}, Val Acc: {mean_correct_val:.4f}, Test Best Acc: {best_test_acc:.4f}')
        try:
            # Best-effort logging: `wandb` is not imported in this file, so this
            # is skipped unless it is provided elsewhere.
            # NOTE(review): the "test_loss" key actually holds the training loss.
            wandb.log({f"{name}/test_acc": mean_correct_test, f"{name}/train_acc": mean_correct_train, f"{name}/val_acc": mean_correct_val, f"{name}/besttest_acc": best_test_acc, f"{name}/test_loss": loss})
        except Exception:
            pass

    try:
        # Best-effort: save_src_file is not defined in this file.
        save_src_file()
    except Exception:
        pass
    # Printed unconditionally; previously it sat inside the try above and was
    # skipped whenever save_src_file was unavailable.
    print("(finished) run model on ", name, " with num parameters: ", count_parameters(model))
    return mean_correct_test

In [None]:
def start_experiments():
    """Run start_agent once for every (model type, dataset) combination.

    Each run is isolated: a failure is reported with a traceback and the
    sweep continues with the next combination.
    """
    # (dataset name, use_softmax flag handed to start_agent), in run order.
    runs = (
        ("PROTEINS", False),
        ("IMDB-BINARY", False),
        ("ENZYMES", True),
        ("MUTAG", False),
    )
    for modeltype in ("PNA", "GIN", "GCN"):
        for dataset_name, softmax_flag in runs:
            try:
                start_agent(name=dataset_name, use_softmax=softmax_flag, modeltype=modeltype)
            except Exception as e:
                print(f"An error occurred while running start_agent for {dataset_name} with modeltype {modeltype}: {e}")
                traceback.print_exc()
        


In [None]:
# Number of times the full benchmark sweep is repeated. get_dataset shuffles
# with a fixed seed, so every repetition uses the same split; repetitions
# presumably differ only through unseeded randomness (weight init, dropout).
NUM_REPEATS = 10

# The guard keeps notebook / script execution unchanged (both run with
# __name__ == "__main__") while making the module importable without
# launching training.
if __name__ == "__main__":
    for _ in range(NUM_REPEATS):
        start_experiments()