In [1]:
import numpy as np
import scipy as sp
import torch
import time
import copy
import pickle
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
import os
import sys
import torch_geometric
import itertools
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATConv, GCNConv, GATv2Conv
from my_gat import my_GATConv
from my_mlp_gat_edges import my_MLP_GATConv_edges
from my_mlp_gat import my_MLP_GATConv
sys.path.insert(0, os.path.abspath('../../'))
from torch_geometric.datasets import Planetoid, Amazon
from torch_sparse import SparseTensor

# Device and datasets

In [2]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Benchmark graphs used in the experiments, in the order they are processed.
# Each dataset object is constructed with its own root directory under data/.
datasets = [
    Amazon(root='data/Amazon_Computers/', name='Computers'),
    Amazon(root='data/Amazon_Photo/', name='Photo'),
    Planetoid(root='data/Cora/', name='Cora'),
    Planetoid(root='data/PubMed/', name='PubMed'),
    Planetoid(root='data/CiteSeer/', name='CiteSeer'),
]

# Setup models

In [4]:
class Model_GAT(torch.nn.Module):
    """Single-layer model wrapping one ``my_GATConv`` layer.

    Args:
        d: First positional argument of ``my_GATConv`` (presumably the input
            feature size -- confirm against ``my_gat``).
        out_d: Second positional argument of ``my_GATConv`` (presumably the
            output size -- confirm against ``my_gat``).
        K: Forwarded to the layer as ``heads``.
    """

    def __init__(self, d, out_d, K):
        super().__init__()
        self.conv1 = my_GATConv(d, out_d, heads=K, bias_lin=True)

    def forward(self, data):
        """Run the layer on ``data.x`` / ``data.edge_index``.

        Returns the layer output with its last dimension squeezed away when
        that dimension has size one.
        """
        out = self.conv1(data.x, data.edge_index)
        return out.squeeze(-1)
    
class Model_MLP_GAT(torch.nn.Module):
    """Single-layer model wrapping one ``my_MLP_GATConv`` layer.

    Args:
        d: First positional argument of ``my_MLP_GATConv`` (presumably the
            input feature size -- confirm against ``my_mlp_gat``).
        out_d: Second positional argument of ``my_MLP_GATConv`` (presumably
            the output size -- confirm against ``my_mlp_gat``).
    """

    def __init__(self, d, out_d):
        super().__init__()
        # The positional constants 2 and 16 are passed through unchanged; their
        # meaning is defined by my_MLP_GATConv (not visible in this notebook).
        self.conv1 = my_MLP_GATConv(d, out_d, 2, 16, bias=True, add_self_loops=True)

    def forward(self, data):
        """Run the layer on ``data.x`` / ``data.edge_index``.

        Returns:
            A pair ``(out, gamma)``: the two values returned by the layer,
            each with a trailing size-one dimension squeezed away.
        """
        out, gamma = self.conv1(data.x, data.edge_index)
        return out.squeeze(-1), gamma.squeeze(-1)
    
class Model_MLP_GAT_edges(torch.nn.Module):
    """Single-layer model wrapping one ``my_MLP_GATConv_edges`` layer.

    Args:
        d_n, out_d_n, d_e, out_d_e: Forwarded, in this order, as the first four
            positional arguments of ``my_MLP_GATConv_edges`` (presumably the
            input/output sizes of the ``x`` and ``e`` features -- confirm
            against ``my_mlp_gat_edges``).
    """

    def __init__(self, d_n, out_d_n, d_e, out_d_e):
        super().__init__()
        # The positional constants 2 and 16 are passed through unchanged; their
        # meaning is defined by my_MLP_GATConv_edges (not visible here).
        self.conv1 = my_MLP_GATConv_edges(
            d_n, out_d_n, d_e, out_d_e, 2, 16, bias=True, add_self_loops=True
        )

    def forward(self, data):
        """Run the layer on ``data.x``, ``data.edge_index`` and ``data.e``.

        Returns:
            A pair ``(out, gamma)``: the two values returned by the layer,
            each with a trailing size-one dimension squeezed away.
        """
        out, gamma = self.conv1(data.x, data.edge_index, data.e)
        return out.squeeze(-1), gamma.squeeze(-1)
    
