In [1]:
import os
import sys
import uuid

import random

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Dataset preprocessing

In [2]:
vdjdb_raw = pd.read_csv('../data/raw/vdjdb_paired_010923.tsv', sep='\t')
iedb_raw = pd.read_csv('../data/raw/iedb_010923.csv')
mcpas_raw = pd.read_csv('../data/raw/mcpas-tcr_010923.csv', encoding='latin1')

  iedb_raw = pd.read_csv('../data/raw/iedb_010923.csv')
  mcpas_raw = pd.read_csv('../data/raw/mcpas-tcr_010923.csv', encoding='latin1')


In [3]:
unknown_mhc = pd.read_csv('../data/utils/unrecognized_mhcs.txt', header=None, sep='\t')[0].values.tolist()
alphabeta_df = pd.read_csv('../data/utils/alphabeta_db.tsv', sep='\t')
alphabeta_df = alphabeta_df[alphabeta_df['organism']=='human'].copy()
known_va = alphabeta_df.loc[(alphabeta_df['chain']=='A') & (alphabeta_df['region']=='V'), 'id'].values.tolist()
known_vb = alphabeta_df.loc[(alphabeta_df['chain']=='B') & (alphabeta_df['region']=='V'), 'id'].values.tolist()

known_ja = alphabeta_df.loc[(alphabeta_df['chain']=='A') & (alphabeta_df['region']=='J'), 'id'].values.tolist()
known_jb = alphabeta_df.loc[(alphabeta_df['chain']=='B') & (alphabeta_df['region']=='J'), 'id'].values.tolist()

## VDJdb

In [4]:
vdj_cols = ['complex.id', 'CDR3', 'V', 'J', 'Species', 'MHC A','MHC class', 'Epitope']
print('VDJdb raw:', vdjdb_raw.shape)

vdjdb = vdjdb_raw[['complex.id', 'CDR3', 'V', 'J', 'Species', 'MHC A','MHC class', 'Epitope', 'Gene']].dropna()
vdjdb_a = vdjdb[vdjdb['Gene']=='TRA'][vdj_cols].copy()
vdjdb_b = vdjdb[vdjdb['Gene']=='TRB'][vdj_cols].copy()

vdjdb = pd.merge(vdjdb_a, vdjdb_b, on='complex.id')
print('VDJdb paired:', vdjdb.shape)

assert (vdjdb['Species_x'] == vdjdb['Species_y']).all()
assert (vdjdb['Epitope_x'] == vdjdb['Epitope_y']).all()
assert (vdjdb['MHC A_x'] == vdjdb['MHC A_y']).all()

vdj_cols = ['complex.id', 'CDR3_x', 'V_x', 'J_x', 'CDR3_y', 'V_y', 'J_y', 'Species_y', 'MHC A_y', 'MHC class_y', 'Epitope_y']
vdjdb = vdjdb[vdj_cols].copy()
vdj_cols = ['complex.id', 'cdr3a', 'va', 'ja', 'cdr3b', 'vb', 'jb', 'species', 'mhc', 'mhc_class', 'epitope']
vdjdb.columns = vdj_cols

# final data formating
try:
    vdjdb = vdjdb.drop('complex.id', axis=1)
except:
    pass
vdjdb = vdjdb.replace('MHCI', 1) # name formatting
vdjdb = vdjdb.replace('HomoSapiens', 'human') # name formatting
vdjdb = vdjdb[vdjdb['species']=='human'].copy() # keep only humans
vdjdb = vdjdb.drop_duplicates()
print("preprocessed vdjdb:", vdjdb.shape)
vdjdb.head()

save_path = '../data/preprocessed'
vdjdb.to_csv(os.path.join(save_path, 'vdjbd.tsv'), sep='\t')

VDJdb raw: (3050, 17)
VDJdb paired: (1436, 15)
preprocessed vdjdb: (1138, 10)


## IEDB

IEDB queried for MHC 1 and human only

In [5]:
subset = ['Chain 1 CDR3 Curated', 'Chain 2 CDR3 Curated']
iedb = iedb_raw.dropna(subset=subset)
iedb = iedb.astype('str')
print("Full IEDB:", iedb.shape)

iedb_base_cols = ['Group Receptor ID', 'Receptor ID', 'Description', 'Organism', 'MHC Allele Names']

