# Control

## imports

In [None]:
# general imports
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import json
import os
import pickle

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from collections import defaultdict
from datetime import datetime

# biopython
import Bio
from Bio import SeqIO
from Bio import pairwise2
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.pairwise2 import format_alignment
from Bio.SubsMat import MatrixInfo as matlist

# pytorch
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch import nn
import torch.nn.functional as F

# pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

import optuna

# ImmunoBERT
import pMHC
from pMHC.logic import PresentationPredictor
from pMHC.data import MhcAllele
from pMHC import SEP, \
    SPLITS, SPLIT_TRAIN, SPLIT_VAL, SPLIT_TEST, \
    VIEWS, VIEW_SA, VIEW_SAMA, VIEW_DECONV, \
    INPUT_PEPTIDE, INPUT_CONTEXT
from pMHC.data.utils import convert_example_to_batch, move_dict_to_device, get_input_rep_PSEUDO


# generative model
import SpikeOracle
from SpikeOracle import PHASE_TRAIN, PHASE_VALID, PHASE_TEST
from SpikeOracle.data import StandardDataset
from SpikeOracle.presentation_scoring.IB import score_seq_IB
from SpikeOracle.presentation_scoring.nMp import eval_peptides_nMp, score_seq_nMp
from SpikeOracle.models.VAE.fc import FcVAE
from SpikeOracle.models.VAE.conv import ConvVAE
from SpikeOracle.latent import get_latent_from_seq_FcVAE, get_seq_from_latent_FcVAE
from SpikeOracle.utils import write_seqs_to_fasta, calc_entropy_vector

## constants

### notebook control

In [None]:
# Stage switches for this notebook. Each flag gates one expensive step below.

# True: rebuild the train/valid/test split from the FASTA file and save it;
# False: load the previously saved split files.
CREATE_DATA_SPLIT = False

# True: run the Optuna hyper-parameter search.
RUN_HYP_PARAM_SRCH = False
# True: continue training the VAE from the selected checkpoint.
RUN_TRAINING = False
# True: sample new sequences and write them to FASTA;
# False: read previously generated sequences back from FASTA.
RUN_GEN_SEQS = False

# TensorBoard log version and checkpoint file name used to build the checkpoint path.
VERSION = 1
# NOTE: this first assignment is immediately overwritten by the next line,
# so only the epoch=99 checkpoint is ever used.
CKPT = "epoch=24-step=25424.ckpt"
CKPT = "epoch=99-step=101699.ckpt"

# ImmunoBERT switches: load the model, load/save the peptide score cache,
# and re-calibrate the antigenicity thresholds (otherwise they are unpickled).
LOAD_IB_MODEL = True
LOAD_IB_PEPTIDE_SCORES = True
SAVE_IB_PEPTIDE_SCORES = True
CALIB_IB = False

# Selects which per-sequence antigenicity categories are attached to the datasets.
ANTIGENICITY = 2  # 1... ImmunoBERT, 2... netMHCpan

# max_epochs for the optional training run and number of latent samples per generation run.
FC_EPOCHS = 100
FC_SAMPLES = 50000

In [None]:
# NOTE(review): FC_BLOCKS, FC_HIDDEN_DIM, FC_LATENT_DIM, FC_KL_TARGET and FC_DROPOUT
# are not referenced in the visible part of this notebook; the model is built from
# Optuna trial parameters instead. Confirm whether they are still needed.

# encoder parameters
FC_BLOCKS = 5
FC_HIDDEN_DIM = 2048 # intermediate dimensions of the encoder

# latent space
FC_LATENT_DIM = 2   # dimensions of the latent space

# VAE parameters
FC_KL_TARGET = 0.1    # value of the KL divergence in the loss function

# training parameters (FC_LR, FC_BATCH_SIZE and FC_WEIGHT_DECAY are passed to FcVAE in define_model)
FC_LR = 3e-4         # the learning rate
FC_BATCH_SIZE = 64   # batch size
FC_DROPOUT = 0.05
FC_WEIGHT_DECAY = 1e-6 # 3e-5

# model and data
# MODEL_NAME is used as the prefix of the generated-sequence FASTA file names.
MODEL_NAME = "FC_004"
# Files holding the saved train/validation/test split.
FILENAME_TRAIN = f"..{os.sep}data{os.sep}spikeprot_train.txt"
FILENAME_VALID = f"..{os.sep}data{os.sep}spikeprot_valid.txt"
FILENAME_TEST = f"..{os.sep}data{os.sep}spikeprot_test.txt"

In [None]:
# Aligned FASTA file the datasets are built from when CREATE_DATA_SPLIT is True.
FILENAME_FASTA = f"..{os.sep}data{os.sep}spikeprot_final_dataset.afa"