class Model_GCN(torch.nn.Module):
    """Single ``GCNConv`` layer (with bias) used as the graph-convolution baseline.

    Args:
        d: Input size given to ``GCNConv``.
        out_d: Output size given to ``GCNConv``.
    """

    def __init__(self, d, out_d):
        super().__init__()
        self.conv1 = GCNConv(d, out_d, bias=True)

    def forward(self, data):
        """Apply the convolution to ``data.x`` over ``data.edge_index``.

        Returns the layer output with its last dimension squeezed away when
        that dimension has size one.
        """
        out = self.conv1(data.x, data.edge_index)
        return out.squeeze(-1)
    
class Model_linear(torch.nn.Module):
    """Graph-free baseline: a single ``torch.nn.Linear`` layer with bias.

    Args:
        d: Number of input features.
        out_d: Number of outputs per node.
    """

    def __init__(self, d, out_d):
        super().__init__()
        self.linear = torch.nn.Linear(d, out_d, bias=True)

    def forward(self, data):
        """Apply the linear map to ``data.x``; the graph structure is ignored.

        Returns the output with its last dimension squeezed away when that
        dimension has size one.
        """
        return self.linear(data.x).squeeze(-1)

## Setup train and accuracy functions

In [5]:
def train(model, data, criterion, opt):
    """Perform one full-batch optimisation step and return its loss.

    Args:
        model: Module called as ``model(data)`` and returning the logits.
        data: Object exposing ``train_mask`` and the binary targets ``ynew``.
        criterion: Loss applied to the masked logits and masked targets.
        opt: Optimiser holding the model parameters.

    Returns:
        The loss tensor computed on the training nodes before the update.
    """
    model.train()
    opt.zero_grad()

    mask = data.train_mask
    outputs = model(data)  # forward pass over the whole graph
    objective = criterion(outputs[mask], data.ynew[mask])

    objective.backward()  # gradients w.r.t. the model parameters
    opt.step()  # parameter update using those gradients
    return objective
        
@torch.no_grad()
def measure_accuracy(model, data):
    """Compute binary classification accuracy on the train and test nodes.

    A node is predicted positive when ``sigmoid(logit) > 0.5``; predictions are
    compared against ``data.ynew``.

    Returns:
        ``(train_acc, test_acc)`` as Python floats in ``[0, 1]``.
    """
    model.eval()
    predicted = torch.sigmoid(model(data)) > 0.5  # forward pass, then threshold

    def _masked_accuracy(mask):
        # Fraction of the nodes selected by `mask` whose prediction matches the label.
        matches = predicted[mask] == data.ynew[mask]
        return int(matches.sum()) / int(mask.sum())

    # Training-split accuracy first, then test-split accuracy.
    train_acc = _masked_accuracy(data.train_mask)
    test_acc = _masked_accuracy(data.test_mask)
    return train_acc, test_acc

## Setup GCN function

