In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib

In [34]:
from visualise_results.utils import *
import numpy as np
from typing import *
import glob

import sys, os, shutil

ROOT_PROJECT =  os.path.join(os.path.normpath(os.path.join(os.getcwd()))) 
sys.path[0] = ROOT_PROJECT

In [None]:
from bo.custom_init import InitialBODataset, get_initial_dataset_path
from bo.utils import save_w_pickle
from tqdm import tqdm

In [14]:
result_by_method = []

for meth_file in glob.glob("./results_data/*_optim_res.csv"):
    result_by_method.append(pd.read_csv(meth_file, index_col=0))
results = pd.concat(result_by_method)
results = results.loc[np.sort(results.index.values)]

In [15]:
with open("./dataloader/core_antigens.txt") as f:
    core_antigens = list(map(lambda antigen: antigen.rstrip(), f.readlines()))


with open('./utils_data/all_antigens.txt') as file:
    antigens = file.readlines()
    all_antigens = [antigen.rstrip() for antigen in antigens]

remaining_antigens = list(set(all_antigens)-set(core_antigens))

In [24]:
thresholdPd = pd.read_csv("./utils_data/ListThresholds.txt", sep=" ")
# Mascotte and above 1% 

In [26]:
# For each of the 12 core_antigens
# 
# 1 inital dataset Loosers (20)  # samples to get to Mascotte,Heroes,SuperHeroes
# 1 inital dataset Loosers+Mascotte (20) # samples to get to Heroes,SuperHeroes
# 1 inital dataset Loosers+Mascotte+Heroes (20) # samples to get to SuperHeroes

AA = 'ACDEFGHIKLMNPQRSTVWY'
AA_to_idx = {aa: i for i, aa in enumerate(AA)}

categories = ['NonBinders', 'Loosers', 'Mascotte', 'Heroes', 'SuperHeroes']

def get_energy_interval(antigen, category):
    if category not in categories:
        raise(ValueError(category))
    ind = categories.index(category)
    upper_bound = thresholdPd[thresholdPd.AGname==antigen][thresholdPd.type==category]['maxEnergy'].values[0]
    if ind == (len(categories) - 1):
        lower_bound = thresholdPd[thresholdPd.AGname==antigen][thresholdPd.type==categories[ind]]['minEnergy'].values[0]
    else:
        lower_bound = thresholdPd[thresholdPd.AGname==antigen][thresholdPd.type==categories[min(len(categories) - 1, ind + 1)]]['maxEnergy'].values[0] + 0.01
    return lower_bound, upper_bound

In [27]:
antigen_name_ = "2DD8_S"
get_energy_interval(antigen=antigen_name_, category='SuperHeroes')

  upper_bound = thresholdPd[thresholdPd.AGname==antigen][thresholdPd.type==category]['maxEnergy'].values[0]
  lower_bound = thresholdPd[thresholdPd.AGname==antigen][thresholdPd.type==categories[ind]]['minEnergy'].values[0]


(-125.07, -119.16)

In [28]:
def get_proteins_by_category(antigen, category, top_cut_ratio=0):
    """
    antigen: the name of the antigen
    category: the name of the category (Loosers,...)
    top_cut_ratio: remove the top `top_cut_ratio` * 100 % of the protein in terms of binding energy (the lower the better)
    """
    aux_results = results[results.Antigen == antigen]
    lower_energy, upper_energy = get_energy_interval(antigen=antigen, category=category)
    aux_results = aux_results[aux_results['Last Binding Energy'] < upper_energy][aux_results['Last Binding Energy'] > lower_energy]
    aux_results = aux_results[['Last Protein', 'Last Binding Energy']].drop_duplicates(subset=['Last Protein']).values
    
    # Keep the highest energies
    keeps = np.ceil(top_cut_ratio * len(aux_results)).astype(int)
    inds = aux_results[:, 1].argsort()[keeps:]
    return aux_results[inds]

In [29]:
proteins_scores_ = get_proteins_by_category('2DD8_S', 'SuperHeroes', 0)
proteins_scores_.shape

  upper_bound = thresholdPd[thresholdPd.AGname==antigen][thresholdPd.type==category]['maxEnergy'].values[0]
  lower_bound = thresholdPd[thresholdPd.AGname==antigen][thresholdPd.type==categories[ind]]['minEnergy'].values[0]
  aux_results = aux_results[aux_results['Last Binding Energy'] < upper_energy][aux_results['Last Binding Energy'] > lower_energy]


(1734, 2)

