# Imports

In [None]:
import os
import pickle
import shutil
import re

import pandas as pd
import numpy as np
import copy
from tqdm.auto import tqdm
import glasbey

import Levenshtein

import statsmodels.api as sm
from scipy import stats

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.ticker import PercentFormatter
import glasbey
import seaborn as sns
import string

from Bio import Align
from Bio.Align import substitution_matrices

import kit
import kit.globals as G
from kit.loch.oo import Loch
from kit.loch.path import cp_fasta_to_dir
from kit.path import join, get_entries
from kit.data import DD, Split, file_to_str, str_to_file
from kit.data.utils import set_df_cols_to, cast_df_cols_to
from kit.data.trees import PrefixTree
from kit.bioinf import get_kmers
from kit.bioinf.alignment.structure.tm_align import align_structures
from kit.bioinf.immuno.mhc_1 import Mhc1Predictor, MHC_1_PEPTIDE_LENGTHS
from kit.bioinf.immuno.utils import get_mhc_1_setup_hash
from kit.bioinf.utils import get_seq_hash
from kit.bioinf.utils.filter import filter_pdbs
from kit.bioinf.proteins.similar import get_similar_proteins
from kit.bioinf.proteins import Protein, ProteinType
from kit.bioinf.pdb import download_structure, includes_dna, check_xray, get_protein_name_organism
from kit.bioinf.fasta import fasta_to_df, read_fasta
from kit.plot import A4_height, A4_width, plot_text, plot_legend_patches

from CAPE.MPNN import run_mpnn
from CAPE.MPNN.model import CapeMPNN
from CAPE.MPNN.data.aux import S_to_seqs
from CAPE.MPNN.beam import set_config as set_config_beam, get_beam_search_hash, get_features_from_pdb

# Startup

In [None]:
# Initialise the kit framework for project 'CAPE' / job 'CAPE-Beam' without creating a job.
# Presumably returns the parsed and the unrecognised command-line arguments — confirm in kit.init.
args, args_unknown = kit.init('CAPE', 'CAPE-Beam', create_job=False, arg_string=None)

In [None]:
# Select the 'v2' domain in the kit globals (effect on paths is defined inside kit — verify).
G.DOMAIN = 'v2'

# Configs

## Design

In [None]:
# Device used by kit-managed models.
kit.DEVICE = 'cpu'

# Name of the base MPNN model weights.
BASE_MODEL_NAME = 'v_48_020'
# Human reference proteome FASTA; forwarded to get_source_id and to cape-beam.py (--proteome_file_name).
PROTEOME_FILE_NAME = "2022-05-29-Homo_sapiens_GRCh38_biomart_v94.fasta"
# 20 standard amino acids plus 'X' (not referenced in this section).
ALPHABET = 'ACDEFGHIKLMNPQRSTVWYX'

# MHC-I predictor used while decoding CAPE-Beam designs (key into immuno_predictor_setups).
MHC_1_PREDICTOR_DECODING = 'pwm_dynamic'

## Hyper-parameters

In [None]:
# NOTE: the first assignment is immediately overridden by the second; only the
# shorter sweep is used (the longer one is presumably kept as an alternative).
NON_SELF_PROB_FACTORS = [0., 0.01, 0.1, 0.25, 0.5, 0.9, 0.95, 0.99, 0.999]
NON_SELF_PROB_FACTORS = [0., 0.1, 0.5, 0.9, 0.99]
# Minimum self k-mer lengths swept for CAPE-Beam.
MIN_PROTEOME_KMER_LENS = [5, 6, 7]

## Genotypes

In [None]:
# MHC-I genotype ('+'-separated alleles) that designs are tuned for.
person_mhc_1_genotype = "HLA-A*02:01+HLA-A*24:02+HLA-B*07:02+HLA-B*39:01+HLA-C*07:01+HLA-C*16:01"
person_mhc_1_genotype_list = person_mhc_1_genotype.split("+")

# Additional genotypes designs are evaluated against (see benchmark_sources).
alternative_mhc_1_genotypes = ["HLA-A*29:02+HLA-A*30:07+HLA-B*15:13+HLA-B*57:01+HLA-C*14:02+HLA-C*04:04"]

all_mhc_1_genotypes = [person_mhc_1_genotype] + alternative_mhc_1_genotypes

immuno_setup = {'mhc_1': person_mhc_1_genotype}

DISTANCE_RANK_THRESHOLD = 0.02  # used to identify alternative genotypes

ANALYSE_PWM_PREDICTOR = False

## Protein Data

In [None]:
# NOTE(review): SEED is not referenced in this section; the samplers use hard-coded seeds.
SEED = 42
# Specific proteins whose first chain is shorter than this are rejected (add_to_specific_proteins).
MIN_SPECIFIC_CHAIN_LENGTH = 51
# Specific proteins whose summed chain length exceeds this are rejected.
MAX_PROTEIN_LENGTH = 2000

In [None]:
# Manually curated PDB IDs of the "specific" proteins for each split.
SPECIFIC_MANUAL = {
    # Split.VAL: sorted(['1B9K', '1OA4', '1QWK', '1TJE', '1XGD', '4RQG']),
    Split.VAL: sorted(['1B9K', '1OA4', '1QWK', '1TEJ', '1TJE', '1UIZ', '1XGD', '2C2X', '2QT4', '4RQG', '1QTS', '6Q3V', '4BVK', '2R90', '5XBH']),
    Split.TEST: sorted(['1A3H', '1P3C', '1PGS', '1QKD', '1S5T', '1X0M', '2BK8', '3O6A', '3TIP', '3WOY',
                        '4BOK', '4QTZ', '1A2O', '5OA9', '6PNW', '3SFT', '5V5F', '6TPT', '5TPJ', '3RFW']),
    Split.PREDICT: [],
}

# Target number of specific proteins per split.
N_SPECIFIC = {
    Split.VAL: 15,
    Split.TEST: 20,
    Split.PREDICT: 0
}

# Further specific proteins need to be identified only if some manual list
# is shorter than its target size.
IDENTIFY_SPECIFIC = any(len(SPECIFIC_MANUAL[split]) < N_SPECIFIC[split] for split in [Split.VAL, Split.TEST, Split.PREDICT])
LOAD_RESULTS = None

# Candidate filters (not referenced in this section).
CANDIDATE_MAX_UNMODELLED = 5
CANDIDATE_EXCLUDE_SMALL_MOLECULES = True

In [None]:
# Parameters for get_peers / get_similar_proteins.
# A protein needs at least SIMILAR_MIN_PROTEINS_CNT similar proteins, otherwise get_peers returns {}.
SIMILAR_MIN_PROTEINS_CNT = 3
SIMILAR_MAX_PROTEINS_CNT = 10

SIMILAR_MAX_SCORE = 0.99
SIMILAR_MIN_SCORE = 0.1
SIMILAR_MIN_TM_SCORE = 0.9
SIMILAR_MAX_LENGTH_DIFF = 0.1

In [None]:
# Per-PDB chain selection passed to Protein.from_pdb(keep_chains=...).
KEEP_CHAINS = {
    '1UIZ': ['A', 'B', 'C'],
    '4BVK': ['A'],
}

# Sequence hashes presumably known to fail structure prediction (AF); likely intended
# as `ignore_seq_hashes` for to_colabfold — confirm.
AF_ERRORS_SEQ_HASHES = [
    '409f67caa69e3646d259a003db9e1ae79174e7ab317cf584eab3d0e31c66c4f1', 'e9a929bd99de028d6a03c415b668f0688cca0f98d62d7169e3231056f5383b85',
    '86403f1cba1c1f920de5861e20cd71d91b56a226df783a75f24dffaf73d1e8c1', 'fdc5626917d775784984dd415d639b41e3053bb3967ddb5dbeb0ce57b4a11fc5',
    'c0f7d4c3dd8b0d83bee7047c164c8dc2fe8f306b0c574614ebc3dc3261ce6ac8'
]

## Benchmarks

In [None]:
# Sampling temperatures for the 'standard' and 'CAPE-MPNN' benchmarks.
SAMPLING_TEMPS = [1e-8, 0.1]

# source_id -> list of column overrides; get_df_decodings creates one row per
# protein for every override dict.
benchmark_sources = {
    'template': [{'eval_mhc_1_genotype': mhc_alleles} for mhc_alleles in all_mhc_1_genotypes],
    'standard': [{'sampling_temp': sampling_temp, 'eval_mhc_1_genotype': mhc_genotype} for sampling_temp in SAMPLING_TEMPS for mhc_genotype in all_mhc_1_genotypes],
    'CAPE-MPNN': [{'sampling_temp': sampling_temp, 'tune_mhc_1_genotype': person_mhc_1_genotype, 'eval_mhc_1_genotype': person_mhc_1_genotype} for sampling_temp in SAMPLING_TEMPS]
}
benchmark_source_ids = list(benchmark_sources.keys())

In [None]:
# Root of trained CAPE-MPNN models; checkpoints live in <root>/<model_id>/ckpts.
CAPE_MPNN_WEIGHTS_PATH = os.path.join(G.ENV.PROJECT, 'artefacts', 'CAPE-MPNN', 'models')
# "<model_id>:<ckpt>" — split on ':' in generate_CAPE_MPNN.
CAPE_MPNN_CKPT_ID = '458340e4:epoch_20'

## Figures

In [None]:
def _default_fig_configs():
    """Build a fresh per-split figure config.

    Each split maps to one filter dict per sampling temperature, evaluating on
    the person genotype. Every call returns new objects, so the figure configs
    below stay independent of one another.
    """
    return {split: [
            {'eval_mhc_1_genotype': (person_mhc_1_genotype,), 'sampling_temp': (None, _sampling_temp)} for _sampling_temp in SAMPLING_TEMPS
        ] for split in Split
    }


fig_X_configs = _default_fig_configs()
fig_A_configs = _default_fig_configs()
fig_B_configs = _default_fig_configs()
fig_C_configs = _default_fig_configs()

## System

In [None]:
# Global switches (not referenced in this section).
SAVE_FIGURES = True
OVERWRITE = False

In [None]:
# Server used by get_similar_proteins to download structures.
PDB_SERVER = r'https://files.rcsb.org'
# Root of all designs: <root>/<protein_id>/<source_id>/...
FASTA_OUTPUT_PATH = join(G.ENV.ARTEFACTS, 'designs')
# ProteinMPNN checkout handed to set_config_beam.
PROTEIN_MPNN_REPO_PATH = os.path.join(G.ENV.PROJECT, 'external', 'repos', 'ProteinMPNN')

In [None]:
# Sequence/structure store ("loch"), addressed by sequence hash.
LOCH_PATH = join(G.ENV.ARTEFACTS, 'loch')
# Experimental template structures, stored as <pdb_id>.pdb.
PDB_DIR_PATH = join(G.ENV.PROJECT, 'data', 'input', 'PDBs')
# ColabFold working directory with 'input' and 'output' sub-directories.
# NOTE(review): the only path built from G.PROJECT_ENV instead of G.ENV — confirm intended.
COLABFOLD_PATH = join(G.PROJECT_ENV.ARTEFACTS, 'colabfold')
SIMILAR_PROTEINS_PATH = join(G.ENV.ARTEFACTS, 'eval', 'similar_PDBs')

GENERAL_DATA_DIR_PATH = os.path.join(G.ENV.INPUT, "CAPE-MPNN", "pdb_2021aug02")

NETMHCPAN_DIR_PATH = join(G.ENV.ARTEFACTS, "eval", "immuno", "mhc_1", "netmhcpan")
# Link the PWM predictor definitions into the NetMHCpan artefact directory.
# os.path.lexists is used because os.path.exists is False for a dangling
# link, which would make os.symlink fail with FileExistsError on a rerun.
_netmhcpan_definitions_link = os.path.join(NETMHCPAN_DIR_PATH, 'definitions')
if not os.path.lexists(_netmhcpan_definitions_link):
    os.makedirs(NETMHCPAN_DIR_PATH, exist_ok=True)
    os.symlink(
        '../../../../../../data/input/immuno/mhc_1/Mhc1PredictorPwm/definitions',
        _netmhcpan_definitions_link
    )
elif not os.path.exists(_netmhcpan_definitions_link):
    print(f"WARNING: {_netmhcpan_definitions_link} is a dangling symlink")

# Both currently point to the same PWM predictor ranks directory.
PEPTIDE_RANKS_DIR_PATH = os.path.join(os.environ['PF'], 'data', 'input', "immuno", "mhc_1", "Mhc1PredictorPwm", "ranks")
PWM_PREDICTOR_TEST_RANKS_DIR_PATH = os.path.join(os.environ['PF'], 'data', 'input', "immuno", "mhc_1", "Mhc1PredictorPwm", "ranks")

In [None]:
# Not referenced in this section.
DESTRESS_PROG_DIR_PATH = None

In [None]:
# Point the beam-search module at the ProteinMPNN checkout and tell CapeMPNN
# where the vanilla weights and base hyper-parameter YAMLs live.
set_config_beam(PROTEIN_MPNN_REPO_PATH)
CapeMPNN.base_model_pt_dir_path = os.path.join(G.ENV.INPUT, "CAPE-MPNN", "vanilla_model_weights")
CapeMPNN.base_model_yaml_dir_path = os.path.join(G.ENV.INPUT, 'CAPE-MPNN', 'base_hparams')

## Eval

In [None]:
# Predictor used for evaluation (matches the kmers_presented_netmhcpan column).
MHC_1_PREDICTOR_EVAL = 'netmhcpan'

In [None]:
# CAPE-Beam settings shown in plots, and the TM-score threshold for plots.
PLOT_MIN_SELF_KMER_LENGTH = [5, 6, 7]
PLOT_NON_SELF_PROB_FACTORS = [0., 0.1, 0.5, 0.9, 0.99] #, 0.999]
PLOT_MIN_TM_SCORE = 0.9

# Function definitions

In [None]:
def get_immuno_chains(source_id, protein_id, seq):
    """Return the chains of a design that are scored for immunogenicity.

    ``seq`` is a '/'-separated multi-chain sequence; gap characters ('-') are
    replaced by 'X'. Heterooligomers contribute every chain. All other protein
    types contribute only the first chain, and a warning is printed if their
    chains differ. ``source_id`` is only used in that warning.
    """
    template = proteins[protein_id]
    chain_seqs = [chain.replace("-", "X") for chain in seq.split('/')]

    if template.get_protein_type() == ProteinType.HETEROOLIGOMER:
        return chain_seqs

    n_distinct = len(set(chain_seqs))
    if n_distinct > 1:
        print(f"WARNING: {source_id} {protein_id}... {n_distinct} different chain seq")
    return chain_seqs[:1]

def get_possible_peptides(immuno_chains, lengths):
    """Count all peptide windows of the given lengths across all chains.

    A chain of length L has max(L - k + 1, 0) windows of length k.
    """
    return sum(
        max(len(chain) - peptide_len + 1, 0)
        for chain in immuno_chains
        for peptide_len in lengths
    )

In [None]:
def get_peers(pdb_id):
    """Return proteins similar to ``pdb_id``, or {} if there are too few.

    Similarity is determined by get_similar_proteins using the SIMILAR_*
    thresholds. At least SIMILAR_MIN_PROTEINS_CNT hits are required.
    """
    peers = get_similar_proteins(
        pdb_id,
        PDB_DIR_PATH,
        similar_max_score=SIMILAR_MAX_SCORE,
        similar_min_score=SIMILAR_MIN_SCORE,
        similar_min_tm_score=SIMILAR_MIN_TM_SCORE,
        similar_max_count=SIMILAR_MAX_PROTEINS_CNT,
        max_length_diff=SIMILAR_MAX_LENGTH_DIFF,
        similar_proteins_dir_path=SIMILAR_PROTEINS_PATH,
        server=PDB_SERVER,
    )
    enough_peers = len(peers) >= SIMILAR_MIN_PROTEINS_CNT
    return peers if enough_peers else {}


def add_to_specific_proteins(liste, pdb_id, peer_groups, KEEP_CHAINS):
    """Append ``pdb_id`` to ``liste`` if it is distinct enough and passes the quality filters.

    ``pdb_id`` is rejected if any of the following holds:

    * it is already in ``liste``;
    * it is within Levenshtein distance 1 of an ID already in ``liste``;
    * it is listed in the peer group of an ID already in ``liste``;
    * its first chain is shorter than MIN_SPECIFIC_CHAIN_LENGTH;
    * its total chain length exceeds MAX_PROTEIN_LENGTH;
    * it includes DNA;
    * it is not an X-ray structure.

    The quality filters now also apply when ``liste`` is empty. Previously
    they ran inside the loop over ``liste``, so the first candidate was
    accepted unchecked. The structure is also loaded only once.

    Args:
        liste: list of already accepted PDB IDs; mutated in place.
        pdb_id: candidate PDB ID, read from PDB_DIR_PATH/<pdb_id>.pdb.
        peer_groups: mapping accepted-ID -> container of similar IDs.
        KEEP_CHAINS: mapping PDB ID -> chains to keep when loading.

    Returns:
        The loaded Protein if ``pdb_id`` was appended, otherwise None.
    """
    if pdb_id in liste:
        return None

    # Cheap similarity checks against the accepted IDs before touching the structure.
    for _pdb_id in liste:
        if Levenshtein.distance(pdb_id, _pdb_id) <= 1:
            return None
        if _pdb_id in peer_groups and pdb_id in peer_groups[_pdb_id]:
            return None

    pdb_file_path = os.path.join(PDB_DIR_PATH, f"{pdb_id}.pdb")
    protein = Protein.from_pdb(pdb_file_path, keep_chains=KEEP_CHAINS.get(pdb_id, None))
    chain_lengths = [len(chain) for chain in protein.chains.values()]

    if MIN_SPECIFIC_CHAIN_LENGTH > chain_lengths[0]:
        return None
    if sum(chain_lengths) > MAX_PROTEIN_LENGTH:
        return None
    if includes_dna(pdb_file_path):
        return None
    if not check_xray(pdb_file_path):
        return None

    liste.append(pdb_id)
    return protein