In [6]:
def run_gcn(data, d, out_d, device, weight_decay, loss_tol, epochs):
    """Train a fresh ``Model_GCN`` on ``data`` and report its final metrics.

    Uses ``BCEWithLogitsLoss`` and Adam (lr ``1e-3``). Training stops early as
    soon as the training loss drops to ``loss_tol`` or below.

    Args:
        data: Graph object consumed by ``train`` / ``measure_accuracy``.
        d: Input feature size of the model.
        out_d: Output size of the model.
        device: Device the model is moved to (``data`` is used as passed in).
        weight_decay: Weight decay passed to Adam.
        loss_tol: Early-stopping threshold on the training loss.
        epochs: Loop bound; the loop is ``range(1, epochs)``, so at most
            ``epochs - 1`` optimisation steps are taken.

    Returns:
        ``(loss, train_acc, test_acc)``. ``loss`` is the loss tensor of the last
        step, or ``None`` when no step was taken (``epochs <= 1``).
    """
    model = Model_GCN(d, out_d=out_d).to(device)
    criterion = torch.nn.BCEWithLogitsLoss()
    opt = torch.optim.Adam(model.parameters(), lr=1.0e-3, weight_decay=weight_decay)

    # Defined up front so the return below cannot hit an unbound name when the
    # loop body never executes.
    loss = None
    for epoch in range(1, epochs):
        loss = train(model, data, criterion, opt)  # one Adam step
        if loss <= loss_tol:
            break

    # Only the final accuracies are returned, so evaluate once after training
    # instead of running an extra forward pass every epoch.
    train_acc, test_acc = measure_accuracy(model, data)
    return loss, train_acc, test_acc

## Setup linear model

In [7]:
def run_linear(data, d, out_d, device, weight_decay, loss_tol, epochs):
    """Train a fresh ``Model_linear`` on ``data`` and report its final metrics.

    Uses ``BCEWithLogitsLoss`` and Adam (lr ``1e-3``). Training stops early as
    soon as the training loss drops to ``loss_tol`` or below.

    Args:
        data: Graph object consumed by ``train`` / ``measure_accuracy``.
        d: Input feature size of the model.
        out_d: Output size of the model.
        device: Device the model is moved to (``data`` is used as passed in).
        weight_decay: Weight decay passed to Adam.
        loss_tol: Early-stopping threshold on the training loss.
        epochs: Loop bound; the loop is ``range(1, epochs)``, so at most
            ``epochs - 1`` optimisation steps are taken.

    Returns:
        ``(loss, train_acc, test_acc)``. ``loss`` is the loss tensor of the last
        step, or ``None`` when no step was taken (``epochs <= 1``).
    """
    model = Model_linear(d, out_d=out_d).to(device)
    criterion = torch.nn.BCEWithLogitsLoss()
    opt = torch.optim.Adam(model.parameters(), lr=1.0e-3, weight_decay=weight_decay)

    # Defined up front so the return below cannot hit an unbound name when the
    # loop body never executes.
    loss = None
    for epoch in range(1, epochs):
        loss = train(model, data, criterion, opt)  # one Adam step
        if loss <= loss_tol:
            break

    # Only the final accuracies are returned, so evaluate once after training
    # instead of running an extra forward pass every epoch.
    train_acc, test_acc = measure_accuracy(model, data)
    return loss, train_acc, test_acc

## Setup MLP GAT and GAT

In [8]:
def train_gat(model, data, criterion, opt):
    """Perform one optimisation step for a model whose forward returns logits only.

    This is the same procedure as ``train`` (defined in an earlier cell), so it
    simply delegates to it; the name is kept for existing callers.

    Returns:
        The loss tensor computed on the training nodes before the update.
    """
    return train(model, data, criterion, opt)
        
def measure_accuracy_gat(model, data):
    """Train/test accuracy for a model whose forward returns logits only.

    This is the same computation as ``measure_accuracy`` (defined in an earlier
    cell, already wrapped in ``torch.no_grad()``), so it simply delegates to it;
    the name is kept for existing callers.

    Returns:
        ``(train_acc, test_acc)`` as Python floats in ``[0, 1]``.
    """
    return measure_accuracy(model, data)

