In [16]:
import os
import sys
import uuid

import random

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Dataset preprocessing

In [3]:
vdjdb_raw = pd.read_csv('data/raw/vdjdb_paired_010923.tsv', sep='\t')
iedb_raw = pd.read_csv('data/raw/iedb_010923.csv')
mcpas_raw = pd.read_csv('data/raw/mcpas-tcr_010923.csv', encoding='latin1')

  iedb_raw = pd.read_csv('data/raw/iedb_010923.csv')
  mcpas_raw = pd.read_csv('data/raw/mcpas-tcr_010923.csv', encoding='latin1')


## VDJdb

In [4]:
vdj_cols = ['complex.id', 'CDR3', 'V', 'J', 'Species', 'MHC A','MHC class', 'Epitope']
print('VDJdb raw:', vdjdb_raw.shape)

vdjdb = vdjdb_raw.dropna()
vdjdb_a = vdjdb[vdjdb['Gene']=='TRA'][vdj_cols].copy()
vdjdb_b = vdjdb[vdjdb['Gene']=='TRB'][vdj_cols].copy()

vdjdb = pd.merge(vdjdb_a, vdjdb_b, on='complex.id')
print('VDJdb paired:', vdjdb.shape)

assert (vdjdb['Species_x'] == vdjdb['Species_y']).all()
assert (vdjdb['Epitope_x'] == vdjdb['Epitope_y']).all()
assert (vdjdb['MHC A_x'] == vdjdb['MHC A_y']).all()

vdj_cols = ['complex.id', 'CDR3_x', 'V_x', 'J_x', 'CDR3_y', 'V_y', 'J_y', 'Species_y', 'MHC A_y', 'MHC class_y', 'Epitope_y']
vdjdb = vdjdb[vdj_cols].copy()
vdj_cols = ['complex.id', 'cdr3a', 'va', 'ja', 'cdr3b', 'vb', 'jb', 'species', 'mhc', 'mhc_class', 'epitope']
vdjdb.columns = vdj_cols

# final data formating
try:
    vdjdb = vdjdb.drop('complex.id', axis=1)
except:
    pass
vdjdb = vdjdb.replace('MHCI', 1) # name formatting
vdjdb = vdjdb.replace('HomoSapiens', 'human') # name formatting
vdjdb = vdjdb[vdjdb['species']=='human'].copy() # keep only humans
vdjdb = vdjdb.drop_duplicates()
print("preprocessed vdjdb:", vdjdb.shape)
vdjdb.head()


VDJdb raw: (3050, 17)
VDJdb paired: (1397, 15)
preprocessed vdjdb: (1102, 10)


Unnamed: 0,cdr3a,va,ja,cdr3b,vb,jb,species,mhc,mhc_class,epitope
0,CIVRAPGRADMRF,TRAV26-1*01,TRAJ43*01,CASSYLPGQGDHYSNQPQHF,TRBV13*01,TRBJ1-5*01,human,HLA-B*08,1,FLKEKGGL
1,CAVPSGAGSYQLTF,TRAV20*01,TRAJ28*01,CASSFEPGQGFYSNQPQHF,TRBV13*01,TRBJ1-5*01,human,HLA-B*08,1,FLKEKGGL
2,CAYRPPGTYKYIF,TRAV38-2/DV8*01,TRAJ40*01,CASSALASLNEQFF,TRBV14*01,TRBJ2-1*01,human,HLA-B*08,1,FLKEKGGL
3,CIVRAPGRADMRF,TRAV26-1*01,TRAJ43*01,CASSYLPGQGDHYSNQPQHF,TRBV13*01,TRBJ1-5*01,human,HLA-B*08,1,FLKEQGGL
4,CAVPSGAGSYQLTF,TRAV20*01,TRAJ28*01,CASSFEPGQGFYSNQPQHF,TRBV13*01,TRBJ1-5*01,human,HLA-B*08,1,FLKEQGGL


## IEDB

IEDB queried for MHC 1 and human only

