# Parameters

In [None]:
# specify which data to use
domain = "HIV_nef"
dataset_name = "HIV_nef_full"
# MHC class I alleles used for presentation prediction / immune-visibility scores
mhc_1_alleles = ['HLA-A*02:01', 'HLA-A*24:02', 'HLA-B*07:02', 'HLA-B*39:01', 'HLA-C*07:01', 'HLA-C*16:01']  

# specify the CAPE-XVAE model (job id; its DHPARAMS.yaml is read from artefacts/CAPE-XVAE/jobs/<job>/)
XVAE_job = "mlp_1606474" 

#
# specify the predictors
#

# Structure
predictor_structure_name = "AF"  # ESM or AF

# MHC Class 1
predictor_mhc_1_class = 'Mhc1PredictorNetMhcPan'  # presentation predictor used to assess immune visibility
# Divided by 100 before it is passed to the predictor as `limit` (2 -> 0.02).
predictor_mhc_1_limit = 2
# NOTE(review): the name embeds limit*100 (-> "..._200") while the predictor receives
# limit/100; presumably kept for compatibility with names already stored in the DB — confirm.
predictor_mhc_1_name = f"{predictor_mhc_1_class}_{predictor_mhc_1_limit*100}"
# if True, predictor_mhc_1.save() is called after the presentations are computed
predictor_mhc_1_save = False

# Function
predictor_function_name = 'TransFun'

# specify what constitutes extreme cases: percentiles (given as fractions) of the
# natural sequences' visibility used as the lower/upper visibility thresholds
vis_down_visibility_percentile = 0.005
vis_up_visibility_percentile = 0.995


#
# Workflow management
#

create_DB = True     # passed to CapeDB as create_database
update_DB = True     # if True, this domain's packs are rebuilt from the seq_hash files
save_figures = True  # if True, figures are written as PDF under ARTEFACTS/figures/<domain>/

run_MSA = False         # not referenced in the visible part of this notebook
construct_tree = False  # not referenced in the visible part of this notebook

random_seed = 19  # seeds NumPy's global RNG

# Imports

In [None]:
import importlib
from dotenv import load_dotenv
import os
import sys

load_dotenv()
# PYTHONPATH may list several directories separated by os.pathsep. Prepend each of
# them, in order, instead of inserting the whole string as one (invalid) entry.
sys.path = [path for path in os.environ['PYTHONPATH'].split(os.pathsep) if path] + sys.path

In [None]:
import os
import string
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import pickle
from collections import defaultdict

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import glasbey

from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.manifold import TSNE, MDS
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, Normalizer
from sklearn.decomposition import PCA

import mdtraj as md

In [None]:
import kit
from kit.loch import seq_hashes_to_file
from kit.loch.utils import get_seq_hash
from kit.log import log_info
import kit.globals as G
from kit.path import join
from kit.bioinf import get_kmers
from kit.bioinf.sf import SequenceFrame
from kit.bioinf.immuno.mhc_1 import Mhc1Predictor
from kit.bioinf.fasta import read_fasta, seqs_to_fasta
from kit.bioinf.alignment.sequence import PairwiseSequenceAligner, MultipleSequenceAligner
from kit.data import DD, Split, str_to_file
from kit.data.utils import set_df_cell_to_np, get_np_from_df_cell
from kit.plot import plot_text, plot_legend_patches, plot_legend_scatter, rm_axes_elements
from kit.maths import ttest

In [None]:
from CAPE.Eval import CapeDB
from CAPE.Eval.utils import pack_to_source_profile_step
from CAPE.Eval.kmers import get_visible_mhc_1_peptides, get_metrics, add_metrics, add_precision_to_seq_kmers, calc_recall_metrics
from CAPE.Eval.plot import set_palettes, set_markers, set_dashes, \
    plot_tsne_kde, plot_avg_dissimilarity_boxplots, plot_boxplots, plot_kmer_similarity, plot_kmer_similarity_box, plot_vs_visibility, \
    plot_seq_epitopes, plot_seq_precision_pie, plot_epitope_recall_by_seq, plot_seqs_precision_bar, \
    get_label_from_pack, plot_natural_vs_pack
from CAPE.profiles import Profile
from CAPE.datasets import FastaDS

In [None]:
# Seed NumPy's global RNG so that results drawing on it are reproducible.
np.random.seed(random_seed)

# Immune-visibility profiles under evaluation, in a fixed order.
vis_base, vis_down, vis_up, vis_up_nat = Profile.BASE, Profile.VIS_DOWN, Profile.VIS_UP, Profile.VIS_UP_NAT
vis_profiles = [vis_base, vis_down, vis_up, vis_up_nat]

In [None]:
# Render all figure text through LaTeX (requires a working LaTeX installation).
plt.rcParams.update({'text.usetex': True})

In [None]:
# Initialise the kit framework for "CAPE" / "CAPE-Eval" without creating a job, then
# set the active domain (G.DOMAIN is used in the artefact paths and queries below).
kit.init("CAPE", "CAPE-Eval", create_job=False)
G.DOMAIN = domain

# Constants

In [None]:
# Checkpoint reference "<job>:last" (presumably the job's last checkpoint) and the
# job's hyper-parameters read from its DHPARAMS.yaml.
model_checkpoint = "{}:last".format(XVAE_job)
xvae_job_dir = os.path.join(G.ENV.PROJECT, 'artefacts', 'CAPE-XVAE', 'jobs', XVAE_job)
DHPARAMS_path = os.path.join(xvae_job_dir, 'DHPARAMS.yaml')
DHPARAMS = DD.from_yaml(DHPARAMS_path)

In [None]:
# t-SNE settings; the representation name is built from them.
TSNE_perplexity, TSNE_iter = 30, 500
tsne_rep_name = "tsne_{}_{}".format(TSNE_perplexity, TSNE_iter)

# PCA representation name.
pca_rep_name = "pca"

In [None]:
# Pack-name prefixes of the two generative sources.
source_xvae, source_packer = 'CAPE-XVAE', 'CAPE-Packer'

In [None]:
kmer_similarity_lengths = [*range(3, 16)]  # k = 3 .. 15 inclusive

In [None]:
# Gap-opening and gap-extension scores of the global pairwise aligner.
open_gap_score, extend_gap_score = -20, -10

# BLOSUM62 pairwise aligner; '-' is configured as the wildcard character.
pairwise_sequence_aligner = PairwiseSequenceAligner(
    wildcard="-",
    substitution_matrix="BLOSUM62",
    extend_gap_score=extend_gap_score,
    open_gap_score=open_gap_score,
)

In [None]:
# A4 page size in inches (210 x 297 mm); used to size the figures.
A4_width, A4_height = 8.27, 11.69

# DB

In [None]:
# Open (and, if create_DB is set, create) the CAPE-Eval database for the current domain.
cape_eval_db_path = join(G.ENV.ARTEFACTS, "DBs", "CAPE-Eval.db")
db = CapeDB(
    cape_eval_db_path,
    G.DOMAIN,
    os.environ['LOCH'],
    None,
    predictor_structure_name,
    predictor_function_name,
    pairwise_sequence_aligner,
    create_database=create_DB,
)

In [None]:
## Create packs
# Rebuild this domain's pack membership from the seq_hash files under artefacts/.

if update_DB:
    # Only clear packs of the current domain. The DB file is shared and packs carry a
    # `domain` column (see the evaluation query below). An unqualified DELETE would
    # also wipe the packs of every other domain.
    db.cursor.execute('''DELETE FROM packs WHERE domain == ?''', (G.DOMAIN,))
    db.conn.commit()

    artefacts_dir = os.path.join(G.ENV.PROJECT, 'artefacts')
    seq_hash_files = (
        # Data splits
        [os.path.join(artefacts_dir, 'CAPE', f'{G.DOMAIN}.data.{split}.seq_hash') for split in ('TRAIN', 'VAL', 'TEST')]
        # Supports
        + [os.path.join(artefacts_dir, 'CAPE', f'{G.DOMAIN}.support.seq_hash')]
        # CAPE-XVAE, one pack per visibility profile
        + [os.path.join(artefacts_dir, 'CAPE-XVAE', f'{G.DOMAIN}.CAPE-XVAE.{profile}.final.seq_hash') for profile in vis_profiles]
        # CAPE-Packer, one pack per visibility profile
        + [os.path.join(artefacts_dir, 'CAPE-Packer', f'{G.DOMAIN}.CAPE-Packer.{profile}.final.seq_hash') for profile in vis_profiles]
    )
    for seq_hash_file in seq_hash_files:
        db.add_seq_hashes_as_pack(seq_hash_file)

    db.conn.commit()

# Evaluation

In [None]:
# All (pack, sequence) pairs of the current domain whose sequence is marked complete.
# NOTE: the WHERE condition on s.complete drops pack rows without a matching sequence,
# so this LEFT JOIN effectively behaves as an inner join.
# NOTE(review): G.DOMAIN is interpolated into the SQL text without escaping.
sql = f'''
    SELECT p.pack, p.seq_hash, s.seq
    FROM packs p LEFT JOIN sequences s ON p.seq_hash == s.seq_hash
    WHERE p.domain == '{G.DOMAIN}' AND s.complete == 1
'''
df_eval = db.sql_to_df(sql).sort_values('seq_hash')
print(f"{len(df_eval)} evaluation entries")

## add source, profile, step

In [None]:
# Split every pack name into its (source, profile, step) parts and store them as
# object-dtype columns.
parsed_packs = [pack_to_source_profile_step(pack) for pack in df_eval['pack']]
for position, column in enumerate(('source', 'profile', 'step')):
    df_eval[column] = pd.Series([parts[position] for parts in parsed_packs], index=df_eval.index, dtype=object)

N_Supports = int((df_eval['source'] == 'support').sum())
print(N_Supports)

## add length

In [None]:
# Sequence length per row, then min/mean/max length per pack.
df_eval['seq_length'] = df_eval['seq'].map(len)
df_eval.groupby('pack')[['seq_length']].agg(['min', 'mean', 'max'])

## add count to natural sequences

