In [1]:
%reload_ext autoreload
%autoreload 2
import sys
sys.path.append('/home/sebastian/masters/') # add my repo to python path
import os
import torch
import torch_geometric
import kmbio  # fork of biopython PDB with some changes in how the structure, chain, etc. classes are defined.
import numpy as np
import proteinsolver
import modules

from pathlib import Path
from modules.dataset import *
from modules.utils import *
from modules.model import *
from modules.my_model import *

np.random.seed(1)

In [2]:
# Root of the project data tree. The default is the original hard-coded
# location; set the MASTERS_DATA_ROOT environment variable to run the notebook
# on another machine without editing this cell.
root = Path(os.environ.get("MASTERS_DATA_ROOT", "/home/sebastian/masters/data/"))
# Directory holding the raw / processed structure data used by the cells below.
data_root = root / "neat_data"
# Saved weights that are later loaded into the GNN via `load_state_dict`.
state_file = root / "state_files" / "e53-s1952148-d93703104.state"

### Load positive examples and generate different structure configurations

In [3]:
class ChainFilter(kmbio.PDB.Select):
    """Selector that accepts only chains whose id is contained in ``subset``."""

    def __init__(self, subset):
        # Container of chain ids to keep (membership is tested with ``in``).
        self.subset = subset

    def accept_chain(self, chain):
        """Return 1 when ``chain.id`` is in ``self.subset`` and 0 otherwise."""
        return int(chain.id in self.subset)

# Set to True to regenerate the derived PDB files even when they already exist.
overwrite = False

raw_files, targets = get_data(
    model_dir=data_root / "raw" / "tcrpmhc",
    metadata=data_root / "metadata.csv",
)
# Boolean selector for the positive examples (entries with a non-zero target).
# This replaces the former detour through `np.ma.masked_array(...).mask`.
positive_mask = np.asarray(targets, dtype=bool)

pmhc_chain_subset = ["M", "P"]
p_chain_subset = ["P"]
annotated_paths = list()

outdir_1 = data_root / "raw" / "pmhc"
outdir_2 = data_root / "raw" / "p"

outdir_1.mkdir(parents=True, exist_ok=True)
outdir_2.mkdir(parents=True, exist_ok=True)
for raw_file in raw_files[positive_mask]:

    # The model id is the part of the file name before the first underscore.
    model_id = raw_file.name.split("_")[0]
    pmhc_file_name = outdir_1 / f"{model_id}_pmhc.pdb"
    p_file_name = outdir_2 / f"{model_id}_p.pdb"

    # Only parse the structure when at least one derived file is missing
    # (or when regeneration is forced).
    if overwrite or not (pmhc_file_name.is_file() and p_file_name.is_file()):
        structure = kmbio.PDB.load(raw_file)

        # Write one reduced copy of the structure per chain subset.
        for out_path, chain_subset in (
            (pmhc_file_name, pmhc_chain_subset),
            (p_file_name, p_chain_subset),
        ):
            writer = kmbio.PDB.io.PDBIO()
            writer.set_structure(structure)
            writer.save(out_path, ChainFilter(subset=chain_subset))

    # Each entry is [path, configuration label]:
    #   "0" = the unmodified input file, "1" = chains M and P, "2" = chain P only.
    annotated_paths.append([raw_file, "0"])
    annotated_paths.append([pmhc_file_name, "1"])
    annotated_paths.append([p_file_name, "2"])

# Persist the (path, label) pairs as a header-less two-column CSV.
metadata_path = data_root / "embedding_dataset.csv"
with open(metadata_path, "w") as metadata_outfile:
    for path, label in annotated_paths:
        print(path, label, sep=",", file=metadata_outfile)

### Preprocess data

In [4]:
# Read back the header-less "path,label" file and build the dataset from it.
metadata_path = data_root / "embedding_dataset.csv"
raw_files = list()
targets = list()
with open(metadata_path, "r") as infile:
    for line in infile:
        line = line.strip()
        if not line:
            # Tolerate blank lines instead of failing with an IndexError.
            continue
        # Split on the LAST comma only, so a path that itself contains a comma
        # is still parsed as (path, label).
        path, label = line.rsplit(",", 1)
        raw_files.append(path)
        targets.append(int(label))

raw_files = np.array(raw_files)
targets = np.array(targets)