iedb_cols = [col for col in iedb.columns if "Curated" in col and ("CDR3" in col or "V" in col or "J" in col)]
iedb_cols = ['Description', 'Organism', 'MHC Allele Names', 'Curated Chain 1 V Gene', 'Curated Chain 1 J Gene', 'Chain 1 CDR3 Curated',  
            'Curated Chain 2 V Gene','Curated Chain 2 J Gene', 'Chain 2 CDR3 Curated']
iedb = iedb[iedb_cols]
print("IEDB Selected Columns:", iedb.shape)

iedb = iedb.replace('nan', np.nan)
iedb = iedb.dropna()
iedb.columns = ['epitope', 'species', 'mhc', 'va', 'ja', 'cdr3a', 'vb', 'jb', 'cdr3b']
iedb['mhc_class'] = 1
iedb['species'] = 'human' # IEDB filtered for only human

iedb = iedb[vdjdb.columns]
print('IEDB remove nan:', iedb.shape)
iedb = iedb.drop_duplicates()
print('IEDB remove duplicates:', iedb.shape)

# we need the CDR3 ends for 3D generation, one way to get this is to check if the sequence starts with C
iedb = iedb[iedb['cdr3a'].str.startswith('C')]
iedb = iedb[iedb['cdr3b'].str.startswith('C')]
print('IEDB with cdr3 ends only:', iedb.shape)

# remove rows with sequences containing #
print('IEDB cdr3a sequences with #:', iedb['cdr3a'].str.contains('#').sum())
iedb = iedb[iedb['cdr3a'].str.contains('#')==False]
print('IEDB cdr3b sequences with #:', iedb['cdr3b'].str.contains('#').sum())
iedb = iedb[iedb['cdr3b'].str.contains('#')==False]

# V/J genes need to have allele information (e.g., '*01') at the end; ]
# if that's not available we add '*01'
for gene in ['va', 'vb', 'ja', 'jb']:
    print(f"IEDB {gene} sequences without allele information:", iedb[gene].str.contains("\*").sum())
    # iedb = iedb[iedb[gene].str.contains("\*")==True]
    iedb.loc[iedb[gene].str.contains("\*")==False, gene] = iedb.loc[iedb[gene].str.contains("\*")==False, gene] + '*01'

# VJ name formating for cases where V, J gene starts with 'TCR' instead of 'TR'
iedb = iedb.replace('TCR', 'TR', regex=True)

print('IEDB final:', iedb.shape)
save_path = '../data/preprocessed'
iedb.to_csv(os.path.join(save_path, 'iedb.tsv'), sep='\t')
iedb

Full IEDB: (24032, 71)
IEDB Selected Columns: (24032, 9)
IEDB remove nan: (6093, 10)
IEDB remove duplicates: (6002, 10)
IEDB with cdr3 ends only: (5483, 10)
IEDB cdr3a sequences with #: 22
IEDB cdr3b sequences with #: 0
IEDB va sequences without allele information: 858
IEDB vb sequences without allele information: 872
IEDB ja sequences without allele information: 878
IEDB jb sequences without allele information: 879
IEDB final: (5461, 10)


Unnamed: 0,cdr3a,va,ja,cdr3b,vb,jb,species,mhc,mhc_class,epitope
81,CAVRPTSGGSYIPTF,TRAV21*01,TRAJ6*01,CASSYVGNTGELFF,TRBV6-5*01,TRBJ2-2*01,human,HLA-A*02:01,1,SLLMWITQC
173,CAGGTGNQFYF,TRAV35*02,TRAJ49*01,CAISEVGVGQPQHF,TRBV10-3*01,TRBJ1-5*01,human,HLA-A*02:01,1,AAGIGILTV
617,CALSEAGTGGSYIPTF,TRAV19*01,TRAJ6*01,CASSMFVGQPQHF,TRBV19*01,TRBJ1-5*01,human,HLA-A*02:01,1,GILGFVFTL
618,CAVSVEETSGSRLTF,TRAV41*01,TRAJ58*01,CASSFFHNNEQFF,TRBV19*01,TRBJ2-1*01,human,HLA-A*02:01,1,GILGFVFTL
619,CAYRSARDSSYKLIF,TRAV38-2/DV8*01,TRAJ12*01,CASSDHSVTGISSPLHF,TRBV7-9*03,TRBJ1-6*02,human,HLA-B7,1,TPRVTGGGAM
...,...,...,...,...,...,...,...,...,...,...
24299,CGAGETSGSRLTF,TRAV21*01,TRAJ58*01,CSVNLGGPTDTQYF,TRBV29-1*01,TRBJ2-3*01,human,HLA-B*07:02,1,KPVETSNSF
24300,CALEGSQGNLIF,TRAV9-2*01,TRAJ42*01,CSVPDGAEPYGYTF,TRBV20-1*01,TRBJ1-2*01,human,HLA-A*01:01,1,TTDPSFLGRY
24301,CLVGNTGGFKTIF,TRAV4*01,TRAJ9*01,CSVPDRGNTEAFF,TRBV29-1*01,TRBJ1-1*01,human,HLA-B*07:02,1,SPRWYFYYL
24303,CAPSRHAGNNRKLIW,TRAV9-2*01,TRAJ38*01,CSVQGGTNEKLFF,TRBV29-1*01,TRBJ1-4*01,human,HLA-A*01:01,1,VSDGGPNLY