In [None]:
# Natural-sequence datasets, one per Split, used to attach observation counts below.
# Note: this rebinds DHPARAMS, replacing the XVAE job parameters loaded earlier.
DHPARAMS = FastaDS.get_dhparams(dataset_name, None, True)
input_dir = os.path.join(os.environ['PF'], 'data', 'input')
datasets = {}
for split in Split:
    datasets[split] = FastaDS(split, input_dir, DHPARAMS)

In [None]:
# Add to every evaluated sequence how often it occurs among the complete sequences of
# the natural datasets (0 if it does not occur there).
# NOTE(review): a complete dataset sequence missing from df_eval raises KeyError here.
if df_eval.index.name is not None:
    df_eval.reset_index(inplace=True)
df_eval.set_index('seq_hash', inplace=True)

df_eval['cnt'] = 0
for ds in datasets.values():
    if ds.df is None:
        continue
    for ds_seq, ds_row in ds.df.iterrows():
        if not ds_row.complete:
            continue
        df_eval.loc[get_seq_hash(ds_seq), 'cnt'] += ds_row.cnt

df_eval.reset_index(inplace=True)

In [None]:
# Relative abundance: each row's count divided by the total count over all rows.
total_cnt = df_eval['cnt'].sum()
df_eval['weight'] = df_eval['cnt'] / total_cnt

## Immune Visibility

### Load Predictors

In [None]:
# MHC-I presentation predictor of the configured class, restricted to the configured
# alleles. predictor_mhc_1_limit is divided by 100 before it is passed as `limit`.
predictor_mhc_1_factory = Mhc1Predictor.get_predictor(predictor_mhc_1_class)
predictor_mhc_1 = predictor_mhc_1_factory(
    join(G.ENV.PROJECT, 'artefacts', 'immunology', 'netMHCpan', 'percentile'),
    limit=predictor_mhc_1_limit / 100,
    mhc_1_alleles_to_load=mhc_1_alleles,
)

### Predict Visibility

In [None]:
# Register the MHC-I predictor with the DB under predictor_mhc_1_name; the cells below
# refer to it by this name.
db.predictors_mhc_1 = {predictor_mhc_1_name: predictor_mhc_1}

In [None]:
# Compute presentations allele by allele. Committing after each allele keeps the
# finished alleles stored if a later one fails.
for allele in tqdm(mhc_1_alleles, desc="mhc_1_alleles"):
    db.add_mhc_1_presentations(predictor_mhc_1_name, allele)
    db.conn.commit()

In [None]:
# Optionally persist the predictor (presumably its cached predictions — confirm in kit).
if predictor_mhc_1_save:
    predictor_mhc_1.save()

In [None]:
# Immune visibility of each row's sequence under the configured alleles (float column).
visibilities = [
    db.get_visibility_mhc_1(predictor_mhc_1_name, mhc_1_alleles, seq_hash)
    for seq_hash in tqdm(df_eval['seq_hash'])
]
df_eval['visibility'] = pd.Series(visibilities, index=df_eval.index, dtype=object).astype('float')

In [None]:
# Visibility thresholds for "extreme" cases: percentiles of the natural sequences.
df_natural = df_eval[df_eval['source'] == 'natural']
vis_down_visibility_threshold, vis_up_visibility_threshold = np.percentile(
    df_natural['visibility'],
    [vis_down_visibility_percentile * 100, vis_up_visibility_percentile * 100],
)
print(f"{vis_down}_visibility_threshold: {vis_down_visibility_threshold}")
print(f"{vis_up}_visibility_threshold: {vis_up_visibility_threshold}")

In [None]:
df_eval.groupby('pack')[['visibility']].agg(['min', 'mean', 'max'])  # visibility range per pack

In [None]:
# Report sequences that occur in several packs where at least one of those packs is
# not a data/support pack (e.g. a generated sequence duplicating another entry).
data_packs = ['support', 'data.TRAIN', 'data.VAL', 'data.TEST']
df_tmp = df_eval[['seq_hash', 'pack']].groupby('seq_hash').agg(list)
df_tmp['len'] = df_tmp['pack'].map(len)  # unlike apply(axis=1), also works on an empty frame
df_tmp = df_tmp.query('len > 1')
for seq_hash, row in df_tmp.iterrows():
    # Print each offending sequence once, not once per non-data pack it belongs to.
    if any(p not in data_packs for p in row.pack):
        print(row)

## Similarity

### Structure

In [None]:
# Columns for the closest support structure: hash as text, numeric results unset.
df_eval['max_TMscore_support_seq_hash'] = ""
for tm_column in ['max_TMscore', 'max_TMscore_aligned_length', 'max_TMscore_rmsd']:
    df_eval[tm_column] = None

In [None]:
# (pack, self_match) pairs searched for their closest support structure; self_match is
# forwarded to db.get_closest. The CAPE-Packer inc-nat profile is not included.
packs_TMscore = [('support', False)]
packs_TMscore += [(f'CAPE-XVAE.{profile}.final', True) for profile in (vis_base, vis_down, vis_up, vis_up_nat)]
packs_TMscore += [(f'CAPE-Packer.{profile}.final', True) for profile in (vis_base, vis_down, vis_up)]

In [None]:
# For every pack, find the closest support structure (highest TM-score) of each of its
# sequences that has a predicted structure, and store the match in df_eval.
if df_eval.index.name is not None:
    df_eval.reset_index(inplace=True)
df_eval.set_index('seq_hash', inplace=True)
for pack, self_match in packs_TMscore:
    df = db.get_pack(pack)
    cnt = 0
    for _, row in tqdm(df.iterrows()):
        seq_hash = row['seq_hash']
        if not os.path.exists(db.get_pdb_file_path(seq_hash)):
            continue
        sequence = db.get_sequence(seq_hash=seq_hash)
        # self_match is passed through (presumably controls whether the sequence itself
        # may be returned as its own closest match — confirm in CapeDB.get_closest).
        max_TMscore_support_seq_hash, max_TMscore, max_TMscore_aligned_length, max_TMscore_rmsd = db.get_closest(sequence['seq'], self_match=self_match)

        df_eval.at[seq_hash, 'max_TMscore_support_seq_hash'] = max_TMscore_support_seq_hash
        df_eval.at[seq_hash, 'max_TMscore'] = max_TMscore
        df_eval.at[seq_hash, 'max_TMscore_aligned_length'] = max_TMscore_aligned_length
        df_eval.at[seq_hash, 'max_TMscore_rmsd'] = max_TMscore_rmsd
        cnt += 1
    print(f"{pack}: {cnt} aligned")
df_eval.reset_index(inplace=True)

# All three numeric result columns were initialised with None (object dtype). Convert
# all of them, not only max_TMscore, so aggregations and plots treat them as numbers.
for tm_column in ['max_TMscore', 'max_TMscore_aligned_length', 'max_TMscore_rmsd']:
    df_eval[tm_column] = df_eval[tm_column].astype('float')

print(f'len(df_eval): {len(df_eval)}')

In [None]:
# Per-pack summary of the best TM-score to a support structure. The statistics live
# on a Series instead of variables named min/max, so the built-ins stay usable.
print(f"{'pack':40s} | {'mean':5} | {'std':5} | {'min':5} | {'max':5}")

for pack, _ in packs_TMscore:
    tm_scores = df_eval.loc[df_eval['pack'] == pack, 'max_TMscore']
    print(f"{pack:40s} | {tm_scores.mean():5.3f} | {tm_scores.std():5.3f} | {tm_scores.min():5.3f} | {tm_scores.max():5.3f}")

In [None]:
# Per-pack summary of the aligned length of the best TM-score match. The statistics
# live on a Series instead of variables named min/max, so the built-ins stay usable.
print(f"{'pack':40s} | {'mean':5} | {'std':5} | {'min':5} | {'max':5}")

for pack, _ in packs_TMscore:
    aligned_lengths = df_eval.loc[df_eval['pack'] == pack, 'max_TMscore_aligned_length']
    print(f"{pack:40s} | {aligned_lengths.mean():5.1f} | {aligned_lengths.std():5.1f} | {aligned_lengths.min():5.1f} | {aligned_lengths.max():5.1f}")

In [None]:
# Per-pack summary of the RMSD of the best TM-score match. The statistics live on a
# Series instead of variables named min/max, so the built-ins stay usable.
print(f"{'pack':40s} | {'mean':5} | {'std':5} | {'min':5} | {'max':5}")

for pack, _ in packs_TMscore:
    rmsds = df_eval.loc[df_eval['pack'] == pack, 'max_TMscore_rmsd']
    print(f"{pack:40s} | {rmsds.mean():5.1f} | {rmsds.std():5.1f} | {rmsds.min():5.1f} | {rmsds.max():5.1f}")

### Sequence

#### K-mers

In [None]:
# Occurrence counts of every natural k-mer for k in kmer_similarity_lengths, cached on
# disk per domain, plus the set of all natural k-mers.
natural_kmers_data_file_path = os.path.join(G.ENV.ARTEFACTS, f"{G.DOMAIN}.kmers.natural.pickle")
df_natural = df_eval.query('source == "natural"')

natural_kmers_data = None
if os.path.exists(natural_kmers_data_file_path):
    # NOTE: pickle.load can execute code; this file is written by this notebook itself.
    with open(natural_kmers_data_file_path, "rb") as file:
        natural_kmers_data = pickle.load(file)
    # The cache is keyed only by domain. Discard it if it was built for other k-mer
    # lengths or another number of natural sequences, instead of failing with a
    # KeyError below or silently using stale counts.
    cached_dict = natural_kmers_data.get('natural_kmers_dict', {})
    if (natural_kmers_data.get('n_natural') != len(df_natural)
            or any(length not in cached_dict for length in kmer_similarity_lengths)):
        natural_kmers_data = None

if natural_kmers_data is None:
    natural_kmers_dict = {length: {} for length in kmer_similarity_lengths}
    for _, row in tqdm(df_natural.iterrows()):
        for length in kmer_similarity_lengths:
            for kmer in get_kmers(row.seq, lengths=[length]):
                natural_kmers_dict[length][kmer] = natural_kmers_dict[length].get(kmer, 0) + 1

    natural_kmers_data = {
        'natural_kmers_dict': natural_kmers_dict,
        'n_natural': len(df_natural)
    }
    with open(natural_kmers_data_file_path, "wb") as file:
        pickle.dump(natural_kmers_data, file)

