In [4]:
import sys

# allows importing local scripts (utils folder)
sys.path.append("..")

In [None]:
import torch

from utils.metrics import *
from utils.data_prep import dataset, trainloader, testloader

from torch_geometric.data import DataLoader as DataLoader

In [None]:
import math

total_samples = len(dataset)
n_iterations = math.ceil(total_samples / 5)

total_samples, n_iterations

In [None]:
from torch.optim.lr_scheduler import MultiStepLR


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


def train(model, trainloader, optimizer, epoch):
    model.train()

    loss_func = torch.nn.MSELoss()
    predictions_tr = torch.Tensor()
    scheduler = MultiStepLR(optimizer, milestones=[1, 5], gamma=0.5)

    labels_tr = torch.Tensor()
    for prot_1, prot_2, label in trainloader:
        prot_1 = prot_1.to(device)
        prot_2 = prot_2.to(device)

        optimizer.zero_grad()
        output = model(prot_1, prot_2)

        predictions_tr = torch.cat((predictions_tr, output.cpu()), 0)
        labels_tr = torch.cat((labels_tr, label.view(-1, 1).cpu()), 0)
        loss = loss_func(output, label.view(-1, 1).float().to(device))

        loss.backward()
        optimizer.step()

    scheduler.step()
    labels_tr = labels_tr.detach().numpy()
    predictions_tr = predictions_tr.detach().numpy()
    acc_tr = get_accuracy(labels_tr, predictions_tr, 0.5)

    print(f"Epoch {epoch - 1}/30 - train_loss: {loss} - train_accuracy: {acc_tr}")

In [None]:
@torch.no_grad()
def predict(model, loader):
    model.eval()
    predictions = torch.Tensor()
    labels = torch.Tensor()
    for prot_1, prot_2, label in loader:
        prot_1 = prot_1.to(device)
        prot_2 = prot_2.to(device)
        # print(torch.Tensor.size(prot_1.x), torch.Tensor.size(prot_2.x))
        output = model(prot_1, prot_2)
        predictions = torch.cat((predictions, output.cpu()), 0)
        labels = torch.cat((labels, label.view(-1, 1).cpu()), 0)
    labels = labels.numpy()
    predictions = predictions.numpy()
    return labels.flatten(), predictions.flatten()

## GCNN

In [None]:
from utils.models import GCNN

N_EPOCHS_TO_STOP = 6
NUM_EPOCHS = 50

epochs_no_improve = 0
early_stop = False

min_loss = 100
best_accuracy = 0

model = GCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = torch.nn.MSELoss()

In [15]:
try:
    for epoch in range(NUM_EPOCHS):
        train(model, trainloader, optimizer, epoch + 1)
        G, P = predict(model, testloader)

        loss = get_mse(G, P)
        accuracy = get_accuracy(G, P, 0.5)

        print(f"Epoch {epoch}/{NUM_EPOCHS} - val_loss: {loss} - val_accuracy: {accuracy}")
        
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_acc_epoch = epoch

            torch.save(model.state_dict(), "../datasets/models/GCN.pth")

        if loss < min_loss:
            epochs_no_improve = 0
            min_loss = loss
            min_loss_epoch = epoch
        elif loss > min_loss:
            epochs_no_improve += 1
        if epoch > 5 and epochs_no_improve == N_EPOCHS_TO_STOP:
            print("Early stopping!")
            early_stop = True
            break
except KeyboardInterrupt:
    torch.save(model.state_dict(), "../datasets/models/GCN.pth")