# data constants
# NOTE(review): the trailing comment mentions 1271 while the value is 1299 - presumably stale; confirm.
SEQ_LEN = 1299 # 1449 # 1282  # 18   # restricted to 1271 aa sequence lengths
MAX_SEQ_LEN = SEQ_LEN
AA_ENC_DIM = 21   # count of amino acid encoding dimensions
# NOTE: this rebinds the SEP name imported from pMHC in the import cell.
SEP = os.sep
# NOTE(review): not referenced in the visible cells (define_model hard-codes conditional = 3).
IMMUNO_CATS = 3

# most relevant MHC alleles
MHC_list = ["HLA-A01:01", "HLA-A02:01", "HLA-A03:01", "HLA-A24:02", "HLA-A26:01",
            "HLA-B07:02", "HLA-B08:01", "HLA-B27:05", "HLA-B39:01", "HLA-B40:01", "HLA-B58:01", "HLA-B15:01"]

# ImmunoBERT
IB_VERSION = "CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.00001"
IB_CHECKPOINT = "epoch=4-step=3648186"
# NOTE(review): machine-specific absolute path; must be adapted before running elsewhere.
IB_PROJ_PATH = r"C:\Users\s2118339\Documents\MSc_AI_Thesis_final\MScProject"

# netMHCpan
# NMP_FOLDER_1 is where the .pep peptide lists are written.
NMP_FOLDER_1 = f"..{os.sep}netMHCpan"
NMP_FOLDER_2 = r"~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan"

# Dataset

In [None]:
# Build the train / validation / test datasets.
if not CREATE_DATA_SPLIT:
    # Re-use the split that was saved to disk by an earlier run.
    ds, ds_val, ds_test = (
        StandardDataset(SEQ_LEN, MAX_SEQ_LEN, filename=split_file)
        for split_file in (FILENAME_TRAIN, FILENAME_VALID, FILENAME_TEST)
    )
else:
    # Read every sequence from the FASTA file, then carve validation and
    # test sets of 2000 sequences each out of it; the remainder is training data.
    ds = StandardDataset(SEQ_LEN, MAX_SEQ_LEN)
    ds.load_from_fasta(FILENAME_FASTA)
    print(f"Fasta len: {len(ds.viral_seqs)}")

    ds_val = ds.splitoff(2000)
    ds_test = ds.splitoff(2000)

    for label, dataset in (("Train", ds), ("Valid", ds_val), ("Test", ds_test)):
        print(f"{label} len: {len(dataset.viral_seqs)}")

    # Persist the split so later runs can skip this branch.
    ds.save_to_file(FILENAME_TRAIN)
    ds_val.save_to_file(FILENAME_VALID)
    ds_test.save_to_file(FILENAME_TEST)

In [None]:
# Sanity check: report every validation/test sequence that also occurs in
# the training set (the splits are expected to be disjoint).
# A set gives O(1) membership tests; scanning the training list for every
# sequence would be quadratic in the dataset size.
train_seq_set = set(ds.viral_seqs)

for seq in tqdm(ds_val.viral_seqs):
    if seq in train_seq_set:
        print(f"Error: {seq}")

for seq in tqdm(ds_test.viral_seqs):
    if seq in train_seq_set:
        print(f"Error: {seq}")

In [None]:
all_viral_seqs = ds.viral_seqs + ds_val.viral_seqs + ds_test.viral_seqs
len(all_viral_seqs)

In [None]:
len(ds.viral_seqs)

## ImmunoBERT assessment

In [None]:
# The ImmunoBERT source code can be found at https://github.com/hcgasser/ImmunoBERT

In [None]:
# Placeholder so that `model` exists even when LOAD_IB_MODEL is False;
# the scoring cells below pass it to score_seq_IB either way.
model = None

In [None]:
# Load the ImmunoBERT presentation predictor from its checkpoint and
# switch it to inference mode on the GPU.
if LOAD_IB_MODEL:
    pMHC.set_paths(IB_PROJ_PATH)
    MODEL_PATH = f"..{os.sep}data{os.sep}{IB_CHECKPOINT}.ckpt"

    load_options = dict(num_workers=0, shuffle_data=False, output_attentions=False)
    model = PresentationPredictor.load_from_checkpoint(MODEL_PATH, **load_options)

    model.setup()
    model.to("cuda")
    model.eval()

In [None]:
# Per-allele ImmunoBERT score thresholds for "weak" and "strong" antigenic peptides.
# They are either re-calibrated on random peptides or loaded from pickle files.
weak_threshold_file = f"..{os.sep}data{os.sep}IB_weak_antigenic_threshold.pickle"
strong_threshold_file = f"..{os.sep}data{os.sep}IB_strong_antigenic_threshold.pickle"