natural_kmers_dict = natural_kmers_data['natural_kmers_dict']

natural_kmers = set()
natural_kmers.update(*[set(natural_kmers_dict[length].keys()) for length in kmer_similarity_lengths])
print(f"K-mer count: {len(natural_kmers):,}")

In [None]:
# Fraction of each sequence's k-mers that also occur in natural sequences: overall
# ('kmer_similarity') and per k ('<k>mer_similarity'). NaN when a sequence has no
# k-mers of the respective length (previously a ZeroDivisionError).
df_eval['kmer_similarity'] = None
for k in kmer_similarity_lengths:
    df_eval[f'{k}mer_similarity'] = None


a_len = np.max(kmer_similarity_lengths) + 1  # counters are indexed directly by k
for idx, row in tqdm(df_eval.iterrows()):
    natural_kmers_cnt, artificial_kmers_cnt = [0]*a_len, [0]*a_len

    for kmer in get_kmers(row.seq, lengths=kmer_similarity_lengths):
        if kmer in natural_kmers:
            natural_kmers_cnt[len(kmer)] += 1
        else:
            artificial_kmers_cnt[len(kmer)] += 1

    n_natural = sum(natural_kmers_cnt)
    n_total = n_natural + sum(artificial_kmers_cnt)
    df_eval.at[idx, 'kmer_similarity'] = n_natural / n_total if n_total else np.nan
    for k in kmer_similarity_lengths:
        n_k = natural_kmers_cnt[k] + artificial_kmers_cnt[k]
        df_eval.at[idx, f'{k}mer_similarity'] = natural_kmers_cnt[k] / n_k if n_k else np.nan

df_eval['kmer_similarity'] = df_eval['kmer_similarity'].astype('float')
for k in kmer_similarity_lengths:
    df_eval[f'{k}mer_similarity'] = df_eval[f'{k}mer_similarity'].astype('float')

#### recall/precision of visible peptides

In [None]:
mhc_1_peptide_lengths = [8, 9, 10]

# Column-name suffix naming the peptide lengths, e.g. "8+9+10".
mhc_1_peptide_lengths_col = "+".join(map(str, mhc_1_peptide_lengths))
# Distinct natural sequences (sorted), used as recall/precision targets.
seq_hashes_natural = sorted(set(df_eval.loc[df_eval['source'] == 'natural', 'seq_hash']))

metrics = ['recall', 'precision']

def f_metrics(d_visible: dict, seq_hash_proxy: str, seq_hash_target: str, zero_division: float = 0.0):
    """Return (recall, precision) of the proxy's visible peptides w.r.t. the target's.

    Args:
        d_visible: maps a seq_hash to the set of its visible peptides.
        seq_hash_proxy: sequence whose visible peptides are evaluated.
        seq_hash_target: sequence whose visible peptides are the reference.
        zero_division: value returned for a metric whose denominator is empty (target
            or proxy has no visible peptides), following scikit-learn's convention,
            instead of raising ZeroDivisionError.

    Returns:
        (recall, precision), where recall = |proxy & target| / |target| and
        precision = |proxy & target| / |proxy|.
    """
    visible_proxy = d_visible[seq_hash_proxy]
    visible_target = d_visible[seq_hash_target]
    l_intersection = len(visible_proxy.intersection(visible_target))
    n_target, n_proxy = len(visible_target), len(visible_proxy)
    recall = l_intersection / n_target if n_target else zero_division
    precision = l_intersection / n_proxy if n_proxy else zero_division
    return recall, precision

def P01(values):
    """Return the 0.1th percentile of *values* (NumPy's default linear interpolation)."""
    return np.percentile(values, q=0.1)

def P05(values):
    """Return the 0.5th percentile of *values* (NumPy's default linear interpolation)."""
    return np.percentile(values, q=0.5)

def P1(values):
    """Return the 1st percentile of *values* (NumPy's default linear interpolation)."""
    return np.percentile(values, q=1)

def P5(values):
    """Return the 5th percentile of *values* (NumPy's default linear interpolation)."""
    return np.percentile(values, q=5)

def P10(values):
    """Return the 10th percentile of *values* (NumPy's default linear interpolation)."""
    return np.percentile(values, q=10)

# Aggregates applied to each proxy's per-target recall/precision values. Each name is
# turned into the matching callable (NumPy function or P* helper above) when the
# metrics are aggregated.
functions = ['np.mean', 'np.median', 'np.min', 'P01', 'P05', 'P1', 'P5', 'P10']

# Visible MHC-I peptides per sequence. f_metrics indexes this by seq_hash and
# intersects the values, so presumably dict seq_hash -> set of peptides — confirm.
d_visible = get_visible_mhc_1_peptides(df_eval, predictor_mhc_1, mhc_1_alleles, mhc_1_peptide_lengths=mhc_1_peptide_lengths)

In [None]:
if df_eval.index.name is not None:
    df_eval.reset_index(inplace=True)

# Number of observations of each natural sequence; each target is weighted by it.
d_cnt = df_eval.set_index('seq_hash').query("source == 'natural'").cnt.to_dict()
target_weights = [d_cnt[seq_hash_target] for seq_hash_target in seq_hashes_natural]

# Resolve each aggregate name once ('np.mean' -> np.mean, 'P01' -> P01, ...) instead of
# eval()-ing it for every sequence. Columns are named after the part after the last
# '.', e.g. 'mean_recall_8+9+10'.
aggregates = {}
for function in functions:
    function_name = function.split('.')[-1]
    aggregates[function_name] = getattr(np, function_name) if function.startswith('np.') else globals()[function]

# Create exactly the columns that are filled below. The unsplit name used previously
# produced extra 'np.*' columns that always stayed None.
for metric in metrics:
    for function_name in aggregates:
        df_eval[f"{function_name}_{metric}_{mhc_1_peptide_lengths_col}"] = None

# A sequence may occur in several packs. Its metrics are identical there, and the
# label-based assignment below fills all of its rows, so compute them once.
seq_hashes_proxy = list(dict.fromkeys(df_eval.seq_hash))
df_eval.set_index('seq_hash', inplace=True)
for seq_hash_proxy in tqdm(seq_hashes_proxy):
    # (recall, precision) against every distinct natural target ...
    per_target = np.array(
        [f_metrics(d_visible, seq_hash_proxy, seq_hash_target) for seq_hash_target in seq_hashes_natural],
        dtype=float,
    ).reshape(-1, len(metrics))
    # ... repeated cnt times per target, so the aggregates are abundance-weighted.
    results = np.repeat(per_target, target_weights, axis=0)

    for m, metric in enumerate(metrics):
        for function_name, aggregate in aggregates.items():
            col = f"{function_name}_{metric}_{mhc_1_peptide_lengths_col}"
            df_eval.loc[seq_hash_proxy, col] = aggregate(results[:, m])

df_eval.reset_index(inplace=True)

### Function

In [None]:
# Packs whose sequences enter the GO-function analysis.
function_packs = [
    'support',
    'CAPE-XVAE.baseline.final',
    'CAPE-XVAE.reduced.final',
    'CAPE-XVAE.increased.final',
    'CAPE-XVAE.inc-nat.final',
    'CAPE-Packer.baseline.final',
    'CAPE-Packer.reduced.final',
    'CAPE-Packer.increased.final',
]

# Index-only frame of their seq_hashes, used to restrict df_GO via a join.
function_seq_hashes = df_eval.loc[df_eval['pack'].isin(function_packs), ['seq_hash']].set_index('seq_hash')

In [None]:
# Read one GO-term score vector per sequence (first occurrence wins) from the
# function-prediction files and stack them into a seq_hash x GO_term frame.
go_vectors = {}
for seq_hash in tqdm(df_eval['seq_hash']):
    if seq_hash in go_vectors:
        continue
    go_function_file_path = db.get_function_path(seq_hash)
    if not os.path.exists(go_function_file_path):
        continue
    scores = (
        pd.read_csv(go_function_file_path, sep=' ', names=['structure', 'GO_term', 'score'])
        .sort_values('GO_term')
        .set_index('GO_term')
    )
    go_vectors[seq_hash] = scores[['score']].rename(columns={'score': seq_hash})

df_GO_vectors = pd.concat(list(go_vectors.values()), axis=1).transpose()
df_GO_vectors.index.name = 'seq_hash'

In [None]:
# Metadata of every evaluated sequence joined with its GO vector (NaN where none
# exists); print how many sequences per source/profile/step have a GO vector.
df_GO_meta = df_eval.set_index('seq_hash').loc[:, ['source', 'profile', 'step']]
df_GO_1 = df_GO_meta.join(df_GO_vectors, how='left')
go_coverage = df_GO_1.groupby(['source', 'profile', 'step'])[['GO:0000006']].count()
print(go_coverage)

In [None]:
# Keep sequences that have a GO vector (GO:0000006 serves as the presence marker) and
# drop the metadata columns.
df_GO = df_GO_1.query("not `GO:0000006`.isna()").drop(columns=['source', 'profile', 'step']).copy()
# A sequence appears once per pack it belongs to, so deduplicate on the seq_hash index.
# DataFrame.drop_duplicates ignores the index: it would also merge *different*
# sequences whose GO score vectors happen to be identical.
df_GO = df_GO[~df_GO.index.duplicated(keep='first')]
assert df_GO.index.is_unique

In [None]:
# Restrict the GO vectors to sequences of function_packs (inner join on seq_hash).
df_GO = df_GO.join(function_seq_hashes, how='inner')

In [None]:
# Scale each sequence's GO score vector to unit (L2) norm, then fit a full PCA.
pipeline = Pipeline([
    ('scaler', Normalizer()),  # per-sample unit-norm scaling; it does NOT standardise features
    ('pca', PCA())  # Perform PCA
])

pipeline.fit(df_GO)
# Explained-variance ratio of the first 20 components.
print(pipeline['pca'].explained_variance_ratio_.round(2)[:20])

In [None]:
format(pipeline['pca'].explained_variance_ratio_[:2].sum(), '.2f')  # variance explained by PCs 1-2