# NOTE(review): a comment here used to read "filter sequences to contain
# peptide only", but no such filtering is performed in this cell.

processed_dir = data_root / "processed" / "embedding_verification"
dataset = ProteinDataset(processed_dir, raw_files, targets, overwrite=False)

### Define classifier

In [5]:
from torch import nn, optim
import torch.nn.functional as F

# Hyper-parameters used to construct the GNN below.
batch_size = 3
num_features = 20
adj_input_size = 2
hidden_size = 128
#frac_present = 0.5
#frac_present_valid = frac_present
#info_size= 1024

# Run on CUDA device index 1 when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1")
else:
    device = torch.device("cpu")

# Build the network, restore the weights stored in `state_file`, switch it to
# evaluation mode and move it to the selected device.
gnn = Net(
    x_input_size=num_features + 1,
    adj_input_size=adj_input_size,
    hidden_size=hidden_size,
    output_size=num_features,
)
gnn.load_state_dict(torch.load(state_file, map_location=device))
gnn = gnn.eval().to(device)

class MyFNN(nn.Module):
    """Fully connected head that flattens its input and maps it to logits.

    Args:
        hidden_size: Width of the two hidden layers; also one factor of the
            expected flattened input length (``hidden_size * num_features``).
        num_features: Second factor of the flattened input length. Defaults to
            20, the value the notebook assigns to its module-level
            ``num_features`` (previously read implicitly from that global).
        output_size: Number of output units. Defaults to 1, which is what the
            layer was hard-coded to.
            TODO: the original author noted the output should be 3 long;
            pass ``output_size=3`` once the surrounding code expects that.
    """

    def __init__(self, hidden_size, num_features=20, output_size=1):
        super().__init__()

        self.linear = nn.Sequential(
            nn.Linear(hidden_size * num_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        """Transpose ``x``, flatten it to 1-D and apply the linear stack.

        ``x`` must hold exactly ``hidden_size * num_features`` elements so the
        flattened vector matches the first linear layer.
        """
        flat = x.T.flatten()
        return self.linear(flat)

# Hyper-parameters passed to the MyLSTM classifier instantiated further down.
# NOTE: `hidden_size` and `num_features` are re-assigned here; `hidden_size`
# was 128 when the GNN was built earlier in this cell.
hidden_size = 26
num_layers = 2
num_features = 20
num_classes = 3

#class MyLSTM(nn.Module):
#    def __init__(self,  num_classes, num_features, num_layers, hidden_size):
#        super(MyLSTM, self).__init__()
#        
#        self.num_layers = num_layers
#        self.hidden_size = hidden_size
#        
#        self.lstm = nn.LSTM(
#            input_size=num_features,
#            hidden_size=hidden_size,
#            num_layers=num_layers, 
#            dropout=0.5, 
#            bidirectional=True
#        )
#        self.dropout = nn.Dropout(p=0.5)
#        self.linear = nn.Linear(hidden_size * 2, num_classes)
#        
#        torch.nn.init.xavier_uniform_(self.linear.weight)
#    
#    def forward(self, x):
#        x = nn.utils.rnn.pack_sequence(x).to(device)
#        x, (h, c) = self.lstm(x)
#        h_cat = torch.cat((h[-2, :, :], h[-1, :, :]), dim=1)
#        out = self.dropout(h_cat)
#        out = self.linear(out)
#        return out

# Instantiate the classifier and move it to the selected device.
net = MyLSTM(num_classes, num_features, num_layers, hidden_size).to(device)

# Adam over the parameters of this particular `net` instance, and a
# binary-cross-entropy loss that takes raw logits.
optimizer = optim.Adam(net.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()

In [6]:
import random
import hashlib

# Directory that receives the per-fold loss and state files.
out_dir = root / "state_files" / "embedding_verification"

In [7]:
#def gnn_to_lstm_batch(data, gnn_instance, num_classes):
#        data = data.to(device)
#        y = data.y
#        with torch.no_grad():
#            out = gnn_instance(data.x, data.edge_index, data.edge_attr)
#        
#        batches = torch.unique(data.batch)
#        sliced_embeddings = list()
#        encoded_y = list()
#        for batch_idx in batches:
#            # split sub graphs into batches
#            batch_slice = torch.nonzero(data.batch == batch_idx)
#            chain_map = data.chain_map[batch_idx]
#            one_batch_peptide_emb = out[batch_slice][chain_map == "P"]  # get peptide only
#            sliced_embeddings.append(one_batch_peptide_emb.squeeze(1))
#            
#            # one hot encode targets
#            sliced_y = int(y[batch_idx].item())
#            one_hot_y = np.zeros(num_classes)
#            one_hot_y[sliced_y] = 1
#            encoded_y.append(one_hot_y)
#            
#        sliced_embeddings.sort(key=lambda x: len(x))
#        encoded_y = torch.Tensor(encoded_y)
#        
#        return sliced_embeddings, encoded_y

#x, y = gnn_to_batch(next(train_loader), gnn, num_classes)
#x = net(x)

### Fine-tuning the classifier

In [8]:
#def train_model(
#    epochs, 
#    model,
#    criterion,
#    train_data, 
#    valid_data,
#    batch_size,
#    device,
#    transform=None,
#):
#    train_losses = list()
#    valid_losses = list()
#    for i in range(epochs):
#        train_loader = iter(torch_geometric.loader.DataLoader(train_data, batch_size=batch_size))
#        valid_loader = iter(torch_geometric.loader.DataLoader(valid_data, batch_size=batch_size))
#        
#        train_len = len(train_loader)
#        valid_len = len(valid_loader)
#
#        train_loss = 0
#        model.train()
#        for j, x in enumerate(train_loader):
#            if transform:
#                x, y = transform(x)  # y needs to be changed to multiclass y i.e. [0,0,1] instead of [2].
#            else:
#                x, y = x.to(device), y.to(device)
#            out = model(x)
#            loss = criterion(out, y)
#            loss.backward()
#            train_loss += loss.item()
#            optimizer.step()
#
#            display_func(j, train_len, i, train_losses, valid_losses)
#
#        valid_loss = 0
#        model.eval()
#        with torch.no_grad():
#            for x in valid_loader:
#                if transform:
#                    x, y = transform(x)
#                else:
#                    x, y = x.to(device), y.to(device)
#                out = model(x)
#                loss = criterion(out, y)
#                valid_loss += loss.item()
#
#        train_losses.append(train_loss / train_len)
#        valid_losses.append(valid_loss / valid_len)
#    return model, train_losses, valid_losses

In [None]:
from sklearn.model_selection import KFold

epochs = 1
n_splits = 5
batch_size = 5
num_classes = 3
# Same learning rate as the module-level optimiser defined earlier.
learning_rate = 0.0001
# Upper bound on the number of samples taken from each train / validation
# split. NOTE(review): 10 looks like a debugging limit; set to None to use the
# full split.
max_samples_per_split = 10

loss_paths, state_paths = touch_output_files(out_dir, n_splits)


def gnn_transform(x):
    """Convert one batch with `gnn_to_lstm_batch`, using the fixed GNN, device and class count."""
    return gnn_to_lstm_batch(
        x,
        gnn_instance=gnn,
        device=device,
        num_classes=num_classes,
    )


CV = KFold(n_splits=n_splits, shuffle=True)
for i, (train_idx, valid_idx) in enumerate(CV.split(dataset)):

    train_subset = dataset[torch.LongTensor(train_idx)][0:max_samples_per_split]
    valid_subset = dataset[torch.LongTensor(valid_idx)][0:max_samples_per_split]

    # Every fold starts from a freshly initialised classifier ...
    net = MyLSTM(num_classes, num_features, num_layers, hidden_size)
    net = net.to(device)
    # ... and therefore needs its own optimiser. Re-using the module-level
    # optimiser would keep stepping the parameters of the first `net` instance
    # and leave this fold's model untrained.
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    net, train_subset_losses, valid_subset_losses = train_model(
        epochs=epochs,
        model=net,
        criterion=criterion,
        optimizer=optimizer,
        train_data=train_subset,
        valid_data=valid_subset,
        batch_size=batch_size,
        device=device,
        transform=gnn_transform,
    )

    # Store this fold's loss curves and trained weights at the fold's paths.
    torch.save({"train": train_subset_losses, "valid": valid_subset_losses}, loss_paths[i])
    torch.save(net.state_dict(), state_paths[i])