if not CALIB_IB:
    # Re-use thresholds stored by an earlier calibration run.
    with open(weak_threshold_file, "rb") as file:
        IB_weak_antigenic_threshold = pickle.load(file)

    with open(strong_threshold_file, "rb") as file:
        IB_strong_antigenic_threshold = pickle.load(file)

else:
    # Score 10000 random 9-mers against every allele in MHC_list.
    rand_peptides = defaultdict(dict)
    for j in tqdm(range(10000)):
        token_ids = np.random.choice(range(1, 21), 9)
        rand_peptide = "".join(ds.tok.dec_dict[x] for x in token_ids)
        for mhc_name in MHC_list:
            pseudo_seq = MhcAllele.mhc_alleles[mhc_name].pseudo_seq
            example = get_input_rep_PSEUDO("", rand_peptide, "", pseudo_seq, model)
            batch = move_dict_to_device(convert_example_to_batch(example), model)
            pred = float(torch.sigmoid(model(batch)))
            rand_peptides[mhc_name][rand_peptide] = pred

    # The thresholds are the 98% (weak) and 99.5% (strong) quantiles of
    # the random-peptide scores of each allele.
    IB_weak_antigenic_threshold = {}
    IB_strong_antigenic_threshold = {}
    for mhc_name in MHC_list:
        allele_scores = list(rand_peptides[mhc_name].values())
        IB_weak_antigenic_threshold[mhc_name] = np.quantile(allele_scores, 0.98)
        IB_strong_antigenic_threshold[mhc_name] = np.quantile(allele_scores, 0.995)

    with open(weak_threshold_file, "wb") as file:
        pickle.dump(IB_weak_antigenic_threshold, file)

    with open(strong_threshold_file, "wb") as file:
        pickle.dump(IB_strong_antigenic_threshold, file)

In [None]:
IB_weak_antigenic_threshold

In [None]:
IB_strong_antigenic_threshold

In [None]:
# Per-sequence ImmunoBERT score tables; unseen sequences read as 0.
IB_seq_scores_50 = defaultdict(int)
IB_seq_scores_weak = defaultdict(int)
IB_seq_scores_strong = defaultdict(int)
IB_seq_avg_scores = defaultdict(int)

# Cache of peptide scores; optionally pre-filled from disk so that
# already scored peptides do not have to be evaluated again.
IB_peptide_scores = {}
if LOAD_IB_PEPTIDE_SCORES:
    with open(f"..{os.sep}data{os.sep}IB_peptide_scores.pickle", "rb") as file:
        IB_peptide_scores = pickle.load(file)

In [None]:
# Score every known sequence with ImmunoBERT. Per the original note,
# peptide scores missing from IB_peptide_scores are added to it automatically.
for seq in tqdm(all_viral_seqs):
    seq_result = score_seq_IB(
        model, seq, MHC_list, IB_peptide_scores,
        weak_antigenic_threshold=IB_weak_antigenic_threshold,
        strong_antigenic_threshold=IB_strong_antigenic_threshold,
    )
    (IB_seq_scores_50[seq],
     IB_seq_scores_weak[seq],
     IB_seq_scores_strong[seq],
     IB_seq_avg_scores[seq]) = seq_result

# Persist the (possibly extended) peptide score cache.
if SAVE_IB_PEPTIDE_SCORES:
    with open(f"..{os.sep}data{os.sep}IB_peptide_scores.pickle", "wb") as file:
        pickle.dump(IB_peptide_scores, file)

In [None]:
# Drop cache entries that do not hold one value per allele in MHC_list.
# The keys are collected first because a dict must not change size while
# it is being iterated.
# NOTE(review): this runs after the cache was pickled in the previous cell,
# so the file on disk still contains the incomplete entries - confirm intended.
to_delete = [
    key
    for key, values in IB_peptide_scores.items()
    if len(values) != len(MHC_list)
]

for key in to_delete:
    del IB_peptide_scores[key]

In [None]:
# Use the "weak" table as the ImmunoBERT sequence score from here on.
# This is an alias, not a copy: both names refer to the same dict.
IB_seq_scores = IB_seq_scores_weak

In [None]:
# Split the sequences into three antigenicity categories by ImmunoBERT score:
# 0 = below the 25th percentile, 1 = between the 25th and 75th, 2 = the rest.
h = np.array(list(IB_seq_scores.values()))
IB_seq_scores_p25, IB_seq_scores_p75 = (np.percentile(h, q) for q in (25, 75))

IB_seq_immuno_cat = {
    seq: (
        0 if IB_seq_scores[seq] < IB_seq_scores_p25
        else 1 if IB_seq_scores[seq] < IB_seq_scores_p75
        else 2
    )
    for seq in all_viral_seqs
}