In [5]:
subset = ['Chain 1 CDR3 Curated', 'Chain 2 CDR3 Curated']
iedb = iedb_raw.dropna(subset=subset)
iedb = iedb.astype('str')
print("Full IEDB:", iedb.shape)

iedb_base_cols = ['Group Receptor ID', 'Receptor ID', 'Description', 'Organism', 'MHC Allele Names']

iedb_cols = [col for col in iedb.columns if "Curated" in col and ("CDR3" in col or "V" in col or "J" in col)]
iedb_cols = ['Description', 'Organism', 'MHC Allele Names', 'Curated Chain 1 V Gene', 'Curated Chain 1 J Gene', 'Chain 1 CDR3 Curated',  
            'Curated Chain 2 V Gene','Curated Chain 2 J Gene', 'Chain 2 CDR3 Curated']
iedb = iedb[iedb_cols]
print("IEDB Selected Columns:", iedb.shape)

iedb = iedb.replace('nan', np.nan)
iedb = iedb.dropna()
iedb.columns = ['epitope', 'species', 'mhc', 'va', 'ja', 'cdr3a', 'vb', 'jb', 'cdr3b']
iedb['mhc_class'] = 1
iedb['species'] = 'human' # IEDB filtered for only human

iedb = iedb[vdjdb.columns]
print('IEDB remove nan:', iedb.shape)
iedb = iedb.drop_duplicates()
print('IEDB remove duplicates:', iedb.shape)

# we need the CDR3 ends for 3D generation, one way to get this is to check if the sequence starts with C
iedb = iedb[iedb['cdr3a'].str.startswith('C')]
iedb = iedb[iedb['cdr3b'].str.startswith('C')]
print('IEDB with cdr3 ends only:', iedb.shape)

# remove rows with sequences containing #
print('IEDB cdr3a sequences with #:', iedb['cdr3a'].str.contains('#').sum())
iedb = iedb[iedb['cdr3a'].str.contains('#')==False]
print('IEDB cdr3b sequences with #:', iedb['cdr3b'].str.contains('#').sum())
iedb = iedb[iedb['cdr3b'].str.contains('#')==False]



# VJ name formating for cases where V, J gene starts with 'TCR' instead of 'TR'
iedb = iedb.replace('TCR', 'TR', regex=True)

print('IEDB final:', iedb.shape)

iedb

Full IEDB: (24032, 71)
IEDB Selected Columns: (24032, 9)
IEDB remove nan: (6093, 10)
IEDB remove duplicates: (6002, 10)
IEDB with cdr3 ends only: (5483, 10)
IEDB cdr3a sequences with #: 22
IEDB cdr3b sequences with #: 0
IEDB final: (5461, 10)


Unnamed: 0,cdr3a,va,ja,cdr3b,vb,jb,species,mhc,mhc_class,epitope
81,CAVRPTSGGSYIPTF,TRAV21*01,TRAJ6*01,CASSYVGNTGELFF,TRBV6-5*01,TRBJ2-2*01,human,HLA-A*02:01,1,SLLMWITQC
173,CAGGTGNQFYF,TRAV35*02,TRAJ49*01,CAISEVGVGQPQHF,TRBV10-3,TRBJ1-5*01,human,HLA-A*02:01,1,AAGIGILTV
617,CALSEAGTGGSYIPTF,TRAV19,TRAJ6,CASSMFVGQPQHF,TRBV19,TRBJ1-5,human,HLA-A*02:01,1,GILGFVFTL
618,CAVSVEETSGSRLTF,TRAV41,TRAJ58,CASSFFHNNEQFF,TRBV19,TRBJ2-1,human,HLA-A*02:01,1,GILGFVFTL
619,CAYRSARDSSYKLIF,TRAV38-2/DV8*01,TRAJ12*01,CASSDHSVTGISSPLHF,TRBV7-9*03,TRBJ1-6*02,human,HLA-B7,1,TPRVTGGGAM
...,...,...,...,...,...,...,...,...,...,...
24299,CGAGETSGSRLTF,TRAV21,TRAJ58,CSVNLGGPTDTQYF,TRBV29-1,TRBJ2-3,human,HLA-B*07:02,1,KPVETSNSF
24300,CALEGSQGNLIF,TRAV9-2,TRAJ42,CSVPDGAEPYGYTF,TRBV20-1,TRBJ1-2,human,HLA-A*01:01,1,TTDPSFLGRY
24301,CLVGNTGGFKTIF,TRAV4,TRAJ9,CSVPDRGNTEAFF,TRBV29-1,TRBJ1-1,human,HLA-B*07:02,1,SPRWYFYYL
24303,CAPSRHAGNNRKLIW,TRAV9-2,TRAJ38,CSVQGGTNEKLFF,TRBV29-1,TRBJ1-4,human,HLA-A*01:01,1,VSDGGPNLY