In [None]:
def get_df_decodings(benchmark_sources, pdb_ids, min_self_kmer_lens, non_self_prob_factors, cape_beam_mhc_1_genotypes):
    """Build the table of all decodings to run/evaluate: one row per protein and configuration.

    Rows come from two sources:

    * Benchmarks: one block per override dict in ``benchmark_sources``, with
      ``source_id`` set to the benchmark key.
    * CAPE-Beam: one block per (min self k-mer length, non-self prob factor,
      genotype) combination.

    Design-result columns are initialised to None and ``protein_type`` is
    filled from ``protein_infos``.
    """
    beam_columns = ['checked_kmer_length', 'min_self_kmer_length', 'width', 'branching_factor', 'depth', 'tune_mhc_1_genotype', 'tune_mhc_1_predictor', 'eval_mhc_1_genotype', 'non_self_prob_factor']
    frames = []

    # benchmark blocks
    for bench_id, bench_versions in benchmark_sources.items():
        for bench_version in bench_versions:
            frame = pd.DataFrame(data=pdb_ids, columns=['protein_id'])
            frame['source_id'] = bench_id
            for col_name, col_value in bench_version.items():
                frame[col_name] = col_value
            frames.append(frame)

    # CAPE-Beam blocks
    for kmer_len in min_self_kmer_lens:
        for prob_factor in non_self_prob_factors:
            for genotype in cape_beam_mhc_1_genotypes:
                frame = pd.DataFrame(data=pdb_ids, columns=['protein_id'])

                if prob_factor is not None:
                    beam_config = [10, kmer_len, 10, 1, kmer_len * 2, genotype, MHC_1_PREDICTOR_DECODING, genotype, prob_factor]
                else:
                    # no immunogenicity tuning: only k-mer checks up to kmer_len
                    beam_config = [kmer_len, kmer_len, 10, 1, kmer_len * 2, None, None, None, None]
                set_df_cols_to(frame, beam_columns, beam_config)

                frame['proteome_file_name'] = PROTEOME_FILE_NAME
                frame['prune_min_acc_log_prob'] = -2000.
                frames.append(frame)

    # reset_index keeps the per-block index as an 'index' column
    df_result = pd.concat(frames).reset_index()

    result_columns = ['seq_hash', 'seq', 'tm_data', 'kmers_not_in_proteome', 'kmers_presented', 'kmers_presented_pwm_dynamic', 'kmers_presented_netmhcpan', 'proportion', 'sampling_state']
    set_df_cols_to(df_result, result_columns, None)
    df_result['protein_type'] = df_result.apply(lambda r: protein_infos[r.protein_id][1], axis=1)

    # benchmark rows have NaN in the beam columns; -1 is used as the integer placeholder
    cast_df_cols_to(df_result, ['checked_kmer_length', 'min_self_kmer_length', 'width', 'branching_factor', 'depth'], int, -1)

    return df_result

## Sampling

In [None]:
def sample_std(protein_id, pdb_input_file_path, protein_type, sampling_temp):
    """Sample one design with the base MPNN model (the 'standard' benchmark).

    The sampled sequence is written to
    FASTA_OUTPUT_PATH/<protein_id>/standard/<sampling_temp>/beams.txt, which
    is where design_sequences and to_colabfold look for it.

    Uses BASE_MODEL_NAME instead of repeating its value as a literal, so the
    configured base model is actually the one sampled.
    """
    output_dir_path = join(FASTA_OUTPUT_PATH, protein_id, 'standard', str(sampling_temp))

    fasta_output_file_path = join(output_dir_path, 'result.txt')
    seed = 37
    run_mpnn(BASE_MODEL_NAME,
             pdb_input_file_path,
             fasta_output_file_path,
             seed,
             protein_type=protein_type,
             designed_positions=None,
             sampling_temp=sampling_temp
    )
    # Row 1 of the output holds the sampled design (row 0 presumably the input/template — confirm).
    seq = fasta_to_df(fasta_output_file_path).iloc[1].seq
    str_to_file(seq, os.path.join(output_dir_path, 'beams.txt'))

In [None]:
def get_cape_mpnn_design_path(fasta_output_path, cape_mpnn_ckpt_id, protein_id, sampling_temp):
    """Return the CAPE-MPNN design directory <root>/<protein_id>/CAPE-MPNN/<sampling_temp>/<ckpt_id>."""
    temp_dir_name = str(sampling_temp)
    return join(fasta_output_path, protein_id, "CAPE-MPNN", temp_dir_name, cape_mpnn_ckpt_id)
    
def get_cape_mpnn_fasta_file_path(fasta_output_path, cape_mpnn_ckpt_id, protein_id, seed, sampling_temp):
    """Return the FASTA file of the CAPE-MPNN trial sampled with ``seed``."""
    design_dir_path = get_cape_mpnn_design_path(fasta_output_path, cape_mpnn_ckpt_id, protein_id, sampling_temp)
    trial_file_name = f"{protein_id}_{seed}.fasta"
    return join(design_dir_path, trial_file_name)

def generate_CAPE_MPNN(protein_id, pdb_input_file_path, fasta_output_path, sampling_temp, overwrite=False):
    """Sample three CAPE-MPNN trials (seeds 37, 38, 39) for ``protein_id``.

    Existing trial FASTAs are reused unless ``overwrite`` is set. Each trial
    sequence must have the template's length.

    If all three trial files exist, an empty beams.txt is written into the
    design directory as a marker. add_cape_mpnn_designs later replaces it
    with the selected trial.

    Fixes: the missing-file message referenced the undefined names
    ``ckpt_id`` / ``pdb_id`` (NameError). The marker path now comes from
    get_cape_mpnn_design_path, so it honours ``fasta_output_path`` and
    accepts a non-str ``sampling_temp``.
    """
    template_protein = proteins[protein_id]
    protein_type = template_protein.get_protein_type()
    template_seq = template_protein.get_seq()

    model_id, ckpt = CAPE_MPNN_CKPT_ID.split(':')
    ckpt_dir_path = os.path.join(CAPE_MPNN_WEIGHTS_PATH, model_id, 'ckpts')

    success = True
    for trial in [1, 2, 3]:
        seed = 36 + trial
        fasta_output_file_path = get_cape_mpnn_fasta_file_path(fasta_output_path, CAPE_MPNN_CKPT_ID, protein_id, seed, sampling_temp)
        if overwrite or not os.path.exists(fasta_output_file_path):
            run_mpnn(
                (ckpt_dir_path, ckpt),
                pdb_input_file_path,
                fasta_output_file_path,
                seed,
                protein_type=protein_type,
                designed_positions=None,
                sampling_temp=sampling_temp
            )

        if not os.path.exists(fasta_output_file_path):
            print(f"{CAPE_MPNN_CKPT_ID:30s} {protein_id} {trial} does not exist!")
            success = False
        else:
            seq = fasta_to_df(fasta_output_file_path).iloc[1].seq
            assert len(seq) == len(template_seq), f"{protein_id} trial {trial}: design length {len(seq)} != template length {len(template_seq)}"

    if success:
        design_dir_path = get_cape_mpnn_design_path(fasta_output_path, CAPE_MPNN_CKPT_ID, protein_id, sampling_temp)
        str_to_file("", os.path.join(design_dir_path, 'beams.txt'))

In [None]:
def get_row_source_id(row):
    """Return the source ID of a decoding row.

    Benchmark rows keep their ``source_id``. Every other (CAPE-Beam) row gets
    an ID derived from its hyper-parameters via get_source_id.
    """
    if row.source_id in benchmark_source_ids:
        return row.source_id

    tune_predictor = immuno_predictor_setups[row.tune_mhc_1_predictor]['mhc_1']
    return get_source_id(
        row.checked_kmer_length,
        PROTEOME_FILE_NAME,
        row.min_self_kmer_length,
        row.tune_mhc_1_genotype,
        tune_predictor,
        row.non_self_prob_factor,
        row.width,
        row.branching_factor,
        row.depth,
        row.prune_min_acc_log_prob,
    )


def design_sequences(df_decodings, loch):
    """Create every design that does not exist yet.

    For each row the source ID is resolved and written back into
    ``source_id``. If neither its beams.txt nor an exception.txt exists, the
    design is produced as follows:

    * 'standard': sampled with sample_std;
    * 'template': the template sequence is written as the design;
    * 'CAPE-MPNN': sampled with generate_CAPE_MPNN;
    * CAPE-Beam: the cape-beam.py command line is printed (to be run
      separately) and counted as a missing search.
    """
    n_missing = 0
    for idx, row in tqdm(df_decodings.iterrows()):
        source_id = get_row_source_id(row)
        df_decodings.at[idx, 'source_id'] = source_id

        design_dir = os.path.join(FASTA_OUTPUT_PATH, row.protein_id, source_id)
        exception_file_path = os.path.join(design_dir, 'exception.txt')
        if source_id == 'standard':
            beams_dir = os.path.join(design_dir, str(row.sampling_temp))
        elif source_id == 'CAPE-MPNN':
            beams_dir = os.path.join(design_dir, str(row.sampling_temp), CAPE_MPNN_CKPT_ID)
        else:
            beams_dir = design_dir
        design_file_path = os.path.join(beams_dir, 'beams.txt')

        if os.path.exists(design_file_path) or os.path.exists(exception_file_path):
            continue

        pdb_input_file_path = loch.get_pdb_file_path(proteins[row.protein_id].seq_hash, predictor_structure_name='exp')

        if source_id == 'standard':
            sample_std(row.protein_id, pdb_input_file_path, row.protein_type, str(row.sampling_temp))
        elif source_id == 'template':
            str_to_file(proteins[row.protein_id].get_seq(), join(FASTA_OUTPUT_PATH, row.protein_id, 'template', 'beams.txt'))
        elif source_id.startswith('CAPE-MPNN'):
            generate_CAPE_MPNN(row.protein_id, pdb_input_file_path, FASTA_OUTPUT_PATH, str(row.sampling_temp))
        elif source_id not in benchmark_source_ids:
            flag_values = [
                ('--pdb_input_file_path', pdb_input_file_path),
                ('--protein_id', row.protein_id),
                ('--protein_type', row.protein_type),
                ('--output_dir_path', os.path.join(G.ENV.ARTEFACTS, 'designs')),
                ('--proteome_file_name', row.proteome_file_name),
                ('--min_self_kmer_length', row.min_self_kmer_length),
                ('--checked_kmer_length', row.checked_kmer_length),
                ('--width', row.width),
                ('--depth', row.depth),
                ('--branching_factor', row.branching_factor),
                ('--prune_min_acc_log_prob', row.prune_min_acc_log_prob),
            ]
            if row.tune_mhc_1_genotype is not None:
                flag_values += [
                    ('--mhc_1_alleles', row.tune_mhc_1_genotype),
                    ('--mhc_1_predictor', row.tune_mhc_1_predictor),
                    ('--non_self_prob_factor', row.non_self_prob_factor),
                ]
            command = ["cape-beam.py"] + [str(token) for flag_value in flag_values for token in flag_value]
            print(f'# {source_id}')
            print(' '.join(command))
            n_missing += 1

    print(f"missing searches: {n_missing}")




## Evaluate 

In [None]:
def to_colabfold(df, loch, ignore_seq_hashes=None):
    """Register design sequences in the loch and stage them as ColabFold inputs.

    For each row the design sequence(s) are collected:

    * the template sequence for 'template' rows;
    * the three trial FASTAs for CAPE-MPNN rows whose ``seq_hash`` is not
      yet a single value;
    * otherwise the first line of beams.txt.

    Each sequence is added to ``loch`` and copied as FASTA to
    COLABFOLD_PATH/input unless its hash is in ``ignore_seq_hashes``. The
    row's ``seq``/``seq_hash`` become lists of the collected sequences/hashes.
    Rows with an exception.txt get ``seq``/``seq_hash`` = None.

    Args:
        ignore_seq_hashes: hashes not copied to ColabFold. None (the default)
            means nothing is ignored; previously it raised a TypeError on the
            first membership test.
    """
    if ignore_seq_hashes is None:
        ignore_seq_hashes = ()

    for i, row in tqdm(df.iterrows()):
        source_id = get_row_source_id(row)

        seqs = []
        if source_id == 'template':
            seqs.append(proteins[row.protein_id].get_seq())
        elif source_id.startswith('CAPE-MPNN'):
            if row.seq_hash is not None and not isinstance(row.seq_hash, list):
                # already resolved to a single design; nothing to stage
                seqs = None
            else:
                for trial in [1, 2, 3]:
                    fasta_output_file_path = get_cape_mpnn_fasta_file_path(FASTA_OUTPUT_PATH, CAPE_MPNN_CKPT_ID, row.protein_id, 36 + trial, row.sampling_temp)
                    seqs.append(fasta_to_df(fasta_output_file_path).iloc[1].seq)
        else:
            if source_id == "standard":
                source_design_dir_path = os.path.join(FASTA_OUTPUT_PATH, row.protein_id, source_id, str(row.sampling_temp))
            else:
                source_design_dir_path = os.path.join(FASTA_OUTPUT_PATH, row.protein_id, source_id)

            design_file_path = os.path.join(source_design_dir_path, 'beams.txt')
            exception_file_path = os.path.join(source_design_dir_path, 'exception.txt')

            if os.path.exists(exception_file_path):
                df.at[i, 'seq_hash'] = None
                df.at[i, 'seq'] = None
                continue
            elif os.path.exists(design_file_path):
                seqs.append(file_to_str(design_file_path).split('\n')[0])

        if seqs is not None:
            seq_hashes = []
            for seq in seqs:
                seq_hash = loch.add_seq(seq)
                seq_hashes.append(seq_hash)

                if seq_hash not in ignore_seq_hashes:
                    cp_fasta_to_dir(seq_hash, os.path.join(COLABFOLD_PATH, 'input'), translate=("/-", ":X", ""))
                else:
                    print(f"Ignore: {row.source_id} {row.protein_id}")

            df.at[i, 'seq'] = seqs
            df.at[i, 'seq_hash'] = seq_hashes


def from_colabfold(df_decodings, loch):
    """Copy finished ColabFold predictions into the loch as 'AF' structures.

    For every sequence hash of a row whose AF structure is not in the loch
    yet, the single file matching ``<seq_hash>_relaxed_rank_001_*.pdb`` in
    COLABFOLD_PATH/output is copied there. None hashes (rows with a design
    exception) are skipped. Afterwards, one-element ``seq``/``seq_hash``
    lists are unwrapped to plain values.

    Raises:
        ValueError: if more than one ColabFold output matches a hash. This
            replaces ``raise Error()``, which was itself a NameError.
    """
    colabfold_output_path = os.path.join(COLABFOLD_PATH, 'output')
    for i, row in tqdm(df_decodings.iterrows()):
        seq_hashes = row.seq_hash if isinstance(row.seq_hash, list) else [row.seq_hash]

        for seq_hash in seq_hashes:
            if seq_hash is None:
                continue
            target_file_path = loch.get_pdb_file_path(seq_hash, predictor_structure_name='AF')
            if os.path.exists(target_file_path):
                continue

            # raw string: '\.' is an invalid escape sequence in a normal string literal
            files = get_entries(colabfold_output_path, rf"{seq_hash}_relaxed_rank_001_.+\.pdb")
            if len(files) > 1:
                raise ValueError(f"{len(files)} ColabFold outputs match {seq_hash}: {list(files)}")
            if len(files) == 1:
                file_name = next(iter(files))
                shutil.copy(os.path.join(colabfold_output_path, file_name), target_file_path)

        if isinstance(row.seq_hash, list) and len(row.seq_hash) == 1:
            df_decodings.at[i, 'seq'] = row.seq[0]
            df_decodings.at[i, 'seq_hash'] = row.seq_hash[0]
           

def add_tm_data(df_decodings, loch, overwrite=False):
    """Fill ``tm_data`` with the TM-align result of each design against its template.

    The design structure is the loch's 'exp' structure for template rows and
    the 'AF' prediction otherwise. It is aligned against
    PDB_DIR_PATH/<protein_id>.pdb. Rows that already have ``tm_data`` (unless
    ``overwrite``), that have no single resolved ``seq_hash``, or whose
    structure file does not exist yet are left untouched.
    """
    for i, row in tqdm(df_decodings.iterrows()):
        if row.tm_data is not None and not overwrite:
            continue
        # None: design failed; list: CAPE-MPNN trials not yet reduced to one design
        if row.seq_hash is None or isinstance(row.seq_hash, list):
            continue

        predictor_structure_name = 'exp' if row.source_id == 'template' else 'AF'
        source_pdb_file_path = loch.get_pdb_file_path(row.seq_hash, predictor_structure_name=predictor_structure_name)

        if os.path.exists(source_pdb_file_path):
            template_pdb_file_path = os.path.join(PDB_DIR_PATH, f'{row.protein_id}.pdb')
            df_decodings.at[i, 'tm_data'] = align_structures(template_pdb_file_path, source_pdb_file_path)

