# Control

## imports

In [1]:
# general imports
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import json
import os
import pickle

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from collections import defaultdict
from datetime import datetime

# biopython
import Bio
from Bio import SeqIO
from Bio import pairwise2
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.pairwise2 import format_alignment
from Bio.SubsMat import MatrixInfo as matlist

# pytorch
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch import nn
import torch.nn.functional as F

# pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

import optuna

# ImmunoBERT
import pMHC
from pMHC.logic import PresentationPredictor
from pMHC.data import MhcAllele
from pMHC import SEP, \
    SPLITS, SPLIT_TRAIN, SPLIT_VAL, SPLIT_TEST, \
    VIEWS, VIEW_SA, VIEW_SAMA, VIEW_DECONV, \
    INPUT_PEPTIDE, INPUT_CONTEXT
from pMHC.data.utils import convert_example_to_batch, move_dict_to_device, get_input_rep_PSEUDO

# visualizations
from protein_map_visualisation_tools import generate_embedding_map_from_database

# generative model
import SpikeOracle
from SpikeOracle import PHASE_TRAIN, PHASE_VALID, PHASE_TEST
from SpikeOracle.data import StandardDataset
from SpikeOracle.presentation_scoring.IB import score_seq_IB
from SpikeOracle.presentation_scoring.nMp import eval_peptides_nMp, score_seq_nMp
from SpikeOracle.models.VAE.fc import FcVAE
from SpikeOracle.models.VAE.conv import ConvVAE
from SpikeOracle.latent import get_latent_from_seq_FcVAE, get_seq_from_latent_FcVAE
from SpikeOracle.utils import write_seqs_to_fasta, calc_entropy_vector



## constants

In [2]:
# Full aligned spike protein FASTA file; only read when CREATE_DATA_SPLIT is True.
FILENAME_FASTA = f"..{os.sep}data{os.sep}spikeprot_final_dataset.afa"

# data constants
# Sequence length handed to StandardDataset (earlier values tried: 1449, 1282, 18).
# NOTE(review): an older comment here said sequences were restricted to 1271 aa - confirm which is right.
SEQ_LEN = 1299
MAX_SEQ_LEN = SEQ_LEN
AA_ENC_DIM = 21   # count of amino acid encoding dimensions
# NOTE(review): this rebinds the `SEP` name imported from pMHC in the import cell - confirm that is intended.
SEP = os.sep
IMMUNO_CATS = 3

# most relevant MHC alleles
MHC_list = ["HLA-A01:01", "HLA-A02:01", "HLA-A03:01", "HLA-A24:02", "HLA-A26:01",
            "HLA-B07:02", "HLA-B08:01", "HLA-B27:05", "HLA-B39:01", "HLA-B40:01", "HLA-B58:01", "HLA-B15:01"]

# ImmunoBERT
IB_VERSION = "CONTEXT-PSEUDO-HEAD_Cls-DECOY_19-LR_0.00001"
IB_CHECKPOINT = "epoch=4-step=3648186"
# Machine-specific location of the ImmunoBERT project. Set the IB_PROJ_PATH environment
# variable to run the notebook elsewhere; the original path stays the default.
IB_PROJ_PATH = os.environ.get(
    "IB_PROJ_PATH", r"C:\Users\s2118339\Documents\MSc_AI_Thesis_final\MScProject")

# netMHCpan
NMP_FOLDER_1 = f"..{os.sep}netMHCpan"   # folder as seen from this notebook
# Same folder as seen from the Linux side where netMHCpan is executed (machine-specific,
# overridable through the NMP_FOLDER_2 environment variable).
NMP_FOLDER_2 = os.environ.get(
    "NMP_FOLDER_2", r"~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan")

### fully connected VAE

In [129]:
# Settings for the manually configured fully connected VAE (used by the FC_SAVE training cell).
# The Optuna search samples blocks / hidden_dim / latent_dim / dropout / kl_target itself and
# takes only FC_LR, FC_BATCH_SIZE and FC_WEIGHT_DECAY from here.

# encoder parameters
FC_BLOCKS = 5
FC_HIDDEN_DIM = 2048 # intermediate dimensions of the encoder

# latent space
FC_LATENT_DIM = 2   # dimensions of the latent space

# VAE parameters
FC_KL_TARGET = 0.1    # passed to FcVAE as `kl_target`; presumably the KL term's target/weight in the loss - confirm in FcVAE

# training parameters
FC_LR = 3e-4         # the learning rate
FC_BATCH_SIZE = 64   # batch size
FC_DROPOUT = 0.05
FC_WEIGHT_DECAY = 1e-6 # previously tried: 3e-5

# model and data
# MODEL_NAME is the stem of the saved checkpoint and of the generated FASTA file names.
MODEL_NAME = "FC_004"
# Cached train / validation / test splits (written when CREATE_DATA_SPLIT is True, read otherwise).
FILENAME_TRAIN = f"..{os.sep}data{os.sep}spikeprot_train.txt"
FILENAME_VALID = f"..{os.sep}data{os.sep}spikeprot_valid.txt"
FILENAME_TEST = f"..{os.sep}data{os.sep}spikeprot_test.txt"

## notebook control

In [4]:
# Switches deciding which expensive steps are recomputed and which are read from disk.
# True: re-split the FASTA file into train/valid/test and save the splits; False: load the saved splits.
CREATE_DATA_SPLIT = False

# Load the ImmunoBERT checkpoint (the loading cell moves it to "cuda").
LOAD_IB_MODEL = True
# Read the cached peptide scores from IB_peptide_scores.pickle.
LOAD_IB_PEPTIDE_SCORES = True
# Write the peptide score cache back to IB_peptide_scores.pickle after scoring.
SAVE_IB_PEPTIDE_SCORES = True
# True: recompute the per-allele thresholds from random peptides; False: load them from pickle.
CALIB_IB = False

ANTIGENICITY = 2  # 1... ImmunoBERT, 2... netMHCpan

# max_epochs for the final FC VAE training runs.
FC_EPOCHS = 100
# Truthy: train a fresh FcVAE and save it as MODEL_NAME.ckpt (earlier values: "FC_003.ckpt", "Fc_test", None).
FC_SAVE = True # "FC_003.ckpt" # "Fc_test" # None
# Truthy: restore the best Optuna trial's checkpoint and continue training from it.
FC_LOAD = True # "FC_003.ckpt" # "Fully_Eps_100_KlTgt_25e-2KL_Cdtl.ckpt"

# Number of latent samples drawn per antigenicity class when generating sequences.
FC_SAMPLES = 50000

# Dataset

In [5]:
# Build the train / validation / test datasets: either read the three cached
# split files, or re-create the split from the full FASTA file and cache it.
if not CREATE_DATA_SPLIT:
    ds, ds_val, ds_test = (
        StandardDataset(SEQ_LEN, MAX_SEQ_LEN, filename=split_file)
        for split_file in (FILENAME_TRAIN, FILENAME_VALID, FILENAME_TEST)
    )
else:
    ds = StandardDataset(SEQ_LEN, MAX_SEQ_LEN)
    ds.load_from_fasta(FILENAME_FASTA)
    print(f"Fasta len: {len(ds.viral_seqs)}")

    # validation first, then test: each call splits 2000 sequences off `ds`
    # (presumably removing them from `ds` - confirm in StandardDataset.splitoff)
    ds_val = ds.splitoff(2000)
    ds_test = ds.splitoff(2000)

    splits = ((ds, "Train", FILENAME_TRAIN),
              (ds_val, "Valid", FILENAME_VALID),
              (ds_test, "Test", FILENAME_TEST))

    for split, label, _split_file in splits:
        print(f"{label} len: {len(split.viral_seqs)}")

    for split, _label, split_file in splits:
        split.save_to_file(split_file)

In [6]:
# Leakage check: report every validation or test sequence that also occurs in the
# training split. The training sequences are put into a set once, so each lookup is
# O(1) instead of a scan over the whole training list.
train_seq_set = set(ds.viral_seqs)

for seq in tqdm(ds_val.viral_seqs):
    if seq in train_seq_set:
        print(f"Error: {seq}")

for seq in tqdm(ds_test.viral_seqs):
    if seq in train_seq_set:
        print(f"Error: {seq}")

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

In [7]:
# Pool the sequences of all three splits; the scoring cells iterate over this list.
all_viral_seqs = ds.viral_seqs + ds_val.viral_seqs + ds_test.viral_seqs
len(all_viral_seqs)

69027

## ImmunoBERT assessment