## McPAS-TCR

In [6]:
mcpas_cols = ['CDR3.alpha.aa', 'CDR3.beta.aa', 'Species', 'Epitope.peptide', 'MHC', 'TRAV', 'TRAJ', 'TRBV', 'TRBJ']
mcpas = mcpas_raw[mcpas_cols].copy()
print('McPAS-TCR raw:', mcpas.shape)
mcpas = mcpas.dropna()
print('McPAS-TCR drop na:', mcpas.shape)
mcpas = mcpas.drop_duplicates()
print('McPAS-TCR drop duplicates:', mcpas.shape)

# formatting
mcpas['Species'] = mcpas['Species'].str.lower()
mcpas.columns = ['cdr3a', 'cdr3b', 'species', 'epitope', 'mhc', 'va', 'ja', 'vb', 'jb']
mcpas['mhc_class'] = 1
mcpas = mcpas[vdjdb.columns].copy()

# remove rows with sequences containing #
print('McPAS cdr3a sequences with #:', mcpas['cdr3a'].str.contains('#').sum())
mcpas = mcpas[mcpas['cdr3a'].str.contains('#')==False]
print('McPAS cdr3b sequences with #:', mcpas['cdr3b'].str.contains('#').sum())
mcpas = mcpas[mcpas['cdr3b'].str.contains('#')==False]

print('McPAS-TCR final:', mcpas.shape)

mcpas 

McPAS-TCR raw: (39985, 9)
McPAS-TCR drop na: (3124, 9)
McPAS-TCR drop duplicates: (2914, 9)
McPAS cdr3a sequences with #: 0
McPAS cdr3b sequences with #: 0
McPAS-TCR final: (2914, 10)


Unnamed: 0,cdr3a,va,ja,cdr3b,vb,jb,species,mhc,mhc_class,epitope
73,CALGLMSNYNVLYF,TRAV4,TRAJ14-4,CASSSGLGGTLYF,TRBV10,TRBJ2-4,mouse,H-2Db,1,SSGVENPGGYCLTKW
80,CAAETTASLGKLQF,TRAV11,TRAJ9,CASGDHGLSYEQYF,TRBV13-3,TRBJ2-6,mouse,H-2b,1,DEPLTSLTPRCNTAWNRLKL
84,CALGDRGSGGSNYK,TRAV4,TRAJ3DT,CAWSRTGGNSDYTF,TRBV31,TRBJ1-2,mouse,H-2b,1,DEPLTSLTPRCNTAWNRLKL
120,CAAEASSSFSKLVF,TRAV11-1,TRAJ42,CASAPDRGGERLF,TRBV8-2,TRBJ1-4,mouse,H-2q,1,GPEGAQGPRGEPGTP
121,CAAEASSSFSKLVF,TRAV11-1,TRAJ42,CASAPDRGGERLF,TRBV13-2,TRBJ1-4,mouse,H-2q,1,GPEGAQGPRGEPGTP
...,...,...,...,...,...,...,...,...,...,...
39028,CALDGPSNTGKLIF,TRAV16,TRAJ37,CATSESSGQTYEQYF,TRBV15,TRBJ2-2,human,HLA-A2:01,1,FLCMKALLL
39029,CATDAEGNNRLAF,TRAV17,TRAJ7,CASSIFGGGLGEQFF,TRBV19,TRBJ2-7,human,HLA-A2:01,1,FLCMKALLL
39030,CGAVGYQKVTF,TRAV34,TRAJ13,CALNGEISYNEQFF,TRBV2,TRBJ2-2,human,HLA-A2:01,1,FLCMKALLL
39031,CAVIWYNNNDMRF,TRAV8-1,TRAJ43,CASSQGVNTGELFF,TRBV4-2,TRBJ2-1,human,HLA-A2:01,1,FLCMKALLL