In [30]:
def getInitUpToCat(antigen: str, top_category: str, n: int, top_cat_top_cut_ratio: float=0, seed: int = 0):
    if top_category not in categories:
        raise(ValueError(top_category))
    np.random.seed(seed)
    ind = categories.index(top_category)
    n_per_cats = [n // ind for _ in range(ind)]
    n_per_cats[-1] = n - sum(n_per_cats[:-1])
    top_cut_ratios = [0 for _ in range(ind)]
    top_cut_ratios[-1] = top_cat_top_cut_ratio
    assert sum(n_per_cats) == n
    samples = []
    for i in range(ind):
        n_sample = n_per_cats[i]
        top_cut_ratio = top_cut_ratios[i]
        category = categories[i + 1]
        proteins_scores = get_proteins_by_category(antigen, category=category, top_cut_ratio=top_cut_ratio)
        sampled_inds = np.random.choice(np.arange(len(proteins_scores)), size=n_sample)
        samples.extend([[category, np.array([AA_to_idx[el] for el in  proteins_scores[sample_ind][0]]), proteins_scores[sample_ind][0], proteins_scores[sample_ind][1]] for sample_ind in sampled_inds])
    return pd.DataFrame(samples, columns=['type', 'AA to ind', 'Protein', 'Binding Energy'])

In [31]:
def getInitDataset(antigen: str, n_per_cat: Dict[str, int], top_cut_ratio_per_cat: Dict[str, float], seed: int):
    np.random.seed(seed)
    samples = []
    for category, n_sample in n_per_cat.items():
        top_cut_ratio = top_cut_ratio_per_cat.get(category, 0)
        proteins_scores = get_proteins_by_category(antigen, category=category, top_cut_ratio=top_cut_ratio)
        sampled_inds = np.random.choice(np.arange(len(proteins_scores)), size=n_sample)
        samples.extend([[category, np.array([AA_to_idx[el] for el in  proteins_scores[sample_ind][0]]), proteins_scores[sample_ind][0], proteins_scores[sample_ind][1]] for sample_ind in sampled_inds])
    return pd.DataFrame(samples, columns=['type', 'AA to ind', 'Protein', 'Binding Energy'])

In [32]:
def get_n_per_cat(n_loosers: int, n_mascottes: int, n_heroes):
    return dict(Loosers=n_loosers, Mascotte=n_mascottes, Heroes=n_heroes)

def get_top_cut_ratio_per_cat(top_cut_ratio_loosers: int, top_cut_ratio_mascottes: int, top_cut_ratio_heroes):
    return dict(Loosers=top_cut_ratio_loosers, Mascotte=top_cut_ratio_mascottes, Heroes=top_cut_ratio_heroes)


In [56]:
n_per_cat_s = [get_n_per_cat(20, 0, 0), get_n_per_cat(10, 10, 0), get_n_per_cat(6, 6, 8)] 
top_cut_ratio_per_cat_s = [get_top_cut_ratio_per_cat(0.5, 0, 0), get_top_cut_ratio_per_cat(0, 0.5, 0), get_top_cut_ratio_per_cat(0, 0, 0.5)]

selected_antigens = all_antigens

invalid_antigens_ = []
for antigen_name_ in tqdm(selected_antigens):
    for top_cut_ratio_per_cat_, n_per_cat_ in zip(top_cut_ratio_per_cat_s, n_per_cat_s):
        if len(results[results.Antigen == antigen_name_]) == 0:
            invalid_antigens_.append(antigen_name_)
            continue
        try:
            for seed_ in range(10):
                res = getInitDataset(antigen=antigen_name_, n_per_cat=n_per_cat_, top_cut_ratio_per_cat=top_cut_ratio_per_cat_, seed=seed_)        
                data = InitialBODataset(res)
                save_path = get_initial_dataset_path(
                    antigen_name=antigen_name_, n_per_cat=n_per_cat_, top_cut_ratio_per_cat=top_cut_ratio_per_cat_, seed=seed_
                )
                assert data.get_index_encoded_x().shape == (sum(n_per_cat_.values()), 11)
                save_w_pickle(data, save_path)
                print(save_path)
        except IndexError as e:
            invalid_antigens_.append(antigen_name_)
            print(antigen_name_)

  upper_bound = thresholdPd[thresholdPd.AGname==antigen][thresholdPd.type==category]['maxEnergy'].values[0]
  lower_bound = thresholdPd[thresholdPd.AGname==antigen][thresholdPd.type==categories[min(len(categories) - 1, ind + 1)]]['maxEnergy'].values[0] + 0.01
  aux_results = aux_results[aux_results['Last Binding Energy'] < upper_energy][aux_results['Last Binding Energy'] > lower_energy]


/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/0/Loosers-20_Loosers-0.5/init_data
/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/1/Loosers-20_Loosers-0.5/init_data
/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/2/Loosers-20_Loosers-0.5/init_data
/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/3/Loosers-20_Loosers-0.5/init_data
/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/4/Loosers-20_Loosers-0.5/init_data
/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/5/Loosers-20_Loosers-0.5/init_data
/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/6/Loosers-20_Loosers-0.5/init_data
/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/7/Loosers-20_Loosers-0.5/init_data
/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/8/Loosers-20_Loosers-0.5/init_data
/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/9/Loosers-20_Loosers-0.5/init_data
/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/0/Loosers-10_Mascotte-10_Mascotte-0.5/init_data
/home/rladmin/antigenbinding/b

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.59s/it]

/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/8/Loosers-6_Mascotte-6_Heroes-8_Heroes-0.5/init_data
/home/rladmin/antigenbinding/bo/init_dataset/2DD8_S/9/Loosers-6_Mascotte-6_Heroes-8_Heroes-0.5/init_data





In [48]:
with open('./dataloader/valid_antigens_init_data.txt', 'w') as f:
    f.writelines(map(lambda a: a + '\n', list(set(all_antigens).difference(invalid_antigens_))))

In [39]:
len(set(all_antigens).difference(invalid_antigens_))

139

In [32]:
proteins_scores_ = get_proteins_by_category(antigen_name_, category='Loosers', top_cut_ratio=.5)

  upper_bound = thresholdPd[thresholdPd.AGname==antigen][thresholdPd.type==category]['maxEnergy'].values[0]
  lower_bound = thresholdPd[thresholdPd.AGname==antigen][thresholdPd.type==categories[min(len(categories) - 1, ind + 1)]]['maxEnergy'].values[0] + 0.01


In [33]:
proteins_scores_

array([], shape=(0, 2), dtype=object)