In [22]:
%reload_ext autoreload
%autoreload 2

import os
import sys
import torch
import torch_geometric
import kmbio  # fork of biopython PDB with some changes in how the structure, chain, etc. classes are defined.
import numpy as np
import proteinsolver

from proteinsolver.models.model import *
from proteinsolver.datasets import *

# custom stuff
#import proteinsolver_utils
#import proteinsolver_datasets
np.random.seed(1)

sys.path.append('/home/sebastian/masters/')

### Parameter file

In [15]:
UNIQUE_ID = "191f05de"
BEST_STATE_FILES = {
    #
    "191f05de": "/home/sebastian/proteinsolver/data/e53-s1952148-d93703104.state"
}
state_file = BEST_STATE_FILES[UNIQUE_ID]

### Load data

In [23]:
from modules.dataset import *


raw_files, targets = get_data("/home/sebastian/masters/data/210916_TCRpMHCmodels/")
root = "/home/sebastian/masters/data/210916_TCRpMHCmodels/"

# make quick data split
n_data = len(raw_files)
valid_frac = 0.1
valid_num = int(n_data * valid_frac)
selection = np.random.randint(0, n_data, valid_num)
mask = np.zeros(n_data, bool)
mask[selection] = 1

valid_files = raw_files[mask]
valid_targets = targets[mask]
d_valid = ProteinDataset(f"{root}/valid", valid_files, valid_targets, overwrite=False)

train_files = raw_files[~mask]
train_targets = targets[~mask]
d_train = ProteinDataset(f"{root}/train", train_files, train_targets, overwrite=False)

### Init proteinsolver network

In [18]:
batch_size = 1
num_features = 20
adj_input_size = 2
hidden_size = 128
#frac_present = 0.5
#frac_present_valid = frac_present
#info_size= 1024

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

gnn = Net(
    x_input_size=num_features + 1, adj_input_size=adj_input_size, hidden_size=hidden_size, output_size=num_features
)
gnn.load_state_dict(torch.load(state_file, map_location=device))
gnn.eval()
gnn = gnn.to(device)

### Init classifier network

In [19]:
from torch import nn, optim
import torch.nn.functional as F

num_features = 20
epochs = 1

class TestNet(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        
        self.pool = nn.AdaptiveAvgPool1d(  # https://stackoverflow.com/a/63603993/11398318
            output_size=hidden_size
        ) 
        
        self.linear = nn.Sequential(
            nn.Linear(hidden_size * num_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        x = self.pool(x)
        x = x.T.flatten()
        x = self.linear(x)
        return x
    
net = TestNet(50)
net = net.to(device)
save_path = "/home/sebastian/masters/data/trained_models/test_model.state"

#criterion = nn.CrossEntropyLoss()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

In [24]:
from modules.utils import *

train_losses = list()
valid_losses = list()

print(f"Training for {epochs} epochs:")

for i in range(epochs):
    train_loader = iter(torch_geometric.data.DataLoader(d_train, batch_size=batch_size))
    valid_loader = iter(torch_geometric.data.DataLoader(d_valid, batch_size=batch_size))
    
    train_len = len(train_loader)
    valid_len = len(valid_loader)
    
    train_loss = 0
    net.train()
    for j, x in enumerate(train_loader):
        optimizer.zero_grad()
        x = x.to(device)
        y = torch.Tensor(x.y)
        with torch.no_grad():
            x = gnn(x.x, x.edge_index, x.edge_attr)
        out = net(x.T.unsqueeze(0))
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()

        display_func(j, train_len, i, train_losses, valid_losses)
        
    valid_loss = 0
    net.eval()
    with torch.no_grad():
        for x in valid_loader:
            x = x.to(device)
            y = torch.Tensor(x.y)
            x = gnn(x.x, x.edge_index, x.edge_attr)
            out = net(x.T.unsqueeze(0))
            loss = criterion(out, y)

            valid_loss += loss.item()

    train_losses.append(train_loss / train_len)
    valid_losses.append(valid_loss / valid_len)
    
print_loss(train_losses, valid_losses, clear_print=True)

torch.save(net.state_dict(), save_path)

epoch: 1 	train_loss: 0.33302355 	valid_loss: 0.58643964


### Visualization and performance

In [None]:
from sklearn import metrics
import matplotlib.pyplot as plt

d_test = d_valid
test_loader = iter(torch_geometric.data.DataLoader(d_test, batch_size=1))

net.load_state_dict(torch.load(save_path))
net.eval()

pred = list()
y = list()
with torch.no_grad():
        for x in test_loader:
            x = x.to(device)
            y = x.y
            x = gnn(x.x, x.edge_index, x.edge_attr)
            out = torch.sigmoid(net(x.T.unsqueeze(0)))
            pred.append(out.item())
            y.append(y[0])

pred = np.array(pred)
y = np.array(y)

In [None]:
epoch_range = list(range(epochs))
plt.plot(epoch_range, train_losses, '-')
plt.plot(epoch_range, valid_losses, '-')
plt.legend(["Training loss", "Validation loss"])
plt.title("Binary cross-entropy loss")
plt.xlabel("Epochs")
plt.ylabel("BCE loss")

plt.show()

In [None]:
##### fpr, tpr, thresholds = metrics.roc_curve(y, pred, pos_label=1)
auc = metrics.auc(fpr, tpr)
plt.plot(fpr, tpr, '-')
plt.legend([f"AUC = {round(auc, 5)}"])
plt.title("ROC curve")
plt.xlabel("FPR")
plt.ylabel("TPR")

plt.show()