In [1]:
%reload_ext autoreload
%autoreload 2
import sys
sys.path.append('/home/sebastian/masters/') # add my repo to python path
import os
import torch
import torch.nn.functional as F
import torch_geometric
import kmbio  # fork of biopython PDB with some changes in how the structure, chain, etc. classes are defined.
import numpy as np
import pandas as pd
import proteinsolver
import modules

#from Bio import SeqIO
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, BatchSampler
from sklearn.model_selection import KFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import *
from torch import nn, optim
from pathlib import Path

from modules.dataset import *
from modules.utils import *
from modules.models import *
from modules.lstm_utils import *


np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f71f810d6d0>

In [2]:
root = Path("/home/sebastian/masters/data/")
data_root = root / "neat_data"
metadata_path = data_root / "metadata.csv"
processed_dir = data_root / "processed" / "tcr_binding"
state_file = root / "state_files" / "e53-s1952148-d93703104.state"
out_dir = root / "state_files" / "tcr_binding"

### Dataset

In [13]:
model_dir = data_root / "raw" / "tcrpmhc"

paths = list(model_dir.glob("*"))
join_key = [int(x.name.split("_")[0]) for x in paths]
path_df = pd.DataFrame({'#ID': join_key, 'path': paths})

metadata = pd.read_csv(metadata_path)
metadata = metadata.join(path_df.set_index("#ID"), on="#ID", how="inner")  # filter to non-missing data
metadata = metadata.reset_index(drop=True)
metadata

Unnamed: 0,#ID,CDR3a,CDR3b,peptide,partition,binder,v_gene_alpha,j_gene_alpha,v_gene_beta,j_gene_beta,origin,v_alpha_vdjdb_name,j_alpha_vdjdb_name,v_beta_vdjdb_name,j_beta_vdjdb_name,path
0,1,AVSQSNTGKLI,ASSQLMENTEAF,NLVPMVATV,1,0,TRAV12-2,TRAJ37,TRBV4-1,TRBJ1-1,tenX,TRAV12-2*01,TRAJ37*01,TRBV4-1*01,TRBJ1-1*01,/home/sebastian/masters/data/neat_data/raw/tcr...
1,2,AASEVCADYKLS,ASSYSLLRAAPNTEAF,NLVPMVATV,1,0,TRAV29DV5,TRAJ20,TRBV6-3,TRBJ1-1,tenX,TRAV29/DV5*01,TRAJ20*01,TRBV6-3*01,TRBJ1-1*01,/home/sebastian/masters/data/neat_data/raw/tcr...
2,3,AGRLGAQKLV,ASSQGGRRNQPQH,NLVPMVATV,1,0,TRAV25,TRAJ54,TRBV4-2,TRBJ1-5,tenX,TRAV25*01,TRAJ54*01,TRBV4-2*01,TRBJ1-5*01,/home/sebastian/masters/data/neat_data/raw/tcr...
3,4,AVEPLYGNKLV,ASSSREAEAF,NLVPMVATV,1,0,TRAV22,TRAJ47,TRBV7-9,TRBJ1-1,tenX,TRAV22*01,TRAJ47*01,TRBV7-9*01,TRBJ1-1*01,/home/sebastian/masters/data/neat_data/raw/tcr...
4,5,ASGTYKYI,ASSQRAGRVDTQY,NLVPMVATV,1,0,TRAV19,TRAJ40,TRBV27,TRBJ2-3,tenX,TRAV19*01,TRAJ40*01,TRBV27*01,TRBJ2-3*01,/home/sebastian/masters/data/neat_data/raw/tcr...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10326,12961,AVNSYYNQGGKLI,SVLQGSPYEQY,GILGFVFTL,1,1,TRAV12-2*01,TRAJ23*01,TRBV29-1*01,TRBJ2-7*01,positive,TRAV12-2*01,TRAJ23*01,TRBV29-1*01,TRBJ2-7*01,/home/sebastian/masters/data/neat_data/raw/tcr...
10327,12962,AGNYGGSQGNLI,ASSIYSVNEQF,GILGFVFTL,1,1,TRAV35*01,TRAJ42*01,TRBV19*01,TRBJ2-1*01,positive,TRAV35*01,TRAJ42*01,TRBV19*01,TRBJ2-1*01,/home/sebastian/masters/data/neat_data/raw/tcr...
10328,12966,AVGGSQGNLI,ASSVRSSYEQY,GILGFVFTL,1,1,TRAV8-6*02,TRAJ42*01,TRBV19*01,TRBJ2-7*01,positive,TRAV8-6*01,TRAJ42*01,TRBV19*01,TRBJ2-7*01,/home/sebastian/masters/data/neat_data/raw/tcr...
10329,12968,AENGGGGADGLT,ASSIRSSYEQY,GILGFVFTL,1,1,TRAV13-2*01,TRAJ45*01,TRBV19*01,TRBJ2-7*01,positive,TRAV13-2*01,TRAJ45*01,TRBV19*01,TRBJ2-7*01,/home/sebastian/masters/data/neat_data/raw/tcr...