print(f"IB_seq_scores_p25: {IB_seq_scores_p25:.5f} IB_seq_scores_p75: {IB_seq_scores_p75:.5f}")
print(f"mean: {np.mean(list(IB_seq_scores.values())):.5f}")

In [None]:
# Distribution of the ImmunoBERT sequence scores.
# NOTE(review): seaborn documents distplot as deprecated in favour of
# histplot/displot - confirm against the installed seaborn version.
sns.distplot(np.array(list(IB_seq_scores.values())))

## netMHCpan assessment

In [None]:
# Unique 9-character peptides taken from the ImmunoBERT cache keys: the
# nine characters that follow the first "_" of each key, in first-seen order.
# NOTE(review): presumably keys look like "<flank>_<peptide>_..." - confirm the key layout.
peptides_db = list(dict.fromkeys(
    key[key.find("_") + 1:key.find("_") + 1 + 9]
    for key in IB_peptide_scores.keys()
))

In [None]:
# Write one peptide per line as input for netMHCpan.
# The context manager guarantees the file is closed (and flushed) even if
# writing fails part-way through.
with open(f"{NMP_FOLDER_1}{os.sep}peptides_db.pep", "w") as file:
    file.writelines(f"{peptide}\n" for peptide in peptides_db)

In [None]:
# run shell script in Linux

# folder=/home/tux/win/2022H1/Group_project/CovidProject/netMHCpan
# for mhc in A01:01 A02:01 A03:01 A24:02 A26:01 B07:02 B08:01 B27:05 B39:01 B40:01 B58:01 B15:01
# do
# 	./netMHCpan -p $folder/peptides_db.pep -a HLA-$mhc > $folder/peptides_db_${mhc:0:3}${mhc:4:2}.pep.out	
# done

In [None]:
# Collect the netMHCpan results for the "peptides_db" run for every allele in MHC_list.
# NOTE(review): presumably this parses the *.pep.out files produced by the shell
# loop above - confirm in SpikeOracle.presentation_scoring.nMp.
nMp_peptide_scores = eval_peptides_nMp("peptides_db", MHC_list);

In [None]:
# netMHCpan score per sequence (unseen sequences read as 0) and the pool of
# all epitopes found in the known sequences.
nMp_seq_scores = defaultdict(int)
nMp_epitopes_db = set()
for seq in tqdm(all_viral_seqs):
    seq_score, epitopes = score_seq_nMp(seq, MHC_list, nMp_peptide_scores)
    nMp_seq_scores[seq] = seq_score
    nMp_epitopes_db.update(epitopes)

In [None]:
len(nMp_epitopes_db)

In [None]:
# Split the sequences into three antigenicity categories by netMHCpan score:
# 0 = below the 25th percentile, 1 = between the 25th and 75th, 2 = the rest.
h = np.array(list(nMp_seq_scores.values()))
nMp_seq_scores_p25, nMp_seq_scores_p75 = (np.percentile(h, q) for q in (25, 75))

nMp_seq_immuno_cat = {
    seq: (
        0 if nMp_seq_scores[seq] < nMp_seq_scores_p25
        else 1 if nMp_seq_scores[seq] < nMp_seq_scores_p75
        else 2
    )
    for seq in all_viral_seqs
}

print(f"nMp_seq_scores_p25: {nMp_seq_scores_p25:.5f} nMp_seq_scores_p75: {nMp_seq_scores_p75:.5f}")
print(f"mean: {np.mean(list(nMp_seq_scores.values()))}")

## assign antigenicity category

In [None]:
# Attach the chosen category mapping (1 = ImmunoBERT, 2 = netMHCpan) to all
# three datasets; any other value leaves the datasets untouched.
if ANTIGENICITY == 1 or ANTIGENICITY == 2:
    selected_immuno_cat = IB_seq_immuno_cat if ANTIGENICITY == 1 else nMp_seq_immuno_cat
    for dataset in (ds, ds_val, ds_test):
        dataset.seq_immuno_cat = selected_immuno_cat

# FC VAE

## Hyperparameter search

In [None]:
# Number of training epochs per hyper-parameter search trial (Trainer max_epochs in objective).
HYP_EPOCHS = 25

