Next steps:
1. Add a "master node" that connects all nodes together, so that message passing works between all nodes

In [1]:
import pandas as pd

from EmbedDataset import LigandBinaryDataset

import time
import numpy as np
import torch
from torch_geometric.loader import DataLoader

In [2]:
dataset = LigandBinaryDataset('./data2/')

In [3]:
dataset = dataset.shuffle()
train_dataset = dataset[:int(len(dataset) * 0.7)]
val_dataset = dataset[int(len(dataset) * 0.7):int(len(dataset) * 0.85)]
test_dataset = dataset[int(len(dataset) * 0.85):]

In [4]:
train_dl = DataLoader(train_dataset, batch_size=16)
val_dl = DataLoader(val_dataset, batch_size=16)
test_dl = DataLoader(test_dataset, batch_size=16)

In [27]:
from LigandGNNV2 import LigandGNNV2
from LigandGNNV1 import LigandGNNV1
from sagn.models import SAGN

device = torch.device('cuda')
# model = LigandGNNV1(dataset.num_node_features, 1).to(device)
model = LigandGNNV2(128, 37).to(device)
# model = SAGN(in_feats=1070, hidden=1024, out_feats=1, num_hops=3, n_layers=2, num_heads=1, dropout=0.3).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([22]).to(device))

In [28]:
def train(model, loader, criterion, optimizer):
    model.train()

    loss_acc = 0.
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        label = data.y

        loss = criterion(output, label.reshape(-1, 1))
        loss.backward()
        loss_acc += data.num_graphs * loss.item()
        optimizer.step()

    return loss_acc / len(loader.dataset)

In [29]:
from sklearn.metrics import f1_score, precision_recall_curve, roc_curve, auc, precision_recall_fscore_support, roc_auc_score

def evaluate(model, loader):
    model.eval()

    preds = np.asarray([])
    labels = np.asarray([])

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            pred = torch.sigmoid(model(data).round().squeeze().cpu()).round().numpy()
            label = data.y.cpu().numpy()

            preds = np.concatenate([preds, pred])
            labels = np.concatenate([labels, label])

    precision, recall, _ = precision_recall_curve(labels, preds)
    fpr, tpr, thresholds = roc_curve(labels, preds, pos_label=1)
    print(precision_recall_fscore_support(labels, preds))
    return auc(fpr, tpr), auc(recall, precision)
    # return roc_auc_score(labels, preds)

In [30]:
train_hist = []
val_hist = []

for epoch in range(1, 201):
    s = time.time()
    loss = train(model, train_dl, criterion, optimizer)
    train_score = evaluate(model, train_dl)
    val_score = evaluate(model, val_dl)
    scheduler.step(loss)
    e = time.time()

    train_hist.append(train_score)
    val_hist.append(val_score)

    if sum(val_score) > (.77 + .44) and sum(val_score) >= np.asarray(val_hist).sum(axis=1).max():
        print("saving...")
        torch.save(model.state_dict(), './models/BestModel5.pt')

    # print(f'Epoch: {epoch:03d}, Loss: {loss:.05f}, Train Score: {train_score:.05f}, Val Score: {val_score:.05f}, Time: {e - s:.05f}s')
    print(f'Epoch: {epoch:03d}, Loss: {loss:.05f}, Train Score: {train_score}, Val Score: {val_score}, Time: {e - s:.05f}s')

(array([0.96641103, 0.13048841]), array([0.98945599, 0.04398803]), array([0.97779775, 0.06579602]), array([334313,  12026], dtype=int64))
(array([0.96239707, 0.19136961]), array([0.99384013, 0.03618304]), array([0.9778659 , 0.06085919]), array([69969,  2819], dtype=int64))
Epoch: 001, Loss: 1.16346, Train Score: (0.5167220073962809, 0.10383611597356257), Val Score: (0.515011586416284, 0.1324401155528596), Time: 23.89848s
(array([0.97786954, 0.13533015]), array([0.90035087, 0.43356062]), array([0.93751051, 0.20627448]), array([334313,  12026], dtype=int64))
(array([0.97463133, 0.14254729]), array([0.89829782, 0.41965236]), array([0.93490904, 0.21280806]), array([69969,  2819], dtype=int64))
Epoch: 002, Loss: 1.03540, Train Score: (0.6669557437280889, 0.29427967938088545), Val Score: (0.6589750882987448, 0.2923379430966717), Time: 23.40100s
(array([0.97762835, 0.18591582]), array([0.9363052 , 0.40437386]), array([0.95652068, 0.25472069]), array([334313,  12026], dtype=int64))
(array([0.9

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
plt.plot(val_hist)
plt.plot(train_hist)
plt.show()

In [31]:
evaluate(model, test_dl)

(array([0.98432008, 0.27153033]), array([0.95537485, 0.52220839]), array([0.9696315 , 0.35728486]), array([75630,  2409], dtype=int64))


(0.7387916182357941, 0.40424387344167656)

In [32]:
model.load_state_dict(torch.load('./models/BestModel5.pt'))
evaluate(model, test_dl)

(array([0.99411908, 0.12418102]), array([0.80911014, 0.84973018]), array([0.89212377, 0.21669401]), array([75630,  2409], dtype=int64))


(0.8294201599877756, 0.4892749544042394)

In [None]:
data = test_dl.dataset[4].to(device)

In [None]:
pred = model(data)

In [None]:
out = torch.sigmoid(pred).round()
out

In [None]:
print(out.sum())
print(len(out))

In [None]:
len(data.y) / data.y.sum()

In [None]:
data.y.sum()

In [None]:
f1_score(data.y.detach().cpu().numpy(), out.detach().cpu().numpy())

In [None]:
torch.save(model.state_dict(), './models/modelBest.pt')