## Combine

In [7]:
positives = pd.concat((iedb, vdjdb, mcpas), axis=0)
print("concate db:", positives.shape)
positives = positives.drop_duplicates(subset=['cdr3a', 'cdr3b', 'epitope'], keep='last', ignore_index=True)
print("concate db remove duplicates:", positives.shape)
positives['uuid'] = [uuid.uuid4() for _ in range(len(positives.index))]

positives

concate db: (9477, 10)
concate db remove duplicates: (8411, 10)


Unnamed: 0,cdr3a,va,ja,cdr3b,vb,jb,species,mhc,mhc_class,epitope,uuid
0,CAVSVEETSGSRLTF,TRAV41,TRAJ58,CASSFFHNNEQFF,TRBV19,TRBJ2-1,human,HLA-A*02:01,1,GILGFVFTL,5e60dc1f-4716-4e13-955f-c18da5d794f3
1,CAVNKGYGQNFVF,TRAV12-2*02,TRAJ26*01,CASSPAGISYNSPLHF,TRBV7-9*03,TRBJ1-6*01,human,HLA-B8,1,ELRRKMMYM,1c53cd6f-31c2-44d1-adf8-53194b561137
2,CAVRDSSYSGAGSYQLTF,TRAV3*01,TRAJ28*01,CASSRLAGASTDTQYF,TRBV7-3*01,TRBJ2-3*01,human,HLA-B8,1,QIKVRVDMV,c63ee292-7117-467f-807a-a5b31b8bd0c1
3,CAVSDYGQNFVF,TRAV21*01,TRAJ26*01,CASSRLSSNTDTQYF,TRBV7-3*01,TRBJ2-3*01,human,HLA-B8,1,QIKVRVDMV,b9b49467-c3c0-4f8a-9686-b5d6d34865fa
4,CATAQVYSGGGADGLTF,TRAV17*01,TRAJ45*01,CASSRLAGNTDTQYF,TRBV7-3*01,TRBJ2-3*01,human,HLA-B8,1,QIKVRVDMV,f7b6056a-00d7-4e36-bc2b-410c14bb8362
...,...,...,...,...,...,...,...,...,...,...,...
8406,CALDGPSNTGKLIF,TRAV16,TRAJ37,CATSESSGQTYEQYF,TRBV15,TRBJ2-2,human,HLA-A2:01,1,FLCMKALLL,9f7fbc33-46c3-4a02-b217-3b0c6aca2e95
8407,CATDAEGNNRLAF,TRAV17,TRAJ7,CASSIFGGGLGEQFF,TRBV19,TRBJ2-7,human,HLA-A2:01,1,FLCMKALLL,f8432964-7689-4085-8b90-d47efae82822
8408,CGAVGYQKVTF,TRAV34,TRAJ13,CALNGEISYNEQFF,TRBV2,TRBJ2-2,human,HLA-A2:01,1,FLCMKALLL,75c1889d-78f0-493a-8095-d6b7c2daea80
8409,CAVIWYNNNDMRF,TRAV8-1,TRAJ43,CASSQGVNTGELFF,TRBV4-2,TRBJ2-1,human,HLA-A2:01,1,FLCMKALLL,91222f74-3516-4a9d-9cca-d6f597f3f471