In [8]:
# Placeholder so later cells can reference `model` even when LOAD_IB_MODEL is False.
model = None

In [None]:
# Load the ImmunoBERT presentation predictor from its checkpoint and prepare it for GPU inference.
if LOAD_IB_MODEL:
    pMHC.set_paths(IB_PROJ_PATH)
    MODEL_PATH = f"..{os.sep}data{os.sep}{IB_CHECKPOINT}.ckpt"

    ib_load_kwargs = dict(num_workers=0, shuffle_data=False, output_attentions=False)
    model = PresentationPredictor.load_from_checkpoint(MODEL_PATH, **ib_load_kwargs)

    model.setup()
    model.to("cuda")   # NOTE: requires a CUDA device
    model.eval()

In [None]:
# Per-allele ImmunoBERT thresholds: either load the pickled values, or recalibrate them
# as the 0.98 ("weak") and 0.995 ("strong") quantiles of the scores of 10000 random 9-mers.
if not CALIB_IB:
    with open(f"..{os.sep}data{os.sep}IB_weak_antigenic_threshold.pickle", "rb") as file:
        IB_weak_antigenic_threshold = pickle.load(file)

    with open(f"..{os.sep}data{os.sep}IB_strong_antigenic_threshold.pickle", "rb") as file:
        IB_strong_antigenic_threshold = pickle.load(file)

else:
    # rand_peptides[allele][peptide] -> sigmoid of the model output for that pair
    rand_peptides = defaultdict(dict)
    for draw_idx in tqdm(range(10000)):
        # 9 token ids drawn from 1..20, decoded with the dataset tokenizer
        rand_peptide = "".join(ds.tok.dec_dict[x] for x in np.random.choice(range(1, 21), 9))
        for mhc_name in MHC_list:
            pseudo_seq = MhcAllele.mhc_alleles[mhc_name].pseudo_seq
            example = get_input_rep_PSEUDO("", rand_peptide, "", pseudo_seq, model)
            batch = move_dict_to_device(convert_example_to_batch(example), model)
            rand_peptides[mhc_name][rand_peptide] = float(torch.sigmoid(model(batch)))

    IB_weak_antigenic_threshold = {}
    IB_strong_antigenic_threshold = {}
    for mhc_name in MHC_list:
        allele_preds = list(rand_peptides[mhc_name].values())
        IB_weak_antigenic_threshold[mhc_name] = np.quantile(allele_preds, 0.98)
        IB_strong_antigenic_threshold[mhc_name] = np.quantile(allele_preds, 0.995)

    # cache both threshold tables so later runs can skip the calibration
    with open(f"..{os.sep}data{os.sep}IB_weak_antigenic_threshold.pickle", "wb") as file:
        pickle.dump(IB_weak_antigenic_threshold, file)

    with open(f"..{os.sep}data{os.sep}IB_strong_antigenic_threshold.pickle", "wb") as file:
        pickle.dump(IB_strong_antigenic_threshold, file)

In [None]:
# Display the "weak" threshold for each MHC allele.
IB_weak_antigenic_threshold

In [None]:
# Display the "strong" threshold for each MHC allele.
IB_strong_antigenic_threshold

In [9]:
# Peptide-level score cache (optionally pre-filled from disk) plus the four
# per-sequence score tables that the scoring cell fills in.
IB_peptide_scores = {}

# unseen sequences read as 0 in every per-sequence table
IB_seq_scores_50, IB_seq_scores_weak, IB_seq_scores_strong, IB_seq_avg_scores = (
    defaultdict(int) for _table in range(4)
)

if LOAD_IB_PEPTIDE_SCORES:
    ib_scores_path = f"..{os.sep}data{os.sep}IB_peptide_scores.pickle"
    with open(ib_scores_path, "rb") as file:
        IB_peptide_scores = pickle.load(file)

In [None]:
# Score every sequence with ImmunoBERT. score_seq_IB is handed the shared IB_peptide_scores
# cache; per the original note, peptide scores missing from it are added automatically.
for seq in tqdm(all_viral_seqs):
    seq_result = score_seq_IB(
        model, seq, MHC_list, IB_peptide_scores,
        weak_antigenic_threshold=IB_weak_antigenic_threshold,
        strong_antigenic_threshold=IB_strong_antigenic_threshold,
    )
    # the four returned values go to the 50 / weak / strong / average tables, in that order
    (IB_seq_scores_50[seq],
     IB_seq_scores_weak[seq],
     IB_seq_scores_strong[seq],
     IB_seq_avg_scores[seq]) = seq_result

if SAVE_IB_PEPTIDE_SCORES:
    ib_scores_path = f"..{os.sep}data{os.sep}IB_peptide_scores.pickle"
    with open(ib_scores_path, "wb") as file:
        pickle.dump(IB_peptide_scores, file)

In [None]:
# Drop cache entries that do not hold exactly one value per allele in MHC_list.
# Keys are collected first because a dict must not change size while it is iterated.
# NOTE(review): this runs after the cache was pickled, so the file on disk may still
# contain the incomplete entries - confirm that is intended.
to_delete = [
    key for key, values in IB_peptide_scores.items()
    if len(values) != len(MHC_list)
]

for key in to_delete:
    del IB_peptide_scores[key]

In [None]:
# Use the weak-threshold table as the ImmunoBERT per-sequence score from here on.
IB_seq_scores = IB_seq_scores_weak

In [None]:
# Bucket every sequence by the quartiles of the ImmunoBERT score distribution:
#   0 = below the 25th percentile, 1 = from the 25th up to (not including) the 75th, 2 = the rest.
h = np.array(list(IB_seq_scores.values()))
IB_seq_scores_p25, IB_seq_scores_p75 = np.percentile(h, [25, 75])

IB_seq_immuno_cat = {}
for seq in all_viral_seqs:
    seq_score = IB_seq_scores[seq]
    IB_seq_immuno_cat[seq] = (
        0 if seq_score < IB_seq_scores_p25
        else 1 if seq_score < IB_seq_scores_p75
        else 2
    )

IB_seq_scores_mean = np.mean(list(IB_seq_scores.values()))
print(f"IB_seq_scores_p25: {IB_seq_scores_p25:.5f} IB_seq_scores_p75: {IB_seq_scores_p75:.5f}")
print(f"mean: {IB_seq_scores_mean:.5f}")

In [None]:
# Plot the distribution of the per-sequence ImmunoBERT scores.
# NOTE(review): `sns.distplot` is deprecated in recent seaborn releases; switch to
# `sns.histplot` / `sns.displot` when the environment is upgraded.
sns.distplot(np.array(list(IB_seq_scores.values())))

## netMHCpan assessment

In [10]:
# Collect the distinct 9-character peptides from the score cache keys, in first-seen order.
# The peptide is taken as the 9 characters following the first "_" in the key
# (presumably keys carry a prefix before the peptide - confirm against score_seq_IB).
unique_peptides = {}
for key in IB_peptide_scores:
    start = key.find("_") + 1
    unique_peptides.setdefault(key[start:start + 9])

peptides_db = list(unique_peptides)

In [11]:
# Write one peptide per line as input for netMHCpan. The context manager closes the
# file even if writing fails part-way.
with open(f"{NMP_FOLDER_1}{os.sep}peptides_db.pep", "w") as file:
    for peptide in peptides_db:
        file.write(f"{peptide}\n")

In [12]:
# Print one netMHCpan command line per allele; the commands are meant to be
# copied and run by hand (paths use NMP_FOLDER_2, the Linux-side folder).
for mhc_name in MHC_list:
    # file-name friendly allele tag, e.g. "HLA-A01:01" -> "A0101"
    mhc_name_2 = mhc_name.replace(":", "").replace("HLA-", "")

    pep_in = f"{NMP_FOLDER_2}/peptides_db.pep"
    pep_out = f"{NMP_FOLDER_2}/peptides_db_{mhc_name_2}.pep.out"
    command = f"./netMHCpan -p {pep_in} -a {mhc_name} > {pep_out}"

    print(command)
    print("\n")

./netMHCpan -p ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/peptides_db.pep -a HLA-A01:01 > ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/peptides_db_A0101.pep.out


./netMHCpan -p ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/peptides_db.pep -a HLA-A02:01 > ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/peptides_db_A0201.pep.out


./netMHCpan -p ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/peptides_db.pep -a HLA-A03:01 > ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/peptides_db_A0301.pep.out


./netMHCpan -p ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/peptides_db.pep -a HLA-A24:02 > ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/peptides_db_A2402.pep.out


./netMHCpan -p ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/peptides_db.pep -a HLA-A26:01 > ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/peptides_db_A2601.pep.out