In [None]:
def generate_seqs(VAE, n_seqs, antigenicity=0):
    """Sample latent vectors and decode them into sequences.

    The random state is reset with seed 42 on every call, so repeated
    calls with the same model and arguments start from the same seed.

    Args:
        VAE: model exposing ``latent_dim``, ``device``, ``ds``,
            ``get_latent_from_seq`` and ``get_seq_from_latent``.
        n_seqs: number of latent vectors to draw.
        antigenicity: category passed on to ``get_seq_from_latent``.

    Returns:
        Whatever ``VAE.get_seq_from_latent`` returns; callers iterate it
        with ``.items()`` as (sequence, count) pairs.
    """
    pl.seed_everything(42)

    device = VAE.device
    dim = VAE.latent_dim

    # standard normal prior over the latent space
    prior = torch.distributions.Normal(
        torch.zeros(dim, device=device),
        torch.ones(dim, device=device))

    # latent codes of the validation sequences; only these are needed here
    _, _, valid_latents, _ = VAE.get_latent_from_seq(VAE.ds[PHASE_VALID].viral_seqs)

    noise = prior.sample(sample_shape=torch.Size([n_seqs])).to(device)

    # NOTE(review): the samples are multiplied by the covariance matrix of
    # the validation latents itself, not by a square-root/Cholesky factor,
    # and the latent mean is not added - confirm this is intended.
    latent_cov = torch.cov(torch.vstack(valid_latents).t()).to(device).float()
    Zs = noise @ latent_cov

    return VAE.get_seq_from_latent(Zs, antigenicity)


def define_model(trial):
    """Build an FcVAE whose architecture is sampled from an Optuna trial.

    Sampled: number of blocks, hidden dimension, latent dimension, dropout
    and KL target. Learning rate, batch size and weight decay come from the
    notebook constants. The train/validation/test datasets are attached to
    the model before it is returned.

    Returns:
        Tuple ``(model, blocks, hidden_dim, latent_dim, kl_target, dropout)``.
    """
    # The suggestion order is kept as-is: blocks, hidden, latent, dropout, KL.
    hyp_blocks = trial.suggest_int("blocks", 2, 7)
    hyp_hidden = trial.suggest_categorical("hidden_dim", [2048, 1024, 512])
    hyp_latent_dim = trial.suggest_int("latent_dim", 2, 50)
    hyp_dropout = trial.suggest_float("dropout", 0.05, 0.5)
    hyp_kl_target = trial.suggest_float("kl_target", 0.01, 1.0)

    model_kwargs = dict(
        aa_dim=AA_ENC_DIM,
        sequence_len=MAX_SEQ_LEN,
        blocks=hyp_blocks,
        hidden_dim=hyp_hidden,
        hidden_dim_scaling_factor=(0.5, 2.0),
        latent_dim=hyp_latent_dim,
        conditional=3,
        dropout=hyp_dropout,
        kl_target=hyp_kl_target,
        lr=FC_LR,
        batch_size=FC_BATCH_SIZE,
        weight_decay=FC_WEIGHT_DECAY,
    )
    hyp_VAE = FcVAE(**model_kwargs)

    for phase, dataset in ((PHASE_TRAIN, ds), (PHASE_VALID, ds_val), (PHASE_TEST, ds_test)):
        hyp_VAE.ds[phase] = dataset

    return hyp_VAE, hyp_blocks, hyp_hidden, hyp_latent_dim, hyp_kl_target, hyp_dropout

def objective(trial):
    """Optuna objective: train a sampled FcVAE and score its generated sequences.

    The model is trained for HYP_EPOCHS epochs, 100 latent samples are
    decoded, and the entropy vector of the generated sequences is compared
    with the global ``ev_train`` (which must be set before the study runs).

    Returns:
        Norm of the difference between the two entropy vectors (a tensor).
    """
    hyp_VAE, hyp_blocks, hyp_hidden, hyp_latent_dim, hyp_kl_target, hyp_dropout = define_model(trial)

    experiment_name = f"OPTUNA-LATENT_DIM-{hyp_latent_dim}-BLOCKS-{hyp_blocks}-HIDDEN-{hyp_hidden}-KL_TARGET-{hyp_kl_target:.3f}-DROPOUT-{hyp_dropout:.3f}"

    trainer = Trainer(
        max_epochs=HYP_EPOCHS,
        gpus=1,
        logger=TensorBoardLogger("tb_logs", name=experiment_name),
    )
    trainer.fit(hyp_VAE)

    # expand the {sequence: count} mapping into a flat list with repeats
    gen_seqs = [
        seq
        for seq, cnt in generate_seqs(hyp_VAE, 100).items()
        for _ in range(cnt)
    ]

    ev = calc_entropy_vector(gen_seqs, hyp_VAE.ds[PHASE_TRAIN].tok.aa_to_idx)

    return torch.norm(torch.tensor(ev - ev_train))

In [None]:
# Create the Optuna study "hyp" backed by the local SQLite file hyp.db, or
# re-open it if it already exists (load_if_exists=True), so earlier trials
# remain available across kernel restarts.
# No direction is passed, so Optuna's default applies (minimisation per the Optuna docs).
study = optuna.create_study(
    study_name="hyp", storage='sqlite:///hyp.db', load_if_exists=True
)

