Installing Dependencies

In [None]:
# Install PyTorch first: its version string selects the matching PyTorch Geometric wheels below.
!pip install torch
import torch 
# Name of the wheel index page for the installed torch build, e.g. "torch-<version>.html".
pytorch_version = f"torch-{torch.__version__}.html"
# The compiled PyG extensions are resolved only from that wheel page
# (--no-index skips PyPI, -f adds the page as a find-links source).
!pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/$pytorch_version
# Remaining packages come from the default index.
# NOTE(review): none of these installs pin a version, so results may differ between runs.
!pip install torch-geometric
!pip install torchvision
!pip install seaborn
!pip install torchmetrics
# Report the GPU (if any) visible to this runtime.
!nvidia-smi

Looking for GPU

In [None]:
# Pick the compute device: the first CUDA GPU when one is available, otherwise the CPU.
import torch

if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

Processing training data

In [None]:
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

# Seed the global RNG so the dataset shuffle (and therefore the split) is repeatable.
torch.manual_seed(100)

data = TUDataset(root='data/TUDataset/',name='PROTEINS_full')
data = data.shuffle()
print(len(data))

# 75% / 25% train/test split of the shuffled dataset.
split_idx = int(0.75*len(data))
train_data = data[:split_idx]
test_data = data[split_idx:]

print(len(train_data))
print(len(test_data))

trainX = DataLoader(train_data, batch_size=20, shuffle=True)
# The evaluation loader must wrap the held-out split; it previously wrapped
# train_data, so reported "test" accuracy was really training accuracy.
# Shuffling is unnecessary for evaluation and is disabled to keep it deterministic.
testX = DataLoader(test_data, batch_size=20, shuffle=False)

# num_features is the per-node feature dimension, not a node count.
print('Node features of dataset {}'.format(data.num_features))
print('Total labels of dataset {}'.format(data.num_classes))
# edge_index has shape [2, num_edges]; len() would always report 2, so read dimension 1.
print('Edge indexes of one molecule {}'.format(data[1].edge_index.size(1)))
print('Edge attributes of one molecule {}'.format(data[1].edge_attr))

Visualizing molecule

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

# Draw the graph stored at index 1 of the shuffled dataset and save it as a PNG.
one_mol = data[1]
# After the transpose each row holds the two endpoints of one edge.
edge_list = one_mol.edge_index.t().numpy()

graph = nx.Graph()
for u, v in edge_list:
  graph.add_edge(u, v)

plt.figure(1)
plt.title('Proteins')
nx.draw(graph)
plt.savefig('Proteins.png')


GAT-SAGE

In [None]:
import torch
import torch.nn.functional as F
from torch.nn import Module,Dropout,Linear
from torch_geometric.nn import GCNConv,GATConv,SAGEConv,global_max_pool

class GraphSAGE(Module):
    """Graph-level classifier built from two stacked SAGEConv layers.

    Node features go through both convolutions, then dropout and ReLU,
    are max-pooled into one vector per graph, and are finally projected
    to `out_features` scores by a linear layer.

    Args:
        n_hidden: output width of both SAGEConv layers.
        in_features: size of each input node feature vector.
        out_features: number of scores produced per graph.
    """

    def __init__(self, n_hidden, in_features, out_features):
        super().__init__()
        self.n_hidden, self.in_features, self.out_features = (
            n_hidden,
            in_features,
            out_features,
        )
        # Sub-modules are registered in this order on purpose: it fixes the
        # layout of state_dict() and of the printed model.
        self.conv1 = SAGEConv(in_features, n_hidden)
        self.conv2 = SAGEConv(n_hidden, n_hidden)
        self.dropout = Dropout(p=0.2)
        self.out = Linear(n_hidden, out_features)

    def forward(self, x, edge_index, batch):
        """Return one row of `out_features` scores per graph in `batch`."""
        node_embeddings = self.conv2(self.conv1(x, edge_index), edge_index)
        activated = self.dropout(node_embeddings).relu()
        graph_embeddings = global_max_pool(activated, batch)
        return self.out(graph_embeddings)