./netMHCpan -p ~/win/Docu

In [13]:
# Run the netMHCpan commands printed above on Linux before continuing (the next cell is expected to use their output).

In [None]:
# Collect the netMHCpan scores for the "peptides_db" peptide set across all alleles in MHC_list.
# NOTE(review): presumably reads the *.pep.out files produced by the commands above - confirm
# where eval_peptides_nMp looks for them.
nMp_peptide_scores = eval_peptides_nMp("peptides_db", MHC_list);

In [15]:
# Score every sequence with the netMHCpan peptide scores and gather all epitopes seen.
nMp_seq_scores = defaultdict(lambda : 0)
nMp_epitopes_db = set()
for seq in tqdm(all_viral_seqs):
    nMp_seq_scores[seq], epitopes = score_seq_nMp(seq, MHC_list, nMp_peptide_scores)
    # in-place update: `union` would copy the whole accumulated set on every iteration
    nMp_epitopes_db.update(epitopes)

  0%|          | 0/69027 [00:00<?, ?it/s]

In [16]:
# Number of distinct epitopes collected over all sequences.
len(nMp_epitopes_db)

14732

In [17]:
# Bucket every sequence by the quartiles of the netMHCpan score distribution:
#   0 = below the 25th percentile, 1 = from the 25th up to (not including) the 75th, 2 = the rest.
h = np.array(list(nMp_seq_scores.values()))
nMp_seq_scores_p25, nMp_seq_scores_p75 = np.percentile(h, [25, 75])

nMp_seq_immuno_cat = {}
for seq in all_viral_seqs:
    seq_score = nMp_seq_scores[seq]
    nMp_seq_immuno_cat[seq] = (
        0 if seq_score < nMp_seq_scores_p25
        else 1 if seq_score < nMp_seq_scores_p75
        else 2
    )

nMp_seq_scores_mean = np.mean(list(nMp_seq_scores.values()))
print(f"nMp_seq_scores_p25: {nMp_seq_scores_p25:.5f} nMp_seq_scores_p75: {nMp_seq_scores_p75:.5f}")
print(f"mean: {nMp_seq_scores_mean}")

nMp_seq_scores_p25: 49.83333 nMp_seq_scores_p75: 50.33333
mean: 50.07670808804338


## assign antigenicity category

In [18]:
# Attach the chosen sequence -> antigenicity-class mapping to all three datasets.
# The branches stay lazy on purpose: only the selected scorer's mapping has to exist,
# so the cells of the other scorer may be skipped.
if ANTIGENICITY == 1:
    seq_immuno_cat = IB_seq_immuno_cat
elif ANTIGENICITY == 2:
    seq_immuno_cat = nMp_seq_immuno_cat
else:
    # previously an unknown value was silently ignored and the datasets were left
    # without categories, which only surfaced much later
    raise ValueError(
        f"ANTIGENICITY must be 1 (ImmunoBERT) or 2 (netMHCpan), got {ANTIGENICITY!r}")

for split in (ds, ds_val, ds_test):
    split.seq_immuno_cat = seq_immuno_cat

# FC VAE

## Hyperparameter search

In [19]:
# max_epochs for each Optuna trial's training run.
HYP_EPOCHS = 25

In [20]:
# Entropy vector of the training sequences; the Optuna objective compares the entropy
# vector of generated sequences against this reference.
ev_train = calc_entropy_vector(ds.viral_seqs, ds.tok.aa_to_idx)

  0%|          | 0/65027 [00:00<?, ?it/s]

In [125]:
def generate_seqs(VAE, n_seqs, antigenicity=0):
    """Sample latent vectors and decode them into sequences with the given VAE.

    Standard-normal samples are drawn, multiplied by the covariance matrix of the
    latents of the VAE's validation sequences, and decoded.

    Args:
        VAE: trained model exposing `latent_dim`, `device`, `ds`,
            `get_latent_from_seq` and `get_seq_from_latent`.
        n_seqs: number of latent samples to draw.
        antigenicity: category passed through to `get_seq_from_latent`
            (this notebook uses 0 = low, 1 = medium, 2 = high).

    Returns:
        Whatever `VAE.get_seq_from_latent` returns; callers in this notebook treat it
        as a mapping from sequence to count.
    """
    # fixed seed: every call draws the same latent samples
    pl.seed_everything(42)

    latent_dim = VAE.latent_dim
    prior = torch.distributions.Normal(
        torch.zeros(latent_dim, device=VAE.device),
        torch.ones(latent_dim, device=VAE.device))

    # Only the latents of the validation sequences are needed. This call stays
    # before `prior.sample` so the random number stream is consumed in the same order.
    _, _, latents, _ = VAE.get_latent_from_seq(VAE.ds[PHASE_VALID].viral_seqs)
    latent_cov = torch.cov(torch.vstack(latents).t()).to(VAE.device).float()

    # the prior already lives on VAE.device, so the samples need no extra transfer
    Zs = prior.sample(sample_shape=torch.Size([n_seqs]))

    # NOTE(review): multiplying standard-normal samples by the covariance matrix itself
    # yields samples with covariance cov @ cov, not cov, and the latent mean is not added.
    # To match the latent distribution, a Cholesky factor of cov would be needed.
    # Kept unchanged so existing Optuna trials and generated files remain comparable.
    Zs = Zs @ latent_cov

    return VAE.get_seq_from_latent(Zs, antigenicity)
    

def define_model(trial):
    """Build an FcVAE whose architecture is sampled from an Optuna trial.

    Sampled from the trial (in this order): blocks, hidden_dim, latent_dim, dropout,
    kl_target. Learning rate, batch size and weight decay come from the FC_* constants.
    The notebook's train / validation / test datasets are attached to the model.

    Returns:
        Tuple (model, blocks, hidden_dim, latent_dim, kl_target, dropout).
    """
    # dict entries are evaluated top to bottom, so the suggest calls keep their order
    sampled = {
        "blocks": trial.suggest_int("blocks", 2, 7),
        "hidden_dim": trial.suggest_categorical("hidden_dim", [2048, 1024, 512]),
        "latent_dim": trial.suggest_int("latent_dim", 2, 50),
        "dropout": trial.suggest_float("dropout", 0.05, 0.5),
        "kl_target": trial.suggest_float("kl_target", 0.01, 1.0),
    }

    hyp_VAE = FcVAE(
        aa_dim=AA_ENC_DIM,
        sequence_len=MAX_SEQ_LEN,
        hidden_dim_scaling_factor=(0.5, 2.0),
        conditional=3,
        lr=FC_LR,
        batch_size=FC_BATCH_SIZE,
        weight_decay=FC_WEIGHT_DECAY,
        **sampled,
    )

    for phase, dataset in ((PHASE_TRAIN, ds), (PHASE_VALID, ds_val), (PHASE_TEST, ds_test)):
        hyp_VAE.ds[phase] = dataset

    return (hyp_VAE, sampled["blocks"], sampled["hidden_dim"], sampled["latent_dim"],
            sampled["kl_target"], sampled["dropout"])

def objective(trial):
    """Optuna objective: train one sampled FcVAE and score the sequences it generates.

    The model from `define_model` is trained for HYP_EPOCHS epochs, 100 latent samples
    are decoded with `generate_seqs`, and the entropy vector of the generated sequences
    is compared with `ev_train`.

    Returns:
        float: `torch.norm` of (entropy vector of generated sequences - ev_train).
    """
    hyp_VAE, hyp_blocks, hyp_hidden, hyp_latent_dim, hyp_kl_target, hyp_dropout = define_model(trial)

    experiment_name = f"OPTUNA-LATENT_DIM-{hyp_latent_dim}-BLOCKS-{hyp_blocks}-HIDDEN-{hyp_hidden}-KL_TARGET-{hyp_kl_target:.3f}-DROPOUT-{hyp_dropout:.3f}"

    logger = TensorBoardLogger("tb_logs", name=experiment_name)
    trainer = Trainer(max_epochs=HYP_EPOCHS, gpus=1, logger=logger)

    trainer.fit(hyp_VAE)
    # trainer.save_checkpoint(f"..{os.sep}models{os.sep}{experiment_name}.ckpt")

    # expand the sequence -> count mapping so each sequence appears `count` times
    generated_seqs = generate_seqs(hyp_VAE, 100)
    gen_seqs = []
    for seq, cnt in generated_seqs.items():
        gen_seqs.extend([seq] * cnt)

    ev = calc_entropy_vector(gen_seqs, hyp_VAE.ds[PHASE_TRAIN].tok.aa_to_idx)

    # return a plain Python float rather than a 0-d tensor
    return float(torch.norm(torch.tensor(ev - ev_train)))