In [None]:
# Smallest number of leading PCA components whose cumulative explained variance
# exceeds tgt_explained.
tgt_explained = 0.99
explained_variance_ratio = pipeline['pca'].explained_variance_ratio_
cumulative_explained = np.cumsum(explained_variance_ratio)
above_target = np.flatnonzero(cumulative_explained > tgt_explained)
# Use all components if the target is only reached by all of them (or never, due to
# rounding), instead of raising IndexError.
n_pca = int(above_target[0]) + 1 if above_target.size else len(explained_variance_ratio)
print(f"First {n_pca} components, explaining {explained_variance_ratio[:n_pca].sum().round(4)} of variance")

In [None]:
# Project the GO vectors onto the leading n_pca components. Only the GO-term columns
# are transformed, so re-running this cell after the 'GO PCA *' columns were added
# does not feed them back into the fitted pipeline (which would fail on the extra
# features).
go_term_columns = [c for c in df_GO.columns if not str(c).startswith('GO PCA ')]
transformed = pipeline.transform(df_GO[go_term_columns])
for d in range(n_pca):
    df_GO[f'GO PCA {d+1}'] = transformed[:, d]

In [None]:
go_pca_columns = [f'GO PCA {d+1}' for d in range(n_pca)]

if df_eval.index.name is not None:
    df_eval.reset_index(inplace=True)
df_eval.set_index('seq_hash', inplace=True)

# Replace GO PCA columns from a previous run with the freshly computed coordinates.
stale_columns = [c for c in df_eval.columns if c in go_pca_columns]
df_eval = df_eval.drop(columns=stale_columns).join(df_GO[go_pca_columns], how='left')

df_eval.reset_index(inplace=True)
assert 'seq_hash' in df_eval.columns

In [None]:
def run_GaussianNB(packs, cnt=0):
    """Classify sequences into their pack from GO PCA features and print the result.

    A Gaussian naive Bayes classifier is fitted on the 'GO PCA *' columns of all
    non-natural sequences in ``df_eval`` that have them. It is evaluated on the same
    sequences (no held-out split), so the confusion matrix measures how separable the
    packs are, not generalisation. Predictions are written to
    ``df_eval['func_nb_pred']``. The confusion matrix is printed as a table and as
    LaTeX rows of row-wise percentages, ordered by ``packs``.

    Args:
        packs: pack names giving the row/column order of the LaTeX table. The caller's
            list is not modified.
        cnt: if > 0, support sequences with ``cnt <= cnt`` are relabelled
            'support_low_cnt' and listed right after 'support'.
    """
    # Work on a copy. Inserting into the caller's list would make repeated calls with
    # the same list (e.g. a global) accumulate 'support_low_cnt' entries.
    packs = list(packs)

    #
    # Data: all non-natural sequences with GO PCA coordinates
    #
    nb_selection = 'source != "natural" and not `GO PCA 1`.isna()'
    fun_nb_df = df_eval.query(nb_selection).copy()

    if cnt > 0:
        # Treat rarely observed support sequences as a class of their own.
        fun_nb_df.loc[(fun_nb_df.cnt <= cnt) & (fun_nb_df.pack == 'support'), 'pack'] = 'support_low_cnt'
        if 'support' in packs:
            packs.insert(packs.index('support') + 1, 'support_low_cnt')

    fun_nb_X = fun_nb_df[[f'GO PCA {d+1}' for d in range(n_pca)]]
    fun_nb_y = pd.Categorical(fun_nb_df.pack)
    categories = fun_nb_y.categories

    #
    # Classifier: fit and predict on the same sequences
    #
    nb_classifier = GaussianNB()
    nb_classifier.fit(fun_nb_X, fun_nb_y.codes)
    y_pred = nb_classifier.predict(fun_nb_X)

    # Confusion matrix: rows = true pack, columns = predicted pack.
    cm = confusion_matrix(fun_nb_y.codes, y_pred, labels=range(len(categories)))
    p_cm = pd.DataFrame(cm, index=categories, columns=categories)
    p_cm['sum'] = p_cm.sum(axis=1)
    # Fraction of each pack's sequences predicted as 'support'.
    p_cm['pc'] = p_cm['support'] / p_cm['sum']
    print(p_cm[['pc', 'support', 'sum']])

    # fun_nb_df keeps df_eval's row labels, so predictions can be written back by label.
    df_eval['func_nb_pred'] = None
    df_eval.loc[fun_nb_X.index, 'func_nb_pred'] = np.asarray(categories[y_pred], dtype=object)

    # Row-wise percentages as LaTeX strings ("\%" escapes the percent sign; written as
    # "\\%" because "\%" is an invalid escape sequence in a normal string literal).
    for c in categories:
        p_cm[f"{c}_pc"] = p_cm.apply(lambda row: f"{100*row[c]/row['sum']:.0f}\\%", axis=1)

    #
    # LaTeX output: header with the pack labels, then one row per true pack
    #
    print('\n\n')
    text = ''
    for p in packs:
        text += f"& {get_label_from_pack(p)} "
    text += r"\\" + "\n"

    for pack in packs:
        # Packs without sequences in this run have no confusion-matrix row.
        if pack not in p_cm.index:
            continue
        row = p_cm.loc[pack]
        text += f"{get_label_from_pack(pack)} "
        for p in packs:
            value = row[f'{p}_pc'] if f'{p}_pc' in row else r"0\%"
            text += f"& {value} "
        text += r"\\" + "\n"

    print(text)

In [None]:
# Pack classification restricted to the GO-function packs.
run_GaussianNB(packs=function_packs)

In [None]:
# Same on the TM-score packs, splitting off support sequences with cnt <= 2.
run_GaussianNB([pack for pack, _ in packs_TMscore], cnt=2)

## Stability

### MD simulations

In [None]:
# Parameter hash identifying the MD simulation setup, and empty result columns.
md_param_hash = '675742'
for md_column in ('md_rmsd', 'md_max_rmsd', 'md_rmsf'):
    df_eval[md_column] = None

In [None]:
# For every sequence with an MD trajectory, compute per-frame RMSD, its maximum and
# per-atom RMSF of the C-alpha trace relative to the first frame.
if df_eval.index.name is not None:
    df_eval.reset_index(inplace=True)
df_eval.set_index('seq_hash', inplace=True)

# A sequence can appear in several packs (duplicate index labels). All writes below
# are keyed by seq_hash, so process each sequence once instead of once per row.
for seq_hash in df_eval.index.unique():
    md_path = db.get_md_path(seq_hash, md_param_hash)
    dcd_file_path = os.path.join(md_path, 'output_sub.dcd')
    if not os.path.exists(dcd_file_path):
        continue
    print(seq_hash)

    pdb_file_path = db.get_pdb_file_path(seq_hash)
    topology = md.load_topology(pdb_file_path)
    alpha_carbon_atoms = [atom for atom in topology.atoms if atom.name == 'CA']

    # The trajectory is loaded with a CA-only topology (presumably output_sub.dcd
    # stores only C-alpha coordinates — confirm). add_residue expects a residue *name*
    # and add_atom an mdtraj Element, not a Residue object / element string.
    topology_sub = md.Topology()
    chain = topology_sub.add_chain()
    for r, ca_atom in enumerate(alpha_carbon_atoms):
        res = topology_sub.add_residue(ca_atom.residue.name, chain)
        topology_sub.add_atom(f"CA{r}", md.element.carbon, res)

    trajectory = md.load_dcd(dcd_file_path, top=topology_sub)

    rmsd = md.rmsd(trajectory, trajectory[0])
    rmsf = md.rmsf(trajectory, trajectory[0])

    df_eval.at[seq_hash, 'md_max_rmsd'] = rmsd.max()
    set_df_cell_to_np(df_eval, seq_hash, 'md_rmsd', rmsd)
    set_df_cell_to_np(df_eval, seq_hash, 'md_rmsf', rmsf)

# Hint kept from the original: np.frombuffer(rmsd.astype(np.dtype('float32')).tobytes(), dtype=np.dtype('float32'))

df_eval['md_max_rmsd'] = df_eval['md_max_rmsd'].astype('float')

df_eval.reset_index(inplace=True)

In [None]:
df_eval[df_eval['md_rmsd'].notna()][['seq_hash', 'pack']].groupby('pack').count()  # MD results per pack

# Visualizations

## Representations

### support representation

In [None]:
# Presumably computes and stores the support-set representation that the t-SNE / PCA
# reductions below are built from — confirm in CapeDB.add_support_rep.
db.add_support_rep()

### t-SNE support representation

In [None]:
# 2-D t-SNE of the support representation, registered via db.add_reduced_support_rep,
# read back with db.get_rep and joined onto df_eval.
import inspect

# scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter`, and the old name is removed
# in 1.7. Pass the iteration count under whichever name this installation accepts.
tsne_iter_kwarg = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
tsne = TSNE(n_components=2, verbose=0, perplexity=TSNE_perplexity, **{tsne_iter_kwarg: TSNE_iter})
tsne_rep_base_hash = db.add_reduced_support_rep(df_eval, tsne_rep_name, tsne)
df_TSNE = db.get_rep(tsne_rep_name, tsne_rep_base_hash).rename(columns={'rep_1': 't-SNE 1', 'rep_2': 't-SNE 2'})
df_eval = df_eval[[c for c in df_eval.columns if c not in ['t-SNE 1', 't-SNE 2']]].join(df_TSNE, how='left', on='seq_hash')

### PCA support representation

In [None]:
# 2-D PCA of the support representation, registered via db.add_reduced_support_rep,
# read back with db.get_rep and joined onto df_eval (replacing earlier PCA columns).
pca = PCA(n_components=2)
pca_rep_base_hash = db.add_reduced_support_rep(df_eval, pca_rep_name, pca)
df_PCA = db.get_rep(pca_rep_name, pca_rep_base_hash)
df_PCA = df_PCA.rename(columns={'rep_1': 'PCA 1', 'rep_2': 'PCA 2'})
df_eval = df_eval.drop(columns=['PCA 1', 'PCA 2'], errors='ignore').join(df_PCA, how='left', on='seq_hash')

### Dissimilarities

In [None]:
# Preview of the dissimilarity source packs. They are only defined in the next cell, so
# look the name up defensively: a bare reference raises NameError on a fresh "Run All"
# and stops the notebook.
globals().get('dissimilarity_src_packs')

