In [1]:
%reload_ext autoreload
%autoreload 2
import sys
sys.path.append('/home/sebastian/masters/') # add my repo to python path
import os
import torch
import torch.nn.functional as F
import torch_geometric
import kmbio  # fork of biopython PDB with some changes in how the structure, chain, etc. classes are defined.
import numpy as np
import pandas as pd
import proteinsolver
import modules

from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, BatchSampler
from sklearn.model_selection import KFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import *
from torch import nn, optim
from pathlib import Path

from modules.dataset_utils import *
from modules.dataset import *
from modules.utils import *
from modules.models import *
from modules.lstm_utils import *

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f5ff002b6f0>

### Paths

In [2]:
root = Path("/home/sebastian/masters/data/")
data_root = root / "neat_data"
metadata_path = data_root / "metadata.csv"
processed_dir = data_root / "processed"
state_file = root / "state_files" / "e53-s1952148-d93703104.state"
out_dir = root / "state_files" / "tcr_binding"

### Get metadata

In [3]:
model_dir = data_root / "raw" / "tcrpmhc"

paths = list(model_dir.glob("*"))
join_key = [int(x.name.split("_")[0]) for x in paths]
path_df = pd.DataFrame({'#ID': join_key, 'path': paths})

metadata = pd.read_csv(metadata_path)
metadata = metadata.join(path_df.set_index("#ID"), on="#ID", how="inner")  # filter to non-missing data
metadata = metadata.reset_index(drop=True)
metadata

Unnamed: 0,#ID,CDR3a,CDR3b,peptide,partition,binder,v_gene_alpha,j_gene_alpha,v_gene_beta,j_gene_beta,origin,v_alpha_vdjdb_name,j_alpha_vdjdb_name,v_beta_vdjdb_name,j_beta_vdjdb_name,path
0,1,AVSQSNTGKLI,ASSQLMENTEAF,NLVPMVATV,1,0,TRAV12-2,TRAJ37,TRBV4-1,TRBJ1-1,tenX,TRAV12-2*01,TRAJ37*01,TRBV4-1*01,TRBJ1-1*01,/home/sebastian/masters/data/neat_data/raw/tcr...
1,2,AASEVCADYKLS,ASSYSLLRAAPNTEAF,NLVPMVATV,1,0,TRAV29DV5,TRAJ20,TRBV6-3,TRBJ1-1,tenX,TRAV29/DV5*01,TRAJ20*01,TRBV6-3*01,TRBJ1-1*01,/home/sebastian/masters/data/neat_data/raw/tcr...
2,3,AGRLGAQKLV,ASSQGGRRNQPQH,NLVPMVATV,1,0,TRAV25,TRAJ54,TRBV4-2,TRBJ1-5,tenX,TRAV25*01,TRAJ54*01,TRBV4-2*01,TRBJ1-5*01,/home/sebastian/masters/data/neat_data/raw/tcr...
3,4,AVEPLYGNKLV,ASSSREAEAF,NLVPMVATV,1,0,TRAV22,TRAJ47,TRBV7-9,TRBJ1-1,tenX,TRAV22*01,TRAJ47*01,TRBV7-9*01,TRBJ1-1*01,/home/sebastian/masters/data/neat_data/raw/tcr...
4,5,ASGTYKYI,ASSQRAGRVDTQY,NLVPMVATV,1,0,TRAV19,TRAJ40,TRBV27,TRBJ2-3,tenX,TRAV19*01,TRAJ40*01,TRBV27*01,TRBJ2-3*01,/home/sebastian/masters/data/neat_data/raw/tcr...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10326,12961,AVNSYYNQGGKLI,SVLQGSPYEQY,GILGFVFTL,1,1,TRAV12-2*01,TRAJ23*01,TRBV29-1*01,TRBJ2-7*01,positive,TRAV12-2*01,TRAJ23*01,TRBV29-1*01,TRBJ2-7*01,/home/sebastian/masters/data/neat_data/raw/tcr...
10327,12962,AGNYGGSQGNLI,ASSIYSVNEQF,GILGFVFTL,1,1,TRAV35*01,TRAJ42*01,TRBV19*01,TRBJ2-1*01,positive,TRAV35*01,TRAJ42*01,TRBV19*01,TRBJ2-1*01,/home/sebastian/masters/data/neat_data/raw/tcr...
10328,12966,AVGGSQGNLI,ASSVRSSYEQY,GILGFVFTL,1,1,TRAV8-6*02,TRAJ42*01,TRBV19*01,TRBJ2-7*01,positive,TRAV8-6*01,TRAJ42*01,TRBV19*01,TRBJ2-7*01,/home/sebastian/masters/data/neat_data/raw/tcr...
10329,12968,AENGGGGADGLT,ASSIRSSYEQY,GILGFVFTL,1,1,TRAV13-2*01,TRAJ45*01,TRBV19*01,TRBJ2-7*01,positive,TRAV13-2*01,TRAJ45*01,TRBV19*01,TRBJ2-7*01,/home/sebastian/masters/data/neat_data/raw/tcr...