In [8]:
complex = []
for i in range(len(positives)):
    complex.append((positives['cdr3a'].iloc[i], positives['cdr3b'].iloc[i], positives['epitope'].iloc[i]))
positives['complex'] = complex
positives_dict = positives.drop_duplicates(subset='complex').set_index('complex', drop=True).to_dict(orient='index')
try:
    positives = positives.drop('complex', axis=1)
except:
    pass
print('unique tcr-peptide pairs:', len(positives_dict.keys()))

unique tcr-peptide pairs: 8411


# Negative sample generation

In [9]:
# create a lookup dictionary of all positive binding samples
def lookup_dict(df: pd.DataFrame, cdr3a: str, cdr3b: str, epitope: str) -> dict:
    complex = []
    for i in range(len(df.index)):
        complex.append((df[cdr3a].iloc[i], df[cdr3b].iloc[i], df[epitope].iloc[i]))
    df['complex'] = complex
    df_dict = df.drop_duplicates(subset='complex').set_index('complex', drop=True).to_dict(orient='index')
    return df_dict

# vdjdb
vdjdb_lc_raw = pd.read_csv('data/raw/vdjdb_low_confidence.tsv', sep='\t')
vdjdb_lc_a = vdjdb_lc_raw[vdjdb_lc_raw['Gene']=='TRA'].copy()
vdjdb_lc_b = vdjdb_lc_raw[vdjdb_lc_raw['Gene']=='TRB'].copy()
vdjdb_lc = pd.merge(vdjdb_lc_a, vdjdb_lc_b, on='complex.id')

vdjdb_lc_dict = lookup_dict(vdjdb_lc, 'CDR3_x', 'CDR3_y', 'Epitope_y')

#iedb
subset = ['Chain 1 CDR3 Curated', 'Chain 2 CDR3 Curated']
iedb_lc = iedb_raw.dropna(subset=subset)
iedb_lc = iedb_lc.astype('str')

iedb_lc_dict = lookup_dict(iedb_lc, 'Chain 1 CDR3 Curated', 'Chain 2 CDR3 Curated', 'Description')

#mcpas-tcr
mcpas_lc = mcpas_raw.dropna(subset=['CDR3.alpha.aa', 'CDR3.beta.aa', 'Epitope.peptide'])

mcpas_lc_dict = lookup_dict(mcpas_lc, 'CDR3.alpha.aa', 'CDR3.beta.aa', 'Epitope.peptide')


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['complex'] = complex


In [10]:
i = 0
negatives = []
while i<len(positives.index):
    sample_idx = random.choice([k for k in range(len(positives)) if k!=i])
    neg_complex = (positives['cdr3a'].iloc[sample_idx], positives['cdr3b'].iloc[sample_idx], positives['epitope'].iloc[i])

    # check if cdr3a-cdr3a-peptide combination not in either positives/vdjdb/iedb/mcpas
    # .get() will return None if not in dict
    in_positives_dict = positives_dict.get(neg_complex)
    in_vdjdb_lc_dict = vdjdb_lc_dict.get(neg_complex)
    in_iedb_lc_dict = iedb_lc_dict.get(neg_complex)
    in_mcpas_lc_dict = mcpas_lc_dict.get(neg_complex)

    if (in_positives_dict is None) and (in_vdjdb_lc_dict is None) and (in_iedb_lc_dict is None) and (in_mcpas_lc_dict is None):
        # not in either db --> TCR-pMHC complex is negative binding
        sample = positives.iloc[sample_idx].copy()
        sample['epitope'] =  positives['epitope'].iloc[i]
        sample['mhc'] = positives['mhc'].iloc[i]
        if tuple(sample.values) not in negatives: # check if generated negative already exists
            negatives.append(tuple(sample.values))
            i+=1
                            
negatives = pd.DataFrame(negatives, columns=positives.columns)
print('generated negatives:', negatives.shape)
negatives = negatives.drop_duplicates() # this confirms there are no duplicates
print('generated negatives drop duplicates:', negatives.shape)