In [None]:
if df_eval.index.name is not None:
    df_eval.reset_index(inplace=True)
df_eval.set_index('seq_hash', inplace=True)


dissimilarity_src_packs = [
    # (source, [packs])
    ('support', ['support']), 
    (source_xvae, [f"{source_xvae}.{profile}.final" for profile in vis_profiles]),
    (source_packer, [f"{source_packer}.{profile}.final" for profile in vis_profiles]),
]

# Reference packs that every source sequence is compared against.
dissimilarity_tgt_packs = [
    'support',
    f'{source_packer}.{vis_base}.final'
]

for pack_name in dissimilarity_tgt_packs:
    df_eval[f'avg_dissimilarity_{pack_name}'] = None

# Target sequences do not depend on the source pack: load each target pack once instead
# of once per source pack.
tgt_pack_seqs = {pack_name: db.get_pack(pack_name).seq for pack_name in dissimilarity_tgt_packs}

# For every sequence of every source pack, store its mean alignment dissimilarity to all
# sequences of each target pack.
# NOTE(review): for support -> support each sequence is also compared with itself.
for src_pack in tqdm(dissimilarity_src_packs):
    for pack_1_name in src_pack[1]:
        df_pack_1 = db.get_pack(pack_1_name)
        for pack_2_name, tgt_seqs in tgt_pack_seqs.items():
            column = f'avg_dissimilarity_{pack_2_name}'
            for _, row in df_pack_1.iterrows():
                value = np.mean(pairwise_sequence_aligner.get_seq_to_seqs_dissimilarity(row.seq, tgt_seqs))
                df_eval.at[row.seq_hash, column] = value

df_eval.reset_index(inplace=True)

## Plot

In [None]:
import warnings

# Suppress FutureWarning messages for the plotting cells that follow.
warnings.simplefilter("ignore", FutureWarning)

In [None]:
#color_palette = glasbey.create_palette(11, colorblind_safe=True)
# Candidate block palette (1 colour for data, 5 for CAPE-XVAE, 5 for CAPE-Packer). It is
# only displayed here and then replaced by the fixed palette below.
color_palette = glasbey.create_block_palette(
    [1, # Data
     5, # XVAE
     5  # Packer
     ],
    colorblind_safe=True, 
    cvd_severity=60,
)

# color_palette = ['blue', 'yellow', 'red', 'orange', 'cyan', 'green', 'magenta', 'darkgoldenrod', 'darkcyan']
# palette_name = 'colorblind' # 'tab20b'
# color_palette = sns.color_palette(palette_name)

sns.palplot(color_palette)

# Fixed 11-colour palette actually used: index 0 for data/support, 1-5 for CAPE-XVAE
# (1 = source colour), 6-10 for CAPE-Packer (6 = source colour).
color_palette = [np.array([0.82539682, 0.        , 0.53968261]),
 '#4d001a',
 '#8b0f0c',
 '#c52404',
 '#e26a25',
 '#ffaf46',
 '#002e66',
 '#0c4cb5',
 '#1d6bff',
 '#7a95ff',
 '#d6bfff']

sns.palplot(color_palette)

In [None]:
# Colour per pack and per source. Data-derived packs share color_palette[0];
# CAPE-XVAE profiles use entries 2-5 and CAPE-Packer profiles entries 7-10.
pack_palette = {data_pack: color_palette[0] for data_pack in ('support', 'data.TRAIN', 'data.VAL', 'data.TEST')}
for pal_source, pal_first in ((source_xvae, 2), (source_packer, 7)):
    for pal_offset, pal_profile in enumerate((vis_base, vis_down, vis_up, vis_up_nat)):
        pack_palette[f'{pal_source}.{pal_profile}.final'] = color_palette[pal_first + pal_offset]

palettes = {
    'pack': pack_palette,
    'source': {
        'support': color_palette[0],
        'natural': color_palette[0],
        source_xvae: color_palette[1],
        source_packer: color_palette[6],
    },
}
set_palettes(palettes)

In [None]:
# Scatter marker per visibility profile.
styles = {vis_base: '+', vis_down: '1', vis_up: '2', vis_up_nat: '3'}

# Data-derived packs and all sources use the baseline marker; generated packs use
# the marker of their profile.
pack_markers = {data_pack: styles[vis_base] for data_pack in ('support', 'data.TRAIN', 'data.VAL', 'data.TEST')}
for mk_source in (source_xvae, source_packer):
    for mk_profile in (vis_base, vis_down, vis_up, vis_up_nat):
        pack_markers[f'{mk_source}.{mk_profile}.final'] = styles[mk_profile]

markers = {
    'pack': pack_markers,
    'source': {key: styles[vis_base] for key in ('support', 'natural', source_xvae, source_packer)},
}

set_markers(markers)

In [None]:
# Line style per visibility profile.
line_styles = {vis_base: 'solid', vis_down: 'dotted', vis_up: 'dashed', vis_up_nat: 'dashdot'}

# Data-derived packs and all sources use the baseline style; generated packs use the
# style of their profile.
pack_dashes = {data_pack: line_styles[vis_base] for data_pack in ('support', 'data.TRAIN', 'data.VAL', 'data.TEST')}
for ds_source in (source_xvae, source_packer):
    for ds_profile in (vis_base, vis_down, vis_up, vis_up_nat):
        pack_dashes[f'{ds_source}.{ds_profile}.final'] = line_styles[ds_profile]

dashes = {
    'pack': pack_dashes,
    'source': {key: line_styles[vis_base] for key in ('support', 'natural', source_xvae, source_packer)},
}

set_dashes(dashes)

### Figure A - Low-dimensional sequence space representation

In [None]:
# Figure A: t-SNE KDE panels of CAPE-XVAE (row 0) and CAPE-Packer (row 1) sequences per
# visibility profile (one column pair per header), and average-dissimilarity boxplots
# against the two target packs (row 2).
font_scale=10/12
fig = plt.figure(figsize=(A4_width, A4_height*4/5))
all_axes = []  # panel order here determines the a), b), ... labels below
gs = mpl.gridspec.GridSpec(3, 6, 
                           width_ratios=[1, 1, 1, 1, 1, 1], 
                           height_ratios=[2, 2, 2], 
                           wspace=0.2, hspace=0.2)

for col, (header, directions) in enumerate([
    (vis_base, [vis_base]),
    (vis_down, [vis_down]),
    (vis_up, [vis_up, vis_up_nat])
]):
    dots_packs_XVAE, dots_packs_Packer = [], []
    for direction in directions:
        step = 'final'
        dots_packs_XVAE.append(f'{source_xvae}.{direction}.{step}')
        dots_packs_Packer.append(f'{source_packer}.{direction}.final')
    all_axes.append(plot_tsne_kde(header, df_eval, dots_packs_XVAE, gs[0, (2*col):(2*col+2)], font_scale=font_scale, xlim=(-60, 60), ylim=(-60, 60), rm_ylabel=(col > 0), rm_xlabel=True))
    all_axes.append(plot_tsne_kde('', df_eval, dots_packs_Packer, gs[1, (2*col):(2*col+2)], font_scale=font_scale, xlim=(-60, 60), ylim=(-60, 60), rm_ylabel=(col > 0)))

# Left: dissimilarity to 'support'; right: to the CAPE-Packer baseline pack.
# NOTE(review): 1500 is passed positionally to plot_avg_dissimilarity_boxplots; its
# meaning is not visible here.
all_axes.append(plot_avg_dissimilarity_boxplots(df_eval, 
                    dissimilarity_src_packs, dissimilarity_tgt_packs[:1], 
                    1500, gs[2,:3], font_scale=font_scale))
all_axes.append(plot_avg_dissimilarity_boxplots(df_eval, 
                    dissimilarity_src_packs, dissimilarity_tgt_packs[1:], 
                    1500, gs[2,3:], font_scale=font_scale, show_ylabel=False))

# Bold panel labels a), b), ... in the order the axes were created.
labels = [f'{c})' for c in string.ascii_lowercase]
for i, ax in enumerate(all_axes):
    ax.text(-0., 1.08, labels[i], transform=ax.transAxes,
            fontsize=16*font_scale, fontweight='bold', va='top', ha='right') #, color='red')

if save_figures:
    fig.savefig(kit.path.join(G.ENV.ARTEFACTS, "figures", G.DOMAIN, f"Figure_A.pdf"), bbox_inches='tight')

### Figure B - Visibility and Similarity

In [None]:
# Figure B: immune-visibility boxplots and mean k-mer similarity (row 0), per-k k-mer
# similarity boxes for three CAPE-XVAE profiles (row 1), and TM-score boxplots (row 2).
font_scale=9/12
fig = plt.figure(figsize=(A4_width, A4_height*0.8))
all_axes = []  # panel order here determines the a), b), ... labels below
gs = mpl.gridspec.GridSpec(3, 6, 
                           width_ratios=[1, 1, 1, 1, 1, 1], 
                           height_ratios=[1, 1, 1], 
                           wspace=0.2, hspace=0.2)



all_axes.append(plot_boxplots("Immune-visibility", 
                              df_eval, 
                              "visibility", 
                              ["natural", source_xvae, source_packer], gs[0, :3], font_scale=font_scale, add_nat_mean_line=True))
all_axes.append(plot_kmer_similarity(df_eval, [
    "support", 
    f"CAPE-XVAE.{vis_base}.final", f"CAPE-XVAE.{vis_down}.final", f"CAPE-XVAE.{vis_up}.final", f"CAPE-XVAE.{vis_up_nat}.final",
    f"CAPE-Packer.{vis_base}.final", f"CAPE-Packer.{vis_down}.final", f"CAPE-Packer.{vis_up}.final", f"CAPE-Packer.{vis_up_nat}.final"
    ], kmer_similarity_lengths, gs[0, 3:], font_scale))

