In [1]:
# Fetch the BATMAN repository into the current working directory.
!git clone https://github.com/meyer-lab-cshl/BATMAN.git
# Work from run_batman so the relative file names used below resolve
# (presumably where active_learning_functions.py and the CSV inputs live; verify).
# NOTE: both commands use relative paths, so re-running this cell without
# restarting operates relative to the already-changed working directory.
%cd BATMAN/run_batman

Cloning into 'BATMAN'...
remote: Enumerating objects: 211, done.[K
remote: Counting objects: 100% (123/123), done.[K
remote: Compressing objects: 100% (78/78), done.[K
remote: Total 211 (delta 67), reused 93 (delta 45), pack-reused 88[K
Receiving objects: 100% (211/211), 1.27 MiB | 13.01 MiB/s, done.
Resolving deltas: 100% (83/83), done.
/content/BATMAN/run_batman


In [8]:
from active_learning_functions import train,active_learning_cycle,return_peptides_to_sample
from active_learning_functions import generate_mutant_features,peptide2index
import numpy as np
import pandas as pd

In [22]:
# Load the complete peptide table (every TCR) from the example input file
full_peptide_data = pd.read_csv(filepath_or_buffer='test_input.csv')

In [23]:
# Leave-one-TCR-out split: the held-out TCR becomes the test set,
# every other TCR goes into the training set.
left_out_TCR = 'TCR1'

is_training_row = full_peptide_data['tcr'] != left_out_TCR
is_left_out_row = full_peptide_data['tcr'] == left_out_TCR

# Relabel every training TCR with one shared name so that BATMAN
# gives a single weight profile, then write the training file.
train_data = full_peptide_data[is_training_row].assign(tcr='pan-TCR')
train_data.to_csv('train.csv')

# Rows of the held-out TCR are written out unchanged as the test file.
test_data = full_peptide_data[is_left_out_row]
test_data.to_csv('test.csv')

In [24]:
# pyBATMAN is, for an unknown reason, unable to load its AA matrix data,
# so read the matrix from blosum100.csv ourselves (first column = row labels).
aa_matrix = pd.read_csv(filepath_or_buffer='blosum100.csv', index_col=0)

In [25]:
################# Pan TCR #################################################
# Fit BATMAN on the pooled training file; keep the inferred pan-TCR weights
# and AA distance matrix (the middle return value is not used here).
pan_training_output = train(
    'train.csv',
    'full',
    aa_matrix,
    steps=80000,
    seed=100,
)
inferred_weights_pan, _, inferred_matrix_pan = pan_training_output



Shape validation failed: input_shape: (1, 50000), minimum_shape: (chains=2, draws=4)


In [26]:
# Index peptide of the new (held-out) TCR
index_peptide_candidate_TCR = 'NLVPMVATV'

# Round 1: ask which peptides to sample, using the pan-TCR fit
pan_weights_array = inferred_weights_pan.to_numpy()
al_peptides = return_peptides_to_sample(
    index_peptide_candidate_TCR,  # index peptide
    inferred_matrix_pan,          # AA matrix to find distance
    pan_weights_array,            # positional weights
)

In [27]:
# Build the round-1 training set: the suggested peptides plus the index peptide
peptide_column = test_data['peptide']
suggested_rows = test_data.loc[peptide_column.isin(al_peptides)].copy()
al_train_index = test_data.loc[peptide_column == index_peptide_candidate_TCR].copy()
al_train_set = pd.concat([suggested_rows, al_train_index])

# Every peptide that is not in the training set forms the test set
in_training = peptide_column.isin(al_train_set['peptide'])
al_test_set = test_data.loc[~in_training].copy()

In [28]:
# Active learning, round 1: start from the pan-TCR positional weights
round1_settings = {'steps': 40000, 'seed': 111}
w_al, auc_mean = active_learning_cycle(
    al_train_set,
    al_test_set,
    inferred_matrix_pan,
    inferred_weights_pan.to_numpy(),
    **round1_settings,
)



Shape validation failed: input_shape: (1, 50000), minimum_shape: (chains=2, draws=4)


In [29]:
# Round-1 outcome: the returned weights, then the AUC value, one per line
print(w_al, auc_mean, sep='\n')

[[0.00734312 0.31842457 0.50267023 0.21028037 0.32109479 0.40720961
  0.22162884 1.         0.        ]]
0.5939345504753758


In [30]:
# Round 2: choose further peptides using the round-1 weights; the peptides
# already in the training set are passed along as the fourth argument
already_in_training = al_train_set['peptide'].tolist()
al_peptides = return_peptides_to_sample(
    index_peptide_candidate_TCR,  # index peptide
    inferred_matrix_pan,          # AA matrix to find distance
    w_al,                         # positional weights
    already_in_training,
)

# Look up the labels of the newly chosen peptides
is_new_pick = test_data['peptide'].isin(al_peptides)
al_train_set_addition = test_data.loc[is_new_pick].copy()

# Grow the training set; whatever is not in it becomes the test set
al_train_set = pd.concat([al_train_set, al_train_set_addition])
in_training = test_data['peptide'].isin(al_train_set['peptide'])
al_test_set = test_data.loc[~in_training].copy()

In [31]:
# Active learning, round 2: continue from the round-1 weights
round2_settings = {'steps': 40000, 'seed': 111}
w_al, auc_mean = active_learning_cycle(
    al_train_set,
    al_test_set,
    inferred_matrix_pan,
    w_al,
    **round2_settings,
)



Shape validation failed: input_shape: (1, 50000), minimum_shape: (chains=2, draws=4)


In [32]:
# Round-2 outcome: the returned weights, then the AUC value, one per line
print(w_al, auc_mean, sep='\n')

[[0.         0.43308824 0.66397059 0.06323529 0.97205882 0.32573529
  0.80367647 1.         0.18308824]]
0.7583155270655272