def train_mlp_gat(model, data, criterion, opt):
    """Perform one full-batch optimisation step for a model returning ``(logits, gamma)``.

    Only the logits enter the loss; the second value returned by the model is
    ignored here.

    Args:
        model: Module called as ``model(data)`` and returning a 2-tuple.
        data: Object exposing ``train_mask`` and the binary targets ``ynew``.
        criterion: Loss applied to the masked logits and masked targets.
        opt: Optimiser holding the model parameters.

    Returns:
        The loss tensor computed on the training nodes before the update.
    """
    model.train()
    opt.zero_grad()

    outputs, _gamma = model(data)  # forward pass over the whole graph
    mask = data.train_mask
    objective = criterion(outputs[mask], data.ynew[mask])

    objective.backward()  # gradients w.r.t. the model parameters
    opt.step()  # parameter update using those gradients
    return objective
        
@torch.no_grad()
def measure_accuracy_mlp_gat(model, data):
    """Train/test accuracy for a model whose forward returns ``(logits, gamma)``.

    A node is predicted positive when ``sigmoid(logit) > 0.5``; predictions are
    compared against ``data.ynew``. The second model output is not used.

    Returns:
        ``(train_acc, test_acc)`` as Python floats in ``[0, 1]``.
    """
    model.eval()
    outputs, _gamma = model(data)  # forward pass
    predicted = torch.sigmoid(outputs) > 0.5

    def _masked_accuracy(mask):
        # Fraction of the nodes selected by `mask` whose prediction matches the label.
        matches = predicted[mask] == data.ynew[mask]
        return int(matches.sum()) / int(mask.sum())

    # Training-split accuracy first, then test-split accuracy.
    train_acc = _masked_accuracy(data.train_mask)
    test_acc = _masked_accuracy(data.test_mask)
    return train_acc, test_acc

In [9]:
def run_mlp_gat_edges(data, d_n, out_d_n, d_e, out_d_e, device, weight_decay, loss_tol, epochs):
    """Train a fresh ``Model_MLP_GAT_edges`` on ``data`` and report its final metrics.

    Uses ``BCEWithLogitsLoss`` and Adam (lr ``1e-3``). Training stops early as
    soon as the training loss drops to ``loss_tol`` or below.

    Args:
        data: Graph object consumed by ``train_mlp_gat`` /
            ``measure_accuracy_mlp_gat``.
        d_n, out_d_n, d_e, out_d_e: Sizes forwarded to ``Model_MLP_GAT_edges``.
        device: Device the model is moved to (``data`` is used as passed in).
        weight_decay: Weight decay passed to Adam.
        loss_tol: Early-stopping threshold on the training loss.
        epochs: Loop bound; the loop is ``range(1, epochs)``, so at most
            ``epochs - 1`` optimisation steps are taken.

    Returns:
        ``(loss, train_acc, test_acc, model)``. ``loss`` is the loss tensor of
        the last step, or ``None`` when no step was taken (``epochs <= 1``).
        The model is returned in eval mode.
    """
    model = Model_MLP_GAT_edges(d_n=d_n, out_d_n=out_d_n, d_e=d_e, out_d_e=out_d_e).to(device)
    criterion = torch.nn.BCEWithLogitsLoss()
    opt = torch.optim.Adam(model.parameters(), lr=1.0e-3, weight_decay=weight_decay)

    # Defined up front so the return below cannot hit an unbound name when the
    # loop body never executes.
    loss = None
    for epoch in range(1, epochs):
        loss = train_mlp_gat(model, data, criterion, opt)  # one Adam step
        if loss <= loss_tol:
            break

    # Only the final accuracies are returned, so evaluate once after training
    # instead of running an extra forward pass every epoch.
    train_acc, test_acc = measure_accuracy_mlp_gat(model, data)
    return loss, train_acc, test_acc, model

