# DSNN

In [None]:
# Project label; not referenced by the visible cells (presumably used for
# experiment tracking elsewhere — confirm before renaming).
PROJECT_NAME: str = "DSNN"

import sys, os, glob

from tqdm import tqdm
import pickle
import traceback
import time
import random
from tqdm import tqdm

import networkx as nx
import torch
from torch_geometric.utils import to_networkx
import numpy as np
import scipy.linalg


# Compute device for the model and every tensor moved in train()/test():
# the current CUDA device when CUDA is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Hyperparams

In [None]:
# --- Architecture -----------------------------------------------------------
HIDDEN_DIM = 64         # width of the hidden layers of both DeepSet MLPs
LAYER_NUM = 9           # number of linear layers in each MLP
USE_RESIDUAL = True     # MLP default: concat the MLP input to every layer after the first
USE_LAYERNORM = False   # MLP default: LayerNorm before ReLU (not on the last two linears)

# --- Optimisation -----------------------------------------------------------
LR = 1e-4               # Adam learning rate
NUM_EPOCHS = 501        # loop is range(1, NUM_EPOCHS), i.e. 500 training epochs

# --- Regularisation ---------------------------------------------------------
DROPOUT = 0.0           # training-time probability of dropping a whole node (set element)
NOISE_VAR = 5e-3        # Gaussian noise added to the pooled sum while training; it scales
                        # a standard-normal sample, so it acts as a std-dev, not a variance

## Method

In [None]:
def _nearest_distance(g, targets):
    """Map every node of ``g`` to its hop distance to the closest node in ``targets``.

    One unweighted shortest-path search per target: with only ``target`` given,
    ``nx.shortest_path_length`` returns a dict keyed by source node, so the
    distance of every node to that target comes from a single call.

    Raises:
        nx.NetworkXNoPath: if some node cannot reach any of ``targets``.
    """
    lengths_to_targets = [nx.shortest_path_length(g, target=t) for t in targets]
    nearest = dict()
    for n in g.nodes():
        reachable = [lengths[n] for lengths in lengths_to_targets if n in lengths]
        if not reachable:
            raise nx.NetworkXNoPath(f"node {n} cannot reach any of {targets}")
        nearest[n] = np.min(reachable)
    return nearest


def get_minmax_centrality(g):
    """For every node, return hop distances to the least- and most-central nodes.

    A node's centrality score is the product of its betweenness, eigenvector and
    Laplacian centrality, each shifted by 1e-5 so that one zero factor does not
    zero the product. All nodes whose score is within 1e-11 of the minimum
    (maximum) form the "min" ("max") set.

    Args:
        g: networkx graph; expected to be connected (convert_graph checks this).

    Returns:
        dict mapping node -> [distance to nearest min node, distance to nearest max node].

    Raises:
        nx.NetworkXNoPath: if a node cannot reach any min or max node.
    """
    betweenness_centrality = nx.betweenness_centrality(g)
    eigenvector_centrality = nx.eigenvector_centrality(g, max_iter=1000, tol=1e-04)
    try:
        laplacian_centrality = nx.laplacian_centrality(g)
    except Exception:
        # Best-effort fallback (laplacian_centrality may be unavailable or fail on
        # some graphs). Exact betweenness is deterministic, so reuse the values above.
        laplacian_centrality = betweenness_centrality

    eps = 0.00001                 # keeps each factor strictly positive
    tie_tol = 0.00000000001       # scores this close to the extreme count as ties
    prod_centrality = {
        n: (betweenness_centrality[n] + eps)
        * (eigenvector_centrality[n] + eps)
        * (laplacian_centrality[n] + eps)
        for n in g.nodes()
    }

    min_value = min(prod_centrality.values())
    max_value = max(prod_centrality.values())
    min_nodes = [n for n, v in prod_centrality.items() if v < min_value + tie_tol]
    max_nodes = [n for n, v in prod_centrality.items() if v > max_value - tie_tol]

    dist_to_min = _nearest_distance(g, min_nodes)
    dist_to_max = _nearest_distance(g, max_nodes)
    return {n: [dist_to_min[n], dist_to_max[n]] for n in g.nodes()}
         
    
    