## McPAS-TCR

In [6]:
mcpas_cols = ['CDR3.alpha.aa', 'CDR3.beta.aa', 'Species', 'Epitope.peptide', 'MHC', 'TRAV', 'TRAJ', 'TRBV', 'TRBJ']
mcpas = mcpas_raw[mcpas_cols].copy()
print('McPAS-TCR raw:', mcpas.shape)
mcpas = mcpas.dropna()
print('McPAS-TCR drop na:', mcpas.shape)
mcpas = mcpas.drop_duplicates()
print('McPAS-TCR drop duplicates:', mcpas.shape)

# formatting
mcpas['Species'] = mcpas['Species'].str.lower()
mcpas = mcpas[mcpas['Species'] == 'human']
mcpas.columns = ['cdr3a', 'cdr3b', 'species', 'epitope', 'mhc', 'va', 'ja', 'vb', 'jb']
mcpas['mhc_class'] = 1
mcpas = mcpas[vdjdb.columns].copy()

# remove rows with sequences containing #
print('McPAS cdr3a sequences with #:', mcpas['cdr3a'].str.contains('#').sum())
mcpas = mcpas[mcpas['cdr3a'].str.contains('#')==False]
print('McPAS cdr3b sequences with #:', mcpas['cdr3b'].str.contains('#').sum())
mcpas = mcpas[mcpas['cdr3b'].str.contains('#')==False]

for gene in ['va', 'vb', 'ja', 'jb']:
    print(f"McPAS {gene} sequences without allele information:", mcpas[gene].str.contains("\*").sum())
    mcpas.loc[mcpas[gene].str.contains("\*")==False, gene] = mcpas.loc[mcpas[gene].str.contains("\*")==False, gene] + '*01'

print('McPAS-TCR final:', mcpas.shape)

mcpas 

McPAS-TCR raw: (39985, 9)
McPAS-TCR drop na: (3124, 9)
McPAS-TCR drop duplicates: (2914, 9)
McPAS cdr3a sequences with #: 0
McPAS cdr3b sequences with #: 0
McPAS va sequences without allele information: 0
McPAS vb sequences without allele information: 1
McPAS ja sequences without allele information: 0
McPAS jb sequences without allele information: 0
McPAS-TCR final: (1059, 10)


Unnamed: 0,cdr3a,va,ja,cdr3b,vb,jb,species,mhc,mhc_class,epitope
147,CASPDAGGTSYGKLT,TRAV2*01,TRAJ5-1*01,CASLAGQGYNEQF,TRBV4*01,TRBJ2-1*01,human,HLA-Cw* 16:01,1,SAYGEPRKL
148,CAAPQAGTALIF,TRAV8-2*01,TRAJ15*01,CASLGAQNNEQF,TRBV12*01,TRBJ2-1*01,human,HLA-Cw* 16:01,1,AARAVFLAL
150,CTDVSTGGFKTIF,TRAV3*01,TRAJ9*01,CASSYSTGDEQYF,TRBV6*01,TRBJ2-7*01,human,HLA-Cw* 16:01,1,AARAVFLAL
151,CTDLNTGGFKTIF,TRAV3*01,TRAJ9*01,CASSYSTGDEQYF,TRBV6*01,TRBJ2-7*01,human,HLA-Cw* 16:01,1,AARAVFLAL
152,CVVKKNNTDKLIF,TRAV2*01,TRAJ2*01,CASSQGTSQFNEQF,TRBV7*01,TRBJ2-1*01,human,HLA-A*02,1,EAAGIGILTV
...,...,...,...,...,...,...,...,...,...,...
39028,CALDGPSNTGKLIF,TRAV16*01,TRAJ37*01,CATSESSGQTYEQYF,TRBV15*01,TRBJ2-2*01,human,HLA-A2:01,1,FLCMKALLL
39029,CATDAEGNNRLAF,TRAV17*01,TRAJ7*01,CASSIFGGGLGEQFF,TRBV19*01,TRBJ2-7*01,human,HLA-A2:01,1,FLCMKALLL
39030,CGAVGYQKVTF,TRAV34*01,TRAJ13*01,CALNGEISYNEQFF,TRBV2*01,TRBJ2-2*01,human,HLA-A2:01,1,FLCMKALLL
39031,CAVIWYNNNDMRF,TRAV8-1*01,TRAJ43*01,CASSQGVNTGELFF,TRBV4-2*01,TRBJ2-1*01,human,HLA-A2:01,1,FLCMKALLL