In [10]:
def run_mlp_gat(data, d, out_d, device, weight_decay, loss_tol, epochs):
    """Train a fresh ``Model_MLP_GAT`` on ``data`` and report its final metrics.

    Uses ``BCEWithLogitsLoss`` and Adam (lr ``1e-3``). Training stops early as
    soon as the training loss drops to ``loss_tol`` or below.

    Args:
        data: Graph object consumed by ``train_mlp_gat`` /
            ``measure_accuracy_mlp_gat``.
        d: Input feature size of the model.
        out_d: Output size of the model.
        device: Device the model is moved to (``data`` is used as passed in).
        weight_decay: Weight decay passed to Adam.
        loss_tol: Early-stopping threshold on the training loss.
        epochs: Loop bound; the loop is ``range(1, epochs)``, so at most
            ``epochs - 1`` optimisation steps are taken.

    Returns:
        ``(loss, train_acc, test_acc, model)``. ``loss`` is the loss tensor of
        the last step, or ``None`` when no step was taken (``epochs <= 1``).
        The model is returned in eval mode.
    """
    model = Model_MLP_GAT(d, out_d=out_d).to(device)
    criterion = torch.nn.BCEWithLogitsLoss()
    opt = torch.optim.Adam(model.parameters(), lr=1.0e-3, weight_decay=weight_decay)

    # Defined up front so the return below cannot hit an unbound name when the
    # loop body never executes.
    loss = None
    for epoch in range(1, epochs):
        loss = train_mlp_gat(model, data, criterion, opt)  # one Adam step
        if loss <= loss_tol:
            break

    # Only the final accuracies are returned, so evaluate once after training
    # instead of running an extra forward pass every epoch.
    train_acc, test_acc = measure_accuracy_mlp_gat(model, data)
    return loss, train_acc, test_acc, model

In [11]:
def run_gat(data, d, out_d, K, device, weight_decay, loss_tol, epochs):
    """Train a fresh ``Model_GAT`` on ``data`` and report its final metrics.

    Uses ``BCEWithLogitsLoss`` and Adam (lr ``1e-3``). Training stops early as
    soon as the training loss drops to ``loss_tol`` or below.

    Args:
        data: Graph object consumed by ``train_gat`` / ``measure_accuracy_gat``.
        d: Input feature size of the model.
        out_d: Output size of the model.
        K: Number of attention heads forwarded to ``Model_GAT``.
        device: Device the model is moved to (``data`` is used as passed in).
        weight_decay: Weight decay passed to Adam.
        loss_tol: Early-stopping threshold on the training loss.
        epochs: Loop bound; the loop is ``range(1, epochs)``, so at most
            ``epochs - 1`` optimisation steps are taken.

    Returns:
        ``(loss, train_acc, test_acc, model)``. ``loss`` is the loss tensor of
        the last step, or ``None`` when no step was taken (``epochs <= 1``).
        The model is returned in eval mode.
    """
    model = Model_GAT(d, out_d=out_d, K=K).to(device)
    criterion = torch.nn.BCEWithLogitsLoss()
    opt = torch.optim.Adam(model.parameters(), lr=1.0e-3, weight_decay=weight_decay)

    # Defined up front so the return below cannot hit an unbound name when the
    # loop body never executes.
    loss = None
    for epoch in range(1, epochs):
        loss = train_gat(model, data, criterion, opt)  # one Adam step
        if loss <= loss_tol:
            break

    # Only the final accuracies are returned, so evaluate once after training
    # instead of running an extra forward pass every epoch.
    train_acc, test_acc = measure_accuracy_gat(model, data)
    return loss, train_acc, test_acc, model

## Run experiments

In [66]:
# One-vs-rest experiments.
#
# For every dataset and every class, the class is turned into a binary target and
# three models are trained `trials` times: GCN, a linear baseline, and the MLP-GAT
# with edge features. After each trial the attention values returned by the MLP-GAT
# are split into the mass placed on intra-class edges and on inter-class edges and
# compared with uniform neighbourhood weights ("GC"). Results averaged over the
# trials are printed as LaTeX table rows.
#
# NOTE: no random seed is set anywhere in this cell, so the sampled train masks and
# the reported numbers vary from run to run.

# Shared hyper-parameters for all runs.
weight_decay = 0
loss_tol = 1.0e-2
epochs = 500
trials = 5


