# Control

## imports

In [None]:
# general imports
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import json
import os
import pickle

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from collections import defaultdict
from datetime import datetime

# biopython
import Bio
from Bio import SeqIO
from Bio import pairwise2
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.pairwise2 import format_alignment
from Bio.SubsMat import MatrixInfo as matlist

# pytorch
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch import nn
import torch.nn.functional as F

# pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pl_bolts.models import VAE

# ImmunoBERT
import pMHC
from pMHC.logic import PresentationPredictor
from pMHC.data import MhcAllele
from pMHC import SEP, \
    SPLITS, SPLIT_TRAIN, SPLIT_VAL, SPLIT_TEST, \
    VIEWS, VIEW_SA, VIEW_SAMA, VIEW_DECONV, \
    INPUT_PEPTIDE, INPUT_CONTEXT
from pMHC.data.utils import convert_example_to_batch, move_dict_to_device, get_input_rep_PSEUDO

# visualizations
from protein_map_visualisation_tools import generate_embedding_map_from_database

# generative model
import SpikeOracle
from SpikeOracle.data import StandardDataset
from SpikeOracle.presentation_scoring.IB import score_seq_IB
from SpikeOracle.presentation_scoring.nMp import eval_peptides_nMp, score_seq_nMp
from SpikeOracle.models.VAE.fc import FcVAE
from SpikeOracle.latent import get_latent_from_seq_FcVAE, get_seq_from_latent_FcVAE
from SpikeOracle.utils import write_seqs_to_fasta

## constants

In [None]:
# Input file with the spike-protein sequences (aligned FASTA, .afa).
FILENAME = os.path.join("..", "data", "spikeprot_bigger_dataset.afa")

# data constants
SEQ_LEN = 1449       # sequence length passed to StandardDataset (earlier runs used 1282 / 18)
MAX_SEQ_LEN = SEQ_LEN
AA_ENC_DIM = 21      # count of amino acid encoding dimensions
# NOTE: this rebinds the `SEP` imported from pMHC in the imports cell; from here on
# `SEP` is the filesystem path separator, not pMHC's own separator.
SEP = os.sep
IMMUNO_CATS = 3      # number of antigenicity categories (0 = low, 1 = intermediate, 2 = high)

# most relevant MHC alleles
MHC_list = ["HLA-A01:01", "HLA-A02:01", "HLA-A03:01", "HLA-A24:02", "HLA-A26:01",
            "HLA-B07:02", "HLA-B08:01", "HLA-B27:05", "HLA-B39:01", "HLA-B40:01", "HLA-B58:01", "HLA-B15:01"]

# ImmunoBERT
IB_VERSION = "CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.00001"
IB_CHECKPOINT = "epoch=4-step=3648186"
# Project root handed to pMHC.set_paths. Set the IB_PROJ_PATH environment variable
# on other machines instead of editing this machine-specific default.
IB_PROJ_PATH = os.environ.get("IB_PROJ_PATH", r"C:\Users\s2118339\Documents\MSc_AI_Thesis_final\MScProject")

# netMHCpan
# NMP_FOLDER_1: local folder where the .pep input files are written.
# NMP_FOLDER_2: the same folder as seen from the Linux machine that runs netMHCpan
#               (used only in the printed shell commands); overridable via NMP_FOLDER_2.
NMP_FOLDER_1 = os.path.join("..", "netMHCpan")
NMP_FOLDER_2 = os.environ.get("NMP_FOLDER_2", r"~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan")

### fully connected VAE

In [None]:
# encoder parameters
FC_ENC_INT_DIM = 512   # intermediate dimensions of the encoder
FC_ENC_OUT_DIM = 128   # output dimensions of the encoder

# latent space
FC_LATENT_DIM = 2      # dimensions of the latent space

# decoder parameters
FC_DEC_INT_DIM = 512   # intermediate dimensions of the decoder

# VAE parameters
FC_KL_TARGET = 0.1     # target for the KL-divergence term (passed to FcVAE as kl_target)

# training parameters
FC_LR = 3e-4           # the learning rate
FC_BATCH_SIZE = 64     # batch size
FC_DROPOUT = 0.05      # passed to FcVAE as `dropout`
FC_WEIGHT_DECAY = 1e-6 # previously 3e-5

# File-name prefix for the FASTA files written by the generation cells.
FC_GEN_FILENAME = "FcVAE_generated"

## notebook control

In [None]:
# --- ImmunoBERT switches ---------------------------------------------------
LOAD_IB_MODEL = True            # load the ImmunoBERT checkpoint and move it to the GPU
LOAD_IB_PEPTIDE_SCORES = True   # start from the cached IB_peptide_scores.pickle
SAVE_IB_PEPTIDE_SCORES = False  # write the (possibly extended) cache back to disk

# Predictor whose quartile categories become ds.seq_immuno_cat:
#   1 -> ImmunoBERT, 2 -> netMHCpan
ANTIGENICITY = 2

# --- fully connected VAE ---------------------------------------------------
FC_EPOCHS = 100
FC_SAVE = "FC_002.ckpt"   # train and save to this checkpoint; a falsy value skips training
FC_LOAD = "FC_002.ckpt"   # checkpoint to load afterwards; a falsy value skips loading

FC_SAMPLES = 1000         # number of latent vectors drawn for sequence generation

# Dataset

In [None]:
# load the spike-protein dataset from FILENAME (lengths controlled by SEQ_LEN / MAX_SEQ_LEN)
ds = StandardDataset(FILENAME, SEQ_LEN, MAX_SEQ_LEN)

## ImmunoBERT assessment

In [None]:
# Placeholder so later cells can pass `model` to score_seq_IB even when LOAD_IB_MODEL is False.
model = None

In [None]:
# Load the ImmunoBERT presentation predictor (only when LOAD_IB_MODEL is set)
# and prepare it for inference on the GPU.
if LOAD_IB_MODEL:
    pMHC.set_paths(IB_PROJ_PATH)
    MODEL_PATH = os.path.join("..", "data", f"{IB_CHECKPOINT}.ckpt")
    model = PresentationPredictor.load_from_checkpoint(
        MODEL_PATH,
        num_workers=0,
        shuffle_data=False,
        output_attentions=False,
    )
    model.setup()
    model.to("cuda")
    model.eval()

In [None]:
# Peptide-level ImmunoBERT score cache, plus per-sequence aggregates that read
# as 0 for sequences not scored yet.
IB_peptide_scores = {}
IB_seq_presentation = defaultdict(int)
IB_seq_scores = defaultdict(int)

# Optionally start from scores computed in an earlier session.
# NOTE: pickle.load can execute arbitrary code; only load a cache you produced yourself.
if LOAD_IB_PEPTIDE_SCORES:
    cache_path = os.path.join("..", "data", "IB_peptide_scores.pickle")
    with open(cache_path, "rb") as cache_file:
        IB_peptide_scores = pickle.load(cache_file)

In [None]:
# Score every dataset sequence with ImmunoBERT. score_seq_IB adds scores for any
# peptide missing from the shared IB_peptide_scores cache.
for viral_seq in tqdm(ds.viral_seqs):
    presentation, score = score_seq_IB(model, viral_seq, MHC_list, IB_peptide_scores)
    IB_seq_presentation[viral_seq] = presentation
    IB_seq_scores[viral_seq] = score

# Persist the (possibly extended) cache when requested.
if SAVE_IB_PEPTIDE_SCORES:
    with open(os.path.join("..", "data", "IB_peptide_scores.pickle"), "wb") as cache_file:
        pickle.dump(IB_peptide_scores, cache_file)

In [None]:
# Drop cached peptide entries that do not hold exactly one value per allele in
# MHC_list (presumably scored with a different allele set — confirm).
to_delete = [key for key, values in IB_peptide_scores.items() if len(values) != len(MHC_list)]
for key in to_delete:
    IB_peptide_scores.pop(key)

In [None]:
# Split the dataset sequences into antigenicity categories by ImmunoBERT score
# quartiles: 0 = below the 25th percentile, 1 = from the 25th up to the 75th,
# 2 = at or above the 75th.
# The thresholds use only the dataset sequences. Generated sequences are added to
# IB_seq_scores further down, so computing over all dict values would shift the
# thresholds whenever this cell is re-run afterwards.
IB_dataset_scores = np.array([IB_seq_scores[seq] for seq in dict.fromkeys(ds.viral_seqs)])
IB_seq_scores_p25, IB_seq_scores_p75 = np.percentile(IB_dataset_scores, [25, 75])

IB_seq_immuno_cat = {}
for seq in ds.viral_seqs:
    score = IB_seq_scores[seq]
    if score < IB_seq_scores_p25:
        IB_seq_immuno_cat[seq] = 0
    elif score < IB_seq_scores_p75:
        IB_seq_immuno_cat[seq] = 1
    else:
        IB_seq_immuno_cat[seq] = 2

print(f"IB_seq_scores_p25: {IB_seq_scores_p25:.5f} IB_seq_scores_p75: {IB_seq_scores_p75:.5f}")
print(f"mean: {IB_dataset_scores.mean():.5f}")

In [None]:
# Distribution of per-sequence ImmunoBERT presentation values divided by the number
# of alleles in MHC_list. sns.distplot is deprecated (seaborn >= 0.11); histplot
# with stat="density" and a KDE overlay is its documented replacement.
ax_ib_pres = sns.histplot(
    np.array(list(IB_seq_presentation.values())) / len(MHC_list),
    kde=True,
    stat="density",
)
ax_ib_pres.set(xlabel="presentation / number of alleles", title="ImmunoBERT presentation per sequence");

In [None]:
# Distribution of per-sequence ImmunoBERT scores divided by the number of alleles in
# MHC_list. sns.distplot is deprecated (seaborn >= 0.11); histplot with
# stat="density" and a KDE overlay is its documented replacement.
ax_ib_scores = sns.histplot(
    np.array(list(IB_seq_scores.values())) / len(MHC_list),
    kde=True,
    stat="density",
)
ax_ib_scores.set(xlabel="score / number of alleles", title="ImmunoBERT score per sequence");

## netMHCpan assessment

In [None]:
def _peptide_from_key(key):
    """Return the 9 characters after the first "_" of an IB cache key
    (or the first 9 characters when the key has no "_")."""
    start = key.find("_") + 1
    return key[start:start + 9]


# Unique peptides referenced by the ImmunoBERT cache, in first-seen order.
peptides_db = list(dict.fromkeys(_peptide_from_key(key) for key in IB_peptide_scores))

In [None]:
# Write the unique peptides, one per line, as netMHCpan input.
# The `with` block closes the file even if a write fails.
with open(os.path.join(NMP_FOLDER_1, "peptides_db.pep"), "w") as pep_file:
    for peptide in peptides_db:
        pep_file.write(f"{peptide}\n")

In [None]:
# Print one netMHCpan command per allele. Output files are tagged with the allele
# name without "HLA-" and ":" (e.g. HLA-A01:01 -> A0101).
for allele in MHC_list:
    allele_tag = allele.replace(":", "").replace("HLA-", "")
    command = (f"./netMHCpan -p {NMP_FOLDER_2}/peptides_db.pep -a {allele} "
               f"> {NMP_FOLDER_2}/peptides_db_{allele_tag}.pep.out")
    print(command)
    print("\n")

In [None]:
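# Run the netMHCpan commands printed above on a Linux machine before continuing; the next cell presumably parses their .pep.out files.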

In [None]:
# Load the netMHCpan results for the "peptides_db" run into a peptide-keyed mapping
# (presumably read from the peptides_db_<allele>.pep.out files — confirm in SpikeOracle).
nMp_peptide_scores = eval_peptides_nMp("peptides_db", MHC_list)

In [None]:
# netMHCpan-based score for every dataset sequence; unscored sequences read as 0.
nMp_seq_scores = defaultdict(int)
for viral_seq in tqdm(ds.viral_seqs):
    nMp_seq_scores[viral_seq] = score_seq_nMp(viral_seq, MHC_list, nMp_peptide_scores)

In [None]:
# Split the dataset sequences into antigenicity categories by netMHCpan score
# quartiles: 0 = below the 25th percentile, 1 = from the 25th up to the 75th,
# 2 = at or above the 75th.
# The thresholds use only the dataset sequences. Generated sequences are added to
# nMp_seq_scores further down, so computing over all dict values would shift the
# thresholds whenever this cell is re-run afterwards.
nMp_dataset_scores = np.array([nMp_seq_scores[seq] for seq in dict.fromkeys(ds.viral_seqs)])
nMp_seq_scores_p25, nMp_seq_scores_p75 = np.percentile(nMp_dataset_scores, [25, 75])

nMp_seq_immuno_cat = {}
for seq in ds.viral_seqs:
    score = nMp_seq_scores[seq]
    if score < nMp_seq_scores_p25:
        nMp_seq_immuno_cat[seq] = 0
    elif score < nMp_seq_scores_p75:
        nMp_seq_immuno_cat[seq] = 1
    else:
        nMp_seq_immuno_cat[seq] = 2

print(f"nMp_seq_scores_p25: {nMp_seq_scores_p25:.5f} nMp_seq_scores_p75: {nMp_seq_scores_p75:.5f}")
# Same 5-decimal formatting as the ImmunoBERT summary.
print(f"mean: {nMp_dataset_scores.mean():.5f}")

## assign antigenicity category

In [None]:
# Attach the antigenicity categories of the chosen predictor to the dataset.
# An unexpected ANTIGENICITY value would otherwise silently leave
# ds.seq_immuno_cat unset (or stale from a previous run).
if ANTIGENICITY == 1:
    ds.seq_immuno_cat = IB_seq_immuno_cat
elif ANTIGENICITY == 2:
    ds.seq_immuno_cat = nMp_seq_immuno_cat
else:
    raise ValueError(f"ANTIGENICITY must be 1 (ImmunoBERT) or 2 (netMHCpan), got {ANTIGENICITY!r}")

# VAE

In [None]:
# Conditional fully connected VAE.
# NOTE: the name `VAE` rebinds the pl_bolts `VAE` class imported at the top; every
# later cell refers to this instance.
VAE = FcVAE(
    aa_dim=AA_ENC_DIM,
    sequence_len=MAX_SEQ_LEN,
    enc_int_dim=FC_ENC_INT_DIM,
    enc_out_dim=FC_ENC_OUT_DIM,
    latent_dim=FC_LATENT_DIM,
    dec_int_dim=FC_DEC_INT_DIM,
    kl_target=FC_KL_TARGET,
    lr=FC_LR,
    batch_size=FC_BATCH_SIZE,
    dropout=FC_DROPOUT,
    weight_decay=FC_WEIGHT_DECAY,
    # Named constant (3) instead of a bare literal; presumably the number of
    # conditioning categories (0/1/2) — confirm in FcVAE.
    conditional=IMMUNO_CATS,
)

In [None]:
# Attach the dataset to the model; presumably FcVAE builds its dataloaders from VAE.ds — TODO confirm.
VAE.ds = ds

## training and loading

In [None]:
# Single-GPU trainer for FC_EPOCHS epochs.
# NOTE(review): `gpus=` is the older Trainer flag; newer PyTorch Lightning releases use
# accelerator/devices instead — confirm against the installed version.
trainer = Trainer(max_epochs=FC_EPOCHS, gpus=1)

In [None]:
# Train and write a checkpoint only when FC_SAVE names a file.
if FC_SAVE:
    trainer.fit(VAE)
    trainer.save_checkpoint(FC_SAVE)

In [None]:
# Restore the FcVAE from FC_LOAD (if set), move it to the GPU and attach the dataset
# again on the freshly loaded instance.
if FC_LOAD:
    VAE = FcVAE.load_from_checkpoint(checkpoint_path=FC_LOAD).cuda()
    VAE.ds = ds

In [None]:
# Switch the model to evaluation mode before reconstruction and sampling.
VAE.eval()

In [None]:
# Reset the sequence index; each run of the next cell first advances j by one.
j = -1

In [None]:
# Reconstruct the next dataset sequence (each run advances j by one) and print its
# alignment against the original. Pure inference, so no autograd graph is built.
j += 1
original_seq = ds.viral_seqs[j][:MAX_SEQ_LEN]
with torch.no_grad():
    tokens = ds.tok.tokenize(original_seq).unsqueeze(dim=0).to(VAE.device)
    condition = torch.tensor(
        ds.seq_immuno_cat_tokens[ds.seq_immuno_cat[ds.viral_seqs[j]]]
    ).unsqueeze(dim=0).to(VAE.device)
    reconstruction = VAE.forward(tokens, condition, sample=False).reshape(1, MAX_SEQ_LEN, -1)
    h = ds.tok.decode(reconstruction)
alignments = pairwise2.align.globalxx(original_seq, h[0])
print(format_alignment(*alignments[0]))

## latent space

In [None]:
# Encode every dataset sequence. Returns X/Y coordinates for Mu_* and Latent_*
# (presumably encoder means and sampled latents of the 2-D latent space — confirm in SpikeOracle.latent).
Mu_Xs, Mu_Ys, Latent_Xs, Latent_Ys = get_latent_from_seq_FcVAE(VAE, ds.viral_seqs)

In [None]:
# Scatter of Mu_Xs against Mu_Ys for all dataset sequences; labelled so the figure
# stands on its own, trailing `;` suppresses the artist repr.
fig_mu, ax_mu = plt.subplots()
ax_mu.scatter(Mu_Xs, Mu_Ys, s=1)
ax_mu.set(xlabel="Mu X", ylabel="Mu Y", title="Mu of dataset sequences");

In [None]:
# Scatter of Latent_Xs against Latent_Ys for all dataset sequences; labelled so the
# figure stands on its own, trailing `;` suppresses the artist repr.
fig_latent, ax_latent = plt.subplots()
ax_latent.scatter(Latent_Xs, Latent_Ys, s=1)
ax_latent.set(xlabel="Latent X", ylabel="Latent Y", title="Latents of dataset sequences");

In [None]:
# Four-panel overview: Mu scatter, latent scatter and the marginal histograms of the
# latent coordinates. The font size is set once (a global rcParams change that also
# applies to later plots until reset).
matplotlib.rcParams.update({'font.size': 22})
fig = plt.figure(figsize=(40, 20))
ax = fig.subplots(1, 4)
ax[0].scatter(Mu_Xs, Mu_Ys, s=1)
ax[0].set(title="Mu", xlabel="X", ylabel="Y")
ax[1].scatter(Latent_Xs, Latent_Ys, s=1)
ax[1].set(title="Latent", xlabel="X", ylabel="Y")
ax[2].hist(Latent_Xs)
ax[2].set(title="Latent X", xlabel="X", ylabel="count")
ax[3].hist(Latent_Ys)
ax[3].set(title="Latent Y", xlabel="Y", ylabel="count");

In [None]:
# Spread of the encoded dataset: variance of each latent and Mu coordinate.
for label, values in [
    ("Var Latent X", Latent_Xs),
    ("Var Latent Y", Latent_Ys),
    ("Var Mu X", Mu_Xs),
    ("Var Mu Y", Mu_Ys),
]:
    print(f"{label}: {np.var(values):.4f}")

In [None]:
# 2x2 covariance matrix of the latent coordinates (rows/columns: X, Y).
np.cov(Latent_Xs, Latent_Ys)

## generate new sequences

In [None]:
# Per-category results of the generation cells, indexed by category j (0, 1, 2).
generated_seqs = [None] * 3
generated_seqs_new = [None] * 3
antigenicity_names = ["low", "intermediate", "high"]

In [None]:
# Draw FC_SAMPLES latent vectors from a zero-mean Gaussian whose covariance matches the
# empirical covariance of the dataset latents (Latent_Xs, Latent_Ys).
# Standard-normal rows z must be multiplied by a square root of the covariance: with
# cov = L @ L.T (Cholesky), z @ L.T has covariance cov. Multiplying by cov itself would
# instead give covariance cov @ cov.
# NOTE(review): samples are centred at 0; add the latent mean if the dataset latents
# are not centred.
p = torch.distributions.Normal(
    torch.zeros(FC_LATENT_DIM, device=VAE.device),
    torch.ones(FC_LATENT_DIM, device=VAE.device))
Zs = p.sample(sample_shape=torch.Size([FC_SAMPLES]))
latent_cov = torch.tensor(np.cov(Latent_Xs, Latent_Ys), dtype=torch.float32, device=VAE.device)
latent_cov_chol = torch.linalg.cholesky(latent_cov)
Zs = Zs @ latent_cov_chol.T

In [None]:
# Show the empirical latent covariance used to shape the latent samples.
np.cov(Latent_Xs, Latent_Ys)

### low antigenicity

In [None]:
# Conditioning category for the next cell: 0 = lowest score quartile (antigenicity_names[0]).
j = 0

In [None]:
# Decode the shared latent samples Zs conditioned on category j, keep the sequences
# that do not occur in the dataset, and write both sets to FASTA.
generated_seqs[j] = get_seq_from_latent_FcVAE(VAE, Zs, j)

# A set gives an O(1) novelty check per sequence instead of scanning the dataset
# list (assumes viral_seqs holds the sequence strings, as iterated elsewhere).
known_seqs = set(VAE.ds.viral_seqs)
generated_seqs_new[j] = {seq: cnt for seq, cnt in generated_seqs[j].items() if seq not in known_seqs}

out_prefix = os.path.join("..", "data", "spike_protein_sequences", f"{FC_GEN_FILENAME}_{antigenicity_names[j]}")
write_seqs_to_fasta(generated_seqs[j], f"{out_prefix}_all.fasta")
write_seqs_to_fasta(generated_seqs_new[j], f"{out_prefix}.fasta")

print(f"Generated: {len(generated_seqs[j])} New: {len(generated_seqs_new[j])}")

### intermediate antigenicity

In [None]:
# Conditioning category for the next cell: 1 = middle score quartiles (antigenicity_names[1]).
j = 1

In [None]:
# Decode the shared latent samples Zs conditioned on category j, keep the sequences
# that do not occur in the dataset, and write both sets to FASTA.
generated_seqs[j] = get_seq_from_latent_FcVAE(VAE, Zs, j)

# A set gives an O(1) novelty check per sequence instead of scanning the dataset
# list (assumes viral_seqs holds the sequence strings, as iterated elsewhere).
known_seqs = set(VAE.ds.viral_seqs)
generated_seqs_new[j] = {seq: cnt for seq, cnt in generated_seqs[j].items() if seq not in known_seqs}

out_prefix = os.path.join("..", "data", "spike_protein_sequences", f"{FC_GEN_FILENAME}_{antigenicity_names[j]}")
write_seqs_to_fasta(generated_seqs[j], f"{out_prefix}_all.fasta")
write_seqs_to_fasta(generated_seqs_new[j], f"{out_prefix}.fasta")

print(f"Generated: {len(generated_seqs[j])} New: {len(generated_seqs_new[j])}")

### high antigenicity

In [None]:
# Conditioning category for the next cell: 2 = highest score quartile (antigenicity_names[2]).
j = 2

In [None]:
# Decode the shared latent samples Zs conditioned on category j, keep the sequences
# that do not occur in the dataset, and write both sets to FASTA.
generated_seqs[j] = get_seq_from_latent_FcVAE(VAE, Zs, j)

# A set gives an O(1) novelty check per sequence instead of scanning the dataset
# list (assumes viral_seqs holds the sequence strings, as iterated elsewhere).
known_seqs = set(VAE.ds.viral_seqs)
generated_seqs_new[j] = {seq: cnt for seq, cnt in generated_seqs[j].items() if seq not in known_seqs}

out_prefix = os.path.join("..", "data", "spike_protein_sequences", f"{FC_GEN_FILENAME}_{antigenicity_names[j]}")
write_seqs_to_fasta(generated_seqs[j], f"{out_prefix}_all.fasta")
write_seqs_to_fasta(generated_seqs_new[j], f"{out_prefix}.fasta")

print(f"Generated: {len(generated_seqs[j])} New: {len(generated_seqs_new[j])}")

### evaluate antigenicity

#### with ImmunoBERT

In [None]:
# Score the novel generated sequences of every category with ImmunoBERT
# (score_seq_IB extends IB_peptide_scores with any peptides it has not seen).
for j, new_seqs in enumerate(generated_seqs_new):
    for seq in tqdm(new_seqs.keys()):
        presentation, score = score_seq_IB(model, seq, MHC_list, IB_peptide_scores)
        IB_seq_presentation[seq] = presentation
        IB_seq_scores[seq] = score

In [None]:
# ImmunoBERT score densities of the novel generated sequences, one per conditioning
# category. Each histplot call gets its own label, so legend entries are tied to the
# right artists (labels passed only to plt.legend are matched to artists by position).
# sns.distplot is deprecated (seaborn >= 0.11); histplot with stat="density" and a KDE
# overlay replaces it.
matplotlib.rcParams.update({'font.size': 15})

fig, ax = plt.subplots()
for j, label in enumerate(["low", "medium", "high"]):
    sns.histplot([IB_seq_scores[seq] for seq in generated_seqs_new[j].keys()],
                 kde=True, stat="density", alpha=0.4, label=label, ax=ax)
ax.set(xlabel="ImmunoBERT sequence score", ylabel="density")
ax.legend();

#### with netMHCpan

In [None]:
# Collect the 9-mers of the novel generated sequences (alignment gaps "-" removed) that
# have no netMHCpan score yet, write them to missing.pep and print the netMHCpan
# commands to run on the Linux machine.
PEPTIDE_LEN = 9

missing = []
for j in range(3):
    for seq in tqdm(list(generated_seqs_new[j].keys())):
        ungapped = seq.replace("-", "")
        # A sequence of length L has L - PEPTIDE_LEN + 1 windows; the "+ 1" keeps the final 9-mer.
        for position in range(len(ungapped) - PEPTIDE_LEN + 1):
            peptide = ungapped[position:position + PEPTIDE_LEN]
            if peptide not in nMp_peptide_scores:
                missing.append(peptide)
# Deduplicate (first occurrence kept) so netMHCpan scores each peptide only once.
missing = list(dict.fromkeys(missing))

# The `with` block closes the file even if a write fails.
with open(os.path.join(NMP_FOLDER_1, "missing.pep"), "w") as pep_file:
    for peptide in missing:
        pep_file.write(f"{peptide}\n")

for mhc_name in MHC_list:
    mhc_name_2 = mhc_name.replace(":", "").replace("HLA-", "")
    print(f"./netMHCpan -p {NMP_FOLDER_2}/missing.pep -a {mhc_name} > {NMP_FOLDER_2}/missing_{mhc_name_2}.pep.out")
    print("\n")

In [None]:
# Merge the netMHCpan results for the "missing" run into the peptide-keyed score mapping
# (run the printed commands first; presumably reads missing_<allele>.pep.out — confirm).
nMp_peptide_scores.update(eval_peptides_nMp("missing", MHC_list))

In [None]:
# netMHCpan-based score for the novel generated sequences of every category.
for j, new_seqs in enumerate(generated_seqs_new):
    for seq in tqdm(new_seqs.keys()):
        nMp_seq_scores[seq] = score_seq_nMp(seq, MHC_list, nMp_peptide_scores)

In [None]:
# netMHCpan score densities of the novel generated sequences, one per conditioning
# category. Each histplot call gets its own label, so legend entries are tied to the
# right artists (labels passed only to plt.legend are matched to artists by position).
# sns.distplot is deprecated (seaborn >= 0.11); histplot with stat="density" and a KDE
# overlay replaces it.
matplotlib.rcParams.update({'font.size': 15})

fig, ax = plt.subplots()
for j, label in enumerate(["low", "medium", "high"]):
    sns.histplot([nMp_seq_scores[seq] for seq in generated_seqs_new[j].keys()],
                 kde=True, stat="density", alpha=0.4, label=label, ax=ax)
ax.set(xlabel="netMHCpan sequence score", ylabel="density")
ax.legend();