# NOTE(review): 0.005 is passed positionally to plot_kmer_similarity_box; presumably it
# matches vis_down_visibility_percentile — confirm.
all_axes.append(plot_kmer_similarity_box(df_eval, f"{source_xvae}.{vis_down}.final", kmer_similarity_lengths, 0.005, gs[1, :2], font_scale, title_rows=3))
all_axes.append(plot_kmer_similarity_box(df_eval, f"{source_xvae}.{vis_up}.final", kmer_similarity_lengths, 0.005, gs[1, 2:4], font_scale, title_rows=3, rm_y_label=True))
all_axes.append(plot_kmer_similarity_box(df_eval, f"{source_xvae}.{vis_up_nat}.final", kmer_similarity_lengths, 0.005, gs[1, 4:], font_scale, title_rows=3, rm_y_label=True))

# The column is renamed only for this plot's label; df_eval itself is unchanged.
all_axes.append(plot_boxplots("TM Score", 
                              df_eval.rename(columns={"max_TMscore": "TMscore"}), 
                              "TMscore", 
                              ["support", source_xvae, source_packer], gs[2, :], font_scale))

# Bold panel labels a), b), ... in the order the axes were created.
labels = [f'{c})' for c in string.ascii_lowercase]
for i, ax in enumerate(all_axes):
    ax.text(-0., 1.08, labels[i], transform=ax.transAxes,
            fontsize=16*font_scale, fontweight='bold', va='top', ha='right') #, color='red')

if save_figures:
    fig.savefig(kit.path.join(G.ENV.ARTEFACTS, "figures", G.DOMAIN, f"Figure_B.pdf"), bbox_inches='tight')

In [None]:
df_eval.groupby(['source', 'profile'])['visibility'].mean().to_frame('mean')  # mean visibility per source/profile

### Figure C - GO

In [None]:
def t(l, a, b):
    """Re-read a flat sequence of ``a * b`` items column by column.

    The items are laid out row-major on an ``a x b`` grid, the grid is
    transposed, and the result is flattened again into a 1-D numpy array.
    Used below to reorder legend entries before they are passed to a
    multi-column legend.
    """
    grid = np.array(l).reshape(a, b)
    return grid.T.ravel()

In [None]:
models = [source_xvae, source_packer]
# One grid column per entry; profiles listed together are overlaid in one panel.
profiles = [
    [str(vis_base)], 
    [str(vis_down)], 
    [str(vis_up), str(vis_up_nat)]
]

fig = plt.figure(figsize=(A4_width, A4_height/2))
# Row 0: column titles; rows 1..len(models): one scatter row per model;
# last row: shared legend. Column 0 holds the row (model) titles.
gs = mpl.gridspec.GridSpec(2+len(models), 1+len(profiles),
                           width_ratios=[1] + [10]*len(profiles), 
                           height_ratios=[1] + [10]*len(models) + [5], 
                           wspace=0.1, hspace=0.1)

# The support set is the background of every panel and does not depend on the
# loop variables, so select it once.
df_support = df_eval.query('pack == "support"')

# sharex[j] holds the first-row axes of column j; lower rows share its x-range.
sharex = [None] * len(profiles)
packs = []  # design packs in drawing order; consumed by the legend below
for i, model in enumerate(models):
    sharey = None
    for j, l_profiles in enumerate(profiles):
        if i == 0:
            text = "/".join(l_profiles)
            ax = fig.add_subplot(gs[i,1+j])
            plot_text(text, ax)

        if j == 0:
            ax = fig.add_subplot(gs[1+i,j])
            plot_text(model, ax, rotation=90, y_pos=0.75)

        ax = fig.add_subplot(gs[1+i,1+j], sharey=sharey, sharex=sharex[j])
        ax.yaxis.set_label_position("right")
        sns.scatterplot(ax=ax, data=df_support, x='GO PCA 1', y="GO PCA 2", 
                        hue='pack', palette=palettes['pack'], style='pack', markers='o',
                       linewidth=1, s=30)

        for profile in l_profiles:
            df_data = df_eval.query(f'pack == "{model}.{profile}.final" and not `GO PCA 1`.isnull()')
            packs += df_data.pack.unique().tolist()
            sns.scatterplot(ax=ax, data=df_data, x='GO PCA 1', y="GO PCA 2", 
                            hue='pack', palette=palettes['pack'], style='pack', markers=markers['pack'],
                           linewidth=1, s=80)

        if i != len(models)-1:
            ax.set_xlabel('')

        if i == 0:
            # Store (not append): sharex was pre-sized, and appending left
            # sharex[j] as None so lower rows never shared the x-axis.
            sharex[j] = ax

        if j != len(profiles)-1:
            ax.set_ylabel('')
        else:
            sharey = ax

        ax.set_xticklabels([])
        ax.set_yticklabels([])
        rm_axes_elements(ax, 'ticks')
        ax.legend().set_visible(False)

# Shared legend: 'support' + 3 blank slots + one entry per plotted pack + 1 blank,
# reshaped by t() as a 3x4 grid. NOTE(review): this only works when exactly 7
# packs were plotted (4 + 7 + 1 = 12); a different count makes reshape fail.
ax = fig.add_subplot(gs[-1,:])
_labels = t(['support', ' ', ' ', ' '] + [get_label_from_pack(pack) for pack in packs] + [' '], 3, 4)
_markers = t(['o', 'x', 'x', 'x'] + [markers['pack'][pack] for pack in packs] + ['x'], 3, 4)
_colors = t([mpl.colors.to_hex(palettes['pack']['support']), 'white', 'white', 'white'] + [palettes['pack'][pack] for pack in packs] + ['white'], 3, 4)
plot_legend_scatter(ax, _labels, _markers, _colors, ncol=4, loc='lower center', bbox_to_anchor=(0.5, 0), framealpha=1.0)
rm_axes_elements(ax, 'plain')

if save_figures:
    fig.savefig(kit.path.join(G.ENV.ARTEFACTS, "figures", G.DOMAIN, f"Figure_C.pdf"), bbox_inches='tight')

### Figure D - dependence on visibility

In [None]:
# Design packs whose metrics are plotted against immune-visibility.
packs_vs_vis = [
    f"CAPE-XVAE.{vis_base}.final", f"CAPE-XVAE.{vis_down}.final", f"CAPE-XVAE.{vis_up}.final", f"CAPE-XVAE.{vis_up_nat}.final",
    f"CAPE-Packer.{vis_base}.final", f"CAPE-Packer.{vis_down}.final", f"CAPE-Packer.{vis_up}.final"
]

fig = plt.figure(figsize=(A4_width, A4_height*0.8))
# Rows 0-2: one metric-vs-visibility panel each; row 3 (short): natural
# visibility box plot aligned on the same x-axis.
gs = mpl.gridspec.GridSpec(4, 2 + len(packs_vs_vis), 
                           width_ratios=[10, 1] + [1] * len(packs_vs_vis), height_ratios=[4,4,4,1],
                           wspace=0., hspace=0.2)

# The first call returns the axes that the later panels share their x-axis with.
ax_sharex = plot_vs_visibility(df_eval, "support", packs_vs_vis, "9mer_similarity", "9-mer similarity", gs[0, :])
plot_vs_visibility(df_eval, "support", packs_vs_vis, "max_TMscore", "TM score", gs[1, :], ax_sharex=ax_sharex)
plot_vs_visibility(df_eval, "support", packs_vs_vis, "md_max_rmsd", "max MD-RMSD", gs[2, :], ax_sharex=ax_sharex, ylim=5., xlabel_boxplots=True)

# plt.subplot draws on the current figure, i.e. `fig` created above.
ax1 = plt.subplot(gs[3,0], sharex=ax_sharex)
# ax1.set_title("Data visibility distribution")
sns.boxplot(ax=ax1, data=df_eval.query("source == 'natural'"), x="visibility", y="source", palette=palettes['source'])
ax1.set_ylabel('')

if save_figures:
    fig.savefig(kit.path.join(G.ENV.ARTEFACTS, "figures", G.DOMAIN, f"Figure_D.pdf"), bbox_inches='tight')

### Figure E - Recall / Precision

In [None]:
fig = plt.figure(figsize=(A4_width, A4_height*0.3))
gs = mpl.gridspec.GridSpec(1, 2, 
                           width_ratios=[1, 1], height_ratios=[1],
                           wspace=0.1, hspace=0.)

# Two precision-vs-recall panels for the same pack; the right panel
# additionally receives vis_up_visibility_threshold.
# NOTE(review): the pack name hard-codes "inc-nat" while the t-test cell uses
# f"{source_xvae}.{vis_up_nat}.final" — confirm str(vis_up_nat) == "inc-nat".
all_axes = []
all_axes.append(plot_natural_vs_pack(df_eval, f"{source_xvae}.inc-nat.final", 
                                     "mean", 'precision', "mean", 'recall',mhc_1_peptide_lengths_col, 
                                     gs[0,0], xlim=0.35, ylim=0.35, position='left'))
all_axes.append(plot_natural_vs_pack(df_eval, f"{source_xvae}.inc-nat.final", 
                                     "mean", 'precision', "mean", 'recall', mhc_1_peptide_lengths_col, 
                                     gs[0,1], vis_up_visibility_threshold, xlim=0.35, ylim=0.35))

# Label the panels a), b) in append order.
labels = [f'{c})' for c in string.ascii_lowercase]
for i, ax in enumerate(all_axes):
    ax.text(0, 1.2, labels[i], transform=ax.transAxes,
            fontsize=16*font_scale, fontweight='bold', va='top', ha='right') 

fig.tight_layout()
# Respect the workflow flag like the other figure cells do.
if save_figures:
    fig.savefig(kit.path.join(G.ENV.ARTEFACTS, "figures", G.DOMAIN, f"Figure_E.pdf"), bbox_inches='tight')

In [None]:
# Compare the mean epitope recall of the CAPE-XVAE vis_up_nat pack against the
# natural sequences and log the t-test p value.
designed_pack = f"{source_xvae}.{vis_up_nat}.final"
recall_col = f'mean_recall_{mhc_1_peptide_lengths_col}'
natural_recall = df_eval.query(f'source == "natural"')[recall_col]
designed_recall = df_eval.query(f'pack == "{designed_pack}"')[recall_col]

designed_mean = designed_recall.mean()
natural_mean = natural_recall.mean()

# ttest receives float-cast copies; the means above use the raw columns.
p_value = ttest(natural_recall.astype(float), designed_recall.astype(float))

text = (
    f"The mean {designed_pack} recall value of {designed_mean:.3f} is "
    f"different from the natural mean of {natural_mean:.3f} "
    f"with a p value of {p_value} "
)
log_info(text)

