In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pickle

import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

import warnings
warnings.filterwarnings("ignore")

In [2]:
# def format_pytorch_version(version):
#   return version.split('+')[0]

# TORCH_version = torch.__version__
# TORCH = format_pytorch_version(TORCH_version)

# def format_cuda_version(version):
#   return 'cu' + version.replace('.', '')

# CUDA_version = torch.version.cuda
# CUDA = format_cuda_version(CUDA_version)

# !pip install tqdm
# !pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-geometric

In [3]:
from torch_geometric.data import Dataset, download_url, DataLoader
from torch_geometric.data import Data

from scipy.spatial.distance import squareform, pdist
import os
from tqdm import tqdm

In [4]:
# Path prefixes: `root` for the saved checkpoint, `db_path` for the cached .pt graph files.
root, db_path = '', 'pt_cache/'

In [5]:
from torch_geometric.nn import GCNConv, GINConv, GATConv, SAGEConv, GATv2Conv
import glob
import torch.optim as optim


class CustomGATLayer(nn.Module):
    """A multi-head GATConv followed by LayerNorm over the concatenated heads.

    Because the heads are concatenated, the layer emits ``out_dim * num_heads``
    features per node.
    """

    def __init__(self, in_dim, out_dim, num_heads):
        super().__init__()
        concat_width = out_dim * num_heads
        self.multihead_attention = GATConv(in_dim, out_dim, heads=num_heads, concat=True)
        self.norm = nn.LayerNorm(concat_width)

    def forward(self, x, edge_index):
        """Apply attention message passing, then normalise the result."""
        attended = self.multihead_attention(x, edge_index)
        return self.norm(attended)