def kmers_in_proteome(seq, proteome_tree, length):
    """Split the ``length``-mers of ``seq`` into those found in the proteome tree and those not.

    Returns:
        (in_proteome, not_in_proteome), both lists in k-mer order.
    """
    hits, misses = [], []
    for kmer in get_kmers(seq, length, check_aa=False):
        bucket = misses if proteome_tree.get_kmer(kmer) is None else hits
        bucket.append(kmer)
    return hits, misses


def analyse_kmers_non_human(seq, proteome_tree, max_checked_kmer_length=10):
    """Return all k-mers of ``seq`` (k = 1..max_checked_kmer_length) absent from the proteome tree."""
    kmer_lengths = list(range(1, max_checked_kmer_length + 1))
    return [kmer for kmer in get_kmers(seq, kmer_lengths) if not proteome_tree.has_kmer(kmer)]
    

def analyse_kmers_presented(seq, immuno_setup=None, predictor_setup=None, max_checked_kmer_length=10):
    """Return the k-mers of ``seq`` predicted to be presented by MHC-I.

    Only k-mers whose length is in MHC_1_PEPTIDE_LENGTHS are scored. The
    predictor ``predictor_setup['mhc_1']`` is run on all of them for the
    alleles ``immuno_setup['mhc_1']``. Returns [] when ``immuno_setup`` is None.
    """
    kmer_lengths = list(range(1, max_checked_kmer_length + 1))
    peptides = [kmer for kmer in get_kmers(seq, kmer_lengths) if len(kmer) in MHC_1_PEPTIDE_LENGTHS]

    if immuno_setup is None:
        return []

    predictor = predictor_setup['mhc_1']
    alleles = immuno_setup['mhc_1']
    predictor.predict_peptides(peptides, alleles)
    return [peptide for peptide in peptides if predictor.peptide_presented(peptide, alleles)]
    

def analyse_kmer(df_decodings, source_id=None, overwrite=False):
    """Annotate each row with its k-mer immunogenicity / self statistics (in place).

    Rows whose ``seq`` is a string get:

    * ``immuno_chains``: chains chosen by get_immuno_chains;
    * ``kmers_presented_<predictor>``: presented k-mers per predictor in
      immuno_predictor_setups, for the row's ``eval_mhc_1_genotype``;
    * ``kmers_not_in_proteome``: k-mers (k <= 10) absent from ``proteome_tree``;
    * ``proportion_<k>``: fraction of k-mers found in the proteome, k = 5..10.

    Existing list values are recomputed only with ``overwrite``.
    ``kmers_presented`` is reset to None for every row.

    Args:
        source_id: unused; kept for backward compatibility with existing callers.
    """
    # Ensure the result columns exist up front instead of re-checking per row.
    result_cols = ['immuno_chains', 'kmers_not_in_proteome', 'kmers_presented'] + [f'kmers_presented_{predictor_postfix}' for predictor_postfix in immuno_predictor_setups.keys()]
    for col in result_cols:
        if col not in df_decodings.columns:
            df_decodings[col] = None

    for i, row in tqdm(df_decodings.iterrows()):
        immuno_chains = None
        if isinstance(row.seq, str):
            immuno_chains = get_immuno_chains(row.source_id, row.protein_id, row.seq)
            eval_setup = {'mhc_1': row.eval_mhc_1_genotype}

            for predictor_postfix, predictor_setup in immuno_predictor_setups.items():
                col = f'kmers_presented_{predictor_postfix}'
                if df_decodings.at[i, col] is None or overwrite:
                    presented = []
                    for chain in immuno_chains:
                        presented += analyse_kmers_presented(
                            chain,
                            immuno_setup=eval_setup,
                            predictor_setup=predictor_setup,
                            max_checked_kmer_length=10
                        )
                    df_decodings.at[i, col] = presented

            if df_decodings.at[i, 'kmers_not_in_proteome'] is None or overwrite:
                not_in_proteome_all = []
                for chain in immuno_chains:
                    not_in_proteome_all += analyse_kmers_non_human(chain, proteome_tree, max_checked_kmer_length=10)
                df_decodings.at[i, 'kmers_not_in_proteome'] = not_in_proteome_all

            for prop_len in [5, 6, 7, 8, 9, 10]:
                n_in, n_out = 0, 0
                for chain in immuno_chains:
                    _in_proteome, _not_in_proteome = kmers_in_proteome(chain, proteome_tree, prop_len)
                    n_in += len(_in_proteome)
                    n_out += len(_not_in_proteome)
                # max(..., 1.) avoids division by zero for chains shorter than prop_len
                df_decodings.at[i, f'proportion_{prop_len}'] = n_in / max(n_in + n_out, 1.)

        df_decodings.at[i, 'kmers_presented'] = None
        df_decodings.at[i, 'immuno_chains'] = immuno_chains

In [None]:
def calc_TM_data(seq_hash_data, seq_hash_generated, loch):
    """TM-align the experimental structure of ``seq_hash_data`` against the structure of ``seq_hash_generated``.

    The generated structure is taken with the loch's default
    ``predictor_structure_name``.

    Returns:
        (tm_score, tm_alignment_length, tm_rmsd, tm_identical, len_chain_1, len_chain_2),
        or None if either structure file is missing.
    """
    reference_pdb = loch.get_pdb_file_path(seq_hash_data, predictor_structure_name="exp")
    design_pdb = loch.get_pdb_file_path(seq_hash_generated)

    if not (os.path.exists(reference_pdb) and os.path.exists(design_pdb)):
        return None

    score, aln_len, rmsd, identical, chain_1_len, chain_2_len = align_structures(
        reference_pdb, design_pdb, return_chain_lengths=True)
    return score, aln_len, rmsd, identical, chain_1_len, chain_2_len


def _analyse_cape_mpnn_seq(protein_id, seq, eval_mhc_1_genotype, cape_mpnn_ckpt_id):
    """Return (kmers_not_in_proteome, kmers_presented_netmhcpan) for a single CAPE-MPNN sequence."""
    _df = pd.DataFrame({'protein_id': [protein_id], 'seq': [seq], 'source_id': ['CAPE-MPNN'], 'eval_mhc_1_genotype': [eval_mhc_1_genotype]})
    analyse_kmer(_df, cape_mpnn_ckpt_id)
    return _df.iloc[0].kmers_not_in_proteome, _df.iloc[0].kmers_presented_netmhcpan


def _select_best_cape_mpnn_trial(protein_id, sampling_temp, eval_mhc_1_genotype, seq_hash_data, cape_mpnn_ckpt_id, loch, min_tm_score=0.9):
    """Pick the best of the three CAPE-MPNN trials of a protein.

    Only trials with an AF structure in the loch and a TM result are
    candidates. Among candidates with TM-score >= ``min_tm_score``, the one
    with the fewest NetMHCpan-presented k-mers wins. If there is none, the
    candidate with the highest TM-score wins. Ties keep the earlier trial.

    Returns:
        (seq, seq_hash, tm_data, kmers_not_in_proteome, kmers_presented_netmhcpan),
        or None if no trial is a candidate.
    """
    best_by_tm = None   # (tm_score, candidate)
    best_by_vis = None  # (visible k-mer count, candidate)
    for trial in [1, 2, 3]:
        fasta_file_path = get_cape_mpnn_fasta_file_path(FASTA_OUTPUT_PATH, cape_mpnn_ckpt_id, protein_id, 36 + trial, sampling_temp)
        seq = fasta_to_df(fasta_file_path).iloc[1].seq
        seq_hash = get_seq_hash(seq)

        kmers_not_in_proteome, kmers_presented_netmhcpan = _analyse_cape_mpnn_seq(protein_id, seq, eval_mhc_1_genotype, cape_mpnn_ckpt_id)
        vis_mhc_1 = len(kmers_presented_netmhcpan)

        pdb_file_path = loch.get_pdb_file_path(seq_hash, predictor_structure_name='AF')
        print(pdb_file_path)
        if not os.path.exists(pdb_file_path):
            continue

        tm_data = calc_TM_data(seq_hash_data, seq_hash, loch)
        # calc_TM_data returns None when the experimental structure is missing
        tm_score = None if tm_data is None else tm_data[0]
        print(f'{protein_id} {trial} exists tm-score: {tm_score} vis_mhc_1: {vis_mhc_1}')
        if tm_score is None:
            continue

        candidate = (seq, seq_hash, tm_data, kmers_not_in_proteome, kmers_presented_netmhcpan)
        if best_by_tm is None or tm_score > best_by_tm[0]:
            best_by_tm = (tm_score, candidate)
        # Evaluated independently of the max-TM update: a fold-preserving trial
        # with lower TM-score but fewer visible k-mers must still be considered.
        if tm_score >= min_tm_score and (best_by_vis is None or vis_mhc_1 < best_by_vis[0]):
            best_by_vis = (vis_mhc_1, candidate)

    if best_by_vis is not None:
        return best_by_vis[1]
    if best_by_tm is not None:
        return best_by_tm[1]
    return None


def add_cape_mpnn_designs(df, cape_mpnn_ckpt_id, loch):
    """Resolve CAPE-MPNN rows that still list several trials to a single design (in place).

    For CAPE-MPNN rows whose ``seq_hash`` is a list, the design file
    (beams.txt in the checkpoint's design directory) is read:

    * If it is empty (the marker written by generate_CAPE_MPNN), the best
      trial is selected with _select_best_cape_mpnn_trial and written into it.
    * Otherwise the stored sequence is re-analysed.

    ``seq_hash``, ``seq``, ``tm_data``, ``kmers_not_in_proteome`` and
    ``kmers_presented_netmhcpan`` are then set on the row.

    Fixes compared to the previous version:

    * Fold-preserving trials are compared by visibility even when they do
      not also improve the best TM-score.
    * A missing experimental structure no longer crashes on ``tm_data[0]``.
    * A protein with no usable trial is skipped with a warning instead of
      raising KeyError.
    * Trial paths use ``cape_mpnn_ckpt_id`` consistently.
    """
    for idx, row in df.iterrows():
        if not (isinstance(row.source_id, str) and row.source_id.startswith('CAPE-MPNN')):
            continue
        if not isinstance(row.seq_hash, list):
            continue

        protein_id = row.protein_id
        design_file_path = os.path.join(get_cape_mpnn_design_path(FASTA_OUTPUT_PATH, cape_mpnn_ckpt_id, protein_id, row.sampling_temp), "beams.txt")
        best_seq = file_to_str(design_file_path)
        seq_hash_data = get_seq_hash(proteins[protein_id].get_seq())

        if len(best_seq) == 0:
            best = _select_best_cape_mpnn_trial(protein_id, row.sampling_temp, row.eval_mhc_1_genotype, seq_hash_data, cape_mpnn_ckpt_id, loch)
            if best is None:
                print(f"WARNING: {protein_id} {row.sampling_temp}: no CAPE-MPNN trial with a usable structure")
                continue
            best_seq, best_seq_hash, best_tm_data, best_kmers_not_in_proteome, best_kmers_presented_netmhcpan = best
            str_to_file(best_seq, design_file_path)
        else:
            best_seq_hash = get_seq_hash(best_seq)
            best_tm_data = calc_TM_data(seq_hash_data, best_seq_hash, loch)
            best_kmers_not_in_proteome, best_kmers_presented_netmhcpan = _analyse_cape_mpnn_seq(protein_id, best_seq, row.eval_mhc_1_genotype, cape_mpnn_ckpt_id)

        df.at[idx, 'seq_hash'] = best_seq_hash
        df.at[idx, 'seq'] = best_seq
        df.at[idx, 'tm_data'] = best_tm_data
        df.at[idx, 'kmers_not_in_proteome'] = best_kmers_not_in_proteome
        df.at[idx, 'kmers_presented_netmhcpan'] = best_kmers_presented_netmhcpan
            

In [None]:
def check_deimmunized(df, show=True, benchmark_ids=None):
    """Verify that no checked design contains a presented non-self kmer.

    Only rows with a non-null ``seq_hash`` whose ``source_id`` is not a
    benchmark source are considered. Rows with ``checked_kmer_length >= 10``
    are checked. A kmer is problematic when it is listed in
    ``kmers_presented_pwm_dynamic`` and also in ``kmers_not_in_proteome``.
    All other rows are reported as not checked.

    Args:
        df: design table with columns seq_hash, source_id, protein_id,
            checked_kmer_length, kmers_presented_pwm_dynamic,
            kmers_presented_netmhcpan and kmers_not_in_proteome.
        show: print a one-line summary for every checked row.
        benchmark_ids: source ids to skip. Defaults to the notebook-level
            ``benchmark_source_ids``.

    Raises:
        AssertionError: if any checked design contains a problematic kmer.
            This is raised explicitly, so it also fires under ``python -O``.
    """
    if benchmark_ids is None:
        benchmark_ids = benchmark_source_ids

    problems = []
    not_checked = []
    # Boolean mask instead of query('not seq_hash.isnull()') so it works with any query engine.
    for _, row in df[df['seq_hash'].notnull()].iterrows():
        if row.source_id in benchmark_ids:
            continue
        if row.checked_kmer_length >= 10:
            # Set for O(1) membership instead of scanning the list per kmer.
            not_in_proteome = set(row.kmers_not_in_proteome)
            problematic = sum(1 for x in row.kmers_presented_pwm_dynamic if x in not_in_proteome)
            if show:
                print(f"{row.source_id} {row.protein_id} {row.checked_kmer_length} {len(row.kmers_presented_pwm_dynamic)} vs. {len(row.kmers_presented_netmhcpan)} - {problematic}")
            if problematic > 0:
                problems.append(row)
        else:
            # Also reached when checked_kmer_length is NaN.
            not_checked.append(row.source_id)

    if problems:
        details = ", ".join(f"{r.source_id}/{r.protein_id}" for r in problems)
        raise AssertionError(f"{len(problems)} design(s) contain presented non-self kmers: {details}")

    print(f"The following CAPE-Beam IDs were not checked: {not_checked}")

In [None]:
def is_between(v, a, b):
    """Return whether ``a <= v <= b``, treating a ``None`` bound as unbounded."""
    if a is not None and not (a <= v):
        return False
    if b is not None and not (v <= b):
        return False
    return True

## Plots

In [None]:
def get_Beam_source_ids(min_self_kmer_lengths, non_self_prob_factors, tune_mhc_1_genotype):
    """Return the CAPE-Beam source ids for every (kmer length, prob factor) pair.

    The outer order is ``min_self_kmer_lengths`` and the inner order is
    ``non_self_prob_factors``. All other beam-search settings are fixed:
    checked kmer length 10, width 10, branching factor 1,
    depth = 2 * min_self_kmer_length, and prune threshold -2000.
    """
    return [
        get_source_id(
            checked_kmer_length=10,
            proteome_file_name=PROTEOME_FILE_NAME,
            min_self_kmer_length=kmer_len,
            tune_mhc_1_genotype=tune_mhc_1_genotype,
            tune_mhc_1_predictor=immuno_predictor_setups[MHC_1_PREDICTOR_DECODING]['mhc_1'],
            non_self_prob_factor=prob_factor,
            width=10,
            branching_factor=1,
            depth=kmer_len*2,
            prune_min_acc_log_prob=-2000.,
        )
        for kmer_len in min_self_kmer_lengths
        for prob_factor in non_self_prob_factors
    ]

In [None]:
def set_kmers_presented(df, mhc_1_predictor_name):
    """Copy ``kmers_presented_<mhc_1_predictor_name>`` into the generic ``kmers_presented`` column (in place).

    This is a direct column copy rather than a row-wise ``apply``. The copy is
    faster, and it also works on an empty frame, where ``apply(axis=1)``
    returns a DataFrame that cannot be assigned to a single column.
    """
    df['kmers_presented'] = df[f'kmers_presented_{mhc_1_predictor_name}']

def get_df_plot(df):
    """Return a copy of ``df`` with the per-design metrics used by the plots.

    Added columns:
        seq_len: length of ``seq``, or 0 when no sequence string is present.
        kmers_presented: taken from the ``MHC_1_PREDICTOR_EVAL`` predictor column.
        kmers_presented_not_in_proteome: set intersection of presented kmers
            and kmers not in the proteome.
        not_in_proteome / presented / presented_not_in_proteome: counts of the
            corresponding ``kmers_*`` collections.
        tm_score: ``tm_data[0]``.
        tm_aligned: ``tm_data[1]`` divided by the length of the first
            '/'-separated chain.
        presented_not_in_proteome_pc: presented non-self count divided by
            ``get_possible_peptides(immuno_chains, [8, 9, 10])``.

    Rows without a string ``seq_hash`` are dropped from the result.
    """
    df_plot = df.copy()
    # A missing seq (None or NaN) counts as length 0; such rows get no kmer metrics.
    df_plot['seq_len'] = df_plot.apply(lambda row: len(row.seq) if isinstance(row.seq, str) else 0, axis=1)

    set_kmers_presented(df_plot, MHC_1_PREDICTOR_EVAL)

    df_plot['kmers_presented_not_in_proteome'] = df_plot.apply(
        lambda row: set(row.kmers_presented) & set(row.kmers_not_in_proteome) if row.seq_len > 0 else None, 
        axis=1
    )

    for c in ['not_in_proteome', 'presented', 'presented_not_in_proteome']:
        df_plot[c] = df_plot.apply(lambda row: len(row[f'kmers_{c}']) if row.seq_len > 0 else None, axis=1)

    # Guard on seq_len > 0: seq_len is never None, so the old `is not None`
    # check let rows without a sequence reach row.seq.split and crash.
    # NOTE(review): tm_data[1] is presumably the number of aligned residues — confirm.
    df_plot['tm_aligned'] = df_plot.apply(
        lambda row: row.tm_data[1]/len(row.seq.split("/")[0]) if row.tm_data is not None and row.seq_len > 0 else None,
        axis=1
    )
    df_plot['tm_score'] = df_plot.apply(lambda row: row.tm_data[0] if row.tm_data is not None else None, axis=1)

    df_plot['presented_not_in_proteome_pc'] = df_plot.apply(
        lambda r: r.presented_not_in_proteome/get_possible_peptides(r.immuno_chains, [8, 9, 10]) if r.seq_len > 0 else None,
        axis=1
    )

    # Normalise non-string hashes (e.g. NaN) to None so the query below drops them.
    df_plot['seq_hash'] = df_plot.apply(lambda r: r.seq_hash if isinstance(r.seq_hash, str) else None, axis=1)

    return df_plot.query("not seq_hash.isnull()")