def _fmt(value):
    # One-decimal rendering used for every number in the LaTeX table.
    return str(round(value, 1))


for dataset in datasets:
    data = dataset[0].to(device)
    n_classes = int(data.y.max()) + 1
    n = data.y.shape[0]
    d = data.x.shape[1]
    d_n = d // 2

    # Split the raw features in two: the first d_n columns become data.e (passed to
    # the edge-feature layer as `e`), the remaining columns stay as data.x.
    data.e = data.x[:, :d_n]
    data.x = data.x[:, d_n:]

    # Edge list with self-loops added; it indexes the attention values below, so it
    # is assumed to match the edge ordering used inside the attention layer
    # (which is built with add_self_loops=True) -- TODO confirm.
    data.edge_index2 = torch_geometric.utils.add_remaining_self_loops(data.edge_index)[0]

    if dataset.name in ('computers', 'photo'):
        # A training ratio has to be chosen for the Amazon datasets: each node is put
        # in the training set independently with probability 0.01, the rest is test.
        train_draw = np.random.binomial(1, 0.01, size=n)
        data.train_mask = torch.from_numpy(train_draw.astype(bool)).to(device)
        data.test_mask = ~data.train_mask

    for which_class in range(n_classes):
        # Binary target: 1 for nodes of `which_class`, 0 otherwise.
        data.ynew = (data.y == which_class).to(torch.float64)
        labels_np = data.ynew.cpu().numpy()

        sum_test_acc_gcn = 0
        sum_test_acc_lin = 0
        sum_test_acc_mlp_gat = 0

        # Percentages of attention ("gamma") and uniform ("default") mass on
        # intra-/inter-class edges, accumulated over trials.
        intra_gamma_ = 0
        intra_gamma_default_ = 0
        inter_gamma_ = 0
        inter_gamma_default_ = 0

        for trial in range(trials):
            loss, train_acc, test_acc = run_gcn(data, d=data.x.shape[1], out_d=1, device=device, weight_decay=weight_decay, loss_tol=loss_tol, epochs=epochs)
            sum_test_acc_gcn += test_acc

            loss, train_acc, test_acc = run_linear(data, d=data.x.shape[1], out_d=1, device=device, weight_decay=weight_decay, loss_tol=loss_tol, epochs=epochs)
            sum_test_acc_lin += test_acc

            loss, train_acc, test_acc, model_mlp_gat_edges = run_mlp_gat_edges(data, d_n=data.x.shape[1], out_d_n=1, d_e=data.e.shape[1], out_d_e=1, device=device, weight_decay=weight_decay, loss_tol=loss_tol, epochs=epochs)
            sum_test_acc_mlp_gat += test_acc

            # Only the attention values are needed here, so skip autograd bookkeeping.
            with torch.no_grad():
                logits, gamma = model_mlp_gat_edges(data)

            attn_adj = SparseTensor(row=data.edge_index2[0], col=data.edge_index2[1], value=gamma).to_scipy('csr')

            intra_gamma = []
            inter_gamma = []
            intra_gamma_default = []
            inter_gamma_default = []

            for i in range(n):
                # Rows with a non-zero entry in column i, and the matching weights.
                neighbors = attn_adj[:, i].nonzero()[0]
                uniform = np.ones(neighbors.shape[0]) / neighbors.shape[0]
                attention = attn_adj[neighbors, i].data

                same_class = labels_np[i] == labels_np[neighbors]
                other_class = ~same_class

                # Record per-node mass only when the node has edges of that kind.
                for bucket, weights in (
                    (intra_gamma, attention[same_class]),
                    (inter_gamma, attention[other_class]),
                    (intra_gamma_default, uniform[same_class]),
                    (inter_gamma_default, uniform[other_class]),
                ):
                    if weights.shape[0] != 0:
                        bucket.append(np.sum(weights))

            intra_attn = np.asarray(intra_gamma).sum()
            inter_attn = np.asarray(inter_gamma).sum()
            intra_unif = np.asarray(intra_gamma_default).sum()
            inter_unif = np.asarray(inter_gamma_default).sum()
            total_ga = intra_attn + inter_attn
            total_gc = intra_unif + inter_unif

            intra_gamma_ += 100*intra_attn/total_ga
            inter_gamma_ += 100*inter_attn/total_ga
            intra_gamma_default_ += 100*intra_unif/total_gc
            inter_gamma_default_ += 100*inter_unif/total_gc

        # LaTeX rows. Raw strings keep the backslashes literal without relying on
        # invalid escape sequences such as "\m" or "\c".
        class_cell = r"\multirow{2}{*}{$" + str(which_class) + "$}"
        gc_row = ("& " + class_cell + " & GC & $" + _fmt(intra_gamma_default_/trials)
                  + "$ & $" + _fmt(inter_gamma_default_/trials)
                  + "$ & $" + _fmt(100*sum_test_acc_gcn/trials) + "$ " + r"\\")
        ga_row = ("& & GA & $" + _fmt(intra_gamma_/trials)
                  + "$ & $" + _fmt(inter_gamma_/trials)
                  + "$ & $" + _fmt(100*sum_test_acc_mlp_gat/trials) + "$ " + r"\\")
        cline = r" \cline{2-6}"

        if which_class == 0:
            # First class of a dataset: prepend the rotated dataset name spanning all rows.
            dataset_cell = (r"\multirow{" + str(n_classes*2)
                            + r"}{*}{\rotatebox[origin=c]{90}{" + dataset.name + "}} ")
            print(dataset_cell + gc_row)
            print(ga_row + cline)
        elif which_class < n_classes-1:
            print(gc_row)
            print(ga_row + cline)
        else:
            # Last class: close the dataset's group of rows with a double rule.
            print(gc_row)
            print(ga_row)
            print(r"\hline \hline")