def _append_node_feature(g_nx, values):
  """Append ``values[node]`` as one more entry of each node's ``'x'`` feature list."""
  for node_i, value in values.items():
    g_nx.nodes[node_i]['x'] = g_nx.nodes[node_i]['x'] + [value]


def convert_graph(g):
  """Turn a PyG graph into a node-feature matrix of structural descriptors.

  Each node's original features are extended with: distances to the nearest
  least/most central node (get_minmax_centrality), betweenness, degree,
  closeness, second-order, Laplacian and eigenvector centrality. The final
  matrix concatenates these per-node features with the sum of the same
  features over each node's neighbours, doubling the width.

  Args:
    g: torch_geometric data object with node features ``x`` and label ``y``.

  Returns:
    (g_nx, x_tensor, y): the undirected networkx graph (extended ``'x'`` lists
    stored as node attributes), a float tensor of shape
    (num_nodes, 2 * num_features), and a detached copy of ``g.y``.

  Raises:
    AssertionError: if the graph is not connected. get_dataset catches this to
      skip the graph, so it is raised explicitly (an ``assert`` would vanish
      under ``python -O``).

  Note: get_dataset caches converted graphs in ``{name}_deepset.pkl`` without
  checking how they were produced; delete those files after changing features.
  """
  g_nx = to_networkx(g, to_undirected=True, node_attrs=['x'])
  if not nx.is_connected(g_nx):
    raise AssertionError("graph is not connected")

  minmax_centrality = get_minmax_centrality(g_nx)

  # 'x' may arrive as a bare scalar (single-feature graphs) or be missing
  # (default 1.0); normalise every node's features to a list.
  for n in g_nx.nodes():
    x_value = g_nx.nodes[n].get('x', 1.0)
    if not isinstance(x_value, list):
      g_nx.nodes[n]['x'] = [x_value]

  for node_i, minmax_value in minmax_centrality.items():
    assert isinstance(g_nx.nodes[node_i]['x'], list)
    assert isinstance(minmax_value, list)
    g_nx.nodes[node_i]['x'] = g_nx.nodes[node_i]['x'] + minmax_value

  _append_node_feature(g_nx, nx.betweenness_centrality(g_nx))
  _append_node_feature(g_nx, nx.degree_centrality(g_nx))
  _append_node_feature(g_nx, nx.closeness_centrality(g_nx))
  # Append the second-order scores themselves (not closeness a second time).
  _append_node_feature(g_nx, nx.second_order_centrality(g_nx))

  try:
    laplacian_centrality = nx.laplacian_centrality(g_nx)
  except Exception:
    # Fall back (as get_minmax_centrality does) rather than skip the column:
    # skipping would give this graph one feature fewer than the others, while
    # the model's input width is fixed from the first training graph.
    laplacian_centrality = nx.betweenness_centrality(g_nx)
  _append_node_feature(g_nx, laplacian_centrality)

  _append_node_feature(g_nx, nx.eigenvector_centrality(g_nx, max_iter=1000, tol=1e-04))

  nodes = list(g_nx.nodes())
  num_features = len(g_nx.nodes[nodes[0]]['x'])
  x_tensor = torch.zeros([len(nodes), num_features])
  for node_i in nodes:
    # assumes to_networkx labels nodes 0..num_nodes-1, so labels are row indices
    x_tensor[node_i, :] = torch.tensor(g_nx.nodes[node_i]['x'])

  # Row i of x_tensor2 = sum of the feature rows of node i's neighbours.
  x_tensor2 = torch.zeros_like(x_tensor)
  for node_i in nodes:
    neighbors = list(g_nx.neighbors(node_i))
    if neighbors:
      x_tensor2[node_i, :] = x_tensor[neighbors, :].sum(dim=0)

  x_tensor = torch.cat((x_tensor, x_tensor2), dim=1)

  return g_nx, x_tensor, g.y.clone().detach()

## Neural Network

In [None]:
from torch_geometric.nn.conv import x_conv
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential


class MLP(nn.Module):
    """Feed-forward network of ``layer_num`` linear layers with optional input skips.

    Layer widths are input_size -> hidden_size -> ... -> hidden_size -> output_size
    (a single layer maps input_size -> output_size). Every linear layer except the
    last is followed by ReLU. With ``use_layernorm``, a LayerNorm precedes the ReLU
    of all but the last two linear layers. With ``use_residual``, each linear layer
    after the first receives the previous activation concatenated with the original
    input along the last dimension.

    The ``use_residual`` / ``use_layernorm`` defaults are read from the module-level
    USE_RESIDUAL / USE_LAYERNORM constants when the class is defined.
    """

    def __init__(self, input_size, hidden_size, output_size, layer_num=4, use_residual=USE_RESIDUAL, use_layernorm=USE_LAYERNORM):
        super().__init__()
        self.mlp_list = nn.ModuleList()
        self.use_layernorm = use_layernorm
        self.use_residual = use_residual

        for i in range(layer_num):
            in_dim = hidden_size if i > 0 else input_size
            out_dim = hidden_size if i < layer_num - 1 else output_size
            if use_residual and i > 0:
                # forward() concatenates the original input to this layer's input
                in_dim += input_size
            self.mlp_list.append(nn.Linear(in_dim, out_dim))

            if i < layer_num - 1:
                if use_layernorm and i < layer_num - 2:
                    self.mlp_list.append(nn.LayerNorm(out_dim))
                self.mlp_list.append(nn.ReLU())

    def forward(self, x_original):
        """Apply the layers; returns a tensor whose last dimension is output_size."""
        x = x_original.clone()
        for idx, layer in enumerate(self.mlp_list):
            if idx > 0 and self.use_residual and isinstance(layer, nn.Linear):
                # torch.cat keeps the inputs' device. Forcing DEVICE here would move
                # activations off the module's device whenever the model is not on
                # DEVICE (e.g. a CPU copy while CUDA is available).
                x = torch.cat((x, x_original), dim=-1)
            x = layer(x)
        return x


class DeepSet(nn.Module):
  """Sum-pooling set model: per-element MLP, sum over elements, then a second MLP.

  Args:
    input_size: features per set element (per node).
    hidden_size: width of both MLPs and size of the pooled representation.
    output_size: number of outputs of the final MLP.
    use_softmax: if true, return a softmax over the flattened outputs.
    use_sigmoid: if true (and use_softmax is false), return element-wise sigmoid.
    layer_num: linear layers in each MLP.
    dropout: training-time probability of dropping each set element (row).
  """
  def __init__(self, input_size=11, hidden_size=32, output_size=1, use_softmax = False, use_sigmoid=False, layer_num=4, dropout=0.2):
    super().__init__()
    self.mlp1 = MLP(input_size, hidden_size, hidden_size, layer_num=layer_num)
    self.mlp2 = MLP(hidden_size, hidden_size, output_size, layer_num=layer_num)
    self.use_softmax = use_softmax
    self.use_sigmoid = use_sigmoid
    self.dropout = dropout

  def forward(self, x):
    """Map a set of elements (assumed shape (num_elements, input_size)) to one output."""
    x = self.mlp1(x)
    if self.training and self.dropout > 0.0:
      # Element-level dropout: keep each row with probability 1 - dropout.
      # torch.bernoulli has no `device` argument, so build the probabilities on
      # x's device instead of passing one.
      keep_prob = torch.full((x.shape[0],), 1.0 - self.dropout, device=x.device)
      keep = torch.bernoulli(keep_prob).to(dtype=torch.bool)
      x = x[keep, :]

    x_sum = torch.sum(x, dim=0)
    if NOISE_VAR > 0.0000001 and self.training:
      # NOISE_VAR scales a standard-normal sample, i.e. it acts as the noise std-dev.
      x_sum = x_sum + torch.randn_like(x_sum) * NOISE_VAR
    x = x_sum.flatten()

    x = self.mlp2(x)
    if self.use_softmax:
      return F.softmax(x.flatten(), dim=0)
    elif self.use_sigmoid:
      return torch.sigmoid(x.flatten())
    return x