In [None]:
# Run five more search trials when enabled.
if RUN_HYP_PARAM_SRCH:
    # ev_train is read as a global inside objective(), so it must be set first.
    ev_train = calc_entropy_vector(ds.viral_seqs, ds.tok.aa_to_idx)
    study.optimize(objective, n_trials=5)

In [None]:
len(study.trials)

In [None]:
# Print the objective value of every trial stored in the study.
for t in study.trials:
    print(t.value)

In [None]:
# Best trial of the study; its parameters select the checkpoint loaded further below.
trial = study.best_trial

In [None]:
# Print the hyper-parameters of the best trial, one per line.
for key, value in trial.params.items():
    print(f" {key:<20s}: {value}")

In [None]:
trial.value

In [None]:
optuna.visualization.plot_param_importances(study)

In [None]:
optuna.visualization.plot_slice(study)

In [None]:
optuna.visualization.plot_contour(study)

## training and loading 

In [None]:
# Root folder of the TensorBoard logs; url_from_trial builds checkpoint paths below it.
LOG_PATH = f".{os.sep}tb_logs"

In [None]:
def url_from_trial(trial, version, ckpt):
    """Build the checkpoint path that belongs to an Optuna trial.

    Prints the trial's parameters, rebuilds the experiment name from them
    and joins it with LOG_PATH, the log version and the checkpoint file name.

    Args:
        trial: Optuna trial whose ``params`` hold latent_dim, blocks,
            hidden_dim, kl_target and dropout.
        version: log version number (the ``version_<n>`` folder).
        ckpt: checkpoint file name.

    Returns:
        Tuple ``(url, experiment_name)``.
    """
    for key, value in trial.params.items():
        print(f" {key:<20s}: {value}")

    params = trial.params
    hyp_latent_dim, hyp_blocks, hyp_hidden, hyp_kl_target, hyp_dropout = (
        params[name]
        for name in ('latent_dim', 'blocks', 'hidden_dim', 'kl_target', 'dropout')
    )
    experiment_name = f"OPTUNA-LATENT_DIM-{hyp_latent_dim}-BLOCKS-{hyp_blocks}-HIDDEN-{hyp_hidden}-KL_TARGET-{hyp_kl_target:.3f}-DROPOUT-{hyp_dropout:.3f}"

    url = f"{LOG_PATH}{os.sep}{experiment_name}{os.sep}version_{version}{os.sep}checkpoints{os.sep}{ckpt}"

    return url, experiment_name

In [None]:
# Locate the checkpoint of the best trial and restore the VAE from it.
url, experiment_name = url_from_trial(trial, VERSION, CKPT)
logger = TensorBoardLogger("tb_logs", name=experiment_name)

VAE = FcVAE.load_from_checkpoint(checkpoint_path=url)
VAE = VAE.cuda()
# datasets in the order train, validation, test
VAE.ds = [ds, ds_val, ds_test]

# Optionally continue training from that checkpoint.
if RUN_TRAINING:
    trainer = pl.Trainer(max_epochs=FC_EPOCHS, logger=logger, gpus=1)
    trainer.fit(VAE, ckpt_path=url)

In [None]:
VAE.eval()

In [None]:
j = -1

In [None]:
# Reconstruct the next training sequence with the VAE and print its
# alignment against the original. Every run of this cell advances j by one.
j += 1
original_seq = ds.viral_seqs[j][:MAX_SEQ_LEN]

seq_tokens = ds.tok.tokenize(original_seq).unsqueeze(dim=0).to(VAE.device)
cat_tokens = torch.tensor(
    ds.seq_immuno_cat_tokens[ds.seq_immuno_cat[ds.viral_seqs[j]]]
).unsqueeze(dim=0).to(VAE.device)

reconstruction = VAE.forward(seq_tokens, cat_tokens, sample=False)
h = ds.tok.decode(reconstruction.reshape(1, MAX_SEQ_LEN, -1))

alignments = pairwise2.align.globalxx(original_seq, h[0])
print(format_alignment(*alignments[0]))

## latent space

In [None]:
# Encode all training sequences with the VAE.
# NOTE(review): judging by the names, this yields the latent means, log-variances,
# sampled latents and antigenicity categories - confirm in FcVAE.get_latent_from_seq.
mus, log_vars, latents, cats = VAE.get_latent_from_seq(VAE.ds[PHASE_TRAIN].viral_seqs)

## generate new sequences

In [None]:
# One slot per antigenicity category (index 0 = low, 1 = medium, 2 = high).
# generated_seqs holds all generated sequences; generated_seqs_new is meant to
# hold only those that do not occur in the training set.
generated_seqs = [None, None, None]
generated_seqs_new = [None, None, None]
# Category names used in the FASTA file names.
antigenicity_names = ["low", "medium", "high"]