class GAT(Module):
    """Graph-level classifier built from two stacked GATConv layers.

    Node features go through both attention convolutions, then dropout and
    ReLU, are max-pooled into one vector per graph, and are finally projected
    to `out_features` scores by a linear layer.

    Args:
        n_hidden: output width of both GATConv layers.
        in_features: size of each input node feature vector.
        out_features: number of scores produced per graph.
    """

    def __init__(self, n_hidden, in_features, out_features):
        super().__init__()
        self.n_hidden, self.in_features, self.out_features = (
            n_hidden,
            in_features,
            out_features,
        )
        # Registration order is kept stable so state_dict() and the printed
        # model look the same as before.
        self.conv1 = GATConv(in_features, n_hidden)
        self.conv2 = GATConv(n_hidden, n_hidden)
        self.dropout = Dropout(p=0.2)
        self.out = Linear(n_hidden, out_features)

    def forward(self, x, edge_index, batch):
        """Return one row of `out_features` scores per graph in `batch`."""
        node_embeddings = self.conv2(self.conv1(x, edge_index), edge_index)
        activated = self.dropout(node_embeddings).relu()
        graph_embeddings = global_max_pool(activated, batch)
        return self.out(graph_embeddings)


In [None]:
class GAT_SAGE(Module):
    """Runs a GAT model and a GraphSAGE model on the same batch and joins their outputs.

    Args:
        GAT: module called as ``GAT(x, edge_index, batch)``.
        SAGE: module called as ``SAGE(x, edge_index, batch)``.

    ``forward`` returns the two outputs concatenated along dim 1 as raw,
    unnormalised scores. No softmax is applied here: the notebook trains this
    model with ``CrossEntropyLoss``, which applies log-softmax itself, so a
    softmax in the model would normalise twice and flatten the gradients.
    The arg-max of the scores is the same as the arg-max of their softmax;
    call ``F.softmax(scores, dim=1)`` on the result when probabilities are needed.
    """

    def __init__(self, GAT, SAGE):
        super().__init__()
        self.GAT = GAT
        self.SAGE = SAGE

    def forward(self, x, edge_index, batch):
        gat_scores = self.GAT(x, edge_index, batch)
        sage_scores = self.SAGE(x, edge_index, batch)
        # NOTE(review): concatenation makes the output as wide as both sub-model
        # outputs together, so an arg-max over it can select a column index that
        # lies beyond a single sub-model's output width. Confirm this is the
        # intended way to combine the two models.
        return torch.cat((gat_scores, sage_scores), dim=1)

In [None]:
# Build the two sub-models with the same hidden width and the dataset's
# feature/class counts, then wrap them in the combined model.
model1 = GraphSAGE(
    n_hidden=32,
    in_features=data.num_features,
    out_features=data.num_classes,
).to(device)
model2 = GAT(
    n_hidden=32,
    in_features=data.num_features,
    out_features=data.num_classes,
).to(device)
model = GAT_SAGE(GAT=model2, SAGE=model1).to(device)

In [None]:
# Show the nested module structure of the combined model.
print(model)

Training GAT-SAGE

In [None]:
#training
import time
from torch.optim import Adam
from torch.nn import CrossEntropyLoss,L1Loss

# Training configuration.
epochs = 500
loss_func = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.002)

# Per-epoch history filled in by the training loop:
# the loss returned by train() and the accuracy measured on testX.
losses, accuracy = [], []

def train(loader):
  """Run one optimisation pass over `loader` and return the mean batch loss.

  Uses the notebook-level `model`, `optimizer`, `loss_func` and `device`.

  Args:
    loader: iterable of batches exposing `x`, `edge_index`, `batch`, `y`
      and a `to(device)` method.

  Returns:
    The average loss over all batches as a Python float (NaN when the
    loader yields no batches). A plain float keeps the `losses` history
    free of device tensors, so it can be plotted or saved directly.
  """
  # test() switches the model to eval mode; switch back so dropout is active
  # while training on every epoch, not only the first one.
  model.train()
  total_loss = 0.0
  num_batches = 0
  for batch in loader:
    batch = batch.to(device)
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index, batch.batch)
    loss = loss_func(out, batch.y)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    num_batches += 1
  if num_batches == 0:
    return float('nan')
  return total_loss / num_batches