### Plot fig X

In [None]:
def get_successful_source_ids(df, min_tm_score):
    """Return the source ids that have at least one design with ``tm_score >= min_tm_score``.

    Ids are ordered by ascending ``min_self_kmer_length`` and then descending
    ``non_self_prob_factor``, keeping the first occurrence of each id. The
    check is a single boolean mask rather than one string-built query per
    source, so ids containing quotes are handled correctly.
    """
    ordered_ids = df.sort_values(
        ['min_self_kmer_length', 'non_self_prob_factor'], ascending=[True, False]
    ).source_id.unique()
    successful = set(df.loc[df['tm_score'] >= min_tm_score, 'source_id'])
    return [source_id for source_id in ordered_ids if source_id in successful]
    

def fig_X(df, source_ids, rename_source_ids, palette_source_id, min_tm_score=None, 
          homo_sapiens_proteins=None,
          sources_horizontal=True, kmer_lengths=None, 
          hspace=.1, wspace=0.15, groupseps=None, fig_height=0.8, label_fontsize=8):
    """Plot the multi-panel comparison of design sources, one panel per metric.

    Panels show the TM score, self-kmer proportions, presented non-self
    8-10mers, rosetta score, delta pI and aggrescan3d max. Each panel has one
    box per entry of ``source_ids`` (in that order), annotated with its median.

    Args:
        df: per-design table with columns source_id, protein_id, tm_score,
            the metric columns listed in ``infos``, and ``proportion_<k>`` for
            every k in ``kmer_lengths``.
        source_ids: sources to show, in plotting order.
        rename_source_ids: mapping source_id -> display name for tick labels.
        palette_source_id: mapping source_id -> colour.
        min_tm_score: if given, all panels except the TM-score panel, and the
            N counts in the tick labels, only use designs with
            ``tm_score >= min_tm_score``.
        homo_sapiens_proteins: protein ids removed from panels flagged
            ``ignore_human_proteins``. ``None`` disables that filter.
        sources_horizontal: True puts sources on the x axis with panels
            stacked vertically. False puts sources on the y axis with panels
            side by side.
        kmer_lengths: kmer lengths shown in the self-kmer panel (default 5..10).
        groupseps: source indices after which a solid (rather than dashed)
            separator line is drawn.
        fig_height: figure height as a fraction of A4 height.
        label_fontsize: font size of axis and tick labels.

    Uses the notebook-level ``df_pc_human`` (index: kmer length, column
    ``human``) for the proteome reference lines.
    """
    # None defaults avoid shared mutable defaults. kmer_lengths stays a list
    # because it is concatenated with a list below.
    kmer_lengths = [5, 6, 7, 8, 9, 10] if kmer_lengths is None else list(kmer_lengths)
    groupseps = () if groupseps is None else groupseps

    # column -> (axis label, value range, axis scale, median format, drop human proteins?)
    infos = {
        'tm_score': ('TM score', (0., 1.2), 'linear', ".2f", False),
        'proportion': ("self-kmers \n [frac of kmers]", (0.01, 5.), 'log', ".0%", True),
        'presented_not_in_proteome_pc': ("presented \n non-self 8-10mers \n [frac of 8-10mers]", (0.001, 0.2), 'log', ".1%", True),
        'rosetta p.a.': ('rosetta score \n [REU per AA]', (None, 2.), 'linear', ".1f", False),
        # raw string: '\D' is an invalid escape sequence in a normal literal
        'delta isoelectric point': (r'$\Delta$ pI', (None, None), 'linear', ".1f", False),
        'aggrescan3d max': ('aggrescan3d max', (None, None), 'linear', ".1f", False)
    }
    _palette = None

    n_rows, n_cols = (len(infos), 1) if sources_horizontal else (1, len(infos))

    fig = plt.figure(figsize=(A4_width, A4_height*fig_height))
    gs = gridspec.GridSpec(n_rows, n_cols, width_ratios=[1]*n_cols, hspace=hspace, wspace=wspace)

    for j, (info_col, info_details) in enumerate(infos.items()):
        info_name, info_range, info_scale, format_median, ignore_human_proteins = info_details

        ax = fig.add_subplot(gs[j, 0] if sources_horizontal else gs[0, j])
        x_col, y_col = ('source_id', info_col) if sources_horizontal else (info_col, 'source_id')
        # Line at a fixed metric value (perpendicular to the value axis).
        value_line = ax.axhline if sources_horizontal else ax.axvline
        # Line between two neighbouring sources (perpendicular to the source axis).
        source_sep_line = ax.axvline if sources_horizontal else ax.axhline

        _df = df
        if info_col != 'tm_score' and min_tm_score is not None:
            _df = df.query(f'tm_score >= {min_tm_score}')

        if ignore_human_proteins and homo_sapiens_proteins is not None:
            _df = _df[~_df['protein_id'].isin(homo_sapiens_proteins)]

        if info_col != 'proportion':
            sns.boxplot(_df, x=x_col, y=y_col, palette=palette_source_id, order=source_ids, ax=ax, width=0.8)
            if info_col == 'rosetta p.a.':
                # positional argument works for both axhline(y) and axvline(x)
                value_line(-1, color='grey', linestyle='--', alpha=0.5)
        else:
            # one row per (design, kmer length) with the proportion as value
            proportion_cols = [f'proportion_{k}' for k in kmer_lengths]
            _df_melted = pd.melt(_df, id_vars=['source_id'], value_vars=proportion_cols).rename(columns={'variable': 'kmer_len', 'value': info_col})
            _df_melted['kmer_len'] = _df_melted['kmer_len'].str.split('_').str[1]

            # extend the source palette with one dark colour per kmer length
            tmp = glasbey.extend_palette(list(palette_source_id.values()), palette_size=len(palette_source_id) + len(kmer_lengths), lightness_bounds=(0, 50))
            _palette = {str(k): tmp[i] for i, k in enumerate(list(palette_source_id.keys()) + kmer_lengths)}

            sns.boxplot(_df_melted, x=x_col, y=y_col, hue='kmer_len', order=source_ids, palette=_palette, ax=ax, flierprops=dict(marker='o', markersize=1.5))

            # reference lines: fraction of all possible kmers present in the human proteome
            for kmer_len, row in df_pc_human.iterrows():
                if kmer_len not in kmer_lengths:
                    continue  # no palette entry for lengths that are not plotted
                value_line(row.human, color=_palette[str(kmer_len)], linestyle='--', alpha=0.5, label=f'Mean kmer_len={kmer_len}')

            plot_legend_patches(
                {k: v for k, v in _palette.items() if k in [str(k) for k in kmer_lengths]}, 
                ax, 
                location='center', 
                ncol=len(kmer_lengths), plain=False, frameon=False,
                legend_kwargs={'bbox_to_anchor': (0.5, 0.92), 'fontsize': 8}
            )

        if format_median is not None:
            if info_col == 'proportion':
                sub_info_cols = [f'proportion_{k}' for k in kmer_lengths]
                fontstyle = {'fontsize': 7}
            else:
                sub_info_cols = [info_col]
                fontstyle = {'fontweight': 'bold', 'color': 'black'}

            d = 1./len(sub_info_cols)
            rotation = 0 if len(sub_info_cols) == 1 else 90

            for s, source_id in enumerate(source_ids):
                source_rows = _df[_df['source_id'] == source_id]
                for c, _info_col in enumerate(sub_info_cols):
                    v = float(source_rows[[_info_col]].median().iloc[0])
                    if np.isnan(v):
                        continue  # no designs left for this source; don't print 'nan'

                    # spread the labels of several sub-columns around the box centre
                    p1 = s + d * (c - (len(sub_info_cols)-1)/2)*0.8
                    p2 = v if info_range[0] is None else max(v, info_range[0])
                    p2 = p2 if info_range[1] is None else min(p2, info_range[1])

                    if info_col == 'proportion':
                        fontstyle.update({'color': _palette[str(kmer_lengths[c])]})

                    if is_between(v, info_range[0], info_range[1]):
                        x, y = (p1, p2) if sources_horizontal else (p2, p1)
                        ax.text(x=x, y=y, s=f'{v:{format_median}}', verticalalignment='center', horizontalalignment='center', 
                                rotation=rotation, **fontstyle,
                                bbox=dict(facecolor="white", alpha=0.7, edgecolor="none", boxstyle="round,pad=0.05")
                        )

        if sources_horizontal:
            ax.set_ylim(info_range)
            get_ticklabels, set_ticklabels = ax.get_xticklabels, ax.set_xticklabels
            ax.set_xlabel('')
            ax.set_ylabel(info_name, fontsize=label_fontsize)
            ax.set_yscale(info_scale)
        else:
            ax.set_xlim(info_range)
            get_ticklabels, set_ticklabels = ax.get_yticklabels, ax.set_yticklabels
            ax.set_xlabel(info_name, fontsize=label_fontsize)
            ax.set_ylabel('')
            ax.set_xscale(info_scale)

        # only the last panel carries the source names with design counts
        labels = []
        if j == len(infos) - 1:
            for ticklabel in get_ticklabels():
                source_id = ticklabel.get_text()
                mask = df['source_id'] == source_id
                if min_tm_score is not None:
                    mask &= df['tm_score'] >= min_tm_score
                n_designs = int(mask.sum())
                labels.append(f"{rename_source_ids.get(source_id, source_id)}\n N={n_designs}")

        set_ticklabels(labels, rotation=50, fontsize=label_fontsize)

        # separating lines between sources (solid after indices in groupseps)
        for s in range(len(source_ids)):
            if s in groupseps:
                source_sep_line(s+0.5, linestyle='-', color='grey')
            else:
                source_sep_line(s+0.5, linestyle='--', color='lightgrey')

        # panel letter
        x_offset, y_offset = (-0.15, 1.) if sources_horizontal else (1., -0.15)
        ax.text(x_offset, y_offset, f"{string.ascii_lowercase[j]})", transform=ax.transAxes, fontsize=label_fontsize*1.25, fontweight='bold', va='top', ha='right')



### Plot fig A

In [None]:
def plot_fig_A(df, palette_source_id, area=None, rename_source_ids=None, x_lim_pnip=None, min_tm_score=None, show_x_labels=True):
    """Plot one grid row per source with three panels.

    Panels:
        column 0: boxplot of ``tm_score`` over all designs of the source.
        column 1: mean ``proportion_<k>`` for k = 5..10. The human-proteome
            reference (``df_pc_human``) is drawn as crosses.
        column 2: boxplot of ``pnip``, the presented non-self 8-10mers divided
            by the possible 8-10mers, with the mean printed next to it.

    Columns 1 and 2 only use designs with ``tm_score >= min_tm_score`` when
    that is given. Rows whose filtered set is empty only get the TM-score panel.

    Args:
        area: optional SubplotSpec to draw into (on the current figure).
            Otherwise a new half-A4 figure is created.
        rename_source_ids: optional mapping source_id -> display name.
        x_lim_pnip: upper x limit of column 2 (default: max pnip + 0.1).
        show_x_labels: label the x axes of the first row.
    """
    df = df.copy()

    df['pnip'] = df.apply(lambda r: r.presented_not_in_proteome/get_possible_peptides(r.immuno_chains, [8, 9, 10]), axis=1)
    source_ids = list(df.source_id.unique())

    x_lim_pnip = df.pnip.max() + 0.1 if x_lim_pnip is None else x_lim_pnip

    grid_kwargs = dict(width_ratios=[1, 1, 1], hspace=.5, wspace=0.15)
    if area is None:
        fig = plt.figure(figsize=(A4_width, A4_height/2))
        gs = gridspec.GridSpec(len(source_ids), 3, **grid_kwargs)
    else:
        fig = plt.gcf()
        gs = area.subgridspec(len(source_ids), 3, **grid_kwargs)

    axes = np.full((len(source_ids), 3), None, dtype=object)
    _lens = [5, 6, 7, 8, 9, 10]

    for i, source_id in enumerate(source_ids):
        source_name = rename_source_ids[source_id] if (rename_source_ids is not None and source_id in rename_source_ids) else source_id

        _df = df[df['source_id'] == source_id]

        #
        # TM scores
        #
        ax_0 = fig.add_subplot(gs[i, 0])
        axes[i, 0] = ax_0

        sns.boxplot(_df, y='source_id', x='tm_score', hue='source_id', palette=palette_source_id, ax=ax_0)

        # NOTE(review): on the y axis matplotlib maps labeltop/labelbottom to the
        # second/first tick labels, so this likely moves the label to the right side — confirm.
        ax_0.tick_params(axis='y', labeltop=True, labelbottom=False) 
        ax_0.get_legend().remove()

        ax_0.set_ylabel('')
        ax_0.set_xlabel('')
        ax_0.set_xlim((0., 1.))

        if min_tm_score is not None:
            _df = _df.query(f"tm_score >= {min_tm_score}")
        N = len(_df)

        ax_0.set_yticklabels([f"{source_name}\nN={N}"])
        ax_0.xaxis.set_ticks_position('top')
        ax_0.xaxis.set_label_position('top')
        if i != 0 or not show_x_labels:
            ax_0.set_xticklabels([])
        else:
            ax_0.set_xlabel('TM score')

        if N == 0:
            continue

        #
        # self kmer proportions (mean per kmer length)
        #
        ax_1 = fig.add_subplot(gs[i, 1])
        axes[i, 1] = ax_1

        df_h = pd.DataFrame(_df[[f'proportion_{l}' for l in _lens]].mean())
        df_h.columns = ['pc']
        df_h['len'] = [int(name.split('_')[1]) for name in df_h.index]
        sns.scatterplot(df_h, x='len', y='pc', color=palette_source_id[source_id], ax=ax_1)

        sorted_points = df_h.sort_values('len')
        n_points = len(sorted_points)
        for j, (_, row) in enumerate(sorted_points.iterrows(), start=1):
            v = 3 if row['pc'] < .5 else -3
            if j < n_points:  # value label to the right of the dot
                ax_1.annotate(f'{row["pc"]:.2f}', (row['len'], row['pc']), 
                          xytext=(3, v), textcoords='offset points', ha='left', va='center', fontsize=7)
            else:  # last dot: label above/below it
                ax_1.annotate(f'{row["pc"]:.2f}', (row['len'], row['pc']), 
                          xytext=(0, v), textcoords='offset points', ha='center', va='bottom' if v > 0 else 'top', fontsize=7)

        # human-proteome reference, drawn explicitly on ax_1 rather than pyplot's current axes
        sns.scatterplot(df_pc_human, x='length', y='human', marker='x', ax=ax_1)
        ax_1.set_ylim((0., 1.1))

        ax_1.set_ylabel('')
        ax_1.set_xlabel('')

        #
        # presented non-self kmers
        #
        ax_2 = fig.add_subplot(gs[i, 2])
        axes[i, 2] = ax_2
        sns.boxplot(_df, y='source_id', x='pnip', hue='source_id', palette=palette_source_id, ax=ax_2)

        # mean value printed right of the largest value
        ax_2.text(x=_df.pnip.max()+0.01, y=.25, s=f'avg={_df.pnip.mean():.1%}', verticalalignment='bottom', horizontalalignment='left')
        ax_2.get_legend().remove()
        ax_2.set_xlabel('')
        ax_2.set_ylabel('')
        ax_2.set_yticklabels([])
        ax_2.set_xlim((0., x_lim_pnip))

        ax_1.xaxis.set_ticks_position('top')
        ax_1.xaxis.set_label_position('top')
        ax_2.xaxis.set_ticks_position('top')
        ax_2.xaxis.set_label_position('top')
        if i != 0 or not show_x_labels:
            ax_1.set_xticklabels([])
            ax_2.set_xticklabels([])
        else:
            ax_1.set_xlabel("from genome \n [fraction of kmers]")
            ax_2.set_xlabel('presented 8-10mers \n not in proteome \n [fraction of 8-10mers]')

            ax_1.set_xticks(_lens)
            ax_1.set_xticklabels(_lens)
   

### Plot fig B