### Figure F - Highest recall inc-nat sequence

In [None]:
def plot_vis_up_seqs(df_eval, d_df_selected, peptide_lengths_col, font_scale=1., bins=None):
    """Build the Figure F panel grid for one or more selected sequences.

    The grid has ``1 + len(d_df_selected)`` rows and ``len(d_df_selected) + 1``
    columns:

    * top row: one epitope plot per selected sequence, plus a precision bar
      chart in the last column;
    * then one full-width row per selected sequence with its epitope recall in
      every *other* natural sequence (x-axis limited to [0, 1]).

    Every panel receives a bold ``a)``, ``b)``, ... label.

    Args:
        df_eval: evaluation table; natural sequences are those with
            ``source == "natural"``.
        d_df_selected: mapping display name -> per-sequence k-mer DataFrame.
            Each value's ``seq_hash`` attribute is interpolated into a query
            string (NOTE(review): presumably a plain string — confirm).
        peptide_lengths_col: forwarded unchanged to ``plot_epitope_recall_by_seq``.
        font_scale: scales the panel-label font size (base size 16).
        bins: bin edges for the precision bar chart. ``None`` (default) means
            ``[0, 0.001, 0.01, 0.1, 0.5, 1.]``; a fresh list is created on every
            call so a shared default list can never be mutated across calls.

    Returns:
        The created matplotlib Figure.
    """
    if bins is None:
        bins = [0, 0.001, 0.01, 0.1, 0.5, 1.]

    fig = plt.figure(figsize=(A4_width, A4_height/2))
    gs = mpl.gridspec.GridSpec(
        1 + len(d_df_selected), 
        len(d_df_selected) + 1, 
        height_ratios=[2] + [1]*len(d_df_selected), 
        width_ratios=[1]*len(d_df_selected) + [1],
        wspace=0.1, hspace=0.3)

    # (axes, x-offset of its panel label) pairs, in labelling order.
    all_axes = []
    for idx, (source, df_selected) in enumerate(d_df_selected.items()):
        # Only the first epitope plot carries a legend.
        ax = plot_seq_epitopes(gs[0, idx], df_selected, legend=(idx < 1), title=f"{source}\n")
        offset = -0.4 if idx == 0 else 0
        all_axes.append((ax, offset))

    # NOTE(review): the bar chart uses a fixed font_scale=0.6, independent of
    # this function's font_scale argument.
    ax = plot_seqs_precision_bar(gs[0, -1], d_df_selected, font_scale=0.6, kind='bar', bins=bins)
    all_axes.append((ax, -0.05))

    for idx, (source, df_selected) in enumerate(d_df_selected.items()):
        ax = plot_epitope_recall_by_seq(
            gs[1 + idx, :],  
            # Exclude the selected sequence itself from the natural reference set.
            df_eval.query(f'source == "natural" and seq_hash != "{df_selected.seq_hash}"'), 
            df_selected.seq_hash,
            peptide_lengths_col
        )
        ax.set_xlim(0.0, 1.0)

        if idx != 0:
            ax.set_title('')

        # Only the bottom recall row keeps its x-axis label and tick labels.
        if idx != (len(d_df_selected) - 1):
            ax.set_xlabel('')
            ax.set_xticklabels([])
        ax.set_ylabel(f"{source} recall in \n{ax.get_ylabel()}")
        all_axes.append((ax, -0.13))

    labels = [f'{c})' for c in string.ascii_lowercase]
    for i, (ax, x_offset) in enumerate(all_axes):
        ax.text(x_offset, 1.2, labels[i], transform=ax.transAxes,
                fontsize=16*font_scale, fontweight='bold', va='top', ha='right') #, color='red')

    return fig

In [None]:
# Among the CAPE-XVAE designs of profile vis_up_nat with max_TMscore > 0.95,
# pick the row with the highest mean epitope recall, then fetch its 9-mer
# table from predictor_mhc_1.get_seq_kmers for the configured MHC-I alleles.
XVAE_row = df_eval.loc[
    df_eval.query(f"pack == '{source_xvae}.{vis_up_nat}.final' and max_TMscore > 0.95")[f'mean_recall_{mhc_1_peptide_lengths_col}'].idxmax()
]
df_selected_XVAE = predictor_mhc_1.get_seq_kmers(XVAE_row['seq'], mhc_1_alleles, 9)

# Generic aliases used by the following Figure F cells.
selected_row = XVAE_row
df_selected = df_selected_XVAE
df_selected_name = 'XVAE'

In [None]:
# Annotate the selected k-mers with precision w.r.t. natural_kmers_data.
# NOTE(review): the return value is discarded, so this presumably modifies
# df_selected in place — confirm.
add_precision_to_seq_kmers(df_selected, natural_kmers_data, consider_one=True)
# Log hash, sequence, visibility, recall statistics, max_TMscore and func_nb_pred
# of the selected sequence.
log_info([selected_row.seq_hash,
          selected_row.seq,
          f"visibility: {selected_row.visibility}",
          f"Mean recall {selected_row[f'mean_recall_{mhc_1_peptide_lengths_col}']:.3f}",
          f"P01: {selected_row[f'P01_recall_{mhc_1_peptide_lengths_col}']:.3f}",
          f"P1: {selected_row[f'P1_recall_{mhc_1_peptide_lengths_col}']:.3f}",
          f"P5: {selected_row[f'P5_recall_{mhc_1_peptide_lengths_col}']:.3f}",
          selected_row.max_TMscore,
          selected_row.func_nb_pred
         ])

In [None]:
# NOTE(review): add_metrics' return value is unused; presumably it adds the
# per-sequence recall/precision columns to df_eval in place — confirm.
add_metrics(df_selected.seq_hash, df_eval, d_visible, mhc_1_peptide_lengths_col)
# The extra near-zero bin edge (1e-13) presumably gives exactly-zero precision
# values their own bar — confirm against plot_seqs_precision_bar.
fig = plot_vis_up_seqs(df_eval, {df_selected_name: df_selected}, mhc_1_peptide_lengths_col, bins=[0, 0.0000000000001, 0.001, 0.01, 0.1, 0.5, 1.])
# Respect the workflow flag like the other figure cells do.
if save_figures:
    fig.savefig(kit.path.join(G.ENV.ARTEFACTS, "figures", G.DOMAIN, f"Figure_F.pdf"), bbox_inches='tight')

### Figure G - Phylogenetic tree

In [None]:
    # For each group, log the seq_hash of the row with the highest max_TMscore.
    # NOTE(review): this cell is indented as a whole; IPython strips the common
    # leading indent, but the code is not valid as plain top-level Python.
    def _best_tm_seq_hash(query):
        """Return seq_hash of the df_eval row with the largest max_TMscore among rows matching ``query``."""
        candidates = df_eval.query(query)['max_TMscore']
        return df_eval.loc[candidates.idxmax()].seq_hash

    seq_hash_generated_maxTMscore = _best_tm_seq_hash('source in ["CAPE-XVAE"]')
    log_info(f"Most similar generated XVAE sequence: {seq_hash_generated_maxTMscore}")

    seq_hash_incnat_maxTMscore = _best_tm_seq_hash(f'pack in ["{source_xvae}.{vis_up_nat}.final"]')
    log_info(f"Most similar generated XVAE.inc-nat sequence: {seq_hash_incnat_maxTMscore}")

    seq_hash_reduce_maxTMscore = _best_tm_seq_hash(f'pack in ["{source_xvae}.{vis_down}.final"]')
    log_info(f"Most similar generated XVAE.reduce sequence: {seq_hash_reduce_maxTMscore}")

In [None]:
def get_best_proxies(df_eval, sources=None, packs=None):
    """Pick, for each requested pack and source, the sequence with the highest ``max_TMscore``.

    Args:
        df_eval: table with at least the columns ``pack``, ``source``,
            ``max_TMscore``, ``seq_hash`` and ``seq``.
        sources: iterable of ``source`` values to select on, or ``None`` to skip.
        packs: iterable of ``pack`` values to select on, or ``None`` to skip.

    Returns:
        dict mapping each pack/source name to ``{'seq_hash': ..., 'seq': ...}``
        of its best row. Packs are processed before sources, so a source with
        the same name as a pack overwrites that entry.

    Raises:
        ValueError: if a requested pack or source matches no rows.
    """
    proxies = {}
    # Packs first, then sources — same order as before, same dict semantics.
    for column, names in (('pack', packs), ('source', sources)):
        if names is None:
            continue
        for name in names:
            # Boolean masking instead of a string-built query: names containing
            # quotes can no longer break the expression.
            scores = df_eval.loc[df_eval[column] == name, 'max_TMscore']
            if scores.empty:
                raise ValueError(f"No rows with {column} == {name!r} in df_eval")
            row = df_eval.loc[scores.idxmax()]
            log_info(f"Most similar {name} sequence: {row.seq_hash}")
            proxies[name] = {'seq_hash': row.seq_hash, 'seq': row.seq}

    return proxies