## Combine

In [8]:
positives = pd.concat((iedb, vdjdb, mcpas), axis=0)
print("concate db:", positives.shape)
positives = positives.drop_duplicates(subset=['cdr3a', 'cdr3b', 'epitope'], keep='last', ignore_index=True)
print("concate db remove duplicates:", positives.shape)


# positives = positives[~positives['mhc'].isin(unknown_mhc)]
# print('after removing unknown mhc', positives.shape)
# positives = positives[positives['va'].isin(known_va)]
# print('after removing unknown va', positives.shape)
# positives = positives[positives['vb'].isin(known_vb)]
# print('after removing unknown vb', positives.shape)
# positives = positives[positives['ja'].isin(known_ja)]
# print('after removing unknown ja', positives.shape)
# positives = positives[positives['jb'].isin(known_jb)]
# print('after removing unknown jb', positives.shape)

positives['uuid'] = [uuid.uuid4() for _ in range(len(positives.index))]

positives.shape

concate db: (7658, 10)
concate db remove duplicates: (6665, 10)


(6665, 11)

In [9]:
complex = []
for i in range(len(positives)):
    complex.append((positives['cdr3a'].iloc[i], positives['cdr3b'].iloc[i], positives['epitope'].iloc[i]))
positives['complex'] = complex
positives_dict = positives.drop_duplicates(subset='complex').set_index('complex', drop=True).to_dict(orient='index')
try:
    positives = positives.drop('complex', axis=1)
except:
    pass
print('unique tcr-peptide pairs:', len(positives_dict.keys()))

unique tcr-peptide pairs: 6665


# Negative sample generation

In [10]:
# create a lookup dictionary of all positive binding samples
def lookup_dict(df: pd.DataFrame, cdr3a: str, cdr3b: str, epitope: str) -> dict:
    complex = []
    for i in range(len(df.index)):
        complex.append((df[cdr3a].iloc[i], df[cdr3b].iloc[i], df[epitope].iloc[i]))
    df['complex'] = complex
    df_dict = df.drop_duplicates(subset='complex').set_index('complex', drop=True).to_dict(orient='index')
    return df_dict

# vdjdb
vdjdb_lc_raw = pd.read_csv('../data/raw/vdjdb_low_confidence.tsv', sep='\t')
vdjdb_lc_a = vdjdb_lc_raw[vdjdb_lc_raw['Gene']=='TRA'].copy()
vdjdb_lc_b = vdjdb_lc_raw[vdjdb_lc_raw['Gene']=='TRB'].copy()
vdjdb_lc = pd.merge(vdjdb_lc_a, vdjdb_lc_b, on='complex.id')

vdjdb_lc_dict = lookup_dict(vdjdb_lc, 'CDR3_x', 'CDR3_y', 'Epitope_y')

#iedb
subset = ['Chain 1 CDR3 Curated', 'Chain 2 CDR3 Curated']
iedb_lc = iedb_raw.dropna(subset=subset)
iedb_lc = iedb_lc.astype('str')

iedb_lc_dict = lookup_dict(iedb_lc, 'Chain 1 CDR3 Curated', 'Chain 2 CDR3 Curated', 'Description')

#mcpas-tcr
mcpas_lc = mcpas_raw.dropna(subset=['CDR3.alpha.aa', 'CDR3.beta.aa', 'Epitope.peptide'])