In [None]:
def plot_fig_B(df, palette_source_id, rename_source_ids, fig_width=None, ylabel=True, min_tm_score=None):
    """Plot stacked per-source boxplots of 'rosetta p.a.', 'aggrescan3d max' and 'delta isoelectric point'.

    Args:
        df: per-design table with a source_id column and the three metric columns.
        palette_source_id: mapping source_id -> colour.
        rename_source_ids: mapping source_id -> display name (bottom tick labels).
        fig_width: figure width (default A4 width). The height is half of A4.
        ylabel: keep the seaborn y-axis labels (the metric names).
        min_tm_score: if given, only designs with ``tm_score >= min_tm_score`` are used.
    """
    infos = ['rosetta p.a.', 'aggrescan3d max', 'delta isoelectric point']

    if min_tm_score is not None:
        df = df.query(f"tm_score >= {min_tm_score}")

    fig_width = A4_width if fig_width is None else fig_width
    fig = plt.figure(figsize=(fig_width, A4_height/2))
    gs = gridspec.GridSpec(len(infos), # rows
                           1, # cols
                           width_ratios=[1],
                           height_ratios=[1]*len(infos),
                           hspace=0.1,
                           wspace=0.1,
                          )

    for j, info in enumerate(infos):
        ax = fig.add_subplot(gs[j, 0])

        sns.boxplot(
            data=df,
            x='source_id', 
            y=info, 
            palette=palette_source_id,
            ax=ax
        )

        if j < len(infos) - 1:
            ax.set_xticklabels([])
        else:
            # bottom panel: display names plus number of designs per source
            labels = []
            for ticklabel in ax.get_xticklabels():
                source_id = ticklabel.get_text()
                n_designs = int((df['source_id'] == source_id).sum())
                labels.append(f"{rename_source_ids.get(source_id, source_id)}\n N={n_designs}")
            ax.set_xticklabels(labels, rotation=50)

        if not ylabel:
            ax.set_ylabel("")

        if info == 'rosetta p.a.':
            # draw on this panel explicitly instead of pyplot's implicit current axes
            ax.axhline(y=-1.0, color='red', linestyle='--')

        ax.set_xlabel('')

### Plot fig C

In [None]:
def _find_closest_pkmer(kmer):
    """Return ``(proteome kmer, score)`` maximising ``aligner.score(kmer, pkmer)`` over same-length proteome kmers.

    Brute force over ``proteome_kmers_per_length[len(kmer)]``. The result is
    ``(None, -1000)`` if no proteome kmer scores above -1000.
    """
    max_score, max_pkmer = -1000, None
    for pkmer in proteome_kmers_per_length[len(kmer)]:
        score = aligner.score(kmer, pkmer)
        if score > max_score:
            max_pkmer, max_score = pkmer, score
    return max_pkmer, max_score


def add_min_diff(df, closest_pkmer=None, save=True):
    """Annotate each design with how far its non-self kmers are from the proteome.

    For every kmer in ``row.kmers_not_in_proteome`` whose length is in
    ``MHC_1_PEPTIDE_LENGTHS``, the best-scoring proteome kmer of the same
    length is looked up (cached). The value
    ``aligner.score(kmer, kmer) - best_score`` is then recorded.

    Columns written to ``df`` in place:
        min_diff: list of the per-kmer differences.
        avg_min_diff / max_min_diff: mean / max of that list (None if it is empty).

    Args:
        closest_pkmer: cache ``{kmer: (closest proteome kmer, score)}``. When
            None, it is loaded from ``<ARTEFACTS>/eval/closest_pkmer.pickle``
            if that file exists.
        save: persist the cache to that pickle file. It is written after each
            row that changed it, so an interrupted run keeps its progress.

    Returns:
        The (possibly extended) cache dict.
    """
    closest_pkmer_file_path = os.path.join(G.ENV.ARTEFACTS, "eval", "closest_pkmer.pickle")
    # A caller-supplied cache may hold entries that are not on disk yet.
    cache_dirty = closest_pkmer is not None
    if closest_pkmer is None:
        closest_pkmer = {}
        if os.path.exists(closest_pkmer_file_path):
            # NOTE: pickle.load is acceptable only because this is a locally generated artefact.
            with open(closest_pkmer_file_path, "rb") as f:
                closest_pkmer = pickle.load(f)

    df['avg_min_diff'] = None
    df['max_min_diff'] = None
    df['min_diff'] = None

    for i, row in df.iterrows():
        kmers_not_in_proteome = [kmer for kmer in row.kmers_not_in_proteome if len(kmer) in MHC_1_PEPTIDE_LENGTHS]

        min_diff = []
        for kmer in kmers_not_in_proteome:
            if kmer not in closest_pkmer:
                closest_pkmer[kmer] = _find_closest_pkmer(kmer)
                cache_dirty = True
            max_score = closest_pkmer[kmer][1]
            min_diff.append(aligner.score(kmer, kmer) - max_score)

        df.at[i, 'min_diff'] = min_diff
        if min_diff:
            df.at[i, 'avg_min_diff'] = np.mean(min_diff)
            df.at[i, 'max_min_diff'] = np.max(min_diff)

        # Checkpoint the expensive cache, but only when it actually changed.
        if save and cache_dirty:
            os.makedirs(os.path.dirname(closest_pkmer_file_path), exist_ok=True)
            with open(closest_pkmer_file_path, "wb") as f:
                pickle.dump(closest_pkmer, f)
            cache_dirty = False

    return closest_pkmer

In [None]:
def plot_dissimilarity_protein(df, protein_id, source_ids, area, ylabel=True, rename_source_ids=None, palette=None, 
                               source_label_fontsize=7, min_tm_score=None, 
                               dissimilarity_column='min_diff', dissimilarity_label="BLOSUM62\ndissimilarity", dissimilarity_scale='linear'):
    """Boxplot one protein's per-kmer dissimilarity values, one box per source, into ``area``.

    ``df`` must contain exactly one row per (protein_id, source_id) for every
    source in ``source_ids``. ``row[dissimilarity_column]`` is a list of values
    (for example the ``min_diff`` lists written by ``add_min_diff``). Sources
    whose design has ``tm_score < min_tm_score`` are left empty and labelled
    ``n=NA``.
    """
    fig = plt.gcf()
    ax = fig.add_subplot(area)

    source_names, values_per_source = {}, {}
    # 'source_name' actually holds source ids: they match `order` and the palette keys.
    data = {'source_name': [], dissimilarity_column: []}
    for source_id in source_ids:
        source_names[source_id] = rename_source_ids[source_id] if (rename_source_ids is not None and source_id in rename_source_ids) else source_id

        rows = df[(df['protein_id'] == protein_id) & (df['source_id'] == source_id)]
        assert len(rows) == 1, f"expected one row for {protein_id}/{source_id}, got {len(rows)}"
        row = rows.iloc[0]
        if min_tm_score is not None and row.tm_score < min_tm_score:
            continue

        values = list(row[dissimilarity_column])
        # Repeat the id once per value of the *selected* column so both columns
        # have equal length (previously this was tied to 'min_diff').
        data['source_name'] += [source_id] * len(values)
        data[dissimilarity_column] += values
        values_per_source[source_id] = row[dissimilarity_column]

    data = pd.DataFrame(data)
    sns.boxplot(data=data, x='source_name', y=dissimilarity_column, order=source_ids, palette=palette, ax=ax)
    labels = [f"{source_names[h]}\nn={len(values_per_source[h]) if h in values_per_source else 'NA'}" for h in source_ids]
    ax.set_xticklabels(labels, rotation=50, fontsize=source_label_fontsize)
    ax.set_ylabel(dissimilarity_label if ylabel else '')
    ax.set_yscale(dissimilarity_scale)
    ax.set_xlabel('')
    ax.set_title(protein_id)

def plot_dissimilarity(df, protein_ids, source_ids, rename_source_ids=None, n_cols=3, fig_width=None, fig_height=None, 
                       palette=None, wspace=0.25, hspace=0.5, source_label_fontsize=7, min_tm_score=None, 
                       dissimilarity_column='min_diff', dissimilarity_label="BLOSUM62\ndissimilarity", dissimilarity_scale='linear'
                      ):
    """Lay out one ``plot_dissimilarity_protein`` panel per protein in a grid of ``n_cols`` columns.

    The default figure is A4 wide and 0.25 * A4 high per grid row, capped at
    A4 height. Only first-column panels get a y-axis label. The remaining
    styling and filtering arguments are passed through unchanged.
    """
    n_rows = int(np.ceil(len(protein_ids)/n_cols))

    if fig_width is None:
        fig_width = A4_width
    if fig_height is None:
        fig_height = min(A4_height*0.25*n_rows, A4_height)
    plt.figure(figsize=(fig_width, fig_height))

    grid = mpl.gridspec.GridSpec(
        n_rows, n_cols,
        height_ratios=[1] * n_rows,
        width_ratios=[1] * n_cols,
        wspace=wspace, hspace=hspace,
    )

    for position, protein_id in enumerate(protein_ids):
        grid_row, grid_col = divmod(position, n_cols)
        plot_dissimilarity_protein(
            df, protein_id, source_ids, grid[grid_row, grid_col],
            ylabel=(grid_col == 0),
            rename_source_ids=rename_source_ids,
            palette=palette,
            source_label_fontsize=source_label_fontsize,
            min_tm_score=min_tm_score,
            dissimilarity_column=dissimilarity_column,
            dissimilarity_label=dissimilarity_label,
            dissimilarity_scale=dissimilarity_scale,
        )
        

## Destress

In [None]:
def run_destress(df):
    """Split designs into DE-STRESS batches and run them, or export the hash lists.

    Rows with a string ``seq_hash`` go to one of two lists:
        experimental: ``source_id == 'template'``.
        AF: any other row that has a (non-NaN) ``tm_score``, i.e. a predicted
            structure was evaluated.

    If ``DESTRESS_PROG_DIR_PATH`` is None, both lists are written to
    ``<ARTEFACTS>/eval/de-stress/for_destress_{exp,AF}.txt``. Otherwise
    ``loch.run_destress`` is called once per list.

    Returns:
        (seq_hashes_exp, seq_hashes_AF)
    """
    seq_hashes_exp = []
    seq_hashes_AF = []

    for _, row in df.iterrows():
        if not isinstance(row.seq_hash, str):
            continue
        if row.source_id == 'template':
            seq_hashes_exp.append(row.seq_hash)
        # pd.notna: a missing TM score is stored as NaN, which `is not None` did not catch.
        elif pd.notna(row.tm_score):
            seq_hashes_AF.append(row.seq_hash)

    if DESTRESS_PROG_DIR_PATH is None:
        print("DESTRESS_PROG_DIR_PATH not set")
        out_dir = join(G.ENV.ARTEFACTS, 'eval', 'de-stress')
        os.makedirs(out_dir, exist_ok=True)
        str_to_file('\n'.join(seq_hashes_exp), join(out_dir, 'for_destress_exp.txt'))
        str_to_file('\n'.join(seq_hashes_AF), join(out_dir, 'for_destress_AF.txt'))
    else:
        loch.run_destress(
            seq_hashes_exp, 
            DESTRESS_PROG_DIR_PATH, 
            predictor_structure_name='exp'
        )
        loch.run_destress(
            seq_hashes_AF, 
            DESTRESS_PROG_DIR_PATH, 
            predictor_structure_name='AF'
        )

    return seq_hashes_exp, seq_hashes_AF

# Load systems

## Loch

In [None]:
# Open the Loch store at LOCH_PATH, register the same path with kit.loch.path,
# and attach the instance to the Protein class (class attribute, so Protein
# objects presumably resolve their files through it — TODO confirm).
loch = Loch(loch_path=LOCH_PATH)
kit.loch.path.set_loch_path(LOCH_PATH)
Protein.loch = loch

## Immuno-Predictors

### NetMHCpan

In [None]:
# NetMHCpan MHC-I predictor, configured from
# <PROJECT_ENV.CONFIG>/immuno/mhc_1_predictor/netmhcpan.yaml.
# Its data_dir_path comes from NETMHCPAN_DIR_PATH; LIMIT_CALIBRATION is optional in the YAML.
predictor_MHC_I_netmhcpan_dd = DD.from_yaml(
    os.path.join(G.PROJECT_ENV.CONFIG, 'immuno', 'mhc_1_predictor', 'netmhcpan.yaml')
)

predictor_MHC_I_netmhcpan_args = dict(
    data_dir_path=NETMHCPAN_DIR_PATH,
    limit=predictor_MHC_I_netmhcpan_dd.PREDICTOR_MHC_I.LIMIT,
)
if "LIMIT_CALIBRATION" in predictor_MHC_I_netmhcpan_dd.PREDICTOR_MHC_I:
    predictor_MHC_I_netmhcpan_args['limit_calibration'] = (
        predictor_MHC_I_netmhcpan_dd.PREDICTOR_MHC_I.LIMIT_CALIBRATION
    )

predictor_MHC_I_netmhcpan = Mhc1Predictor.get_predictor(
    predictor_MHC_I_netmhcpan_dd.PREDICTOR_MHC_I.NAME
)(**predictor_MHC_I_netmhcpan_args)

### PWM

In [None]:
# Dynamic-PWM MHC-I predictor, configured from
# <PROJECT_ENV.CONFIG>/immuno/mhc_1_predictor/pwm_dynamic.yaml.
# Unlike NetMHCpan, its data directory is read from the YAML (PREDICTOR_MHC_I.FOLDER).
predictor_MHC_I_pwm_dd = DD.from_yaml(
    os.path.join(G.PROJECT_ENV.CONFIG, 'immuno', 'mhc_1_predictor', 'pwm_dynamic.yaml')
)

predictor_MHC_I_dd_args = dict(
    data_dir_path=predictor_MHC_I_pwm_dd.PREDICTOR_MHC_I.FOLDER,
    limit=predictor_MHC_I_pwm_dd.PREDICTOR_MHC_I.LIMIT,
)
if "LIMIT_CALIBRATION" in predictor_MHC_I_pwm_dd.PREDICTOR_MHC_I:
    predictor_MHC_I_dd_args.update(
        limit_calibration=predictor_MHC_I_pwm_dd.PREDICTOR_MHC_I.LIMIT_CALIBRATION
    )

predictor_MHC_I_pwm = Mhc1Predictor.get_predictor(
    predictor_MHC_I_pwm_dd.PREDICTOR_MHC_I.NAME
)(**predictor_MHC_I_dd_args)

In [None]:
# Predictor setups by name. Each entry maps the MHC class key 'mhc_1' to a
# predictor instance and is looked up as immuno_predictor_setups[<name>]['mhc_1']
# (e.g. with MHC_1_PREDICTOR_DECODING when computing CAPE-Beam source ids).
immuno_predictor_setups = {
    'netmhcpan': {'mhc_1': predictor_MHC_I_netmhcpan}, 
    'pwm_dynamic': {'mhc_1': predictor_MHC_I_pwm}
}

## Proteome

### Tree

In [None]:
# Proteome prefix tree, cached as <INPUT>/proteomes/<PROTEOME_FILE_NAME>.pickle.
# On a cache miss it is built with load_proteome_tree and pickled; otherwise it is unpickled.
proteome_pickle_file_path = os.path.join(G.ENV.INPUT, "proteomes", f"{PROTEOME_FILE_NAME}.pickle")
if not os.path.exists(proteome_pickle_file_path):
    print("regenerate tree")
    # proteome_hash is returned but not used in this cell
    proteome_hash, proteome_tree = load_proteome_tree(PROTEOME_FILE_NAME, alphabet=ALPHABET)
    with open(proteome_pickle_file_path, "wb") as f:
        pickle.dump(proteome_tree, f)
else:
    print("load from disk")
    # NOTE: pickle.load is acceptable only because this cache file is produced locally above.
    with open(proteome_pickle_file_path, "rb") as f:
        proteome_tree = pickle.load(f)
    # The alphabet is set explicitly only after unpickling; presumably
    # load_proteome_tree(alphabet=ALPHABET) takes care of it in the other branch — TODO confirm.
    PrefixTree.set_alphabet(ALPHABET)

print(f"Nr of nodes: {proteome_tree.cnt_nodes}")

### Aligner

In [None]:
# Global BLOSUM62 aligner used by add_min_diff to score kmer pairs.
# The huge gap penalties (-1e6) effectively forbid gaps. For equal-length kmers,
# score() is therefore the sum of the position-wise BLOSUM62 substitution scores.
aligner = Align.PairwiseAligner()
aligner.substitution_matrix = substitution_matrices.load('BLOSUM62')
aligner.mode = 'global'
aligner.open_gap_score = -1e6
aligner.extend_gap_score = -1e6

In [None]:
# Read the proteome FASTA as a DataFrame. Its index is used as the list of
# protein sequences, from which all kmers of the MHC-I peptide lengths are collected.
proteome_file_path = os.path.join(G.ENV.INPUT, "proteomes", PROTEOME_FILE_NAME)
proteome = read_fasta(proteome_file_path, stop_token=False, evaluate=False, return_df=True)
proteome_kmers = get_kmers(list(proteome.index), lengths=MHC_1_PEPTIDE_LENGTHS)

In [None]:
# Number of proteome kmers of MHC-I peptide lengths, as returned by get_kmers.
print(len(proteome_kmers))

In [None]:
# Group the proteome kmers by length (one set per MHC-I peptide length).
# add_min_diff searches these sets for the closest same-length proteome kmer.
proteome_kmers_per_length = {l: set() for l in MHC_1_PEPTIDE_LENGTHS}

for kmer in proteome_kmers:
    proteome_kmers_per_length[len(kmer)].add(kmer)

## Fraction of all possible kmers present in the human proteome

In [None]:
# Fraction of all possible 20**l amino-acid l-mers (l = 5..10) that occur in the
# proteome. The result is cached as CSV (index 'length', column 'human') and used
# as the reference lines in the self-kmer plots.
# NOTE(review): the denominator assumes a 20-letter alphabet.
df_pc_human_file_path = join(G.ENV.ARTEFACTS, 'eval', 'df_pc_human.csv')

if os.path.exists(df_pc_human_file_path):
    df_pc_human = pd.read_csv(df_pc_human_file_path).set_index('length')