## Gen Dataset

In [None]:
def get_dataset(name):
  """Load a TU benchmark as featurized (x_tensor, y) pairs split roughly 80/10/10.

  Converted splits are cached in ``{name}_deepset.pkl`` in the working directory
  and reused on later calls. Graphs for which convert_graph raises
  AssertionError (e.g. disconnected graphs) are skipped.

  Args:
    name: one of 'MUTAG', 'PROTEINS', 'ENZYMES', 'IMDB-BINARY'.

  Returns:
    (train, test, val) lists of (x_tensor, y) tuples, shuffled with a fixed seed.

  Raises:
    ValueError: if ``name`` is not a supported dataset.
  """
  supported = ['MUTAG', 'PROTEINS', 'ENZYMES', 'IMDB-BINARY']
  if name not in supported:
    # explicit raise rather than assert, so the check survives `python -O`
    raise ValueError(f"Unsupported dataset {name!r}; expected one of {supported}")

  filename = f'{name}_deepset.pkl'
  if os.path.exists(filename):
    # SECURITY NOTE: pickle.load can execute arbitrary code; only load cache
    # files this script wrote itself. The cache is not tied to the feature code,
    # so delete it after changing convert_graph.
    with open(filename, 'rb') as f:
      dataset_converted_train, dataset_converted_test, dataset_converted_val = pickle.load(f)
    return dataset_converted_train, dataset_converted_test, dataset_converted_val

  # Imported after the cache check so cached runs do not need torch_geometric.
  from torch_geometric.datasets import TUDataset
  dataset = TUDataset(root='data/TUDataset', name=name)
  if name == 'IMDB-BINARY':
    # Give every node a single constant 0 feature.
    modified_data_list = list()
    for data in dataset:
      data.x = torch.zeros([data.num_nodes, 1])
      modified_data_list.append(data)
    dataset = modified_data_list

  dataset_converted = list()
  for g in tqdm(dataset):
    try:
      _, x_tensor, y = convert_graph(g)
    except AssertionError:
      # convert_graph rejects graphs it cannot featurize this way
      continue
    dataset_converted.append((x_tensor, y))

  print("len dataset_converted", len(dataset_converted))

  # Fixed seed -> the same split on every run.
  random.Random(1234).shuffle(dataset_converted)
  split = int(0.8 * len(dataset_converted))
  dataset_converted_train = dataset_converted[:split]
  dataset_converted_testval = dataset_converted[split:]

  split = int(0.5 * len(dataset_converted_testval))
  dataset_converted_test = dataset_converted_testval[:split]
  dataset_converted_val = dataset_converted_testval[split:]

  with open(filename, 'wb') as f:
    pickle.dump((dataset_converted_train, dataset_converted_test, dataset_converted_val), f)

  return dataset_converted_train, dataset_converted_test, dataset_converted_val



In [2]:
## Training

In [3]:
def train(dataset, optimizer, model, use_softmax=False):
  """Run one training epoch, taking one optimizer step per (x, y) sample.

  NaNs in the features are replaced by 0 before the forward pass. With
  ``use_softmax`` the per-sample loss is (1 - out[y])**2; otherwise it is the
  squared absolute error |out - y|**2.

  Returns:
    Mean per-sample loss over the epoch (NaN losses ignored via np.nanmean).
  """
  model.train()
  epoch_losses = []
  for features, label in dataset:
    features = torch.nan_to_num(features, nan=0.0).to(DEVICE)
    label = label.to(DEVICE)
    prediction = model(features)
    if use_softmax:
      sample_loss = (1.0 - prediction[label.item()]) ** 2
    else:
      sample_loss = torch.abs(prediction - label) ** 2
    epoch_losses.append(sample_loss.item())
    sample_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
  return np.nanmean(epoch_losses)