Epoch 0/50 - val_loss: 0.22415585425100204 - val_accuracy: 100.0
Epoch 0/30 - train_loss: 0.2417270988225937 - train_accuracy: 87.5
Epoch 0/50 - val_loss: 0.22415585425100204 - val_accuracy: 100.0
Epoch 1/30 - train_loss: 0.21505936980247498 - train_accuracy: 100.0
Epoch 1/50 - val_loss: 0.2020673362312042 - val_accuracy: 100.0
Epoch 2/30 - train_loss: 0.19504331052303314 - train_accuracy: 100.0
Epoch 2/50 - val_loss: 0.18626437093041126 - val_accuracy: 100.0
Epoch 3/30 - train_loss: 0.17464394867420197 - train_accuracy: 100.0
Epoch 3/50 - val_loss: 0.17689909676743198 - val_accuracy: 100.0
Epoch 4/30 - train_loss: 0.18159420788288116 - train_accuracy: 100.0
Epoch 4/50 - val_loss: 0.1717450691658673 - val_accuracy: 100.0
Epoch 5/30 - train_loss: 0.16102281212806702 - train_accuracy: 100.0
Epoch 5/50 - val_loss: 0.16902849359761163 - val_accuracy: 100.0
Epoch 6/30 - train_loss: 0.16341885924339294 - train_accuracy: 100.0
Epoch 6/50 - val_loss: 0.16762955680216152 - val_accuracy: 100.0
E

In [3]:
print(f"min_val_loss: {min_loss} for epoch {min_loss_epoch}")
print(f"best_val_accuracy: {best_accuracy} for epoch {best_acc_epoch}")

min_val_loss: 0.16619315895778186 for epoch 21
best_val_accuracy: 100.0 for epoch 0


## GAT

In [None]:
from utils.models import AttGNN

N_EPOCHS_TO_STOP = 6
NUM_EPOCHS = 50

epochs_no_improve = 0
early_stop = False

min_loss = 100
best_accuracy = 0

model = AttGNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = torch.nn.MSELoss()

In [2]:
try:
    for epoch in range(NUM_EPOCHS):
        train(model, trainloader, optimizer, epoch + 1)
        G, P = predict(model, testloader)

        loss = get_mse(G, P)
        accuracy = get_accuracy(G, P, 0.5)

        print(f"Epoch {epoch}/{NUM_EPOCHS} - val_loss: {loss} - val_accuracy: {accuracy}")
        
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_acc_epoch = epoch

            torch.save(model.state_dict(), "../datasets/models/GAT.pth")

        if loss < min_loss:
            epochs_no_improve = 0
            min_loss = loss
            min_loss_epoch = epoch
        elif loss > min_loss:
            epochs_no_improve += 1
        if epoch > 5 and epochs_no_improve == N_EPOCHS_TO_STOP:
            print("Early stopping!")
            early_stop = True
            break
except KeyboardInterrupt:
    torch.save(model.state_dict(), "../datasets/models/GAT.pth")

Epoch 0/30 - train_loss: 0.2417270988225937 - train_accuracy: 87.50
Epoch 0/50 - val_loss: 0.2241558542510020 - val_accuracy: 100.00
Epoch 1/30 - train_loss: 0.2393461464416413 - train_accuracy: 88.10
Epoch 1/50 - val_loss: 0.2241558542510020 - val_accuracy: 100.00
Epoch 2/30 - train_loss: 0.2369651940606889 - train_accuracy: 88.69
Epoch 2/50 - val_loss: 0.2241558542510020 - val_accuracy: 100.00
Epoch 3/30 - train_loss: 0.2345842416797365 - train_accuracy: 89.29
Epoch 3/50 - val_loss: 0.2241558542510020 - val_accuracy: 100.00
Epoch 4/30 - train_loss: 0.2322032892987842 - train_accuracy: 89.88
Epoch 4/50 - val_loss: 0.2241558542510020 - val_accuracy: 100.00
Epoch 5/30 - train_loss: 0.2298223369178318 - train_accuracy: 90.48
Epoch 5/50 - val_loss: 0.2241558542510020 - val_accuracy: 100.00
Epoch 6/30 - train_loss: 0.2274413845368794 - train_accuracy: 91.07
Epoch 6/50 - val_loss: 0.2241558542510020 - val_accuracy: 100.00
Epoch 7/30 - train_loss: 0.2250604321559270 - train_accuracy: 91.67
E

In [13]:
print(f"min_val_loss : {min_loss} for epoch {min_loss_epoch}")
print(f"best_val_accuracy : {best_accuracy} for epoch {best_acc_epoch}")

min_val_loss: 0.16619315895778186 for epoch 21
best_val_accuracy: 100.0 for epoch 18