else:
    # Dedicated name, so the module-level `proteome_kmers` (the MHC-I kmers
    # collected above) is not overwritten by this per-length dict.
    _human_kmers = {l: set() for l in range(5, 11)}
    # The proteome index holds the sequences; iterating it directly also gives tqdm a total.
    for seq in tqdm(proteome.index):
        for l, kmers in _human_kmers.items():
            kmers.update(get_kmers(seq, l))

    _lens, _pc = [], []
    for l, kmers in _human_kmers.items():
        fraction = len(kmers)/20**l
        print(f"{l}  {fraction:.3f}")
        _lens.append(l)
        _pc.append(fraction)

    df_pc_human = pd.DataFrame({'length': _lens, 'human': _pc}).set_index('length')
    os.makedirs(os.path.dirname(df_pc_human_file_path), exist_ok=True)
    df_pc_human.to_csv(df_pc_human_file_path)

## Load ProteinMPNN checkpoint

In [None]:
# Base ProteinMPNN weights (BASE_MODEL_NAME) loaded via CapeMPNN and switched to eval mode.
protein_mpnn = CapeMPNN.from_file(BASE_MODEL_NAME).eval()

## Plots

In [None]:
def get_source_id(checked_kmer_length, proteome_file_name, min_self_kmer_length, tune_mhc_1_genotype, tune_mhc_1_predictor, non_self_prob_factor, width, branching_factor, depth, prune_min_acc_log_prob):
    """Return the hash identifying one CAPE-Beam configuration (its source id).

    Thin wrapper around ``get_beam_search_hash``. The MHC-I immuno setup and the
    predictor hash are only passed when ``tune_mhc_1_genotype`` is given (otherwise
    both are None). The remaining hash inputs are fixed to None, 0 and 20 for every
    configuration compared in this notebook.
    """
    if tune_mhc_1_genotype is None:
        immuno_setup, predictor_setup_hash = None, None
    else:
        predictor_setup_hash = tune_mhc_1_predictor.get_predictor_hash()
        immuno_setup = {'mhc_1': tune_mhc_1_genotype}

    beam_hash_args = [
        None,
        checked_kmer_length,
        proteome_file_name,
        min_self_kmer_length,
        0,
        immuno_setup,
        predictor_setup_hash,
        20,
        non_self_prob_factor,
        width,
        branching_factor,
        depth,
        prune_min_acc_log_prob,
    ]
    return get_beam_search_hash(*beam_hash_args)

In [None]:
# Colours, markers and display names per source: the benchmark sources form the first
# colour block, followed by one block per plotted min-self-k-mer length of CAPE-Beam.
palette_source_id = ['template', 'standard', 'CAPE-MPNN']
# block sizes for glasbey.create_block_palette (derived, not hard-coded)
tmp = [len(palette_source_id)]

rename_source_ids_X = {}
rename_source_ids_A = {}
rename_source_ids_B = {}
rename_source_ids_C = {}

for _min_self_kmer_length in PLOT_MIN_SELF_KMER_LENGTH:
    new_source_ids = []
    for _non_self_prob_factor in PLOT_NON_SELF_PROB_FACTORS:
        source_id = get_source_id(
            checked_kmer_length=10, 
            proteome_file_name=PROTEOME_FILE_NAME, 
            min_self_kmer_length=_min_self_kmer_length, 
            tune_mhc_1_genotype=person_mhc_1_genotype, 
            tune_mhc_1_predictor=immuno_predictor_setups[MHC_1_PREDICTOR_DECODING]['mhc_1'], 
            non_self_prob_factor=_non_self_prob_factor, 
            width=10, 
            branching_factor=1, 
            depth=_min_self_kmer_length*2, 
            prune_min_acc_log_prob=-2000.
        )
        _label = f"{_min_self_kmer_length}mers ({_non_self_prob_factor})"
        rename_source_ids_X[source_id] = f"CB {_label}"
        rename_source_ids_A[source_id] = f"CAPE-Beam ({_non_self_prob_factor})"
        rename_source_ids_B[source_id] = _label
        rename_source_ids_C[source_id] = f"CAPE-Beam\n{_label}"

        new_source_ids.append(source_id)

    palette_source_id += new_source_ids
    tmp.append(len(new_source_ids))

# `tmp` now holds the colours themselves (displayed by the next cell)
tmp = glasbey.create_block_palette(tmp)
palette_source_id = dict(zip(palette_source_id, tmp))
markers_source_id = {k: 'o' for k in palette_source_id}

In [None]:
# Preview the block palette created in the previous cell.
sns.palplot(tmp)

In [None]:
# Per-split containers, filled by the Val/Test sections below
# (df_plot[split] by shared snippet 9, selected_source_ids[split] by each split's config cell).
df_plot = {}
selected_source_ids = {}

# Identify alternative genotypes

In [None]:
# Unless configured explicitly, choose alternative MHC-I genotypes: per HLA gene (A, B, C),
# the two alleles (from different allele groups) whose predicted binders differ most from
# those of every allele in the person's genotype.
if len(alternative_mhc_1_genotypes) == 0:
    _allele_name_re = re.compile(r"^HLA-(A|B|C)\*(\d+):(\d+)$")
    _ranks_file_re = re.compile(r"^HLA-(A|B|C)_(\d+):(\d+)\.csv$")

    #
    # get the peptide ranks of the alleles to compare to
    #
    allele_ranks = {}
    peptides = None
    for mhc_1_allele in person_mhc_1_genotype_list:
        result = _allele_name_re.findall(mhc_1_allele)[0]
        file_name = r"HLA-{}_{}:{}.csv".format(*result)
        df = pd.read_csv(os.path.join(PEPTIDE_RANKS_DIR_PATH, file_name)).set_index('peptide')
        if peptides is None:
            peptides = list(df.index)
        else:
            # all rank files must cover exactly the same peptides
            assert len(set(peptides) ^ set(df.index)) == 0
        allele_ranks[mhc_1_allele] = df

    if peptides is None:
        raise ValueError("person_mhc_1_genotype_list is empty; cannot select alternative genotypes")
    _peptide_set = set(peptides)

    #
    # calculate the 'distance' of every other available allele to each of the person's
    # alleles: the fraction of peptides whose call "rank <= DISTANCE_RANK_THRESHOLD" differs
    #
    available_alleles = {}
    # sorted so that ties in the distance ranking below are broken deterministically
    _rank_files = sorted(d for d in os.listdir(PEPTIDE_RANKS_DIR_PATH) if d.endswith(".csv"))
    for file_name in tqdm(_rank_files):
        _match = _ranks_file_re.match(file_name)
        if _match is None:
            continue
        gene, _allele_group, _allele_protein = _match.groups()
        allele_name = f"HLA-{gene}*{_allele_group}:{_allele_protein}"
        if allele_name in person_mhc_1_genotype_list:
            continue

        df = pd.read_csv(os.path.join(PEPTIDE_RANKS_DIR_PATH, file_name)).set_index('peptide')
        assert len(_peptide_set ^ set(df.index)) == 0
        # skip alleles whose rank column contains non-float entries
        if not all(isinstance(x, float) for x in df[allele_name]):
            continue

        dist = []
        for mhc_1_allele, df2 in allele_ranks.items():
            df3 = df.join(df2)
            _binds_person = df3[mhc_1_allele] <= DISTANCE_RANK_THRESHOLD
            _binds_other = df3[allele_name] <= DISTANCE_RANK_THRESHOLD
            dist.append(np.sum(_binds_person != _binds_other) / df3.shape[0])

        # (allele name, distances to the person's alleles, minimum distance)
        available_alleles.setdefault(gene, []).append((allele_name, dist, min(dist)))

    #
    # Take the most distant alleles (two per gene), make sure they belong to different groups
    #
    alternative_mhc_1_genotypes = []
    for gene in ['A', 'B', 'C']:
        sorted_available_alleles = sorted(available_alleles.get(gene, []), key=lambda x: -x[2])

        print(gene)
        found, group = 0, ""
        for sorted_allele in sorted_available_alleles:
            _group = _allele_name_re.findall(sorted_allele[0])[0][1]
            if _group != group:
                print(f"  {sorted_allele}")
                found += 1
                alternative_mhc_1_genotypes.append(sorted_allele[0])
                if found >= 2:
                    break
                group = _group

    # split_mhc_1_genotypes builds `[person_mhc_1_genotype] + alternative_mhc_1_genotypes`,
    # so this must be a list holding one '+'-joined genotype, not a bare string.
    _alternative_genotype = '+'.join(alternative_mhc_1_genotypes)
    alternative_mhc_1_genotypes = [_alternative_genotype]
    print(_alternative_genotype)

In [None]:
# MHC-I genotypes used per split: the person's genotype everywhere, plus the alternative
# genotype(s) on the test split.
# NOTE(review): `alternative_mhc_1_genotypes` must be a list here (list + list concatenation).
split_mhc_1_genotypes = {
    Split.VAL: [person_mhc_1_genotype],
    Split.TEST: [person_mhc_1_genotype] + alternative_mhc_1_genotypes,
    Split.PREDICT: [person_mhc_1_genotype],
}

# Analyse PWM predictor

In [None]:
# Benchmark the PWM MHC-I predictor against netMHCpan ranks (netMHCpan is the reference)
# on the train and test peptide sets. Prints one LaTeX table row per allele and k-mer
# length (all lengths combined first): accuracy, precision and recall on train, then test.
if ANALYSE_PWM_PREDICTOR:
    _pwm_lengths = [8, 9, 10]
    _pwm_rank_dirs = [(Split.TRAIN, PEPTIDE_RANKS_DIR_PATH), (Split.TEST, PWM_PREDICTOR_TEST_RANKS_DIR_PATH)]

    def _ratio(numerator, denominator):
        """Return numerator / denominator, or NaN when the denominator is zero."""
        return numerator / denominator if denominator else float('nan')

    pred = immuno_predictor_setups['pwm_dynamic']['mhc_1']
    # Peptide set per split; every allele's rank file must cover the same peptides.
    # Private name so the global `peptides` from the genotype cell is not clobbered.
    _pwm_peptides = {s: None for s, _ in _pwm_rank_dirs}
    print(rf"{'allele':15s} & {'kmer len':10s} & {'accuracy':10s} & {'precision':10s} & {'recall':10s} & {'accuracy':10s} & {'precision':10s} & {'recall':10s} \\")
    for mhc_1_genotype in all_mhc_1_genotypes:
        for mhc_1_allele in mhc_1_genotype.split('+'):
            for l in _pwm_lengths:
                pred.load_allele(mhc_1_allele, l)

            result = re.findall(r"^HLA-(A|B|C)\*(\d+):(\d+)$", mhc_1_allele)[0]
            file_name = r"HLA-{}_{}:{}.csv".format(*result)

            dfs = {}
            for split, ranks_dir_path in _pwm_rank_dirs:
                df = pd.read_csv(os.path.join(ranks_dir_path, file_name)).set_index('peptide')
                df['PWM'] = [pred.peptide_presented(_peptide, mhc_1_allele) for _peptide in df.index]
                df['netMHCpan'] = df[mhc_1_allele] < pred.limit_calibration
                df['length'] = df.index.str.len()
                # explicit == True / == False kept on purpose: the return type of
                # peptide_presented is not visible here
                _pwm_pos, _pwm_neg = df['PWM'] == True, df['PWM'] == False
                _net_pos, _net_neg = df['netMHCpan'] == True, df['netMHCpan'] == False
                df['TP'] = _pwm_pos & _net_pos
                df['FP'] = _pwm_pos & _net_neg
                df['FN'] = _pwm_neg & _net_pos
                df['TN'] = _pwm_neg & _net_neg
                dfs[split] = df

                if _pwm_peptides[split] is None:
                    _pwm_peptides[split] = set(df.index)
                else:
                    assert len(_pwm_peptides[split] ^ set(df.index)) == 0

            print(r"\hline")
            for k in [None] + _pwm_lengths:
                _k = '+'.join(str(x) for x in _pwm_lengths) if k is None else str(k)
                text = [rf"{mhc_1_allele:15s} & {_k:10s}"]
                for split in [Split.TRAIN, Split.TEST]:
                    _df = dfs[split] if k is None else dfs[split][dfs[split]['length'] == k]

                    TP, FP = _df['TP'].sum(), _df['FP'].sum()
                    FN, TN = _df['FN'].sum(), _df['TN'].sum()

                    accuracy = _ratio(TP + TN, TP + TN + FP + FN)
                    precision = _ratio(TP, TP + FP)
                    recall = _ratio(TP, TP + FN)
                    text.append(rf"{accuracy:10.3f} & {precision:10.3f} & {recall:10.3f}")
                print(" & ".join(text) + r' \\')

# Proteins

## Identify specific proteins

In [None]:
# Pickle caches for the reconstructed per-split PDB IDs and the selected specific proteins.
split_pdb_ids_file_path = os.path.join(G.ENV.ARTEFACTS, "eval", "split_pdb_ids.pickle")
specific_proteins_file_path = os.path.join(G.ENV.ARTEFACTS, "eval", "specific_proteins.pickle")

In [None]:
# Specific proteins per split: restored from the pickle cache when present,
# otherwise start with an empty list for each split.
if os.path.exists(specific_proteins_file_path):
    with open(specific_proteins_file_path, "rb") as f:
        specific_proteins = pickle.load(f)
else:
    specific_proteins = {split: [] for split in [Split.VAL, Split.TEST, Split.PREDICT]}

In [None]:
# Reconstruct the train/val/test PDB IDs of the training data from its cluster files
# (list.csv, valid_clusters.txt, test_clusters.txt) and, when IDENTIFY_SPECIFIC is set,
# pre-filter val/test structures as candidate specific proteins.
split_pdb_ids = {split: set() for split in [Split.TRAIN, Split.VAL, Split.TEST]}
if IDENTIFY_SPECIFIC:
    def _read_cluster_ids(file_name):
        """Return the integer cluster IDs listed one per line in GENERAL_DATA_DIR_PATH/file_name."""
        content = file_to_str(os.path.join(GENERAL_DATA_DIR_PATH, file_name))
        # skip blank / whitespace-only lines (e.g. a trailing newline or '\r')
        return {int(line) for line in content.split('\n') if line.strip()}

    # read all the validation and test clusters
    clusters_val = _read_cluster_ids('valid_clusters.txt')
    clusters_test = _read_cluster_ids('test_clusters.txt')
    df_data_list = pd.read_csv(os.path.join(GENERAL_DATA_DIR_PATH, 'list.csv'))

    # every chain must be listed exactly once
    assert df_data_list.shape[0] == len(set(df_data_list.CHAINID))

    # assign each chain by its cluster: test and val membership are checked independently,
    # chains in neither go to train
    chains_split = {split: set() for split in [Split.TRAIN, Split.VAL, Split.TEST]}
    for _cluster, _chain_id in zip(df_data_list.CLUSTER, df_data_list.CHAINID):
        _in_test = _cluster in clusters_test
        _in_val = _cluster in clusters_val
        if _in_test:
            chains_split[Split.TEST].add(_chain_id)
        if _in_val:
            chains_split[Split.VAL].add(_chain_id)
        if not (_in_test or _in_val):
            chains_split[Split.TRAIN].add(_chain_id)

    # fails if a cluster is listed as both validation and test (chain counted twice)
    assert sum(len(c) for c in chains_split.values()) == df_data_list.shape[0]

    # PDB ID = upper-cased part of the chain ID before '_'
    for split, chains in chains_split.items():
        split_pdb_ids[split].update(chain.split('_')[0].upper() for chain in chains)

    candidate_specific_pdb_ids = {}
    for split in [Split.VAL, Split.TEST]:
        # sorted: set iteration order is arbitrary, keep the filter input deterministic
        candidate_specific_pdb_ids[split] = filter_pdbs(
            sorted(split_pdb_ids[split]), 
            PDB_DIR_PATH, 
            max_unmodelled=CANDIDATE_MAX_UNMODELLED,
            exclude_small_molecules=CANDIDATE_EXCLUDE_SMALL_MOLECULES,
            exclude_complexes=True,
            server=PDB_SERVER
        )
        print(f"Number of {split} candidates: {len(candidate_specific_pdb_ids[split])}")

    with open(split_pdb_ids_file_path, "wb") as f:
        pickle.dump(split_pdb_ids, f)
elif os.path.exists(split_pdb_ids_file_path):
    with open(split_pdb_ids_file_path, "rb") as f:
        split_pdb_ids = pickle.load(f)

In [None]:
# Sanity check on the reconstructed splits: a PDB ID may occur in two splits only via
# different chains (e.g. complexes). Prints every overlapping PDB ID with its chains per
# split, and raises if the very same chain ID is present in both splits.
if IDENTIFY_SPECIFIC:
    # check overlaps are all in complexes
    for split_1 in [Split.TRAIN, Split.VAL, Split.TEST]:
        for split_2 in [Split.TRAIN, Split.VAL, Split.TEST]:
            if split_1 != split_2:
                overlap = set(split_pdb_ids[split_1]) & set(split_pdb_ids[split_2])
                if len(overlap) > 0:
                    print(f"{split_1} {split_2} {overlap}")
                    for pdb_id in overlap:
                        print(f" {pdb_id}")
                        print(f"  {split_1}: ", end='')
                        # chains of this PDB ID in split_1
                        c_1 = set()
                        for pdb_id_1 in chains_split[split_1]:
                            if pdb_id_1.upper().split('_')[0] == pdb_id:
                                print(f"{pdb_id_1} ", end='')
                                c_1.add(pdb_id_1)
                        print("")
    
                        print(f"  {split_2}: ", end='')
                        for pdb_id_2 in chains_split[split_2]:
                            if pdb_id_2.upper().split('_')[0] == pdb_id:
                                print(f"{pdb_id_2} ", end='')
                                # same chain in two splits -> real leakage
                                if pdb_id_2 in c_1:
                                    raise Exception("overlap!")
                        print("")