i = 10326
print(metadata.iloc[i])
print(metadata.iloc[i]["path"])

print(torch.load(f"/home/sebastian/masters/data/neat_data/processed/tcr_binding/data_{i}.pt"))
print(torch.load(f"/home/sebastian/masters/data/neat_data/processed/tcr_binding/data_{i}.pt").y)

print(torch.load(f"/home/sebastian/masters/data/neat_data/processed/tcr_binding/gnn_out_pos_128/data_{i}.pt").shape)
print(targets[i])
print(raw_files[i])
print(dataset[i][0].shape, dataset[i][1])

### Make GNN embeddings

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# init proteinsolver gnn
num_features = 20
adj_input_size = 2
hidden_size = 128

gnn = Net(
    x_input_size=num_features + 1, 
    adj_input_size=adj_input_size, 
    hidden_size=hidden_size, 
    output_size=num_features
)
gnn.load_state_dict(torch.load(state_file, map_location=device))
gnn.eval()
gnn = gnn.to(device)

raw_files = np.array(metadata["path"])
targets = np.array(metadata["binder"])

#dataset = ProteinDataset(processed_dir, raw_files, targets, overwrite=False)

gnn_func = gnn.forward_without_last_layer
out_dir = processed_dir / "proteinsolver_preprocess"
#create_gnn_embeddings(dataset, out_dir, device, gnn_func, cores=4, overwrite=False)
#create_gnn_embeddings(dataset, out_dir, device, gnn_func, overwrite=False)

### Make LOO partitions and init dataset

In [5]:
loo_train_partitions, loo_valid_partitions, val, unique_peptides = generate_3_loo_partitions(metadata)

dataset = LSTMDataset(
    data_dir=processed_dir / "proteinsolver_embeddings_pos", 
    annotations_path=processed_dir / "proteinsolver_embeddings_pos" / "targets.pt"
)

### LOO training scheme

In [7]:
# LSTM params
batch_size = 8
embedding_dim = 128 + 4
hidden_dim = 32 #32
num_layers = 2  # from 2
epochs = 10
learning_rate = 1e-3
lr_decay = 0.95
w_decay = 1e-4
dropout = 0.8  # test scheduled dropout. Can set droput using net.layer.dropout = 0.x https://arxiv.org/pdf/1703.06229.pdf

# touch files to ensure output
n_splits = len(unique_peptides)
save_dir = get_non_dupe_dir(out_dir)
loss_paths = touch_output_files(save_dir, "loss", n_splits)
state_paths = touch_output_files(save_dir, "state", n_splits)
pred_paths = touch_output_files(save_dir, "pred", n_splits)

extra_print_str = "\nSaving to {}\nFold: {}\nPeptide: {}"

i = 0
for train_idx, valid_idx in zip(loo_train_partitions, loo_valid_partitions):
    
    net = MyLSTM(
        embedding_dim=embedding_dim, 
        hidden_dim=hidden_dim, 
        num_layers=num_layers, 
        dropout=dropout,
    )
    net = net.to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(
        net.parameters(), 
        lr=learning_rate, 
        #weight_decay=w_decay,
    )  # test learning rate scheduler to reduce validation volatility
    scheduler = optim.lr_scheduler.MultiplicativeLR(
        optimizer, 
        lr_lambda=lambda epoch: lr_decay
    )
    
    net, train_losses, valid_losses = lstm_quad_train(
        net,
        epochs,
        criterion,
        optimizer,
        scheduler,
        dataset,
        train_idx,
        valid_idx,
        batch_size,
        device,
        collate_fn=pad_collate,
        extra_print=extra_print_str.format(save_dir, i, unique_peptides[i]),
    )
    torch.save(net.state_dict(), state_paths[i])
    torch.save({"train": train_losses, "valid": valid_losses}, loss_paths[i])
    
    pred, true = lstm_quad_predict(net, dataset, valid_idx, device)     
    torch.save({"y_pred": pred, "y_true": true}, pred_paths[i])
    
    i += 1


Saving to /home/sebastian/masters/data/neat_data/processed/proteinsolver_preprocess/34c185642b2f58ff00f303f61d4e6027
Fold: 0
Peptide: NLVPMVATV