In [None]:
unique_peptides = metadata["peptide"].unique()

metadata["merged_chains"] = metadata["CDR3a"] + metadata["CDR3b"]
loo_train_partitions = list()
loo_valid_partitions = list()
for pep in unique_peptides:
    valid_df = metadata[metadata["peptide"] == pep]
    valid_unique_cdr = valid_df["merged_chains"].unique()
    
    # get training rows and drop swapped data
    train_df = metadata[metadata["peptide"] != pep]
    train_df = train_df[~train_df["merged_chains"].str.contains('|'.join(valid_unique_cdr))]

    loo_train_partitions.append(list(train_df.index))
    loo_valid_partitions.append(list(valid_df.index))

# hacky dataset fix
# hacky dataset fix
# hacky dataset fix
filtered_peptides = ["CLGGLLTMV", "ILKEPVHGV"]
filtered_indices = list()
filtered_partitions = list()

for pep in filtered_peptides:
    filtered_indices.extend(list(metadata[metadata["peptide"] == pep].index))
    filtered_partitions.extend(np.where(unique_peptides == pep)[0])

loo_train_partitions = [part for i, part in enumerate(loo_train_partitions) if i not in filtered_partitions]
loo_valid_partitions = [part for i, part in enumerate(loo_valid_partitions) if i not in filtered_partitions]

filtered_indices = set(filtered_indices)

for i in range(len(loo_train_partitions)):
    train_part, valid_part = loo_train_partitions[i], loo_valid_partitions[i]
    train_part = [i for i in train_part if i not in filtered_indices]
    valid_part = [i for i in valid_part if i not in filtered_indices]
    loo_train_partitions[i], loo_valid_partitions[i] = train_part, valid_part
    
unique_peptides = np.delete(unique_peptides, filtered_partitions)

raw_files = np.array(metadata["path"])
targets = np.array(metadata["binder"])
dataset = ProteinDataset(processed_dir, raw_files, targets, overwrite=False)

In [5]:
for x in [(len(x), len(y), len(x)+len(y)) for x,y in zip(loo_train_partitions, loo_valid_partitions)]:
    print(x)

(8381, 1217, 9598)
(10214, 70, 10284)
(10174, 87, 10261)
(10017, 201, 10218)
(10253, 39, 10292)
(10236, 54, 10290)
(10182, 89, 10271)
(10169, 86, 10255)
(9946, 217, 10163)
(10262, 34, 10296)
(9917, 245, 10162)
(7613, 1667, 9280)
(10192, 76, 10268)
(2011, 6223, 8234)
(10306, 8, 10314)
(10308, 7, 10315)


### ProteinSolver model

In [6]:
from torch_geometric.nn import global_mean_pool