In [None]:
# Select the specific val/test proteins: shuffle the candidates reproducibly (SEED), group
# test candidates with their peers, then add candidates until N_SPECIFIC[split] is reached.
if IDENTIFY_SPECIFIC:
    for split in [Split.VAL, Split.TEST]:
        np.random.seed(SEED)
        candidate_specific_pdb_ids[split] = np.random.permutation(sorted(candidate_specific_pdb_ids[split]))
    
    peer_groups = {split: {} for split in [Split.VAL, Split.TEST]}

    # only test candidates are grouped with their peers
    split = Split.TEST
    for pdb_id in sorted(candidate_specific_pdb_ids[split]):
        peers = get_peers(pdb_id)
        if len(peers) == 0:
            continue
        # start a new group unless this PDB ID is already a member of an existing one
        _already_grouped = any(pdb_id in cluster for cluster in peer_groups[split].values())
        if not _already_grouped:
            peer_groups[split][pdb_id] = [x[0] for x in peers.values()]

    for pdb_id in list(peer_groups[Split.TEST]):
        add_to_specific_proteins(specific_proteins[Split.TEST], pdb_id, peer_groups[Split.TEST], KEEP_CHAINS)

    for split in [Split.VAL, Split.TEST]:
        _candidates = candidate_specific_pdb_ids[split]
        i = 0
        while len(specific_proteins[split]) < N_SPECIFIC[split]:
            # fail with a clear message instead of an IndexError when candidates run out
            if i >= len(_candidates):
                raise RuntimeError(
                    f"Only {len(specific_proteins[split])} of {N_SPECIFIC[split]} specific proteins "
                    f"selected for {split}: candidate list exhausted"
                )
            pdb_id = _candidates[i]

            add_to_specific_proteins(
                specific_proteins[split], 
                pdb_id, 
                peer_groups[split],
                KEEP_CHAINS
            )

            i += 1

In [None]:
# Cached results are reused only for splits whose protein list is unchanged.
# When SPECIFIC_MANUAL is set it overrides the selected specific proteins.
LOAD_RESULTS = dict.fromkeys(Split, True)
if SPECIFIC_MANUAL is not None:
    for split in [Split.VAL, Split.TEST, Split.PREDICT]:        
        _manual_ids = SPECIFIC_MANUAL[split]
        if set(specific_proteins[split]) != set(_manual_ids):
            LOAD_RESULTS[split] = False
            print(split)

        specific_proteins[split] = copy.deepcopy(_manual_ids)

## Download protein info and check overlaps between the Train/Val/Test sets

In [None]:
# PDB IDs of human proteins; filled by the download cell below (organism == 'homo sapiens').
HUMAN_PROTEINS = []  # e.g. ['1XGD']

In [None]:
# Download every selected structure, load it as a Protein (stored in the loch under the
# 'exp' structure name), record human proteins, and check that val/test proteins only
# appear in their own training split.
protein_infos = {}
proteins = {}

for split in [Split.VAL, Split.TEST, Split.PREDICT]:
    for pdb_id in specific_proteins[split]:
        found = [_s for _s in [Split.TRAIN, Split.VAL, Split.TEST] if pdb_id in split_pdb_ids[_s]]

        pdb_file_path = os.path.join(PDB_DIR_PATH, f"{pdb_id}.pdb")
        download_structure(pdb_id, PDB_DIR_PATH, overwrite=False, file_format="pdb", source='pdb', server=PDB_SERVER)
        protein_name, protein_organism = get_protein_name_organism(pdb_file_path)

        protein = Protein.from_pdb(pdb_file_path, keep_chains=KEEP_CHAINS.get(pdb_id, None))
        protein.to_loch(predictor_structure_name='exp')
        protein_type = protein.get_protein_type()
        protein_infos[pdb_id] = (None, protein_type, 'pdb')
        proteins[pdb_id] = protein
        
        print(f"{pdb_id} {len(protein.get_seq())} {protein_type}: {found} {protein_organism} - {protein_name}")
        # avoid duplicates when a protein is listed in more than one split
        if protein_organism == 'homo sapiens' and pdb_id not in HUMAN_PROTEINS:
            HUMAN_PROTEINS.append(pdb_id)
        
        if split != Split.PREDICT:
            assert all(f == split for f in found), f"{pdb_id} ({split}) is also in training split(s) {found}"

# Code for splits

In [None]:
def restrict_df_plot(fig_config, df_plot):
    """Filter ``df_plot`` to the rows selected by a figure configuration.

    ``fig_config`` maps column names to lists of allowed values: the values of one
    column are OR-ed, different columns are AND-ed, and ``None`` selects null entries.

    Returns:
        A tuple ``(filtered DataFrame, pandas query string, file-name suffix)``. The
        suffix joins the values of a column with '_' and columns with '__';
        ``eval_mhc_1_genotype`` values are replaced by ``get_mhc_1_setup_hash``.
        An empty ``fig_config`` returns a copy of ``df_plot`` with empty query/suffix.

    Raises:
        ValueError: if a column in ``fig_config`` lists no values.
    """
    query_parts, suffix_parts = [], []
    for column, values in fig_config.items():
        values = list(values)
        if not values:
            raise ValueError(f"fig_config[{column!r}] lists no values")

        conditions, labels = [], []
        for value in values:
            if value is None:
                conditions.append(f"{column}.isnull()")
            elif isinstance(value, str):
                # repr() quotes and escapes, so values containing quotes stay valid
                conditions.append(f"{column} == {str(value)!r}")
            else:
                conditions.append(f"{column} == {value}")

            if column == 'eval_mhc_1_genotype':
                labels.append(get_mhc_1_setup_hash(value))
            else:
                labels.append(str(value))

        query_parts.append('(' + ' or '.join(conditions) + ')')
        suffix_parts.append('_'.join(labels))

    query = " and ".join(query_parts)
    suffix = '__'.join(suffix_parts)

    if not query:
        return df_plot.copy(), query, suffix
    return df_plot.query(query), query, suffix

In [None]:
# Code snippets shared by the Val/Test/Predict sections. Each entry is a string that a split
# section exec()s in order, unless `specific_code[split]` provides a replacement for that
# index. The snippets read notebook globals, notably `split`.
shared_code = []

# 0 - load the cached df_decodings for this split (if LOAD_RESULTS allows) or build it
shared_code.append("""
df_decodings_file_path = os.path.join(G.ENV.ARTEFACTS, "eval", f"df_decodings_{split}.pickle")
if LOAD_RESULTS[split] and os.path.exists(df_decodings_file_path):
    with open(df_decodings_file_path, "rb") as f:
        df_decodings = pickle.load(f)
else:
    df_decodings = get_df_decodings(benchmark_sources, specific_proteins[split], generate_min_proteome_kmer_lens, generate_non_self_prob_factors, split_mhc_1_genotypes[split])
""")

# 1
shared_code.append("""design_sequences(df_decodings, loch)""")

# 2
shared_code.append("""to_colabfold(df_decodings, loch, ignore_seq_hashes=AF_ERRORS_SEQ_HASHES)""")

# 3
shared_code.append("""
from_colabfold(df_decodings, loch)
# from_colabfold(list(df_decodings.query('source_id != "template"').seq_hash), loch)
# template_pdbs_to_loch(df_decodings, loch)
""")

# 4
shared_code.append("""
add_cape_mpnn_designs(df_decodings, CAPE_MPNN_CKPT_ID, loch) 
""")

# 5
shared_code.append("""
add_tm_data(df_decodings, loch, overwrite=OVERWRITE)
""")

# 6
shared_code.append("""analyse_kmer(df_decodings, overwrite=OVERWRITE)""")

# 7
shared_code.append("""
check_deimmunized(df_decodings, show=False)

for _, row in df_decodings.iterrows():
    if row.seq is None or len(row.seq) == 0:
        print(f"Seq is missing: {row.protein_id} {rename_source_ids_X[row.source_id]}")
    elif row.tm_data is None:
        print(f"TM data is missing: {row.protein_id} {rename_source_ids_X[row.source_id]}")
""")

# 8
shared_code.append("""
with open(df_decodings_file_path, "wb") as f:
    pickle.dump(df_decodings, f)  
""")

# 9
shared_code.append("""df_plot[split] = get_df_plot(df_decodings)""")

# 10
shared_code.append("""seq_hashes_exp, seq_hashes_AF = run_destress(df_plot[split])""")

# 11
shared_code.append("""
for seq_hash in seq_hashes_exp:
    Protein.from_loch(seq_hash, predictor_structure_name='exp')
    Protein.proteins[seq_hash].load_destress(predictor_structure_name='exp')

for seq_hash in seq_hashes_AF:
    Protein.from_loch(seq_hash, predictor_structure_name='AF')
    Protein.proteins[seq_hash].load_destress(predictor_structure_name='AF')

for idx, row in df_plot[split].iterrows():
    if isinstance(row.seq_hash, str):
        Protein.proteins[row.seq_hash].ref_seq_hash = df_plot[split].query(f"protein_id == '{row.protein_id}' and source_id == 'template'").iloc[0].seq_hash

df_plot[split]['rosetta p.a.'] = df_plot[split].apply(
    lambda r: 
    Protein.proteins[r.seq_hash].get_info('rosetta_total', delta=False)/len(r.seq) 
    if r.seq_hash in Protein.proteins else
    None
, axis=1)

df_plot[split]['aggrescan3d max'] = df_plot[split].apply(
    lambda r: 
    Protein.proteins[r.seq_hash].get_info('aggrescan3d_max', delta=False)
    if r.seq_hash in Protein.proteins else
    None
, axis=1)

df_plot[split]['delta isoelectric point'] = df_plot[split].apply(
    lambda r: 
    Protein.proteins[r.seq_hash].get_info('isoelectric_point', delta=True)
    if r.seq_hash in Protein.proteins else
    None
, axis=1)

df_plot[split]['isoelectric point'] = df_plot[split].apply(
    lambda r: 
    Protein.proteins[r.seq_hash].get_info('isoelectric_point', delta=False)
    if r.seq_hash in Protein.proteins else
    None
, axis=1)
""")


# 12 - figure X for every configuration in fig_X_configs[split] (saved if SAVE_FIGURES)
shared_code.append("""
for fig_X_config in fig_X_configs[split]:
    _df, query, suffix = restrict_df_plot(fig_X_config, df_plot[split])

    source_ids = get_successful_source_ids(_df, PLOT_MIN_TM_SCORE)
    fig_X(_df, source_ids, rename_source_ids_X, palette_source_id, min_tm_score=PLOT_MIN_TM_SCORE, 
        homo_sapiens_proteins=HUMAN_PROTEINS, fig_height=0.7, hspace=.1, wspace=0.15, groupseps=[2,6,10])
    
    fig = plt.gcf()
    fig.tight_layout()
    if SAVE_FIGURES:
        fig.savefig(join(G.ENV.ARTEFACTS, "eval", "figures", G.DOMAIN, 
            f"figure_X_{split}__{suffix}.pdf"), 
            bbox_inches='tight'
        )
""")

# 13 - figure A: benchmarks on top, then one row per min human k-mer length (5, 6, 7).
#      Note: the trailing backslash inside the string joins the `.sort_values` line.
shared_code.append("""
for fig_A_config in fig_A_configs[split]:
    _df, query, suffix = restrict_df_plot(fig_A_config, df_plot[split])

    fig = plt.figure(figsize=(A4_width, A4_height))
    gs = gridspec.GridSpec(4, 2, width_ratios=[1, 10], height_ratios=[2, 4, 4, 4], wspace=.6, hspace=.23)
    
    #ax = fig.add_subplot(gs[0, 0])
    plot_fig_A(_df.query(f"source_id in {benchmark_source_ids}"), palette_source_id, area=gs[0, 1], x_lim_pnip=0.25, min_tm_score=PLOT_MIN_TM_SCORE)
    for i, l in enumerate([5,6,7]):
        ax = fig.add_subplot(gs[i+1, 0])
        plot_text(f"min human {l}-mers", ax, rotation=90, y_pos=0.5)
        plot_fig_A(
            _df.query(f"source_id in {selected_source_ids[split]['A']} and min_self_kmer_length == {l}").\
                sort_values('non_self_prob_factor', ascending=False), 
            palette_source_id,
            area=gs[i+1, 1],
            rename_source_ids=rename_source_ids_A, x_lim_pnip=0.25, min_tm_score=PLOT_MIN_TM_SCORE, show_x_labels=False
        )
        
    fig.tight_layout()
    
    if SAVE_FIGURES:
        fig.savefig(join(G.ENV.ARTEFACTS, "eval", "figures", G.DOMAIN, 
            f"figure_A_{split}__{suffix}.pdf"), 
            bbox_inches='tight'
        )
""")

# 14
shared_code.append("""
for fig_B_config in fig_B_configs[split]:
    _df, query, suffix = restrict_df_plot(fig_B_config, df_plot[split])

    plot_fig_B(_df.query(f'source_id in {selected_source_ids[split]["B"]}'), palette_source_id, rename_source_ids_B, fig_width=fig_width_B, ylabel=ylabel_B, min_tm_score=PLOT_MIN_TM_SCORE)
    
    fig = plt.gcf()
    fig.tight_layout()
    if SAVE_FIGURES:
        fig.savefig(join(G.ENV.ARTEFACTS, "eval", "figures", G.DOMAIN, 
            f"figure_B_{split}__{suffix}.pdf"), 
            bbox_inches='tight'
        )
""")

# 15
shared_code.append("""
dfs_fig_C = {}
for fig_C_config in fig_C_configs[split]:
    _df, query, suffix = restrict_df_plot(fig_C_config, df_plot[split])
    df_fig_C = _df.query(f'source_id in {selected_source_ids[split]["C"]}').copy()
    closest_pkmer = add_min_diff(df_fig_C, save=False)
    dfs_fig_C[suffix] = df_fig_C
""")

# 16
shared_code.append("""
protein_ids_separate = ['1XGD', '1B9K']
n_cols=2
hspace=0.75
source_label_fontsize=7
""")

# 17
shared_code.append(r"""
for fig_C_config in fig_C_configs[split]:
    _df, query, suffix = restrict_df_plot(fig_C_config, df_plot[split])

    df_fig_C = dfs_fig_C[suffix]
    
    protein_ids = [x for x in  specific_proteins[split] if x not in protein_ids_separate] + protein_ids_separate

    dissimilarity_column='min_diff'
    plot_dissimilarity(df_fig_C, protein_ids, selected_source_ids[split]['C'], 
        rename_source_ids=rename_source_ids_X, 
        palette=palette_source_id, 
        n_cols=n_cols, 
        hspace=1., 
        min_tm_score=PLOT_MIN_TM_SCORE,
        dissimilarity_column=dissimilarity_column,
        dissimilarity_label='BLOSUM62\ndissimilarity',
        dissimilarity_scale='linear'
    )
    
    fig = plt.gcf()
    fig.tight_layout()
    
    if SAVE_FIGURES:
        save_fig = join(G.ENV.ARTEFACTS, "eval", "figures", G.DOMAIN, f"figure_C_{split}__{dissimilarity_column}__{suffix}.pdf")
        fig.savefig(save_fig, bbox_inches='tight')
""")

In [None]:
# Per-split replacements for shared_code entries; None means "use the shared snippet".
specific_code = {s: [None] * len(shared_code) for s in Split}

In [None]:
# Test/Predict variant of snippet 13 (figure A): one benchmark row and one CAPE-Beam row
# on a third of an A4 page. The trailing backslash inside the string joins the
# `.sort_values` line.
specific_code[Split.PREDICT][13] = specific_code[Split.TEST][13] = """
for fig_A_config in fig_A_configs[split]:
    _df, query, suffix = restrict_df_plot(fig_A_config, df_plot[split])

    fig = plt.figure(figsize=(A4_width, A4_height/3))
    gs = gridspec.GridSpec(2, 1, width_ratios=[10], height_ratios=[3, 2], wspace=.6, hspace=.23)

    plot_fig_A(_df.query(f"source_id in {benchmark_source_ids}"), palette_source_id, area=gs[0, 0], x_lim_pnip=0.25, min_tm_score=PLOT_MIN_TM_SCORE)

    plot_fig_A(
        _df.query(f"source_id in {selected_source_ids[split]['A']} and source_id not in {benchmark_source_ids}").\
            sort_values('non_self_prob_factor', ascending=False), 
        palette_source_id, area=gs[1, 0],
        rename_source_ids=rename_source_ids_C, x_lim_pnip=0.25, min_tm_score=PLOT_MIN_TM_SCORE, show_x_labels=False
    )       
    fig.tight_layout()
    
    if SAVE_FIGURES:
        fig.savefig(join(G.ENV.ARTEFACTS, "eval", "figures", G.DOMAIN, 
            f"figure_A_{split}__{suffix}.pdf"), 
            bbox_inches='tight'
        )

"""


# Shared figure C layout for Test/Predict (snippet 16). A private name is used so the
# global `tmp` (the colour palette built earlier) is not overwritten with a code string.
_fig_c_layout_code = """
n_cols=4
hspace=0.8
source_label_fontsize=6   
"""
specific_code[Split.TEST][16] = """protein_ids_separate = ['2BK8', '3WOY', '5OA9', '6TPT']""" + _fig_c_layout_code
specific_code[Split.PREDICT][16] = """protein_ids_separate = ['1UBQ', '1HHK']""" + _fig_c_layout_code


# Val