epoch: 1 - n: 61/953 - [=6%                                                         ]


KeyboardInterrupt: 

In [None]:
# LSTM params
batch_size = 8
embedding_dim = 128
hidden_dim = 32 #32
num_layers = 2  # from 2
epochs = 10
learning_rate = 1e-3
lr_decay = 0.95
w_decay = 1e-4
dropout = 0.8  # test scheduled dropout. Can set droput using net.layer.dropout = 0.x https://arxiv.org/pdf/1703.06229.pdf

# touch files to ensure output
n_splits = len(unique_peptides)
save_dir = get_non_dupe_dir(out_dir)
loss_paths = touch_output_files(save_dir, "loss", n_splits)
state_paths = touch_output_files(save_dir, "state", n_splits)
pred_paths = touch_output_files(save_dir, "pred", n_splits)

extra_print_str = "\nSaving to {}\nFold: {}\nPeptide: {}"

i = 0
for train_idx, valid_idx in zip(loo_train_partitions, loo_valid_partitions):
    
    net = QuadLSTM(
        embedding_dim=embedding_dim, 
        hidden_dim=hidden_dim, 
        num_layers=num_layers, 
        dropout=dropout,
    )
    net = net.to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(
        net.parameters(), 
        lr=learning_rate, 
        #weight_decay=w_decay,
    )  # test learning rate scheduler to reduce validation volatility
    scheduler = optim.lr_scheduler.MultiplicativeLR(
        optimizer, 
        lr_lambda=lambda epoch: lr_decay
    )
    
    net, train_losses, valid_losses = lstm_quad_train(
        net,
        epochs,
        criterion,
        optimizer,
        scheduler,
        dataset,
        train_idx,
        valid_idx,
        batch_size,
        device,
        extra_print=extra_print_str.format(save_dir, i, unique_peptides[i]),
    )
    torch.save(net.state_dict(), state_paths[i])
    torch.save({"train": train_losses, "valid": valid_losses}, loss_paths[i])
    
    pred, true = lstm_quad_predict(net, dataset, valid_idx, device)     
    torch.save({"y_pred": pred, "y_true": true}, pred_paths[i])
    
    i += 1

In [None]:
from io import StringIO 
import sys

class Capturing(list):
    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        return self
    def __exit__(self, *args):
        self.extend(self._stringio.getvalue().splitlines())
        del self._stringio    # free up some memory
        sys.stdout = self._stdout

with Capturing() as output:
    for _ in train_loader:
        pass

In [None]:
err_f = open("/home/sebastian/masters/data/train_run.log", "r")
lines = err_f.readlines()
lines = [line.strip() for line in lines]
data = dict()
prev_i = str()
prev_mode = str()
for line in lines:
    if line[0] == ">":
        if line != ">train" and line != ">pred" and line != ">valid":
            prev_i = line[1:]
            data[prev_i] = dict()
        else:
            prev_mode = line[1:]
            data[prev_i][prev_mode] = list()
    else:
        try:
            data[prev_i][prev_mode].append(int(line))
        except ValueError as err:
            print(err)

### Performance metrics

In [None]:
import matplotlib.pyplot as plt
import copy
valid_pep = "KTWGQYWQV"
#save_dir = Path("/home/sebastian/masters/data/state_files/tcr_binding/2nd_gen/proteinsolver_finetuning_es/60a15522a4f4418ac79b91af3ac55478//")
#save_dir = Path("/home/sebastian/masters/data/state_files/tcr_binding/2nd_gen/lstm_es/2947ef9a6aa76a87acb42a9c02594f54/")
save_dir = Path("/home/sebastian/masters/data/state_files/tcr_binding/2nd_gen/lstm_esm_ps/b3edafc0112356cefbdde3ca0ec5b396")
pred_paths = [save_dir / f"pred_{i}.pt" for i in range(len(unique_peptides))]

In [None]:
t = torch.load(save_dir/"loss_0.pt")
plt.plot(t["train"])
plt.plot(t["valid"])

In [None]:
t["valid"]

In [None]:
n_splits = len(unique_peptides)
threshold = 0.1

overall_pred = list()
overall_true = list()
overall_thres_pred = list()
# compute metrics
perf_data = dict()
for i in range(n_splits):
    data = torch.load(pred_paths[i])
    pred = data["y_pred"]
    true = data["y_true"]

    # auc
    auc = roc_auc_score(true, pred)
    fpr, tpr, thr = roc_curve(true, pred, pos_label=1)
    
    thresh_pred = torch.zeros(len(pred))
    thresh_pred[pred >= threshold] = 1
    mcc = matthews_corrcoef(true, thresh_pred)
    
    pep = unique_peptides[i]
    perf_data[pep] = [fpr, tpr, auc, mcc]

    print(auc, mcc)

    overall_pred.extend(pred)
    overall_true.extend(true)
    overall_thres_pred.extend(thresh_pred)