negatives['uuid'] = [uuid.uuid4() for _ in range(len(negatives.index))]
print('negative samples:', negatives.shape)
negatives

generated negatives: (8411, 11)
generated negatives drop duplicates: (8411, 11)
negative samples: (8411, 11)


Unnamed: 0,cdr3a,va,ja,cdr3b,vb,jb,species,mhc,mhc_class,epitope,uuid
0,CATVPLLTSGGGADGLTF,TRAV17,TRAJ45,CASSPTGRVQPQHF,TRBV3,TRBJ1-5,human,HLA-A*02:01,1,GILGFVFTL,64ecedb3-3427-46f1-a643-7f1725e8ff39
1,CAVAFGNQFYF,TRAV8-3:01,TRAJ4-01,CASSMTSGALYNEQFF,TRBV2-01,TRBJ2-1:01,human,HLA-B8,1,ELRRKMMYM,30203a99-8fb5-4fc7-8638-103789fe548f
2,CAVGSNYNVLYF,TRAV3N-3:01,TRAJ2-01,CASSGDSAETLYF,TRBV1-01,TRBJ2-3:01,mouse,HLA-B8,1,QIKVRVDMV,300a2dc6-6fca-4efc-b5f8-02a1f0b2acb9
3,CVVNALMDSNYQLIW,TRAV12-1,TRAJ33,CASSEGRGYEQYF,TRBV6-1,TRBJ2-7,human,HLA-B8,1,QIKVRVDMV,1b3204f9-defb-464b-b666-354b4f0df670
4,CAALTGNTGKLIF,TRAV5N-4:01,TRAJ3-01,CASSGLGSSAETLYF,TRBV1-01,TRBJ2-3:01,mouse,HLA-B8,1,QIKVRVDMV,0bc94aa1-1be2-4112-9827-0fee7b4d1d3a
...,...,...,...,...,...,...,...,...,...,...,...
8406,CAVNNARLMF,TRAV3,TRAJ31,CSVVWALGQPQHF,TRBV29,TRBJ1,human,HLA-A2:01,1,FLCMKALLL,478dac35-aec2-40f5-bc0c-f2155b55fa85
8407,CVVCRMDSSYKLIF,TRAV10,TRAJ12,CSVGSQGTNEKLFF,TRBV29,TRBJ1,human,HLA-A2:01,1,FLCMKALLL,82764f8a-ebd0-4f3d-9831-987f5cccb2b9
8408,CAMRSNYQLIW,TRAV12-3,TRAJ33,CATQRNQETQYF,TRBV6-1,TRBJ2-5,human,HLA-A2:01,1,FLCMKALLL,e8e7fe31-eb96-4119-919a-66d68839b683
8409,CAGQKYMRSQGNLIF,TRAV35,TRAJ42,CASSARTVNTEAFF,TRBV14,TRBJ1-1,human,HLA-A2:01,1,FLCMKALLL,2d1ffcbb-3e41-4316-8189-3826e48a19b2


In [11]:
# check if any overlap between negatives and positives
pd.merge(positives, negatives, on=['cdr3a', 'cdr3b', 'epitope'])

Unnamed: 0,cdr3a,va_x,ja_x,cdr3b,vb_x,jb_x,species_x,mhc_x,mhc_class_x,epitope,uuid_x,va_y,ja_y,vb_y,jb_y,species_y,mhc_y,mhc_class_y,uuid_y


# Saving Data

In [12]:
# save_path = 'data/preprocessed'
# positives.to_csv(os.path.join(save_path, 'positives_test.tsv'), sep='\t')
# negatives.to_csv(os.path.join(save_path, 'negatives_test.tsv'), sep='\t')

In [14]:
# concatenate positive and negative binding samples

positives['binding'] = 1
negatives['binding'] = 0

binding_df = pd.concat((positives, negatives), ignore_index=True)

print('Total pos/neg binding df:', binding_df.shape)

save_path = 'data/preprocessed'
binding_df.to_csv(os.path.join(save_path, 'tcrpmhc_binding.tsv'), sep='\t')