class MyGNN(nn.Module):
    def __init__(self, x_input_size, adj_input_size, hidden_size, output_size):
        super().__init__()

        self.embed_x = nn.Sequential(
            nn.Embedding(x_input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            # nn.ReLU(),
        )
        self.embed_adj = (
            nn.Sequential(
                nn.Linear(adj_input_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.LayerNorm(hidden_size),
                # nn.ELU(),
            )
            if adj_input_size
            else None
        )
        self.graph_conv_0 = get_graph_conv_layer(
            (2 + bool(adj_input_size)) * hidden_size, 2 * hidden_size, hidden_size
        )

        N = 3
        graph_conv = get_graph_conv_layer(3 * hidden_size, 2 * hidden_size, hidden_size)
        self.graph_conv = _get_clones(graph_conv, N)

        self.linear_out = nn.Linear(hidden_size, output_size)  # re-assign to (hidden_size, 1)
        
    def forward(self, x, edge_index, edge_attr, batch):
        x = self.forward_without_last_layer(x, edge_index, edge_attr)
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5)
        x = self.linear_out(x)
        return x

    def forward_without_last_layer(self, x, edge_index, edge_attr):
        x = self.embed_x(x)
        # edge_index, _ = add_self_loops(edge_index)  # We should remove self loops in this case!
        edge_attr = self.embed_adj(edge_attr) if edge_attr is not None else None

        x_out, edge_attr_out = self.graph_conv_0(x, edge_index, edge_attr)
        x = x + x_out
        edge_attr = (
            (edge_attr + edge_attr_out) if edge_attr is not None else edge_attr_out
        )

        for i in range(3):
            x = F.relu(x)
            edge_attr = F.relu(edge_attr)
            x_out, edge_attr_out = self.graph_conv[i](x, edge_index, edge_attr)
            x = x + x_out
            edge_attr = edge_attr + edge_attr_out

        return x
    

def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def gnn_train(
    model,
    epochs,
    criterion,
    optimizer,
    #scheduler,
    dataset,
    train_idx,
    valid_idx,
    batch_size,
    device,
    extra_print=None,
):
    train_losses = list()
    valid_losses = list()
    
    for e in range(epochs):
        
        train_sampler = BatchSampler(SubsetRandomSampler(train_idx), batch_size=batch_size, drop_last=False)
        valid_sampler = BatchSampler(SubsetRandomSampler(valid_idx), batch_size=1, drop_last=False)
        
        train_loader = torch_geometric.loader.DataLoader(dataset=dataset, batch_sampler=train_sampler)
        valid_loader = torch_geometric.loader.DataLoader(dataset=dataset, batch_sampler=valid_sampler)

        train_len = len(train_loader)
        valid_len = len(valid_loader)
        
        train_loss = 0
        model.train()
        j = 0
        for data in train_loader:    
            data.y = data.y.to(device)
            y_pred = model(data.x, data.edge_index, data.edge_attr, data.batch).squeeze(1)
            optimizer.zero_grad()
            loss = criterion(y_pred, data.y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            display_func(j, train_len, e, train_losses, valid_losses, extra_print)
            j += 1
        
        valid_loss = 0
        model.eval()
        with torch.no_grad():
            for data in valid_loader:    
                data.y = data.y.to(device)
                y_pred = model(data.x, data.edge_index, data.edge_attr, data.batch).squeeze(1)
                loss = criterion(y_pred, data.y)
                valid_loss += loss.item()
        
        #scheduler.step()
        train_losses.append(train_loss / train_len)
        valid_losses.append(valid_loss / valid_len)

    return model, train_losses, valid_losses

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# init proteinsolver gnn
num_features = 20
adj_input_size = 2
hidden_size = 128

gnn = MyGNN(
    x_input_size=num_features + 1, 
    adj_input_size=adj_input_size, 
    hidden_size=hidden_size, 
    output_size=num_features
)
gnn.load_state_dict(torch.load(state_file, map_location=device))
gnn.linear_out = nn.Linear(hidden_size, 1)

#gnn.eval()
gnn = gnn.to(device)

#https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=CN3sRVuaQ88l

learning_rate = 1e-2
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(
    gnn.parameters(), 
    lr=learning_rate, 
    #weight_decay=w_decay,
)


epochs = 10
batch_size = 2
train_idx = list(range(50))
valid_idx = list(range(50, 60))

gnn_train(
    gnn,
    epochs,
    criterion,
    optimizer,
    #scheduler,
    dataset,
    train_idx,
    valid_idx,
    batch_size,
    device,
)




ValueError: Target size (torch.Size([1])) must be the same as input size (torch.Size([1, 1]))

In [None]:
#batch_size = 3
#train_idx = list(range(1000))
#train_sampler = BatchSampler(SubsetRandomSampler(train_idx), batch_size=batch_size, drop_last=False)
#data_loader = iter(torch_geometric.loader.DataLoader(dataset, batch_sampler=train_sampler))
#data_loader = iter(torch_geometric.loader.DataLoader(dataset, batch_size=batch_size))

In [None]:
#data = next(data_loader)
#print(data.x.shape)

In [None]:
#y_pred = gnn(data.x, data.edge_index, data.edge_attr, data.batch)

In [None]:
#y_pred = y_pred
#loss = criterion(y_pred, data.y)
#loss.backward()
#optimizer.step()
#optimizer.zero_grad()
#
#print(loss.item())