In [None]:
# https://itol.embl.de/
# When construct_tree is set:
#   1. align the support sequences and build a RAxML tree from them;
#   2. write TSVs with the recall/precision of each "best proxy" design
#      against every support sequence (tree annotations);
#   3. build one tree per proxy from the support sequences plus that proxy.
if construct_tree:
    min_occupancy = 0.8  # only used by the disabled occupancy filter below

    msa_phylo_packs = ['support']

    sf = SequenceFrame()
    sf.from_seqs(list(df_eval[df_eval['pack'].isin(msa_phylo_packs)].seq))

    msa_phylo = MultipleSequenceAligner(sf, phylogenetic_tree_model_name = 'raxml')
    # msa_phylo.msa_annotations = {}
    # for seq_hash in sf.seq_hash:
    #     recall_natural = df_eval.query(f"seq_hash == '{seq_hash}'").iloc[0][f'recall_{df_selected_natural.seq_hash}']
    #     recall_XVAE = df_eval.query(f"seq_hash == '{seq_hash}'").iloc[0][f'recall_{df_selected_XVAE.seq_hash}']
    #     precision_natural = df_eval.query(f"seq_hash == '{seq_hash}'").iloc[0][f'precision_{df_selected_natural.seq_hash}']
    #     precision_XVAE = df_eval.query(f"seq_hash == '{seq_hash}'").iloc[0][f'precision_{df_selected_XVAE.seq_hash}']
    #     msa_phylo.msa_annotations[seq_hash] = \
    #         f'_RECALL_natural_{recall_natural:.4f}_XVAE_{recall_XVAE:.4f}' + \
    #         f'_PRECISION_natural_{precision_natural:.4f}_XVAE_{precision_XVAE:.4f}'

    msa_phylo.align(msa_file_path=join(G.ENV.ARTEFACTS, 'phylo', G.DOMAIN, 'support.phy'))

    #msa_phylo.calculate_occupancy()
    #msa_phylo.set_occupancy_threshold(min_occupancy)
    msa_phylo.construct_phylogenetic_tree()


    #
    # generate annotation CSV
    #

    proxies = get_best_proxies(df_eval, sources=[source_xvae], packs=[f'{source_xvae}.{vis_up_nat}.final', f'{source_xvae}.{vis_down}.final'])

    # Hashes of the proxy designs. This was used below without being defined in
    # this cell; derive it from `proxies` so the columns match the loop.
    # dict.fromkeys de-duplicates while keeping order, because two groups can
    # share the same best sequence and duplicate column labels are ambiguous.
    seq_hashes_proxies = list(dict.fromkeys(proxy['seq_hash'] for proxy in proxies.values()))

    seq_hashes_support = list(df_eval.query('source == "support"').seq_hash)

    df_pr = pd.DataFrame(
        index=seq_hashes_support, 
        columns=[f'{metric}_{seq_hash}' for seq_hash in seq_hashes_proxies for metric in ['recall', 'precision'] ]
    )

    for seq_hash_target in seq_hashes_support:
        for seq_hash_proxy in seq_hashes_proxies:
            recall, precision = f_metrics(d_visible, seq_hash_proxy, seq_hash_target)
            df_pr.at[seq_hash_target, f"recall_{seq_hash_proxy}"] = recall
            df_pr.at[seq_hash_target, f"precision_{seq_hash_proxy}"] = precision

    df_pr.to_csv(join(G.ENV.ARTEFACTS, 'phylo', G.DOMAIN, 'support_pr.tsv'), sep='\t')
    for seq_hash_proxy in seq_hashes_proxies:
        df_pr[f'recall_{seq_hash_proxy}'].to_csv(join(G.ENV.ARTEFACTS, 'phylo', G.DOMAIN, f'support_{seq_hash_proxy}.tsv'), sep='\t')


    # construct a phylo tree from supports + a different designed sequence
    for proxy_name, proxy in tqdm(proxies.items()):
        seq_hash_proxy = proxy['seq_hash']
        seq_proxy = proxy['seq']

        sf = SequenceFrame()
        sf.from_seqs(list(df_eval[df_eval['pack'] == 'support'].seq) + [seq_proxy])
        msa_phylo = MultipleSequenceAligner(sf, phylogenetic_tree_model_name = 'raxml')
        msa_phylo.align(msa_file_path=join(G.ENV.ARTEFACTS, 'phylo', G.DOMAIN, f'support_plus_{seq_hash_proxy}', f'support_plus_{seq_hash_proxy}.phy'))
        msa_phylo.construct_phylogenetic_tree()

# Save df_eval

In [None]:
# Persist the full evaluation table (pandas writes the index by default).
df_eval.to_csv(os.path.join(G.ENV.ARTEFACTS, "df_eval.csv"))

# Close DB

In [None]:
# Close the database connection.
# NOTE(review): view_structure below still calls db.get_fasta_file_path,
# db.get_pdb_file_path and db.predictors_mhc_1 — confirm those do not need
# the open connection.
db.conn.close()

# Structure rendering (py3Dmol)

In [None]:
import py3Dmol
from CAPE.RL.reward import rewards_seqs

In [None]:
# Pick the two structures rendered below: the first support sequence and the
# sequence with the lowest visibility score in df_eval.
seq_hash_support = df_eval.query("source == 'support'").seq_hash.iloc[0]
print(f"Support: {seq_hash_support}")
seq_hash_di = df_eval.sort_values(by=['visibility'], ascending=True).iloc[0].seq_hash
# Previously printed with the mistyped "Suppoprt" label, mislabelling this
# lowest-visibility hash as the support sequence.
print(f"Lowest visibility: {seq_hash_di}")

In [None]:
def view_structure(seq_hash):
    """Show the predicted structure of ``seq_hash`` in py3Dmol, highlighting rewarded residues.

    The sequence is read from the FASTA file stored in ``db`` and scored with
    ``rewards_seqs`` (the ``predictor_mhc_1_name`` predictor, ``mhc_1_alleles``,
    profile ``vis_up``). Atoms of residues with a non-zero (int-cast) reward are
    drawn as spheres colored by reward (later entries of the palette are
    redder); all other atoms are drawn as cartoon. The view is shown inline,
    rendered to ``<ARTEFACTS>/rendered_structures/<seq_hash>_<predictor>.png``
    and that path is printed.

    NOTE(review): the palette has 7 entries, so rewards are assumed to lie in
    0..6; a larger value raises IndexError.
    """
    fasta_file_path = db.get_fasta_file_path(seq_hash)
    pdb_file_path = db.get_pdb_file_path(seq_hash)
    png_file_path = join(G.ENV.ARTEFACTS, 'rendered_structures', f"{seq_hash}_{db.predictor_structure_name}.png")

    with open(pdb_file_path) as ifile:
        system = ifile.read()

    sequence = kit.bioinf.fasta.read_fasta(fasta_file_path, return_df=True).index[0]
    rewards = rewards_seqs(db.predictors_mhc_1[predictor_mhc_1_name], [sequence], mhc_1_alleles, vis_up)[0].astype('int')

    view = py3Dmol.view(width=800, height=800)
    view.addModelsAsFrames(system)

    colors = [
        '#22FFFF',
        '#44BBBB',
        '#669999',
        '#887777',
        '#AA5555',
        '#CC3333',
        '#EE1111',
    ]

    for line in system.splitlines():
        if not line.startswith("ATOM "):
            continue

        # PDB ATOM records are fixed-column: atom serial in columns 7-11,
        # residue sequence number in columns 23-26. Column parsing also works
        # when the chain ID is blank, unlike whitespace splitting.
        serial = int(line[6:11])
        resid = int(line[22:26])

        # Select by the record's own serial rather than a running ATOM counter,
        # which drifts whenever serials are not contiguous (e.g. TER records).
        reward = rewards[resid - 1]
        if reward != 0:
            style = {"sphere": {'color': colors[reward]}}
        else:
            style = {"cartoon": {}}
        view.setStyle({'model': -1, 'serial': serial}, style)

    view.zoomTo()
    view.show()
    view.render(filename=png_file_path)
    print(png_file_path)

In [None]:
# Render the support structure selected above.
view_structure(seq_hash_support)

In [None]:
# Render the lowest-visibility sequence selected above.
view_structure(seq_hash_di)

In [None]:
# Column dtypes and non-null counts of the evaluation table.
df_eval.info()

# Test

In [None]:
import random

def random_substring(input_string, length):
    """Return a contiguous slice of ``input_string`` of size ``length`` at a random offset.

    The start offset is drawn uniformly with ``random.randint``.

    Raises:
        ValueError: if ``length`` exceeds ``len(input_string)``.
    """
    slack = len(input_string) - length
    if slack < 0:
        raise ValueError("Length of substring cannot be greater than the length of the input string.")
    offset = random.randint(0, slack)
    return input_string[offset:offset + length]

In [None]:
# Draw one random CAPE-XVAE sequence (scratch; the next cell reassigns seq_xvae).
seq_xvae = df_eval.query('source == "CAPE-XVAE"').sample().iloc[0].seq

In [None]:
# Baseline over random pairs of sequences:
#   res_9mers     - fraction of randomly drawn 9-mers of the first sequence that
#                   occur verbatim in the second;
#   res_recall    - |shared presented peptides| / |peptides presented by the second|;
#   res_precision - |shared presented peptides| / |peptides presented by the first|.
# NOTE: despite its name, seq_xvae is drawn from the natural set here (an
# earlier variant sampled the "CAPE-XVAE.inc-nat.final" pack), so these are
# natural-vs-natural numbers. A sequence with no presented peptides raises
# ZeroDivisionError.
n_trials = 1000
n_kmer_draws = 30
# Loop-invariant selection: query once instead of twice per trial. Sampling
# still goes through the same global RNG, so draws are unchanged.
df_natural = df_eval.query('source == "natural"')

res_9mers = []
res_recall = []
res_precision = []
for _ in range(n_trials):
    seq_xvae = df_natural.sample().iloc[0].seq
    seq_natural = df_natural.sample().iloc[0].seq

    hits = sum(random_substring(seq_xvae, 9) in seq_natural for _ in range(n_kmer_draws))
    res_9mers.append(hits / n_kmer_draws)

    # x[0] is taken as the peptide of each presentation hit (presumably
    # (peptide, allele, ...) tuples — confirm against seq_presented).
    vis_xvae = {x[0] for x in predictor_mhc_1.seq_presented(seq_xvae, mhc_1_alleles)}
    vis_natural = {x[0] for x in predictor_mhc_1.seq_presented(seq_natural, mhc_1_alleles)}

    shared = len(vis_xvae & vis_natural)
    res_recall.append(shared / len(vis_natural))
    res_precision.append(shared / len(vis_xvae))

In [None]:
# `res` was never defined (NameError); the per-trial 9-mer hit rates live in res_9mers.
np.mean(res_9mers)

In [None]:
# Mean fraction of sampled 9-mers of the first sequence found in the second.
np.mean(res_9mers)

In [None]:
# NOTE(review): duplicate of the previous cell.
np.mean(res_9mers)

In [None]:
# Mean recall of presented peptides over all sampled pairs.
np.mean(res_recall)

In [None]:
# Mean precision of presented peptides over all sampled pairs.
np.mean(res_precision)

In [None]:
# Scratch cell, unrelated to the analysis.
list('abc')

In [None]:
# Raw per-trial 9-mer hit rates.
res_9mers