mcpas_lc_dict = lookup_dict(mcpas_lc, 'CDR3.alpha.aa', 'CDR3.beta.aa', 'Epitope.peptide')


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['complex'] = complex


In [11]:
i = 0
negatives = []
while i<len(positives.index):
    # sample a cdr3a/cdr3b pair and a another random peptide
    sample_idx = random.choice([k for k in range(len(positives)) if k!=i])
    neg_complex = (positives['cdr3a'].iloc[sample_idx], positives['cdr3b'].iloc[sample_idx], positives['epitope'].iloc[i])

    # check if cdr3a-cdr3a-peptide combination not in either positives/vdjdb/iedb/mcpas
    # .get() will return None if not in dict
    in_positives_dict = positives_dict.get(neg_complex)
    in_vdjdb_lc_dict = vdjdb_lc_dict.get(neg_complex)
    in_iedb_lc_dict = iedb_lc_dict.get(neg_complex)
    in_mcpas_lc_dict = mcpas_lc_dict.get(neg_complex)

    if (in_positives_dict is None) and (in_vdjdb_lc_dict is None) and (in_iedb_lc_dict is None) and (in_mcpas_lc_dict is None):
        # not in either db --> TCR-pMHC complex is negative binding
        sample = positives.iloc[sample_idx].copy()
        sample['epitope'] =  positives['epitope'].iloc[i]
        sample['mhc'] = positives['mhc'].iloc[i]
        if tuple(sample.values) not in negatives: # check if generated negative already exists
            negatives.append(tuple(sample.values))
            i+=1
                            
negatives = pd.DataFrame(negatives, columns=positives.columns)
print('generated negatives:', negatives.shape)
negatives = negatives.drop_duplicates() # this confirms there are no duplicates
print('generated negatives drop duplicates:', negatives.shape)

negatives['uuid'] = [uuid.uuid4() for _ in range(len(negatives.index))]
print('negative samples:', negatives.shape)
negatives

generated negatives: (6665, 11)
generated negatives drop duplicates: (6665, 11)
negative samples: (6665, 11)


Unnamed: 0,cdr3a,va,ja,cdr3b,vb,jb,species,mhc,mhc_class,epitope,uuid
0,CAVSDLEPNSSASKIIF,TRAV8-4*01,TRAJ3*01,CASSFGADTQYF,TRBV25-1*01,TRBJ2-3*01,human,HLA-A*02:01,1,GILGFVFTL,34bc7776-828a-44b0-bd88-1275b5ad47d5
1,CAVSEVGNKLTF,TRAV8-6*01,TRAJ17*01,CASSQSGGGEQFF,TRBV5-6*01,TRBJ2-1*01,human,HLA-B8,1,ELRRKMMYM,7aa90d75-de41-49a1-9f22-cd5dee099d73
2,CAVRGYSGGGADGLTF,TRAV20*01,TRAJ45*01,CASNFPNISEGTCSNQPQHF,TRBV3-1*01,TRBJ1-5*01,human,HLA-B8,1,QIKVRVDMV,e53745aa-ca36-446d-81f1-709a172be09c
3,CALGDSWGKLQF,TRAV9-2*01,TRAJ24*01,CASSAGTGVPDTQYF,TRBV10-2*01,TRBJ2-3*01,human,HLA-B8,1,QIKVRVDMV,6f8fd6f0-b3ce-439b-9da8-aa14fa29f38a
4,CALIEMYSGGGADGLTF,TRAV19*01,TRAJ45*01,CASSPRTDRSGANVLTF,TRBV12-3*01,TRBJ2-6*01,human,HLA-B8,1,QIKVRVDMV,575ed8f7-e800-4557-b5f2-ad4bbdd62755
...,...,...,...,...,...,...,...,...,...,...,...
6660,CAENRGFGNEKLTF,TRAV13-2*01,TRAJ48*01,CASSEGLEIRMRSSYEQYF,TRBV2*01,TRBJ2-7*01,human,HLA-A2:01,1,FLCMKALLL,6f4756b4-4986-49af-b515-bf55ff340126
6661,CAVDTNTDKLIF,TRAV2*01,TRAJ34*01,CASSQGYEQYF,TRBV11-2*01,TRBJ2-7*01,human,HLA-A2:01,1,FLCMKALLL,88da228a-6ee6-4a59-a022-39f9a4080223
6662,CAFCYNNNDMRF,TRAV38-1*01,TRAJ43*01,CASSQETGIYEQYF,TRBV4-1*01,TRBJ2-7*01,human,HLA-A2:01,1,FLCMKALLL,5006084a-2977-486e-b872-a966d8af818c
6663,CAASYGSGGYNKLIF,TRAV13-1*01,TRAJ4*01,CASSQDAGTAHVGEQFF,TRBV4-1*01,TRBJ2-1*01,human,HLA-A2:01,1,FLCMKALLL,9ff0480f-fbc4-4a00-934c-6931962ef765