### lowly antigenic

In [None]:
j = 0

In [None]:
# Generate sequences for category j and keep the ones that are not already
# part of the training set.
if RUN_GEN_SEQS:
    generated_seqs[j] = generate_seqs(VAE, FC_SAMPLES, antigenicity=j)

    # A set gives O(1) membership tests; scanning the training list for every
    # generated sequence would be quadratic.
    train_seq_set = set(VAE.ds[PHASE_TRAIN].viral_seqs)
    generated_seqs_new[j] = {
        seq: cnt
        for seq, cnt in generated_seqs[j].items()
        if seq not in train_seq_set
    }

    print(f"Generated: {len(generated_seqs[j])} New: {len(generated_seqs_new[j])}")

In [None]:
# FASTA files for category j: every generated sequence (*_all.fasta) and
# only the sequences not found in the training set (*.fasta).
filename_all = f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}_all.fasta"
filename_new = f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}.fasta"

if RUN_GEN_SEQS:
    write_seqs_to_fasta(generated_seqs[j], filename_all);
    write_seqs_to_fasta(generated_seqs_new[j], filename_new);
else:
    # Read back {sequence: count}; the count is parsed from the record id.
    generated_seqs[j] = {
        str(record.seq): int(record.id)
        for record in SeqIO.parse(filename_all, "fasta")
    }

    # Fix: the novel sequences come from filename_new. The previous code
    # read filename_all here too, making generated_seqs_new[j] a copy of
    # generated_seqs[j].
    generated_seqs_new[j] = {
        str(record.seq): int(record.id)
        for record in SeqIO.parse(filename_new, "fasta")
    }
    

### intermediate antigenic

In [None]:
j = 1

In [None]:
# Generate sequences for category j and keep the ones that are not already
# part of the training set.
if RUN_GEN_SEQS:
    generated_seqs[j] = generate_seqs(VAE, FC_SAMPLES, antigenicity=j)

    # A set gives O(1) membership tests; scanning the training list for every
    # generated sequence would be quadratic.
    train_seq_set = set(VAE.ds[PHASE_TRAIN].viral_seqs)
    generated_seqs_new[j] = {
        seq: cnt
        for seq, cnt in generated_seqs[j].items()
        if seq not in train_seq_set
    }

    print(f"Generated: {len(generated_seqs[j])} New: {len(generated_seqs_new[j])}")

In [None]:
# FASTA files for category j: every generated sequence (*_all.fasta) and
# only the sequences not found in the training set (*.fasta).
filename_all = f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}_all.fasta"
filename_new = f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}.fasta"

if RUN_GEN_SEQS:
    write_seqs_to_fasta(generated_seqs[j], filename_all);
    write_seqs_to_fasta(generated_seqs_new[j], filename_new);
else:
    # Read back {sequence: count}; the count is parsed from the record id.
    generated_seqs[j] = {
        str(record.seq): int(record.id)
        for record in SeqIO.parse(filename_all, "fasta")
    }

    # Fix: the novel sequences come from filename_new. The previous code
    # read filename_all here too, making generated_seqs_new[j] a copy of
    # generated_seqs[j].
    generated_seqs_new[j] = {
        str(record.seq): int(record.id)
        for record in SeqIO.parse(filename_new, "fasta")
    }
    

### highly antigenic

In [None]:
j = 2

In [None]:
# Generate sequences for category j and keep the ones that are not already
# part of the training set.
if RUN_GEN_SEQS:
    generated_seqs[j] = generate_seqs(VAE, FC_SAMPLES, antigenicity=j)

    # A set gives O(1) membership tests; scanning the training list for every
    # generated sequence would be quadratic.
    train_seq_set = set(VAE.ds[PHASE_TRAIN].viral_seqs)
    generated_seqs_new[j] = {
        seq: cnt
        for seq, cnt in generated_seqs[j].items()
        if seq not in train_seq_set
    }

    print(f"Generated: {len(generated_seqs[j])} New: {len(generated_seqs_new[j])}")

In [None]:
# FASTA files for category j: every generated sequence (*_all.fasta) and
# only the sequences not found in the training set (*.fasta).
filename_all = f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}_all.fasta"
filename_new = f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}.fasta"

if RUN_GEN_SEQS:
    write_seqs_to_fasta(generated_seqs[j], filename_all);
    write_seqs_to_fasta(generated_seqs_new[j], filename_new);