Total pos/neg binding df: (16822, 12)


# Hard split

This train/test split method is detailed in _Grazioli F, Mösch A, Machart P, Li K, Alqassem I, O’Donnell TJ and Min MR (2022) On TCR binding predictors failing to generalize to unseen peptides. Front. Immunol. 13:1014256. [doi: 10.3389/fimmu.2022.1014256](https://www.frontiersin.org/articles/10.3389/fimmu.2022.1014256/full#f)_

Here we aim to understand how the algorithm works

In [17]:
binding_df = pd.read_csv('data/preprocessed/tcrpmhc_binding.tsv', sep='\t')

In [25]:
epitope_count = binding_df['epitope'].value_counts()
print("number of unique epitopes:", len(epitope_count.values))
print("max count:", epitope_count.max())
print("min count:", epitope_count.min())
epitope_count.head(20)

number of unique epitopes: 903
max count: 1044
min count: 2


SSYRRPVGI       1044
GILGFVFTL       1026
SPRWYFYYL        932
YLQPRTFLL        810
TTDPSFLGRY       794
SSLENFRAYV       780
NLVPMVATV        686
LLWNGPMAV        682
HGIRNASFI        500
GLCTLVAML        456
ASNENMETM        414
SSPPMFRV         278
FLCMKALLL        274
LTDEMIAQY        270
LSLRNPILV        268
YVLDHLIVV        236
KTFPPTEPK        230
TVYGFCLL         166
DATYQRTRALVR     166
QYIKWPWYI        166
Name: epitope, dtype: int64

In [50]:
from typing import Tuple, List

def hard_split_df(
        df: pd.DataFrame, target_col: str, min_ratio: float, random_seed: float, low: int, high: int) -> Tuple[pd.DataFrame, pd.DataFrame, List[str]]:
    """ Assume a target column, e.g. `epitope`.
    Then:
        1) Select random sample
        2) All samples sharing the same value of that column
        with the randomly selected sample are used for test
        3)Repeat until test budget (defined by train/test ratio) is
        filled.
    """
    min_test_len = round((1-min_ratio) * len(df))
    test_len = 0
    selected_target_val = []

    train_df = df.copy()
    test_df = pd.DataFrame()
    
    target_count_df = df.groupby([target_col]).size().reset_index(name='counts')
    target_count_df = target_count_df[target_count_df['counts'].between(low, high, inclusive='both')]
    possible_target_val = list(target_count_df[target_col].unique())
    max_target_len = len(possible_target_val)

    while test_len < min_test_len:
#         sample = train_df.sample(n=1, random_state=random_state)
#         target_val = sample[target_col].values[0]
        rng = np.random.default_rng(seed=random_seed)
        target_val = rng.choice(possible_target_val)

        if target_val not in selected_target_val:
            to_test = train_df[train_df[target_col] == target_val]

            train_df = train_df.drop(to_test.index)
            test_df = pd.concat((test_df, to_test), axis=0)
            test_len = len(test_df)

            selected_target_val.append(target_val)
            possible_target_val.remove(target_val)

        if len(selected_target_val) == max_target_len:
            print(f"Possible targets left {possible_target_val}")
            raise Exception('No more values to sample from.')

    print(f"Target {target_col} sequences: {selected_target_val}")

    return train_df, test_df, selected_target_val

In [51]:
train_df, test_df, selected_target_val = hard_split_df(binding_df, 'epitope', min_ratio=0.8, low=50, high=800, random_seed=42)

Target epitope sequences: ['CTELKLSDY', 'DATYQRTRALVR', 'EAAGIGILTV', 'FEDLRLLSF', 'FEDLRVLSF', 'FEDLRVSSF', 'FLCMKALLL', 'ASNENMETM', 'FTSDYYQLY', 'GLCTLVAML', 'HGIRNASFI', 'IMNDMPIYM', 'KLVALGINAV', 'KTFPPTEPK', 'LLWNGPMAV']