class CustomGATModel(nn.Module):
    """Per-node classifier: token embedding -> stacked GAT layers -> MLP head.

    No pooling is applied, so one row of logits is produced per node.

    Args:
        in_dim: width of the input token embedding.
        hidden_dim: per-head output width of each GAT layer; also the width of
            the penultimate fully connected layer.
        num_classes: number of logits produced for each node.
        num_layers: number of stacked ``CustomGATLayer`` blocks.
        num_heads: attention heads per GAT layer. Head outputs are concatenated,
            so each layer emits ``hidden_dim * num_heads`` features.
        dict_size: number of rows in the token embedding table.
        num_fc_layers: the head has ``num_fc_layers - 1`` square hidden layers,
            then a projection to ``hidden_dim``, then the classifier layer.
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers, num_heads, dict_size, num_fc_layers):
        super(CustomGATModel, self).__init__()
        self.embedding = nn.Embedding(dict_size, in_dim)
        self.gat_layers = nn.ModuleList()
        for _ in range(num_layers):
            gat_layer = CustomGATLayer(in_dim, hidden_dim, num_heads)
            self.gat_layers.append(gat_layer)
            in_dim = hidden_dim * num_heads  # Update input dimension for the next layer

        # `in_dim` now holds the width of the features entering the head. That is
        # the concatenated GAT output, or the raw embedding width when
        # num_layers == 0. Using it instead of a hard-coded hidden_dim * num_heads
        # keeps the head shape-correct in both cases.
        feat_dim = in_dim
        self.fc = nn.ModuleList()
        for _ in range(num_fc_layers - 1):
            self.fc.append(nn.Linear(feat_dim, feat_dim))

        self.fc.append(nn.Linear(feat_dim, hidden_dim))
        self.fc.append(nn.Linear(hidden_dim, num_classes))

    def forward(self, data):
        """Return logits of shape (num_nodes, num_classes).

        ``data`` must provide ``x`` and ``edge_index``. ``x`` is 2-D and only
        its column 0 is embedded. Presumably it holds node token ids; TODO confirm.
        """
        # PyG message passing expects int64 (torch.long) edge indices, and
        # nn.Embedding accepts int64 indices as well.
        x, edge_index = data.x.long(), data.edge_index.long()
        x = self.embedding(x)[:, 0, :]
        for gat_layer in self.gat_layers:
            x = F.relu(gat_layer(x, edge_index))

        for fc_layer in self.fc[:-1]:
            x = F.relu(fc_layer(x))

        return self.fc[-1](x)

In [13]:
# Graphs whose `size` (from size_data.csv) exceeded this were excluded when the
# cache was built. It also sets the embedding vocabulary (max_graph_size + 1).
max_graph_size = 2000

# One-off cache builder, kept for provenance. It loads every cached .pt graph
# with size <= max_graph_size and pickles [name, graph] pairs to dataset.pickle.
# Uncomment to regenerate the cache.
#
# # List of file paths to the data
# file_paths = os.listdir(db_path)
#
# df = pd.read_csv('size_data.csv', index_col=0)
# df.file = df.file.apply(lambda x: x.split('/')[-1].split('.')[0] + '.pt')
#
# subset = df[(df['size'] <= max_graph_size) & (df.file.isin(file_paths))]
#
# file_paths = subset.file.values.tolist()
#
# dataset = []
# for x in tqdm(file_paths):
#     dataset.append([x.split('.')[0], torch.load(db_path+x)])
#
# with open("dataset.pickle", "wb") as f:
#     pickle.dump(dataset, f)

# NOTE: pickle.load can execute arbitrary code. Only load a dataset.pickle you
# generated yourself.
with open("dataset.pickle", "rb") as cache_file:
    named_graphs = pickle.load(cache_file)  # list of [protein_name, graph] pairs

protein_names = [entry[0] for entry in named_graphs]
dataset = [entry[1] for entry in named_graphs]

In [19]:
# Model and training configuration consumed by the cells below.
hyperparameters = {
    'n_heads': 8,                 # attention heads per GAT layer
    'n_layers': 4,                # number of stacked GAT layers
    'embed_dim': 256,             # passed to the model as hidden_dim (per-head width)
    'n_fc_layers': 1,
    'batch_size': 32,
    'early_stopping_epochs': 10,  # patience, in epochs, on the test loss
    'lr': 0.001,                  # Adam learning rate
    'clip': 1.0,                  # max gradient norm for clip_grad_norm_
}

In [20]:
# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Deterministic 80/20 split in the pickle's stored order (no shuffling).
train_ratio = 0.8
split_idx = int(train_ratio * len(dataset))
train_files, test_files = dataset[:split_idx], dataset[split_idx:]

# 32-wide token embedding and 3 output classes. The embedding table has
# max_graph_size + 1 rows; presumably token ids lie in [0, max_graph_size],
# TODO confirm.
model = CustomGATModel(
    in_dim=32,
    hidden_dim=hyperparameters['embed_dim'],
    num_classes=3,
    num_layers=hyperparameters['n_layers'],
    num_heads=hyperparameters['n_heads'],
    dict_size=max_graph_size + 1,
    num_fc_layers=hyperparameters['n_fc_layers'],
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=hyperparameters['lr'])

# Cosine learning-rate schedule with T_max = 32 scheduler steps. The training
# loop steps it once per epoch.
T_max = 32
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)

cpu


In [21]:
def get_n_params(model):
    """Return the total number of scalar parameters in `model`.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not
    it requires gradients.
    """
    # Tensor.numel() is the product of the tensor's dimensions (1 for a 0-d
    # tensor). This replaces a hand-rolled loop whose local `nn` variable
    # shadowed the torch.nn import.
    return sum(p.numel() for p in model.parameters())

# Total parameter count of `model`, reported in millions.
get_n_params(model)/1e6

13.278755

In [25]:
# Mini-batch graph loaders. The training set is reshuffled each time it is
# iterated; the test set is always iterated in a fixed order.
_loader_opts = {'batch_size': hyperparameters['batch_size'], 'num_workers': 4}
train_loader = DataLoader(train_files, shuffle=True, **_loader_opts)
test_loader = DataLoader(test_files, shuffle=False, **_loader_opts)

In [None]:
# Training loop. Each epoch proceeds as follows:
#   1. Evaluate the current model on the test set and checkpoint it if the
#      summed test loss improved.
#   2. Evaluate it on the training set and log both sets of metrics.
#   3. Run one optimisation pass over the training set.
# Training stops once the test loss has not improved for more than
# `early_stopping_epochs` consecutive epochs.
epochs = 100

prev_loss_test = 1e100  # best (lowest) summed test loss seen so far
es_epochs = 0           # consecutive epochs without test-loss improvement


def _evaluate(loader):
    """Return (accuracy, sum of per-batch mean losses) of `model` over `loader`."""
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            outputs = model(data)
            loss_sum += criterion(outputs, data.y).item()

            total += data.y.size(0)
            correct += (outputs.argmax(dim=1) == data.y).sum().item()
    return correct / total, loss_sum


def _log(message):
    """Append `message` to output.txt (closing the file) and echo it to stdout."""
    with open("output.txt", "a") as log_file:
        print(message, file=log_file)
    print(message)


for epoch in range(epochs):
    accuracy_test, loss_test = _evaluate(test_loader)
    best_yet = ' '

    if loss_test < prev_loss_test:
        # Reset the patience counter whenever the test loss improves.
        es_epochs = 0
        prev_loss_test = loss_test
        torch.save(model, root + 'best_model.pt')
        best_yet = '*'
    else:
        es_epochs += 1

    accuracy_train, loss_train = _evaluate(train_loader)

    _log(f'Epoch: {epoch}/{epochs}\tTrain Accuracy: {accuracy_train:.2f}\tTest Accuracy: {accuracy_test:.2f}\t' + best_yet + f'\tLoss Train: {loss_train:.2f}\tLoss Test: {loss_test:.2f}')

    model.train()
    for data in train_loader:
        data = data.to(device)
        outputs = model(data)

        loss = criterion(outputs, data.y)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), hyperparameters['clip'])
        optimizer.step()

    scheduler.step()

    if es_epochs > hyperparameters['early_stopping_epochs']:
        _log('Early Stopping...')
        break