print("overall AUC:", roc_auc_score(overall_true, overall_pred))  
print(f"overall MCC (t={threshold}):", matthews_corrcoef(overall_true, overall_thres_pred))

performance_file = save_dir / "performance_data.pt"
torch.save(perf_data, performance_file)

# ROC plot
cm = plt.get_cmap('tab20')  # https://matplotlib.org/stable/tutorials/colors/colormaps.html

fig = plt.figure(figsize=(12, 7))
ax = fig.add_subplot(111)
ax.set_prop_cycle(color=[cm(1*i/n_splits) for i in range(n_splits)])
excluded = ["KLQCVDLHV", "KVAELVHFL", "YLLEMLWRL", "SLLMWITQV"] # TODO delete (filter <40 in test set)
for pep in unique_peptides:
    if pep not in excluded:
        ax.plot(
            perf_data[pep][0], 
            perf_data[pep][1], 
            label=f"{pep}, AUC = {round(perf_data[pep][2], 3)}",
        )
plt.legend()
plt.plot([0, 1], [0, 1], linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("LOO validation ROC curve (peptides with count < 40 left out for visual clarity (drop on CDR3b))")
plt.show()

In [None]:
unique_peptides

In [None]:
t = torch.load(save_dir/"pred_4.pt")

In [None]:
t["y_true"]

In [None]:
t["y_pred"]

In [None]:
pep_df = metadata[metadata["peptide"] == unique_peptides[4]]
pep_df = pep_df.reset_index()
pep_df

In [None]:
i=-7
metadata[metadata["CDR3a"] == pep_df.iloc[i]["CDR3a"]]

In [None]:
print(t["y_pred"][i], t["y_true"][i])

In [None]:
plt.hist(t["y_pred"], bins=len(t["y_pred"]))
plt.xlabel("score")
plt.ylabel("count")
plt.title("distribution of pred scores for KVLEYVIKV")
plt.show()

In [None]:
metadata[metadata["peptide"] == unique_peptides[4]]

In [None]:
fpr, tpr, thr = roc_curve(t["y_true"], t["y_pred"], drop_intermediate=False)

In [None]:
len(thr)

In [None]:
fpr

In [None]:
plt.plot(fpr, tpr)

In [None]:
# confusion matrix
labels = ["non-binder", "binder"]
pred_copy = copy.deepcopy(pred)
pred_copy[pred >= 0.2] = 1
pred_copy[pred < 0.2] = 0
cm = confusion_matrix(true, pred_copy)

# f1
f1 = f1_score(true, pred_copy)
disp = ConfusionMatrixDisplay(cm, display_labels=labels)
plot = disp.plot()
plot.figure_.show()

In [None]:
fold_idx = 1
data = torch.load(performance_file)[fold_idx]

cm, f1, auc = data["cm"], data["f1"], data["auc"]
disp = ConfusionMatrixDisplay(cm, display_labels=labels)
disp.plot()
print(f"LOO performance of fold {fold_idx}:")
print(f"AUC={auc}")
print(f"F1={f1}")

In [None]:
# quick viz
count_dict = dict()
for pep in unique_peptides:
    total = len(metadata[metadata["peptide"] == pep])
    pos = len(metadata[(metadata["peptide"] == pep) & (metadata["binder"] == 1)])
    count_dict[pep] = [total, pos]

In [None]:
count_dict

In [None]:
fig = plt.figure(figsize=(12, 7))
ax = fig.add_subplot(111)

peptides = list(count_dict.keys())
negatives = [x[0] - x[1] for x in count_dict.values()]
positives = [x[1] for x in count_dict.values()]

width = 0.4
idx = np.arange(len(unique_peptides))

ax.bar(idx, negatives, width, zorder=3)
ax.bar(idx + width, positives, width, zorder=3)
ax.set_yscale('log')
ax.set_xticks(idx+width)
ax.set_xticklabels(count_dict.keys(), rotation=45)

ax.grid(zorder=0, which='both', axis='y')

plt.legend(["Negatives", "Positives"])
plt.xlabel("Peptide")
plt.ylabel("log(Count)")
plt.title("Number of TCR-pMHC models for each unique peptide (log-scale)")
plt.show()