In [None]:
# Validation split: which min-proteome k-mer lengths and non-self probability factors to
# generate designs for, which sources figures A/B/C show, and figure B layout options.
split = Split.VAL

generate_min_proteome_kmer_lens = MIN_PROTEOME_KMER_LENS
generate_non_self_prob_factors = NON_SELF_PROB_FACTORS

selected_source_ids[split] = {
    'A': benchmark_source_ids + get_Beam_source_ids([5,6,7], PLOT_NON_SELF_PROB_FACTORS, person_mhc_1_genotype),
    'B': benchmark_source_ids + get_Beam_source_ids([5, 6], [0.5, 0.9, 0.99], person_mhc_1_genotype),
    'C': benchmark_source_ids + get_Beam_source_ids([5, 6], [0.5, 0.9, 0.99], person_mhc_1_genotype),
}

fig_width_B = A4_width
ylabel_B = True

In [None]:
# Assemble the snippets for this split: a split-specific override (if any) wins over
# the shared version at the same index.
code = [
    shared_c if specific_c is None else specific_c
    for shared_c, specific_c in zip(shared_code, specific_code[split])
]

## Construct df_decodings

In [None]:
# Snippet 0: load or build df_decodings for the validation split.
i = 0
print(code[i])
exec(code[i])

## Generate beam search commands

In [None]:
# Snippet 1: generate the cape-beam.py commands.
i = 1
print(code[i])
exec(code[i])

Run the ``cape-beam.py`` commands generated above in a terminal inside the container.

## Predict 3D structures

In [None]:
# Snippet 2: export the sequences for ColabFold structure prediction.
i = 2
print(code[i])
exec(code[i])

Run the following in a shell on the host system:
```
colabfold_batch "${CAPE}/artefacts/CAPE/colabfold/input/" "${CAPE}/artefacts/CAPE/colabfold/output/" --amber --use-gpu-relax

```

In [None]:
# Snippet 3: import the ColabFold predictions.
i = 3
print(code[i])
exec(code[i])

In [None]:
# Snippet 4: add the CAPE-MPNN designs.
i = 4
print(code[i])
exec(code[i])

In [None]:
# Snippet 5: add TM data.
i = 5
print(code[i])
exec(code[i])

In [None]:
# Snippet 6: k-mer analysis.
i = 6
print(code[i])
exec(code[i])

In [None]:
# Detailed report of designs without a sequence or without TM data, including their
# configuration. Source ids without a display name fall back to the raw id instead of
# raising KeyError in the middle of the diagnostic.
for _, row in df_decodings.iterrows():
    if row.seq is None or len(row.seq) == 0:
        _problem = "Seq is missing"
    elif row.tm_data is None:
        _problem = "TM data is missing"
    else:
        continue
    _label = rename_source_ids_X.get(row.source_id, row.source_id)
    print(f"{_problem}: {row.protein_id} {row.source_id} {_label} {row.min_self_kmer_length} {row.non_self_prob_factor} {row.tune_mhc_1_genotype}")

In [None]:
# Snippet 7: deimmunization check and missing-data report.
i = 7
print(code[i])
exec(code[i])

In [None]:
# Snippet 8: cache df_decodings.
i = 8
print(code[i])
exec(code[i])

## Add additional columns

In [None]:
# Snippet 9: build df_plot[split].
i = 9
print(code[i])
exec(code[i])

## Run DESTRESS

In [None]:
# Snippet 10: prepare the DE-STRESS runs.
i = 10
print(code[i])
exec(code[i])

Install the DE-STRESS command-line tool (https://github.com/wells-wood-research/de-stress) and run the DE-STRESS evaluations on the host system:
```
PF=${CAPE}
DESTRESS_PATH=<path where you installed de-stress>
python ${PF}/tools/run_destress.py --destress_prog_dir_path $DESTRESS_PATH --project CAPE-Beam
```


In [None]:
# Snippet 11: load the DE-STRESS results and add the metric columns.
i = 11
print(code[i])
exec(code[i])

## Plots

### Figure X

In [None]:
i = 12
print(code[i])
exec(code[i])

### Figure A

In [None]:
i = 13
print(code[i])
exec(code[i])

### Figure B

In [None]:
i = 14
print(code[i])
exec(code[i])

### Figure C

In [None]:
# Figure C snippets (15-17) are disabled for the validation split.
# i = 15
# print(code[i])
# exec(code[i])

In [None]:
# i = 16
# print(code[i])
# exec(code[i])

In [None]:
# i = 17
# print(code[i])
# exec(code[i])

# Test

In [None]:
# Test split: generate designs only for the selected k-mer lengths / non-self factor, and
# show the same CAPE-Beam sources in figures A, B and C (figure B narrower, no y-label).
MIN_SELF_KMER_LENS_TEST = [5, 6]
NON_SELF_PROB_FACTORS_TEST = [0.9]

split = Split.TEST

generate_min_proteome_kmer_lens = MIN_SELF_KMER_LENS_TEST
generate_non_self_prob_factors = NON_SELF_PROB_FACTORS_TEST

selected_source_ids[split] = {
    'A': benchmark_source_ids + get_Beam_source_ids(MIN_SELF_KMER_LENS_TEST, NON_SELF_PROB_FACTORS_TEST, person_mhc_1_genotype),
    'B': benchmark_source_ids + get_Beam_source_ids(MIN_SELF_KMER_LENS_TEST, NON_SELF_PROB_FACTORS_TEST, person_mhc_1_genotype),
    'C': benchmark_source_ids + get_Beam_source_ids(MIN_SELF_KMER_LENS_TEST, NON_SELF_PROB_FACTORS_TEST, person_mhc_1_genotype),
}

fig_width_B = A4_width * 5/9
ylabel_B = False

In [None]:
# Assemble the snippets for this split: a split-specific override (if any) wins over
# the shared version at the same index.
code = [
    shared_c if specific_c is None else specific_c
    for shared_c, specific_c in zip(shared_code, specific_code[split])
]

## Construct df_decodings

In [None]:
# Snippet 0: load or build df_decodings for the test split.
i = 0
print(code[i])
exec(code[i])

## Generate beam search commands

In [None]:
# Snippet 1: generate the cape-beam.py commands.
i = 1
print(code[i])
exec(code[i])

Run the ``cape-beam.py`` commands generated above in a terminal inside the container.

## Predict 3D structures

In [None]:
# Snippet 2: export the sequences for ColabFold structure prediction.
i = 2
print(code[i])
exec(code[i])

Run the following in a shell on the host system:
```
colabfold_batch "${CAPE}/artefacts/CAPE/colabfold/input/" "${CAPE}/artefacts/CAPE/colabfold/output/" --amber --use-gpu-relax

```

In [None]:
# Snippet 3: import the ColabFold predictions.
i = 3
print(code[i])
exec(code[i])

In [None]:
# Snippet 4: add the CAPE-MPNN designs.
i = 4
print(code[i])
exec(code[i])

In [None]:
# Snippet 5: add TM data.
i = 5
print(code[i])
exec(code[i])

In [None]:
# Snippet 6: k-mer analysis.
i = 6
print(code[i])
exec(code[i])

In [None]:
# Snippet 7: deimmunization check and missing-data report.
i = 7
print(code[i])
exec(code[i])

In [None]:
# Snippet 8: cache df_decodings.
i = 8
print(code[i])
exec(code[i])

## Add additional columns

In [None]:
# Snippet 9: build df_plot[split].
i = 9
print(code[i])
exec(code[i])

## Run DESTRESS

In [None]:
# Snippet 10: prepare the DE-STRESS runs.
i = 10
print(code[i])
exec(code[i])

Install the DE-STRESS command-line tool (https://github.com/wells-wood-research/de-stress) and run the DE-STRESS evaluations on the host system:
```
PF=${CAPE}
DESTRESS_PATH=<path where you installed de-stress>
python ${PF}/tools/run_destress.py --destress_prog_dir_path $DESTRESS_PATH --project CAPE-Beam
```


In [None]:
# Snippet 11: load the DE-STRESS results and add the metric columns.
i = 11
print(code[i])
exec(code[i])

## Plots

### Figure X

In [None]:
i = 12
print(code[i])
exec(code[i])

### Figure A

In [None]:
i = 13
print(code[i])
exec(code[i])

### Figure B

In [None]:
i = 14
print(code[i])
exec(code[i])

### Figure C

In [None]:
# Snippet 15: prepare the figure C frames.
i = 15
print(code[i])
exec(code[i])

In [None]:
# Snippet 16: figure C layout (test-specific override).
i = 16
print(code[i])
exec(code[i])

In [None]:
# Snippet 17: plot figure C.
i = 17
print(code[i])
exec(code[i])

## Other Genotypes

In [None]:
# Display the evaluated MHC-I genotypes (cell output).
all_mhc_1_genotypes

In [None]:
# Compute presented k-mers for the current df_decodings (test split) with the evaluation
# predictor; presumably fills the `kmers_presented` column read below — confirm.
set_kmers_presented(df_decodings, MHC_1_PREDICTOR_EVAL)

In [None]:
# For every evaluation genotype and each specific protein of the current split, collect:
# the TM-score of the CAPE-Beam design with `min_self_kmer_length`, and the number of
# presented k-mers of that design and of the template (absolute, beam / template, and
# relative to sum(get_possible_peptides(..., [8, 9, 10]))).
tm_scores_per_mhc_1_genotype = {}

kmers_presented_per_mhc_1_genotype_template = {}
kmers_presented_per_mhc_1_genotype_beam = {}
kmers_presented_per_mhc_1_genotype_beam_vs_template = {}
kmers_presented_per_mhc_1_genotype_beam_frac = {}

min_self_kmer_length = 5
for mhc_1_genotype in all_mhc_1_genotypes:
    _df = df_decodings.query(f"eval_mhc_1_genotype == '{mhc_1_genotype}'")

    # TM-score and k-mer statistics must describe the same design, so the beam row
    # (non-benchmark source with the chosen min_self_kmer_length) is selected once and reused.
    _beam_rows = [
        _df.query(
            f"source_id not in {benchmark_source_ids} and min_self_kmer_length == {min_self_kmer_length} "
            f"and protein_id == '{_protein_id}'"
        ).iloc[0]
        for _protein_id in specific_proteins[split]
    ]
    _template_rows = [
        _df.query(f"source_id == 'template' and protein_id == '{_protein_id}'").iloc[0]
        for _protein_id in specific_proteins[split]
    ]

    tm_scores_per_mhc_1_genotype[mhc_1_genotype] = [_row.tm_data[0] for _row in _beam_rows]

    kmers_presented_per_mhc_1_genotype_beam[mhc_1_genotype] = [len(_row.kmers_presented) for _row in _beam_rows]
    kmers_presented_per_mhc_1_genotype_template[mhc_1_genotype] = [len(_row.kmers_presented) for _row in _template_rows]
    kmers_presented_per_mhc_1_genotype_beam_vs_template[mhc_1_genotype] = (
        np.array(kmers_presented_per_mhc_1_genotype_beam[mhc_1_genotype])
        / np.array(kmers_presented_per_mhc_1_genotype_template[mhc_1_genotype])
    )
    _n_possible = [np.sum(get_possible_peptides(_row.immuno_chains, [8, 9, 10])) for _row in _beam_rows]
    kmers_presented_per_mhc_1_genotype_beam_frac[mhc_1_genotype] = (
        np.array(kmers_presented_per_mhc_1_genotype_beam[mhc_1_genotype]) / np.array(_n_possible)
    )

In [None]:
# Regress the TM-scores under the alternative genotype (y) on those under the primary
# genotype (x) and test H0: intercept == 0 and H0: slope == 1 with two-sided t-tests.
x = np.array(tm_scores_per_mhc_1_genotype[all_mhc_1_genotypes[0]])
y = np.array(tm_scores_per_mhc_1_genotype[all_mhc_1_genotypes[1]])

X = sm.add_constant(x)  # prepend the intercept column

model = sm.OLS(y, X).fit()

intercept, slope = model.params
# standard errors of the coefficients
se_intercept, se_slope = model.bse

# `stats.t.sf` (survival function) keeps precision for tiny p-values, unlike `1 - cdf`.
# Test H0: Intercept = 0
t_intercept = (intercept - 0) / se_intercept
p_intercept = 2 * stats.t.sf(np.abs(t_intercept), df=model.df_resid)
print(f"p-value intercept == 0: {p_intercept}")

# Test H0: Slope = 1
t_slope = (slope - 1) / se_slope
p_slope = 2 * stats.t.sf(np.abs(t_slope), df=model.df_resid)
print(f"p-value slope == 1:{p_slope}")

affine_estimate = f"y = {intercept:.3f} + {slope:.3f}x"
p_values = f"p-values:\n   intercept == 0: {p_intercept:.3f}\n   slope == 1: {p_slope:.3f}"
print(affine_estimate)

In [None]:
# Sense check: regressing (y - x) on x gives the same intercept and a slope
# that is smaller by exactly 1. statsmodels' default per-coefficient tests
# (H0: coefficient == 0) on this shifted model are therefore the same tests
# as H0: intercept == 0 and H0: slope == 1 on the direct fit. Both the
# estimates and the p-values must match the previous cell.
x = np.array(tm_scores_per_mhc_1_genotype[all_mhc_1_genotypes[0]])
y = np.array(tm_scores_per_mhc_1_genotype[all_mhc_1_genotypes[1]]) - x
X = sm.add_constant(x)

model = sm.OLS(y, X).fit()

intercept_check, slope_check = model.params
slope_check += 1  # undo the shift introduced by subtracting x from y
affine_estimate_check = f"y = {intercept_check:.3f} + {slope_check:.3f}x"
print(affine_estimate_check)
print(model.pvalues)

# Point estimates must agree with the direct fit ...
assert np.allclose([intercept_check, slope_check], [intercept, slope])
# ... and so must the hand-computed identity-line p-values.
assert np.allclose([p_intercept, p_slope], model.pvalues)

In [None]:
# Figure G: two side-by-side panels comparing the primary genotype (x axes)
# with the alternative genotype (y axes).
#   a) TM-scores of the beam-search designs, with the fitted OLS regression line
#   b) presented k-mers of the designs as a fraction of their 8-10mers,
#      with dashed lines at the per-genotype means
fig = plt.figure(figsize=(A4_width, A4_height / 3.4))
gs = gridspec.GridSpec(
    nrows=1,
    ncols=2,
    width_ratios=[1, 1],
    height_ratios=[1],
    hspace=0.1,
    wspace=0.3,
)

_genotype_x = all_mhc_1_genotypes[0]  # primary genotype
_genotype_y = all_mhc_1_genotypes[1]  # alternative genotype

# ---- panel a): TM-score vs TM-score -------------------------------------
ax = fig.add_subplot(gs[0, 0])
ax.text(
    -.1, 1.2, f"{string.ascii_lowercase[0]})",
    transform=ax.transAxes, fontsize=15, fontweight='bold', va='top', ha='right',
)
ax.set_title('TM-scores')
ax.set_xlabel('primary genotype')
ax.set_ylabel('alternative genotype')
sns.scatterplot(
    x=tm_scores_per_mhc_1_genotype[_genotype_x],
    y=tm_scores_per_mhc_1_genotype[_genotype_y],
    ax=ax,
)
# Regression line y = intercept + slope * x, drawn between x = 0 and x = 1
ax.plot(
    [0, 1.], [intercept, intercept + slope],
    label='regression', color='lightblue', linestyle='--',
)
ax.set_xlim((0., 1.))
ax.set_ylim((0., 1.))
# Annotation in data coordinates (limits are fixed to [0, 1] above)
ax.text(x=0.05, y=0.75, s=f"{affine_estimate}\n{p_values}")

# ---- panel b): fraction of presented k-mers vs fraction -----------------
ax = fig.add_subplot(gs[0, 1])
ax.text(
    -.1, 1.2, f"{string.ascii_lowercase[1]})",
    transform=ax.transAxes, fontsize=15, fontweight='bold', va='top', ha='right',
)
ax.set_title('relative number of \n potentially immunogenic peptides')
ax.set_xlabel('primary genotype [frac of 8-10mers]')
ax.set_ylabel('alternative genotype [frac of 8-10mers]')
sns.scatterplot(
    x=kmers_presented_per_mhc_1_genotype_beam_frac[_genotype_x],
    y=kmers_presented_per_mhc_1_genotype_beam_frac[_genotype_y],
    ax=ax,
)
# Dashed reference lines at the mean fraction of each genotype
ax.axvline(np.mean(kmers_presented_per_mhc_1_genotype_beam_frac[_genotype_x]), color='lightblue', linestyle='--')
ax.axhline(np.mean(kmers_presented_per_mhc_1_genotype_beam_frac[_genotype_y]), color='lightblue', linestyle='--')

# Fractions in [0, 1] are shown as percentages
ax.xaxis.set_major_formatter(PercentFormatter(1.0))
ax.yaxis.set_major_formatter(PercentFormatter(1.0))

# Report the mean fraction for every evaluated genotype
for mhc_1_genotype in all_mhc_1_genotypes:
    print(np.mean(kmers_presented_per_mhc_1_genotype_beam_frac[mhc_1_genotype]))

if SAVE_FIGURES:
    _figure_path = join(G.ENV.ARTEFACTS, "eval", "figures", G.DOMAIN, f"figure_G_{split}.pdf")
    fig.savefig(_figure_path, bbox_inches='tight')



# Shutdown

In [None]:
# Persist the `specific_proteins` selection as a pickle at specific_proteins_file_path.
with open(specific_proteins_file_path, mode="wb") as f:
    pickle.dump(specific_proteins, f)