\multirow{20}{*}{\rotatebox[origin=c]{90}{computers}} & \multirow{2}{*}{$0$} & GC & $98.7$ & $1.3$ & $96.8$ \\
& & GA & $98.1$ & $1.9$ & $96.7$ \\ \cline{2-6}
& \multirow{2}{*}{$1$} & GC & $93.6$ & $6.4$ & $91.4$ \\
& & GA & $93.3$ & $6.7$ & $88.7$ \\ \cline{2-6}
& \multirow{2}{*}{$2$} & GC & $98.1$ & $1.9$ & $95.8$ \\
& & GA & $97.8$ & $2.2$ & $92.1$ \\ \cline{2-6}
& \multirow{2}{*}{$3$} & GC & $97.5$ & $2.5$ & $96.0$ \\
& & GA & $96.0$ & $4.0$ & $96.0$ \\ \cline{2-6}
& \multirow{2}{*}{$4$} & GC & $89.4$ & $10.6$ & $89.7$ \\
& & GA & $89.4$ & $10.6$ & $83.3$ \\ \cline{2-6}
& \multirow{2}{*}{$5$} & GC & $99.6$ & $0.4$ & $97.8$ \\
& & GA & $99.5$ & $0.5$ & $98.0$ \\ \cline{2-6}
& \multirow{2}{*}{$6$} & GC & $97.0$ & $3.0$ & $96.4$ \\
& & GA & $96.5$ & $3.5$ & $96.5$ \\ \cline{2-6}
& \multirow{2}{*}{$7$} & GC & $99.1$ & $0.9$ & $96.8$ \\
& & GA & $98.5$ & $1.5$ & $94.9$ \\ \cline{2-6}
& \multirow{2}{*}{$8$} & GC & $91.8$ & $8.2$ & $88.9$ \\
& & GA & $91.7$ & $8.3$ & $86.5$ \\ \cline{2-6}