def test(loader):
  model.eval()
  correct = 0
  for data in loader:
    data.to(device)
    out = model(data.x, data.edge_index, data.batch)
    pred = out.argmax(dim=1)
    correct += int((pred == data.y).sum())  
  return correct / len(loader.dataset) 
 
# One iteration per epoch: optimise on the training loader, then measure
# accuracy on both loaders and log the epoch's numbers.
for epoch in range(epochs):
    start = time.process_time()

    loss = train(trainX)
    losses.append(loss)

    train_acc = test(trainX)
    test_acc = test(testX)
    accuracy.append(test_acc)

    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}, Loss: {loss:.4f}, Time: {time.process_time()-start:.4f}')

Training statistics

In [None]:
import matplotlib.pyplot as plt
plt.style.use('ggplot')
# Plot the per-epoch accuracy history on its own figure/axes instead of the
# implicit "current" pyplot figure, so the curve cannot land on a figure left
# open by an earlier cell.
fig, ax = plt.subplots()
ax.plot(accuracy)
ax.set_title('GAT-SAGE(Proteins)')
ax.set_xlabel('EPOCHS')
ax.set_ylabel('Accuracy')
plt.show()

Saving training weights

In [None]:
import numpy as np
import torch
# Persist the accuracy history plus the model and optimizer state to the working directory.
np.save(file='GatSAGE-Proteins(accuracy).npy', arr=accuracy)
torch.save(obj=model.state_dict(), f='GatSAGE-Proteins(weights).pth')
torch.save(obj=optimizer.state_dict(), f='GatSAGE-Proteins(optimizer).pth')


Loading weights

In [None]:
#loading weights
import torch
from torch.optim import Adam
# Recreate the optimizer around the current model's parameters before restoring its state.
optimizer = Adam(model.parameters(), lr=0.002)

# map_location remaps the stored tensors onto the device chosen for this session,
# so a checkpoint written on a GPU runtime can still be opened on a CPU-only one.
model.load_state_dict(torch.load('GatSAGE-Proteins(weights).pth', map_location=device))
optimizer.load_state_dict(torch.load('GatSAGE-Proteins(optimizer).pth', map_location=device))
# Switch to inference mode (disables dropout); as the last expression this also displays the model.
model.eval()

Testing GAT-SAGE

In [None]:
import time

import pandas as pd
import torch

# Evaluate the loaded weights on a single batch from the test loader.
preds = []

# Make sure dropout is disabled even if the loading cell was skipped.
model.eval()

test_batch = next(iter(testX)).to(device)  # extract one batch from the dataset
target = test_batch.y

with torch.no_grad():
    # Time only this forward pass. The previous code subtracted `start`, a
    # variable left over from the training loop, which made the value
    # meaningless and the cell fail when training had not been run.
    inference_start = time.process_time()
    pred = model(test_batch.x, test_batch.edge_index, test_batch.batch)
    time_step = time.process_time() - inference_start
preds.append(pred)

# Side-by-side table of true labels and arg-max predictions for the batch.
df = pd.DataFrame({
    "y_actual": test_batch.y.tolist(),
    "y_predicted": torch.argmax(pred, dim=1).tolist(),
})
df

Micro F1-score

In [None]:
# Micro-averaged F1 score of the predictions collected in `preds`.
from torchmetrics.classification import MulticlassF1Score

# torch.max(pred, dim=1) returns (values, indices); the predicted class is the
# index, not the maximum score, so take the arg-max directly. The result is
# stored under its own name so the `test()` function is not overwritten.
all_scores = torch.cat(preds).to(device)
predicted_labels = all_scores.argmax(dim=1)
target = test_batch.y.to(device)

# The number of classes is taken from the model's output width so every index
# the arg-max can produce is a valid class for the metric.
f1 = MulticlassF1Score(num_classes=all_scores.size(1), average='micro').to(device)
f1(predicted_labels, target)

Confusion Matrix

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Cross-tabulate actual vs. predicted labels from `df`, draw the counts as an
# annotated heatmap, and save the figure before showing it.
confusion_matrix = pd.crosstab(
    index=df['y_actual'],
    columns=df['y_predicted'],
    rownames=['Actual'],
    colnames=['Predicted'],
)
sns.heatmap(data=confusion_matrix, annot=True)
plt.title('Proteins')
plt.savefig('confusion matrix.png')
plt.show()