else:
    # Read back {sequence: count}; the count is parsed from the record id.
    generated_seqs[j] = {
        str(record.seq): int(record.id)
        for record in SeqIO.parse(filename_all, "fasta")
    }

    # Fix: the novel sequences come from filename_new. The previous code
    # read filename_all here too, making generated_seqs_new[j] a copy of
    # generated_seqs[j].
    generated_seqs_new[j] = {
        str(record.seq): int(record.id)
        for record in SeqIO.parse(filename_new, "fasta")
    }
    

### evaluate antigenicity

#### with ImmunoBERT

In [None]:
# Score the newly generated sequences with ImmunoBERT.
# Fix: the previous code assigned into IB_seq_presentation, which is never
# defined, and unpacked two values. The call now mirrors the scoring of the
# known sequences earlier in this notebook: same thresholds, same four result
# tables. IB_seq_scores is an alias of IB_seq_scores_weak, so the plot below
# picks up the new entries.
for j in range(3):
    for seq in tqdm(generated_seqs_new[j].keys()):
        IB_seq_scores_50[seq], IB_seq_scores_weak[seq], IB_seq_scores_strong[seq], IB_seq_avg_scores[seq] = \
            score_seq_IB(model, seq, MHC_list, IB_peptide_scores,
                         weak_antigenic_threshold=IB_weak_antigenic_threshold,
                         strong_antigenic_threshold=IB_strong_antigenic_threshold)

In [None]:
# Overlay the ImmunoBERT score distributions of the three generated groups.
matplotlib.rcParams["font.size"] = 15

for j in range(3):
    group_scores = [IB_seq_scores[seq] for seq in generated_seqs_new[j].keys()]
    sns.distplot(group_scores)

plt.legend(labels=["low","medium", "high"])

#### with netMHCpan

In [None]:
# Collect every 9-mer of the generated sequences that netMHCpan has not
# scored yet and write them, one per line, as input for another run.
# A dict is used as an ordered set so each peptide is written only once.
missing_unique = {}
for j in range(3):
    for seq in tqdm(list(generated_seqs_new[j].keys())):
        seq = seq.replace("-", "")  # strip alignment gaps
        # Fix: "+ 1" so the final 9-mer of the sequence is included; the
        # previous range(len(seq)-9) stopped one window short.
        for position in range(len(seq) - 9 + 1):
            peptide = seq[position:(position + 9)]
            if peptide not in nMp_peptide_scores:
                missing_unique[peptide] = None

missing = list(missing_unique)

# The context manager guarantees the file is closed even if writing fails.
with open(f"{NMP_FOLDER_1}{os.sep}missing.pep", "w") as file:
    file.writelines(f"{peptide}\n" for peptide in missing)

In [None]:
# run shell script in Linux

# folder=/home/tux/win/2022H1/Group_project/CovidProject/netMHCpan
# for mhc in A01:01 A02:01 A03:01 A24:02 A26:01 B07:02 B08:01 B27:05 B39:01 B40:01 B58:01 B15:01
# do
# 	./netMHCpan -p $folder/missing.pep -a HLA-$mhc > $folder/missing_${mhc:0:3}${mhc:4:2}.pep.out	
# done

In [None]:
# Merge the netMHCpan results of the "missing" run into the existing peptide score table.
nMp_peptide_scores.update(eval_peptides_nMp("missing", MHC_list));

In [None]:
# netMHCpan score and epitope pool for each generated group.
# NOTE(review): sequences are passed as stored, while the missing-peptide scan
# above strips "-" first - confirm score_seq_nMp handles gap characters.
nMp_epitopes_gen = [set() for _ in range(3)]
for j in range(3):
    for seq in tqdm(generated_seqs_new[j].keys()):
        seq_score, epitopes = score_seq_nMp(seq, MHC_list, nMp_peptide_scores)
        nMp_seq_scores[seq] = seq_score
        nMp_epitopes_gen[j].update(epitopes)

In [None]:
# Per group: number of epitopes found, and how many of them do not occur in
# the epitope pool of the known sequences.
for j, name in enumerate(antigenicity_names):
    group_epitopes = nMp_epitopes_gen[j]
    print(
        f"{name}: antigenic epitopes - {len(group_epitopes)} \t"
        f"new antigenic epitopes - {len(group_epitopes.difference(nMp_epitopes_db))}"
    )

In [None]:
# Overlay the netMHCpan score distributions of the three generated groups.
matplotlib.rcParams["font.size"] = 15

for j in range(3):
    group_scores = [nMp_seq_scores[seq] for seq in generated_seqs_new[j].keys()]
    sns.distplot(group_scores)

plt.legend(labels=["low","medium", "high"])

In [None]:
# For each generated group, print the epitopes that do not occur in the
# epitope pool of the known sequences.
for j in range(3):
    print(nMp_epitopes_gen[j].difference(nMp_epitopes_db))