In [12]:
# check if any overlap between negatives and positives
pd.merge(positives, negatives, on=['cdr3a', 'cdr3b', 'epitope'])

Unnamed: 0,cdr3a,va_x,ja_x,cdr3b,vb_x,jb_x,species_x,mhc_x,mhc_class_x,epitope,uuid_x,va_y,ja_y,vb_y,jb_y,species_y,mhc_y,mhc_class_y,uuid_y


# Saving Data

In [15]:
save_path = '../data/preprocessed'
positives.to_csv(os.path.join(save_path, 'positives.tsv'), sep='\t')
negatives.to_csv(os.path.join(save_path, 'negatives.tsv'), sep='\t')

In [13]:
# concatenate positive and negative binding samples

positives['binding'] = 1
negatives['binding'] = 0

binding_df = pd.concat((positives, negatives), ignore_index=True)

print('Total pos/neg binding df:', binding_df.shape)

save_path = '../data/preprocessed'
binding_df.to_csv(os.path.join(save_path, 'tcrpmhc_binding.tsv'), sep='\t')

Total pos/neg binding df: (13330, 12)


# Hard split

This train/test split method is detailed in _Grazioli F, Mösch A, Machart P, Li K, Alqassem I, O’Donnell TJ and Min MR (2022) On TCR binding predictors failing to generalize to unseen peptides. Front. Immunol. 13:1014256. [doi: 10.3389/fimmu.2022.1014256](https://www.frontiersin.org/articles/10.3389/fimmu.2022.1014256/full#f)_

Here we aim to understand how the algorithm works

In [None]:
binding_df = pd.read_csv('../data/preprocessed/tcrpmhc_binding.tsv', sep='\t')

In [None]:
epitope_count = binding_df['epitope'].value_counts()
print("number of unique epitopes:", len(epitope_count.values))
print("max count:", epitope_count.max())
print("min count:", epitope_count.min())
epitope_count.head(20)

In [None]:
from typing import Tuple, List

def hard_split_df(
        df: pd.DataFrame, target_col: str, min_ratio: float, random_seed: float, low: int, high: int) -> Tuple[pd.DataFrame, pd.DataFrame, List[str]]:
    """ Assume a target column, e.g. `epitope`.
    Then:
        1) Select random sample
        2) All samples sharing the same value of that column
        with the randomly selected sample are used for test
        3)Repeat until test budget (defined by train/test ratio) is
        filled.
    """
    min_test_len = round((1-min_ratio) * len(df))
    test_len = 0
    selected_target_val = []

    train_df = df.copy()
    test_df = pd.DataFrame()
    
    target_count_df = df.groupby([target_col]).size().reset_index(name='counts')
    target_count_df = target_count_df[target_count_df['counts'].between(low, high, inclusive='both')]
    possible_target_val = list(target_count_df[target_col].unique())
    max_target_len = len(possible_target_val)

    while test_len < min_test_len:
#         sample = train_df.sample(n=1, random_state=random_state)
#         target_val = sample[target_col].values[0]
        rng = np.random.default_rng(seed=random_seed)
        target_val = rng.choice(possible_target_val)

        if target_val not in selected_target_val:
            to_test = train_df[train_df[target_col] == target_val]

            train_df = train_df.drop(to_test.index)
            test_df = pd.concat((test_df, to_test), axis=0)
            test_len = len(test_df)

            selected_target_val.append(target_val)
            possible_target_val.remove(target_val)

        if len(selected_target_val) == max_target_len:
            print(f"Possible targets left {possible_target_val}")
            raise Exception('No more values to sample from.')

    print(f"Target {target_col} sequences: {selected_target_val}")

    return train_df, test_df, selected_target_val

In [None]:
train_df, test_df, selected_target_val = hard_split_df(binding_df, 'epitope', min_ratio=0.8, low=50, high=800, random_seed=42)