def test(dataset, model, use_softmax = False):
  """Return the accuracy of ``model`` over the (x, y) pairs in ``dataset``.

  With ``use_softmax`` a sample is correct when argmax(out) equals y; otherwise
  y must lie in [0, 1] and a sample is correct when |out - y| < 0.5.
  Runs in eval mode without building autograd graphs.

  Returns:
    Fraction of correct samples, or NaN for an empty dataset.
  """
  model.eval()
  if len(dataset) == 0:
    # accuracy is undefined; NaN instead of a ZeroDivisionError
    return float('nan')

  correct = 0
  with torch.no_grad():
    for x, y in dataset:
      x = torch.nan_to_num(x, nan=0.0).to(DEVICE)
      y = y.to(DEVICE)
      out = model(x)
      if use_softmax:
        if torch.argmax(out) == y.item():
          correct += 1
      else:
        error = torch.abs(out - y).item()
        assert(y <= 1.0 and y >= 0.0)
        if error < 0.5:
          correct += 1

  return correct / len(dataset)


def start_agent(name="MUTAG", use_softmax=True):
  """Train a DeepSet on dataset ``name`` with validation-based model selection.

  ``use_softmax`` is accepted for backward compatibility but ignored: the head
  is derived from the output count (6-way softmax for ENZYMES, a single
  sigmoid otherwise).

  Returns:
    The test accuracy at the epoch with the highest validation accuracy
    (first such epoch on ties), or -1.0 if no epoch ran.
  """
  trainset, testset, valset = get_dataset(name)
  atom_dim = trainset[0][0].shape[1]
  output_num = 6 if name == "ENZYMES" else 1

  use_softmax = output_num > 1
  use_sigmoid = not use_softmax
  print("Train", name, 'with input dim', atom_dim, "and output dim", output_num)

  model = DeepSet(input_size = atom_dim, hidden_size = HIDDEN_DIM, output_size=output_num, use_softmax=use_softmax, use_sigmoid=use_sigmoid, layer_num=LAYER_NUM, dropout=DROPOUT)
  model = model.to(DEVICE)
  optimizer = torch.optim.Adam(model.parameters(), lr=LR)
  optimizer.zero_grad()
  best_val_acc = -1.0
  best_test_acc = -1.0

  for epoch in range(1, NUM_EPOCHS):
      loss = train(trainset, optimizer, model, use_softmax=use_softmax)
      # The final ACC is the test ACC at the epoch where the val ACC is highest.
      mean_correct_train = test(trainset, model, use_softmax=use_softmax)
      mean_correct_test = test(testset, model, use_softmax=use_softmax)
      mean_correct_val = test(valset, model, use_softmax=use_softmax)
      if mean_correct_val > best_val_acc:
        best_val_acc = mean_correct_val
        best_test_acc = mean_correct_test
      if epoch % 10 == 0:
          print(f'({name}) Epoch: {epoch:03d}, Train Loss: {loss:.4f}, Train Acc: {mean_correct_train:.4f}, Test Acc: {mean_correct_test:.4f}, Val Acc: {mean_correct_val:.4f}, Test Best Acc: {best_test_acc:.4f}')
      try:
        wandb.log({f"{name}/test_acc": mean_correct_test, f"{name}/train_acc": mean_correct_train, f"{name}/val_acc": mean_correct_val, f"{name}/besttest_acc": best_test_acc, f"{name}/test_loss": loss})
      except Exception:
        # Best-effort logging: wandb is not imported in the visible cells, so
        # this presumably fails with NameError unless set up elsewhere.
        pass
  return best_test_acc

In [None]:
def start_experiments():
    """Run start_agent once per benchmark dataset, isolating failures.

    A failing run does not stop the others: its traceback is printed and
    appended to ``_error_log.txt``, then the next dataset is tried.
    """
    runs = (
        ("ENZYMES", True),
        ("IMDB-BINARY", False),
        ("MUTAG", False),
        ("PROTEINS", False),
    )
    for dataset_name, softmax_flag in runs:
        try:
            start_agent(name=dataset_name, use_softmax=softmax_flag)
        except Exception:
            error_message = traceback.format_exc()
            print("final error:\n", error_message)
            with open('_error_log.txt', 'a') as f:
                f.write(error_message + '\n')
        




In [None]:
# Repeat the full benchmark sweep ten times; every repetition re-trains on all datasets.
for _repetition in range(10):
    start_experiments()