In [None]:
# Create the Optuna study, or reopen it from hyp.db if it already exists, so trials
# accumulate across sessions. No `direction` is given; Optuna's default is to minimize.
study = optuna.create_study(
    study_name="hyp", storage='sqlite:///hyp.db', load_if_exists=True
)

In [81]:
# Run 5 more trials; each one trains a model for HYP_EPOCHS epochs.
study.optimize(objective, n_trials=5)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | encoder | FcNetwork | 14.1 M
1 | fc_mu   | Linear    | 1.8 K 
2 | fc_var  | Linear    | 1.8 K 
3 | decoder | FcNetwork | 14.2 M
--------------------------------------
28.4 M    Trainable params
0         Non-trainable params
28.4 M    Total params
113.463   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]


The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

Global seed set to 42

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.



Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Global seed set to 42


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

[32m[I 2022-03-15 21:50:18,068][0m Trial 20 finished with value: 2.0298171813248977 and parameters: {'blocks': 4, 'hidden_dim': 512, 'latent_dim': 27, 'dropout': 0.1754304076205893, 'kl_target': 0.4843286875028315}. Best is trial 14 with value: 1.660719525956095.[0m
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | encoder | FcNetwork | 14.1 M
1 | fc_mu   | Linear    | 272   
2 | fc_var  | Linear    | 272   
3 | decoder | FcNetwork | 14.2 M
--------------------------------------
28.4 M    Trainable params
0         Non-trainable params
28.4 M    Total params
113.492   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]


The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

Global seed set to 42

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.



Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Global seed set to 42


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

[32m[I 2022-03-15 23:06:02,595][0m Trial 21 finished with value: 2.4618704656285577 and parameters: {'blocks': 6, 'hidden_dim': 512, 'latent_dim': 16, 'dropout': 0.4985363203463385, 'kl_target': 0.12888934522406867}. Best is trial 14 with value: 1.660719525956095.[0m
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | encoder | FcNetwork | 14.1 M
1 | fc_mu   | Linear    | 1.6 K 
2 | fc_var  | Linear    | 1.6 K 
3 | decoder | FcNetwork | 14.2 M
--------------------------------------
28.4 M    Trainable params
0         Non-trainable params
28.4 M    Total params
113.460   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]


The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

Global seed set to 42

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.



Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Global seed set to 42


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

[32m[I 2022-03-16 00:21:56,047][0m Trial 22 finished with value: 1.6823502619133153 and parameters: {'blocks': 4, 'hidden_dim': 512, 'latent_dim': 24, 'dropout': 0.44724701227469477, 'kl_target': 0.260669212299434}. Best is trial 14 with value: 1.660719525956095.[0m
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | encoder | FcNetwork | 14.1 M
1 | fc_mu   | Linear    | 1.9 K 
2 | fc_var  | Linear    | 1.9 K 
3 | decoder | FcNetwork | 14.2 M
--------------------------------------
28.4 M    Trainable params
0         Non-trainable params
28.4 M    Total params
113.465   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]


The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

Global seed set to 42

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.



Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Global seed set to 42


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

[32m[I 2022-03-16 01:33:30,585][0m Trial 23 finished with value: 1.5620869453825963 and parameters: {'blocks': 4, 'hidden_dim': 512, 'latent_dim': 29, 'dropout': 0.41661403667807795, 'kl_target': 0.23499362217791947}. Best is trial 23 with value: 1.5620869453825963.[0m
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | encoder | FcNetwork | 14.1 M
1 | fc_mu   | Linear    | 990   
2 | fc_var  | Linear    | 990   
3 | decoder | FcNetwork | 14.2 M
--------------------------------------
28.4 M    Trainable params
0         Non-trainable params
28.4 M    Total params
113.493   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]


The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

Global seed set to 42

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.



Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Global seed set to 42


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

[32m[I 2022-03-16 02:36:40,947][0m Trial 24 finished with value: 1.4863442227828325 and parameters: {'blocks': 5, 'hidden_dim': 512, 'latent_dim': 30, 'dropout': 0.3941869036221446, 'kl_target': 0.2315779299781796}. Best is trial 24 with value: 1.4863442227828325.[0m


In [82]:
# Total number of trials stored in the study (includes trials from earlier sessions).
len(study.trials)

25

In [83]:
# Best trial so far; `trial` is read again by the FC_LOAD cell to locate its checkpoint.
trial = study.best_trial

In [84]:
# List the hyperparameters of the best trial, one per line.
for key, value in trial.params.items():
    print(f" {key:<20s}: {value}")

 blocks              : 5
 dropout             : 0.3941869036221446
 hidden_dim          : 512
 kl_target           : 0.2315779299781796
 latent_dim          : 30


In [87]:
# Objective value reached by the best trial.
trial.value

1.4863442227828325

In [85]:
# Hyperparameter importance plot for the study.
optuna.visualization.plot_param_importances(study)

In [86]:
# Slice plot of the study's hyperparameters.
optuna.visualization.plot_slice(study)

In [55]:
# Contour plot of the study's hyperparameters.
optuna.visualization.plot_contour(study)

## training and loading 

In [91]:
# TensorBoard log root; same folder name as the "tb_logs" given to TensorBoardLogger.
LOG_PATH = f".{os.sep}tb_logs"
# Checkpoint file name that url_from_trial expects inside a trial's checkpoints folder.
CKPT = "epoch=24-step=25424.ckpt"

In [105]:
def url_from_trial(trial):
    """Print a trial's hyperparameters and locate the checkpoint of its training run.

    The experiment name is rebuilt with the same pattern the Optuna objective uses for
    its TensorBoard logger; the checkpoint is looked up under `version_0` of that
    experiment inside LOG_PATH, with file name CKPT.

    Returns:
        Tuple (checkpoint path, experiment name).
    """
    params = trial.params
    for name, value in params.items():
        print(f" {name:<20s}: {value}")

    experiment_name = (
        f"OPTUNA-LATENT_DIM-{params['latent_dim']}"
        f"-BLOCKS-{params['blocks']}"
        f"-HIDDEN-{params['hidden_dim']}"
        f"-KL_TARGET-{params['kl_target']:.3f}"
        f"-DROPOUT-{params['dropout']:.3f}"
    )

    # only the first logger version ("version_0") is considered
    url = os.sep.join([LOG_PATH, experiment_name, "version_0", "checkpoints", CKPT])

    return url, experiment_name

In [None]:
# Train the manually configured FC VAE from scratch and save it under MODEL_NAME.
if FC_SAVE:
    VAE = FcVAE(
            aa_dim = AA_ENC_DIM,
            sequence_len = MAX_SEQ_LEN,
            blocks = FC_BLOCKS,
            hidden_dim = FC_HIDDEN_DIM,
            hidden_dim_scaling_factor=(0.5, 2.0),
            latent_dim = FC_LATENT_DIM,
            conditional = 3,
            dropout = FC_DROPOUT,
            kl_target = FC_KL_TARGET,
            lr = FC_LR,
            batch_size = FC_BATCH_SIZE,
            weight_decay = FC_WEIGHT_DECAY
    )

    # Register each split under its phase key, the same way define_model does.
    # Assigning the bare training dataset to VAE.ds (as before) is inconsistent with
    # every VAE.ds[PHASE_...] lookup elsewhere in this notebook.
    VAE.ds[PHASE_TRAIN] = ds
    VAE.ds[PHASE_VALID] = ds_val
    VAE.ds[PHASE_TEST] = ds_test

    trainer = Trainer(max_epochs=FC_EPOCHS, gpus=1)
    trainer.fit(VAE)
    trainer.save_checkpoint(f"..{os.sep}models{os.sep}{MODEL_NAME}.ckpt")

In [114]:
# Restore the best Optuna trial's checkpoint and continue training it up to FC_EPOCHS.
if FC_LOAD:
    # Resolve the checkpoint first: the logger needs `experiment_name`, which
    # url_from_trial returns (it was previously used before being assigned).
    url, experiment_name = url_from_trial(trial)
    logger = TensorBoardLogger("tb_logs", name=experiment_name)

    VAE = FcVAE.load_from_checkpoint(checkpoint_path=url) #f"..{os.sep}models{os.sep}{MODEL_NAME}.ckpt")
    VAE = VAE.cuda()
    VAE.ds = [ds, ds_val, ds_test]   # indexed by PHASE_TRAIN / PHASE_VALID / PHASE_TEST in later cells

    trainer = pl.Trainer(gpus=1,  logger=logger, max_epochs=FC_EPOCHS)
    trainer.fit(VAE, ckpt_path=url)

 blocks              : 5
 dropout             : 0.3941869036221446
 hidden_dim          : 512
 kl_target           : 0.2315779299781796
 latent_dim          : 30


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Restoring states from the checkpoint path at .\tb_logs\OPTUNA-LATENT_DIM-30-BLOCKS-5-HIDDEN-512-KL_TARGET-0.232-DROPOUT-0.394\version_0\checkpoints\epoch=24-step=25424.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

You're resuming from a checkpoint that ended mid-epoch. Training will start from the beginning of the next epoch. This can cause unreliable results if further training is done, consider using an end of epoch checkpoint.

Restored all states from the checkpoint file at .\tb_logs\OPTUNA-LATENT_DIM-30-BLOCKS-5-HIDDEN-512-KL_TARGET-0.232-DROPOUT-0.394\version_0\checkpoints\epoch=24-step=25424.ckpt

  | Name    | Type      | Params
--------------------------------------
0 | encoder | FcNetwork | 14.1 M
1 | fc_mu   | Linear    | 990   
2 | fc_var  | Linear    | 990   
3 | decoder | FcNetwork | 14.2 M
--------------------------------------
28.4 M    Trainable params
0      

Validation sanity check: 0it [00:00, ?it/s]


The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

Global seed set to 42

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.



Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [115]:
# Switch the model to evaluation mode before inspecting reconstructions and sampling.
VAE.eval()

FcVAE(
  (encoder): FcNetwork(
    (network): Sequential(
      (block_0): Sequential(
        (0): Linear(in_features=27282, out_features=512, bias=True)
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.1)
        (3): Dropout(p=0.3941869036221446, inplace=False)
      )
      (block_1): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.1)
        (3): Dropout(p=0.3941869036221446, inplace=False)
      )
      (block_2): Sequential(
        (0): Linear(in_features=256, out_features=128, bias=True)
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.1)
        (3): Dropout(p=0.3941869036221446, inplace=False)
      )
      (block_3): Sequential(
        (0): Line

In [116]:
# Cursor into ds.viral_seqs; the next cell increments it before use, so its first run shows sequence 0.
j = -1

In [117]:
# Step to the next training sequence, reconstruct it with the VAE (no sampling) and
# print a global alignment of the original against the reconstruction.
# Each run of this cell advances `j` by one.
j += 1
original_seq = ds.viral_seqs[j][:MAX_SEQ_LEN]

seq_tokens = ds.tok.tokenize(original_seq).unsqueeze(dim=0).to(VAE.device)
# conditioning token for this sequence's antigenicity class
cat_token = torch.tensor(
    ds.seq_immuno_cat_tokens[ds.seq_immuno_cat[ds.viral_seqs[j]]]
).unsqueeze(dim=0).to(VAE.device)

reconstruction = VAE.forward(seq_tokens, cat_token, sample=False).reshape(1, MAX_SEQ_LEN, -1)
h = ds.tok.decode(reconstruction)

alignments = pairwise2.align.globalxx(original_seq, h[0])
print(format_alignment(*alignments[0]))

MFV-F-LVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHVI--SGTNGTKRFDNPVLPFNDGVYFASIEKSNIIRGWIFGT--TLDSKTQSLLIVNNATNVVIKVCEFQFCNDPFLD-----HKNNKSWMESEFRVYSSANNCTFEYVSQPFLMDLEGKQGNFKNLREFVFKNIDGYFKIYSKHTPINL--V---R------D---LPQGFSALEPLVDLPIGINITRFQTLLALHRS---YLTPGD-SSSGWTAGAAAYYVGYLQPRTFLLKYNENGTITDAVDCALDPLSETKCTLKSFTVEKGIYQTSNFRVQPTESIVRFPNITNLCPFG-EVFNATRFASVYAWNRKRISNCVADYSVLYNS-AS-FS-TFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGK-IADYNYKLPDDFTGCVIAWNSNN-LDSKVG-GNYNYLYRLFRKSNLKPFERDISTEIYQAGST--PCNGVE-GFNCYFPLQ-SYG-FQ-PTN-GVGY-QPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFNFNGLKGTGVLTESNKKFLPFQQFGRDIADTTDAVRDPQTLEILDITPCSFGGVSVITPGTNTSNQVAVLYQGVNCTEVPVAIHADQLTPTWRVYSTGSNVFQTRAGCLIGAEYVNNSYECDIPIGAGICASYQTQTK----SHRRARSVASQSIIAYTMSLGAENSVAYSNNSIAIPTNFTISVTTEILPVSMTKTSVDCTMYICGDSTECSNLLLQYGSFCTQLKRALTGIAVEQDKNTQEVFAQVKQIYKTPPIKYFGGFNFSQILPDPSKPSKRSFIEDLLFNKVTLADAGFIKQYGDCLGDIAARDLICAQKFKGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNHNAQ

## latent space

In [147]:
import statsmodels

In [148]:
# A bare `import statsmodels` is not guaranteed to load the `stats.diagnostic`
# submodule, so import it explicitly; otherwise this lookup can fail with
# AttributeError on a fresh kernel.
import statsmodels.stats.diagnostic

# Inspect the Anderson-Darling normality test function.
statsmodels.stats.diagnostic.normal_ad

<module 'statsmodels.stats.diagnostic' from 'C:\\Users\\s2118339\\anaconda3\\envs\\GroupProject\\lib\\site-packages\\statsmodels\\stats\\diagnostic.py'>

In [149]:
# Encode every training sequence with the VAE. The call returns four values, unpacked
# as mus / log_vars / latents / cats (presumably latent means, log-variances, latent
# codes and conditioning categories — TODO confirm against get_latent_from_seq).
mus, log_vars, latents, cats = VAE.get_latent_from_seq(VAE.ds[PHASE_TRAIN].viral_seqs)

### remove outliers

In [None]:
# Covariance matrix of the two latent coordinate arrays.
# NOTE(review): Latent_Xs / Latent_Ys are not defined in the surrounding cells and this
# cell has no execution count — confirm they exist before relying on it.
np.cov(Latent_Xs, Latent_Ys)

## generate new sequences

In [136]:
# Names of the three antigenicity levels, indexed by `j` in the cells below, and one
# (initially empty) slot per level for the novel generated sequences.
antigenicity_names = ["low", "medium", "high"]
generated_seqs_new = [None for _ in antigenicity_names]

### lowly antigenic

In [126]:
# Select antigenicity level 0 ("low" in antigenicity_names) for the cells below.
j = 0

In [137]:
# Sample sequences for antigenicity level `j`, then keep only those that do not
# already occur in the training set.
generated_seqs[j] = generate_seqs(VAE, FC_SAMPLES, antigenicity=j)

# Build the lookup set once: testing `seq not in <list>` would rescan the whole
# training list for every generated sequence.
known_train_seqs = set(VAE.ds[PHASE_TRAIN].viral_seqs)
generated_seqs_new[j] = {
    seq: cnt
    for seq, cnt in generated_seqs[j].items()
    if seq not in known_train_seqs
}

print(f"Generated: {len(generated_seqs[j])} New: {len(generated_seqs_new[j])}")

Generated: 948 New: 897


In [139]:
# Export level `j` twice: every generated sequence ("_all") and the novel-only subset.
for seqs_to_write, fasta_path in (
    (generated_seqs[j], f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}_all.fasta"),
    (generated_seqs_new[j], f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}.fasta"),
):
    write_seqs_to_fasta(seqs_to_write, fasta_path)

### intermediate antigenic

In [140]:
# Select antigenicity level 1 ("medium" in antigenicity_names) for the cells below.
j = 1

In [141]:
# Sample sequences for antigenicity level `j`, then keep only those that do not
# already occur in the training set.
generated_seqs[j] = generate_seqs(VAE, FC_SAMPLES, antigenicity=j)

# Build the lookup set once: testing `seq not in <list>` would rescan the whole
# training list for every generated sequence.
known_train_seqs = set(VAE.ds[PHASE_TRAIN].viral_seqs)
generated_seqs_new[j] = {
    seq: cnt
    for seq, cnt in generated_seqs[j].items()
    if seq not in known_train_seqs
}

print(f"Generated: {len(generated_seqs[j])} New: {len(generated_seqs_new[j])}")

Global seed set to 42


  0%|          | 0/50000 [00:00<?, ?it/s]

Generated: 1331 New: 1154


In [142]:
# Export level `j` twice: every generated sequence ("_all") and the novel-only subset.
for seqs_to_write, fasta_path in (
    (generated_seqs[j], f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}_all.fasta"),
    (generated_seqs_new[j], f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}.fasta"),
):
    write_seqs_to_fasta(seqs_to_write, fasta_path)

### highly antigenic

In [143]:
# Select antigenicity level 2 ("high" in antigenicity_names) for the cells below.
j = 2

In [144]:
# Sample sequences for antigenicity level `j`, then keep only those that do not
# already occur in the training set.
generated_seqs[j] = generate_seqs(VAE, FC_SAMPLES, antigenicity=j)

# Build the lookup set once: testing `seq not in <list>` would rescan the whole
# training list for every generated sequence.
known_train_seqs = set(VAE.ds[PHASE_TRAIN].viral_seqs)
generated_seqs_new[j] = {
    seq: cnt
    for seq, cnt in generated_seqs[j].items()
    if seq not in known_train_seqs
}

print(f"Generated: {len(generated_seqs[j])} New: {len(generated_seqs_new[j])}")

Global seed set to 42


  0%|          | 0/50000 [00:00<?, ?it/s]

Generated: 1323 New: 1138


In [152]:
# Export level `j` twice: every generated sequence ("_all") and the novel-only subset.
for seqs_to_write, fasta_path in (
    (generated_seqs[j], f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}_all.fasta"),
    (generated_seqs_new[j], f"..{os.sep}data{os.sep}spike_protein_sequences{os.sep}{MODEL_NAME}_gen_{antigenicity_names[j]}.fasta"),
):
    write_seqs_to_fasta(seqs_to_write, fasta_path)

### evaluate antigenicity

#### with ImmunoBERT

In [None]:
# Score every novel generated sequence with ImmunoBERT, level by level, caching the
# per-sequence presentation result and score keyed by the sequence string.
for j in range(3):
    level_seqs = generated_seqs_new[j]
    for seq in tqdm(level_seqs.keys()):
        presentation, seq_score = score_seq_IB(model, seq, MHC_list, IB_peptide_scores)
        IB_seq_presentation[seq] = presentation
        IB_seq_scores[seq] = seq_score

In [None]:
# Overlay the ImmunoBERT score distributions of the three antigenicity levels.
matplotlib.rcParams.update({'font.size': 15})

for j in range(3):
    level_scores = [IB_seq_scores[seq] for seq in generated_seqs_new[j].keys()]
    sns.distplot(level_scores)

plt.legend(labels=["low","medium", "high"])

#### with netMHCpan

In [151]:
# Collect every 9-mer of the novel generated sequences that has no cached netMHCpan
# score yet, write them to `missing.pep`, and print the netMHCpan commands to run.
PEPTIDE_LEN = 9

missing = []            # unique missing peptides, in first-seen order
seen_missing = set()    # same peptides, for O(1) duplicate checks
for j in range(3):
    for seq in tqdm(list(generated_seqs_new[j].keys())):
        seq = seq.replace("-", "")  # drop alignment gaps before windowing
        # `+ 1` so the final window (ending at the last residue) is included.
        for position in range(len(seq) - PEPTIDE_LEN + 1):
            peptide = seq[position:(position + PEPTIDE_LEN)]
            if peptide not in nMp_peptide_scores and peptide not in seen_missing:
                seen_missing.add(peptide)
                missing.append(peptide)

# The context manager closes the file even if writing fails.
with open(f"{NMP_FOLDER_1}{os.sep}missing.pep", "w") as file:
    file.writelines(f"{peptide}\n" for peptide in missing)

# One netMHCpan invocation per MHC allele; the output file name uses the allele
# without the "HLA-" prefix and the colon.
for mhc_name in MHC_list:
    mhc_name_2 = mhc_name.replace(":", "").replace("HLA-", "")

    print(f"./netMHCpan -p {NMP_FOLDER_2}/missing.pep -a {mhc_name} > {NMP_FOLDER_2}/missing_{mhc_name_2}.pep.out")
    print("\n")

  0%|          | 0/897 [00:00<?, ?it/s]

  0%|          | 0/1154 [00:00<?, ?it/s]

  0%|          | 0/1138 [00:00<?, ?it/s]

./netMHCpan -p ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/missing.pep -a HLA-A01:01 > ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/missing_A0101.pep.out


./netMHCpan -p ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/missing.pep -a HLA-A02:01 > ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/missing_A0201.pep.out


./netMHCpan -p ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/missing.pep -a HLA-A03:01 > ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/missing_A0301.pep.out


./netMHCpan -p ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/missing.pep -a HLA-A24:02 > ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/missing_A2402.pep.out


./netMHCpan -p ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/missing.pep -a HLA-A26:01 > ~/win/Documents/2022H1/Group_project/CovidProject/netMHCpan/missing_A2601.pep.out


./netMHCpan -p ~/win/Documents/2022H1/Group_project/CovidProject/

In [None]:
# Merge the netMHCpan results for the "missing" peptide set into the score cache.
# NOTE(review): presumably requires the netMHCpan commands printed above to have been
# run first so that the `missing_*.pep.out` files exist — confirm.
nMp_peptide_scores.update(eval_peptides_nMp("missing", MHC_list))

In [None]:
# Score every novel generated sequence with netMHCpan and pool the returned epitopes
# into one set per antigenicity level.
nMp_epitopes_gen = [set() for _ in range(3)]
for j in range(3):
    for seq in tqdm(generated_seqs_new[j].keys()):
        seq_score, epitopes = score_seq_nMp(seq, MHC_list, nMp_peptide_scores)
        nMp_seq_scores[seq] = seq_score
        nMp_epitopes_gen[j] = nMp_epitopes_gen[j].union(epitopes)

In [None]:
# Per level: number of pooled epitopes and how many are absent from nMp_epitopes_db.
for j, name in enumerate(antigenicity_names):
    level_epitopes = nMp_epitopes_gen[j]
    print(f"{name}: antigenic epitopes - {len(level_epitopes)} \t", end="")
    unseen_epitopes = level_epitopes.difference(nMp_epitopes_db)
    print(f"new antigenic epitopes - {len(unseen_epitopes)}")

In [None]:
# Overlay the netMHCpan score distributions of the three antigenicity levels.
matplotlib.rcParams.update({'font.size': 15})

for j in range(3):
    level_scores = [nMp_seq_scores[seq] for seq in generated_seqs_new[j].keys()]
    sns.distplot(level_scores)

plt.legend(labels=["low","medium", "high"])

In [None]:
# Print, per level, the epitopes that do not appear in nMp_epitopes_db.
for j in range(3):
    epitopes_not_in_db = nMp_epitopes_gen[j].difference(nMp_epitopes_db)
    print(epitopes_not_in_db)

# Conv VAE

In [None]:
# Reconfigure the shared dataset object for the convolutional model.
# 1304 is divisible by 8, so three halving (pooling) steps give a whole number
# (1304 -> 652 -> 326 -> 163); presumably chosen for that reason — TODO confirm.
# NOTE(review): this rebinds SEQ_LEN and mutates `ds` in place, so earlier cells
# re-run after this one will see the new values.
SEQ_LEN = 1304
ds.pad_to = SEQ_LEN
ds.max_seq_len = SEQ_LEN
ds.conv = True

In [None]:
# Number of convolutional blocks in each of the encoder and decoder stacks.
CONV_BLOCKS = 3
# Input layout handed to the conv encoder: [channels (amino-acid encoding width), sequence length].
CONV_INPUT_DIM = [AA_ENC_DIM, SEQ_LEN]

In [None]:
# Collect the ConvVAE configuration in one place, then build the model from it.
# The fully-connected hyperparameters (FC_*) are reused from the FC model section.
conv_vae_config = dict(
    conv_blocks=CONV_BLOCKS,
    conv_input_dim=CONV_INPUT_DIM,
    conv_image_scaling_factor=0.5,
    fc_blocks=FC_BLOCKS,
    fc_hidden_dim=FC_HIDDEN_DIM,
    fc_hidden_dim_scaling_factor=(0.5, 2.0),
    latent_dim=FC_LATENT_DIM,
    conditional=3,
    dropout=FC_DROPOUT,
    kl_target=FC_KL_TARGET,
    lr=FC_LR,
    batch_size=FC_BATCH_SIZE,
    weight_decay=FC_WEIGHT_DECAY,
)
VAE = ConvVAE(**conv_vae_config)

# Sanity check: show the [channels, length] produced by the encoder and decoder stacks.
for conv_stack in (VAE.conv_encoder, VAE.conv_decoder):
    print(conv_stack.output_dim)

In [None]:
# Attach the dataset to the model; ConvVAE.train_dataloader builds its DataLoader from `self.ds`.
VAE.ds = ds

In [None]:
# Display the model (its repr) as the cell output to inspect the layer structure.
VAE

## training and loading

In [None]:
# Training / checkpoint switches for the convolutional VAE.
CONV_EPOCHS = FC_EPOCHS        # train for as many epochs as the FC model
CONV_SAVE = True               # gates the fit + save-checkpoint cell below
CONV_LOAD = True               # gates the load-from-checkpoint cell below
CONV_MODEL_NAME = "conv001"    # checkpoint file stem under ../models

In [None]:
# Lightning trainer for the ConvVAE; `gpus=1` requests a single GPU.
conv_trainer = Trainer(max_epochs=CONV_EPOCHS, gpus=1)

In [None]:
# NOTE(review): despite its name, CONV_SAVE gates the training run as well as the
# checkpoint write — with CONV_SAVE = False the model is not trained here at all.
if CONV_SAVE:
    conv_trainer.fit(VAE)
    conv_trainer.save_checkpoint(f"..{os.sep}models{os.sep}{CONV_MODEL_NAME}.ckpt")

In [None]:
# Replace the in-memory model with the one stored in the checkpoint and move it to the GPU.
if CONV_LOAD:
    VAE = ConvVAE.load_from_checkpoint(checkpoint_path=f"..{os.sep}models{os.sep}{CONV_MODEL_NAME}.ckpt")
    VAE = VAE.cuda()
    # ConvVAE.__init__ sets `ds` to None, so the dataset has to be attached again.
    VAE.ds = ds

In [None]:
# Switch to evaluation mode (affects the Dropout and BatchNorm layers) before reconstructing.
VAE.eval()

In [None]:
# Start one position before the first sequence: the following cell increments `j`
# before indexing `ds.viral_seqs`, so its first run inspects index 0.
j = -1

In [None]:
# Advance to the next sequence and align it against its deterministic
# (sample=False) reconstruction from the ConvVAE. Re-running the cell steps forward.
j += 1
h = ds.tok.decode(
        VAE.forward(
            ds.tok.tokenize(ds.viral_seqs[j][:SEQ_LEN]).unsqueeze(dim=0).to(VAE.device),
            torch.tensor(ds.seq_immuno_cat_tokens[ds.seq_immuno_cat[ds.viral_seqs[j]]]).unsqueeze(dim=0).to(VAE.device),
            sample=False).reshape(1, SEQ_LEN, -1)
)
# Truncate the reference with SEQ_LEN — the same length used for the model input
# above — instead of MAX_SEQ_LEN, which is the length used in the FC section.
alignments = pairwise2.align.globalxx(ds.viral_seqs[j][:SEQ_LEN], h[0])
print(format_alignment(*alignments[0]))

In [None]:
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F

import pytorch_lightning as pl

from SpikeOracle.models.VAE.fc import FcNetwork

class ConvNetwork(torch.nn.Module):
    """Stack of 1-D convolutional blocks that rescales channels and sequence length.

    With ``image_scaling_factor < 1`` every block multiplies the channel count by
    ``1 / image_scaling_factor`` and max-pools the length by the same ratio
    (encoder direction). With ``image_scaling_factor > 1`` every block upsamples
    the length by that factor and divides the channel count by it (decoder direction).

    Args:
        blocks: number of convolutional blocks; must be at least 1.
        input_dim: ``[channels, sequence_length]`` of the input tensor.
        image_scaling_factor: per-block scaling of the sequence length.
        dropout: dropout probability for all blocks except the last one.

    Attributes:
        network: ``nn.Sequential`` holding the blocks as ``block_0`` ... ``block_{blocks-1}``.
        output_dim: ``[channels, sequence_length]`` produced by the last block.

    Raises:
        ValueError: if ``blocks`` is smaller than 1.
    """

    def __init__(self,
                 blocks: int,
                 input_dim: (int, int),
                 image_scaling_factor: float,
                 dropout: float
                 ):
        super().__init__()
        if blocks < 1:
            raise ValueError(f"blocks must be >= 1, got {blocks}")

        self.input_dim = input_dim
        self.input_seq_len = input_dim[1]
        self.dropout = dropout

        dim_ = input_dim
        self.network = nn.Sequential()
        for block in range(blocks - 1):
            block_, dim_ = self.get_conv_block(dim_, image_scaling_factor, dropout)
            self.network.add_module(f"block_{block}", block_)

        # The final block is built without dropout. Its name is derived from `blocks`
        # rather than from the loop variable, which is unbound when blocks == 1.
        block_, dim_ = self.get_conv_block(dim_, image_scaling_factor, 0)
        self.network.add_module(f"block_{blocks - 1}", block_)

        self.output_dim = dim_

    def forward(self, x):
        """Apply all blocks in order to ``x`` (Conv1d layout: batch, channels, length)."""
        return self.network(x)

    def get_conv_block(self, dim, image_scaling_factor, dropout):
        """Build one block and compute the dimensions it produces.

        Layer order: optional Upsample -> Conv1d -> BatchNorm1d -> LeakyReLU ->
        optional Dropout -> optional MaxPool1d.

        Args:
            dim: ``[channels, sequence_length]`` entering the block.
            image_scaling_factor: length scaling applied by the block.
            dropout: dropout probability; no Dropout layer is added when it is 0.

        Returns:
            ``(block, out_dim)`` where ``block`` is an ``nn.Sequential`` and
            ``out_dim`` is the ``[channels, sequence_length]`` leaving the block.
        """
        channel_scaling_factor = 1./image_scaling_factor
        out_dim = [int(dim[0]/image_scaling_factor), dim[1]]

        block = nn.Sequential()
        if image_scaling_factor > 1.:
            block.add_module("Upsample", nn.Upsample(scale_factor=image_scaling_factor, mode='linear'))
            if out_dim[1]:
                # Keep the tracked length an int even for a float factor; it is later
                # used as a reshape size.
                out_dim[1] = int(out_dim[1] * image_scaling_factor)

        # kernel_size=3 with padding=1 leaves the sequence length unchanged.
        block.add_module("Conv", nn.Conv1d(dim[0], out_dim[0], kernel_size=3, padding=1))
        block.add_module("BN", nn.BatchNorm1d(out_dim[0]))
        block.add_module("Activation", nn.LeakyReLU(0.1))

        if dropout > 0:
            block.add_module("Dropout", nn.Dropout(dropout))

        if channel_scaling_factor > 1.:
            block.add_module("Pooling", nn.MaxPool1d(int(channel_scaling_factor), stride=int(channel_scaling_factor)))
            if out_dim[1]:
                out_dim[1] = int(out_dim[1] * image_scaling_factor)

        return block, out_dim


class ConvVAE(pl.LightningModule):
    """Standard VAE with Gaussian Prior and approx posterior.
    """

    def __init__(
            self,
            conv_blocks: int,
            conv_input_dim: (int, int),
            conv_image_scaling_factor: float,
            fc_blocks: int,
            fc_hidden_dim: int,
            fc_hidden_dim_scaling_factor: (float, float),
            latent_dim: int,
            conditional: int,
            dropout: float,
            kl_target: float,
            lr: float,
            batch_size: int,
            weight_decay: float
    ):
        """Build the conv encoder -> FC encoder -> latent -> FC decoder -> conv decoder stack.

        Args:
            conv_blocks: number of blocks in both the conv encoder and the conv decoder.
            conv_input_dim: ``[channels, sequence_length]`` of the encoder input.
            conv_image_scaling_factor: length scaling per encoder block; the decoder
                uses ``int(1 / conv_image_scaling_factor)``.
            fc_blocks: number of blocks in both fully-connected networks.
            fc_hidden_dim: hidden width passed to the FC encoder.
            fc_hidden_dim_scaling_factor: ``(encoder_factor, decoder_factor)`` handed
                to the FC encoder and FC decoder respectively.
            latent_dim: size of the latent vector.
            conditional: width of the conditioning vector concatenated to the FC
                encoder and decoder inputs; 0 disables conditioning.
            dropout: dropout probability passed to all sub-networks.
            kl_target: KL value that ``calc_beta`` steers towards.
            lr: Adam learning rate.
            batch_size: batch size of the training DataLoader.
            weight_decay: Adam weight decay.
        """
        super().__init__()

        # Record the constructor arguments as Lightning hyperparameters.
        self.save_hyperparameters()

        self.conv_blocks = conv_blocks
        self.conv_input_dim = conv_input_dim
        self.conv_image_scaling_factor = conv_image_scaling_factor
        self.fc_blocks = fc_blocks
        self.fc_hidden_dim = fc_hidden_dim
        self.fc_hidden_dim_scaling_factor = fc_hidden_dim_scaling_factor
        self.latent_dim = latent_dim
        self.conditional = conditional
        self.dropout = dropout
        self.kl_target = kl_target
        self.lr = lr
        self.batch_size = batch_size
        self.weight_decay = weight_decay


        # State of the KL-weight controller, updated by calc_beta.
        self.beta = 0
        self.P = 0
        self.I = 0

        # Dataset and DataLoader; `ds` must be assigned by the caller before training.
        self.ds = None
        self.dl = None


        self.conv_encoder = ConvNetwork(
            blocks=conv_blocks,
            input_dim=conv_input_dim,
            image_scaling_factor=conv_image_scaling_factor,
            dropout=dropout
        )

        # The FC encoder consumes the flattened conv output plus the conditioning vector.
        enc_fc_input_dim = self.conv_encoder.output_dim[0]*self.conv_encoder.output_dim[1] + conditional
        # hidden_dim rescaled (fc_blocks - 1) times by the encoder factor
        # (presumably the width of FcNetwork's last hidden layer — TODO confirm).
        enc_output_dim = int(fc_hidden_dim * (fc_hidden_dim_scaling_factor[0]**(fc_blocks - 1)))
        self.fc_encoder = FcNetwork(
            blocks=fc_blocks,
            input_dim=enc_fc_input_dim,
            hidden_dim=fc_hidden_dim,
            hidden_dim_scaling_factor=fc_hidden_dim_scaling_factor[0],
            output_dim=enc_output_dim,
            dropout=dropout
        )

        # Linear heads producing the posterior mean and log-variance.
        self.fc_mu = nn.Linear(enc_output_dim, latent_dim)
        self.fc_var = nn.Linear(enc_output_dim, latent_dim)

        # The FC decoder maps (latent + conditioning) back to the flattened conv feature size.
        dec_fc_input_dim = self.latent_dim + conditional
        dec_fc_output_dim = self.conv_encoder.output_dim[0] * self.conv_encoder.output_dim[1]
        self.fc_decoder = FcNetwork(
            blocks=fc_blocks,
            input_dim=dec_fc_input_dim,
            hidden_dim=int(enc_output_dim * fc_hidden_dim_scaling_factor[1]),
            hidden_dim_scaling_factor=fc_hidden_dim_scaling_factor[1],
            output_dim=dec_fc_output_dim,
            dropout=dropout
        )

        # The conv decoder starts from the conv encoder's output shape and inverts its scaling.
        self.conv_decoder_input_dim = self.conv_encoder.output_dim
        self.conv_decoder = ConvNetwork(
            blocks=conv_blocks,
            input_dim=self.conv_decoder_input_dim,
            image_scaling_factor=int(1./conv_image_scaling_factor),
            dropout=dropout
        )

        # Latent code of the most recent forward() call.
        self.z = None

    def forward(self, x, y, sample=True):
        """Reconstruct ``x`` and return per-position probabilities.

        Args:
            x: input batch, passed to the conv encoder.
            y: conditioning vectors (only used when ``self.conditional != 0``).
            sample: if True the latent is sampled from the posterior, otherwise
                the posterior mean is used.

        Returns:
            Tensor with the decoder output's last two axes swapped and a softmax
            applied over the (new) last axis.

        Side effects:
            Stores the latent code of this call in ``self.z``.
        """
        self.z, x_hat_logit, _, _ = self._run_step(x, y, sample)
        # Swap the last two axes so the softmax below normalises over what was axis 1
        # of the decoder output.
        x_hat_logit = x_hat_logit.permute(0, 2, 1)
        x_hat = torch.nn.functional.softmax(x_hat_logit, -1)
        return x_hat

    def _run_step(self, x, y, sample=True):
        x = self.conv_encoder(x)

        x = torch.flatten(x, start_dim=-2)
        x = x if self.conditional == 0 else torch.cat([x, y], dim=1)
        x = self.fc_encoder(x)
        x = F.relu(x)

        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        if sample:
            p, q, z = self.sample(mu, log_var)
        else:
            p, q, z = None, None, mu

        d = z if self.conditional == 0 else torch.cat([z, y], dim=1)

        d = self.fc_decoder(d)
        d = d.reshape(-1, self.conv_decoder_input_dim[0], self.conv_decoder_input_dim[1])
        d = self.conv_decoder(d)

        return z, d, p, q

    def sample(self, mu, log_var):
        """Draw a reparameterised latent sample.

        Returns:
            ``(p, q, z)``: the standard-normal prior, the posterior
            ``Normal(mu, exp(log_var / 2))`` and one ``rsample()`` from the posterior.
        """
        std = (log_var / 2).exp()
        prior = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        posterior = torch.distributions.Normal(mu, std)
        return prior, posterior, posterior.rsample()

    def step(self, batch, batch_idx):
        """Compute the VAE loss for one batch.

        Args:
            batch: ``(x, y)`` pair of inputs and conditioning vectors.
            batch_idx: batch index (unused).

        Returns:
            ``(loss, logs)`` where ``loss = kl * kl_coeff + recon_loss`` and ``logs``
            holds ``recon_loss``, ``kl``, ``kl_coeff`` and ``loss``.
        """
        x, y = batch

        z, x_hat_logit, p, q = self._run_step(x, y)

        # Both x and the decoder output use the Conv1d layout (batch, channels,
        # positions), so the class axis is dim 1: the target is the channel index of
        # the maximum at every position, and cross_entropy receives the logits as
        # (N, C, L). Flattening to rows of length conv_input_dim[-1] would instead
        # treat sequence positions as the classes.
        recon_loss = F.cross_entropy(
            x_hat_logit,
            torch.max(x, dim=1).indices)

        kl = torch.distributions.kl_divergence(q, p)
        kl = kl.mean()

        # KL weight from the controller; a plain float, so no gradient flows through it.
        kl_coeff = self.calc_beta(float(kl.detach()), self.kl_target, 1e-3, 5e-4, 1e-4, 1)

        loss = kl * kl_coeff + recon_loss

        logs = {
            "recon_loss": recon_loss,
            "kl": kl,
            "kl_coeff": kl_coeff,
            "loss": loss,
        }
        return loss, logs

    def calc_beta(self, actual_kl, target_kl, Kp, Ki, beta_min, beta_max):
        """Update and return the KL weight from the gap between target and actual KL.

        The proportional term is ``Kp / (1 + exp(error))`` with
        ``error = target_kl - actual_kl``. The integral term only accumulates
        (``I -= Ki * error``) while the previous beta lies strictly inside
        ``(beta_min, beta_max)``. The result is clamped to ``[beta_min, beta_max]``
        and stored in ``self.beta``.
        """
        error = target_kl - actual_kl
        self.P = Kp / (1 + np.exp(error))

        # Freeze the integral term while beta sits on (or outside) either bound.
        beta_is_unsaturated = beta_min < self.beta < beta_max
        if beta_is_unsaturated:
            self.I -= Ki * error

        unclamped = self.P + self.I + beta_min
        self.beta = min(max(unclamped, beta_min), beta_max)
        return self.beta

    def training_step(self, batch, batch_idx):
        """Run one training step, log every metric with a ``train_`` prefix per step, return the loss."""
        loss, logs = self.step(batch, batch_idx)
        prefixed_logs = {}
        for metric_name, metric_value in logs.items():
            prefixed_logs[f"train_{metric_name}"] = metric_value
        self.log_dict(prefixed_logs, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation step, log every metric with a ``val_`` prefix, return the loss."""
        loss, logs = self.step(batch, batch_idx)
        prefixed_logs = {}
        for metric_name, metric_value in logs.items():
            prefixed_logs[f"val_{metric_name}"] = metric_value
        self.log_dict(prefixed_logs)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr`` and ``self.weight_decay``."""
        adam_settings = {"lr": self.lr, "weight_decay": self.weight_decay}
        optimizer = torch.optim.Adam(self.parameters(), **adam_settings)
        return optimizer

    def train_dataloader(self):
        """Create a shuffling DataLoader over ``self.ds``, keep it in ``self.dl`` and return it."""
        loader_options = {"shuffle": True, "batch_size": self.batch_size}
        self.dl = DataLoader(self.ds, **loader_options)
        return self.dl