In [3]:
import pandas as pd
import numpy as np
import re
from ast import literal_eval
import matplotlib.pyplot as plt

In [4]:
from sklearn import tree
import graphviz

In [5]:
def HLA_cd8_converter(x):
    """Parse the stringified ``HLA_cd8`` cell of the CSV into a list of strings.

    Square brackets, commas and single quotes are removed and the remainder is
    split on single spaces. Note that an empty list string (``"[]"``) yields
    ``['']`` rather than an empty list.
    """
    stripped = x.translate(str.maketrans("", "", "[],'"))
    return stripped.split(" ")

def cdr3_lst_converter(x):
    """Parse a stringified, space-separated list cell into a list of strings.

    Square brackets and single quotes are removed, then the text is split on
    single spaces. Commas are NOT removed, so a comma-separated list repr keeps
    a trailing comma on each element but the last.
    """
    cleaned = x
    for token in ("[", "]", "'"):
        cleaned = cleaned.replace(token, "")
    return cleaned.split(" ")

def epitope_converter(x):
    """Parse a stringified list of quoted entries into a list of strings.

    Square brackets and newlines are removed, the text is split on single
    quotes, and fragments that are empty or a lone space (the separators
    between quoted entries) are discarded.
    """
    unwrapped = x.translate(str.maketrans("", "", "[]\n"))
    return [piece for piece in unwrapped.split("'") if piece not in ('', ' ')]

def peptide_hla_converter(x):
    """Extract ``"<word> <letter><digits>"`` entries from a stringified list.

    Square brackets, newlines and single quotes are stripped first; every
    match of one-or-more word characters, a single whitespace, one word
    character and one-or-more digits (e.g. ``"RVRAYTYSK A0301"``) is returned.

    Returns:
        list[str]: the matches in order of appearance (empty if none).
    """
    cleaned = x.replace("[","").replace("]","").replace("\n","").replace("'","")
    # Raw string: "\w", "\s" and "\d" are invalid escape sequences in a normal
    # string literal and trigger a Deprecation/SyntaxWarning on modern Python.
    return re.findall(r"\w+\s{1}\w{1}\d+", cleaned)

def literal_converter(val):
    """Evaluate a CSV cell as a Python literal, mapping an empty cell to ``[]``."""
    if val == '':
        # Empty cells (missing values) become an empty list.
        return []
    return literal_eval(val)

# Column name -> parser, passed to pd.read_csv(converters=...) so that
# list-like columns stored as strings are turned back into Python lists.
converters = {
    'peptide_HLA_lst': peptide_hla_converter,
    'umi_count_lst_mhc': literal_eval,
    'umi_count_lst_TRA': literal_converter,
    'umi_count_lst_TRB': literal_converter,
    'cdr3_lst_TRA': cdr3_lst_converter,
    'cdr3_lst_TRB': cdr3_lst_converter,
    'HLA_lst_mhc': cdr3_lst_converter,
    'HLA_cd8': HLA_cd8_converter,
}

In [6]:
def acc(df):
    """Return the sum of ``df.train_label`` divided by its count of non-null entries."""
    labels = df.train_label
    n_labelled = len(labels.dropna())
    return labels.sum() / n_labelled

In [3]:
import random
# Draw an arbitrary integer in [0, 99999999999] (unseeded, so it differs per run).
n = random.randint(0, 99_999_999_999)
n

38280423168

In [4]:
# NOTE(review): `r` is not defined anywhere in the visible notebook, so this
# cell raises NameError (see the stored output); it breaks Restart & Run All.
F1 = f'r{r}'

NameError: name 'r' is not defined

In [5]:
# printf-style template; fill the integer in later with `F1 % value`.
F1 = 'r%d'

In [6]:
# Format the template with an integer to check the result.
F1 %3

'r3'

# Input

In [7]:
# CSV that is loaded into `df` below (path is relative to the notebook's working directory).
VALID = '../../experiments/exp13/run1/cat/eval_clonotypes/valid_ct.csv'

In [12]:
# Input files; all paths are relative to the notebook's working directory.
_EXP_DIR = '../../experiments/exp13'
_VDJ_CONTIGS = 'tcr/cellranger_tot/outs/multi/vdj_t/all_contig_annotations.csv'

HTO = f'{_EXP_DIR}/run1_archive/brc/outfile.csv'      # HTO table (archive run)
GEX = f'{_EXP_DIR}/run1_archive/tcr/usable_gems.txt'  # one GEM barcode per line
TCR = f'{_EXP_DIR}/run1/{_VDJ_CONTIGS}'               # contig annotations, current run
TCR_ARC = f'{_EXP_DIR}/run1_archive/{_VDJ_CONTIGS}'   # contig annotations, archive run

# Load

In [9]:
# Load the main table; list-like columns are parsed by the `converters` mapping.
# NOTE(review): 'umi_count_lst_mhc' uses plain literal_eval, which raises on an
# empty cell — presumably that column is never empty; confirm.
df = pd.read_csv(VALID, converters=converters)

In [13]:
def _is_cell_by_barcode(contigs):
    """Return, per barcode, the first distinct `is_cell` value of its contigs."""
    return contigs.groupby('barcode').is_cell.unique().apply(lambda x: x[0])

# Contig annotations of the current and the archived run.
tcr = pd.read_csv(TCR)
tcr_arc = pd.read_csv(TCR_ARC)

# Despite the `_dct` suffix these are Series indexed by barcode.
tcr_dct = _is_cell_by_barcode(tcr)
arc_dct = _is_cell_by_barcode(tcr_arc)

# Outer join keeps barcodes seen in either run; the side that lacks a barcode
# gets NaN in its `is_cell_arc` / `is_cell_gex` column.
tcr_cell = pd.merge(arc_dct, tcr_dct, how='outer',
                    left_index=True, right_index=True,
                    suffixes=['_arc','_gex'])

# Single-column file of GEM barcodes without a header row.
gex = pd.read_csv(GEX, header=None, names=['gem'])

In [18]:
# The file's own header row is skipped and replaced by these column names.
hto_columns = [
    'gem', 'seurat', 'umi_count_hto', 'feature_rna', 'count_hto', 'feature_hto',
    'hto_max_id', 'hto_sec_id', 'hto_margin',
    'hto_classification', 'hto_global_class', 'hash_id',
]
hto = pd.read_csv(HTO, names=hto_columns, header=None, skiprows=1)

# Process

In [19]:
# Attach the HTO table to every GEM; GEMs missing from `hto` keep NaN in the
# hto_* columns after the left join.
# NOTE(review): this cell rebinds `df`, so re-running it merges twice and
# produces suffixed duplicate columns — run once per fresh kernel.
df = pd.merge(df, hto, how='left', on='gem')

# GEMs without an HTO call are labelled 'Singlet' (assumption — TODO confirm).
# Assign the result back instead of `df.col.fillna(..., inplace=True)`: the
# in-place call on an accessor is chained assignment, which warns in pandas 2.x
# and silently leaves `df` unchanged under Copy-on-Write.
df['hto_global_class'] = df['hto_global_class'].fillna('Singlet')

# True when the GEM is listed in the `gex` barcode table.
df['gex'] = df.gem.isin(gex.gem)
#df['gex'] = df.gem.map(tcr_dct)

# Default (inner) join: GEMs absent from `tcr_cell` are dropped here.
df = pd.merge(df, tcr_cell, left_on='gem',right_index=True)

In [21]:
# Recode the truthy/falsy flag as a categorical label.
# NOTE(review): not idempotent — after one run the column holds non-empty
# strings, which are all truthy, so a re-run labels every row 'pMHC singlet'.
df['single_barcode_mhc'] = np.where(df['single_barcode_mhc'], 'pMHC singlet', 'pMHC multiplet')

In [85]:
# Per GEM: is the HLA of the pMHC (`HLA_mhc`) among the GEM's `HLA_cd8` entries?
df['hla_match_per_gem'] = [
    'HLA match' if hla_mhc in hla_cd8 else 'HLA mismatch'
    for hla_mhc, hla_cd8 in zip(df['HLA_mhc'], df['HLA_cd8'])
]

In [87]:
# Label each GEM by the archive run's `is_cell` call.
# `is_cell_arc` originates from an outer merge (`tcr_cell`), so barcodes absent
# from the archive run carry NaN. NaN is truthy, so `np.where(df.is_cell_arc, ...)`
# would label those rows 'is cell'; comparing explicitly against True makes a
# missing call count as 'is not cell'.
df['cell_flag'] = np.where(df['is_cell_arc'].eq(True), 'is cell', 'is not cell')

# One-hot encoding

In [24]:
from sklearn.preprocessing import OneHotEncoder

In [88]:
# Inspect the categorical columns that will be one-hot encoded.
df.loc[:, ['hto_global_class', 'single_barcode_mhc', 'tcr_category',
           'cell_flag', 'hla_match_per_gem']]

Unnamed: 0,hto_global_class,single_barcode_mhc,tcr_category,cell_flag,hla_match_per_gem
0,Doublet,pMHC multiplet,multiple chains,is cell,HLA match
1,Singlet,pMHC multiplet,unique chains,is cell,HLA match
2,Singlet,pMHC multiplet,unique chains,is cell,HLA match
3,Singlet,pMHC multiplet,unique chains,is cell,HLA match
4,Singlet,pMHC multiplet,unique chains,is cell,HLA match
...,...,...,...,...,...
7120,Doublet,pMHC multiplet,missing chain,is cell,HLA match
7121,Singlet,pMHC multiplet,missing chain,is not cell,HLA match
7122,Singlet,pMHC multiplet,missing chain,is not cell,HLA match
7123,Singlet,pMHC multiplet,missing chain,is cell,HLA mismatch


In [130]:
# Training rows are the GEMs that have a `pep_match` label.
labelled = df.pep_match.notna()

# t1: categorical features (one-hot encoded further down).
t1 = df.loc[labelled, ['hto_global_class', 'single_barcode_mhc', 'tcr_category',
                       'cell_flag', 'hla_match_per_gem']]
# t2: numeric UMI features.
t2 = df.loc[labelled, ['umi_count_mhc', 'delta_umi_mhc',
                       'umi_count_cd8', 'delta_umi_cd8',
                       'umi_count_TRA', 'delta_umi_TRA',
                       'umi_count_TRB', 'delta_umi_TRB']]

In [131]:
# Un-encoded categorical + UMI features of the labelled rows.
# NOTE(review): `Xn` is not referenced again in the visible part of the notebook.
Xn = df.loc[
    df.pep_match.notna(),
    ['hto_global_class', 'single_barcode_mhc', 'tcr_category', 'cell_flag', 'hla_match_per_gem',
     'umi_count_mhc', 'delta_umi_mhc',
     'umi_count_cd8', 'delta_umi_cd8',
     'umi_count_TRA', 'delta_umi_TRA',
     'umi_count_TRB', 'delta_umi_TRB'],
]

In [132]:
# Fit the encoder on the labelled rows' categorical features only.
# NOTE(review): with the default settings, transforming rows that contain a
# category unseen here will raise — confirm the unlabelled rows have no new levels.
enc = OneHotEncoder()
enc.fit(t1)

# Flatten the per-feature category labels into column names; `columns_drop`
# omits the first level of every feature.
# NOTE(review): names are bare category labels, so two features sharing a label
# would yield duplicate column names.
columns = []
columns_drop = []
for levels in enc.categories_:
    columns.extend(levels)
    columns_drop.extend(levels[1:])

# NOTE(review): `t1` is overwritten with its encoded form, so this cell cannot
# be re-run without first re-creating `t1`.
encoded = enc.transform(t1).toarray()
t1 = pd.DataFrame(encoded, index=t2.index, columns=columns)

enc.categories_

[array(['Doublet', 'Negative', 'Singlet'], dtype=object),
 array(['pMHC multiplet', 'pMHC singlet'], dtype=object),
 array(['missing chain', 'multiple chains', 'unique chains'], dtype=object),
 array(['is cell', 'is not cell'], dtype=object),
 array(['HLA match', 'HLA mismatch'], dtype=object)]

In [133]:
# One-hot columns left after omitting the first level of each categorical feature.
columns_drop

['Negative',
 'Singlet',
 'pMHC singlet',
 'multiple chains',
 'unique chains',
 'is not cell',
 'HLA mismatch']

In [134]:
# Same feature split as t1/t2, but for every row of `df` (labelled or not).
t4 = df[['umi_count_mhc', 'delta_umi_mhc',
         'umi_count_cd8', 'delta_umi_cd8',
         'umi_count_TRA', 'delta_umi_TRA',
         'umi_count_TRB', 'delta_umi_TRB']]
t3 = df[['hto_global_class', 'single_barcode_mhc', 'tcr_category',
         'cell_flag', 'hla_match_per_gem']]

# Encode with the encoder fitted on the labelled rows so columns line up with t1.
t3 = pd.DataFrame(enc.transform(t3).toarray(), index=t4.index, columns=columns)

In [135]:
# Training matrix: one-hot categorical columns followed by the UMI columns.
X = pd.merge(t1, t2, right_index=True, left_index=True) #[columns_drop]
# Target: pep_match of the labelled rows cast to int.
y = df.loc[df.pep_match.notna(), 'pep_match'].astype(int)

In [136]:
# Feature matrix for every row of `df` (labelled and unlabelled), with the same
# column layout as `X`: one-hot categorical columns (t3) + UMI columns (t4).
# The former preliminary assignment of the UMI columns alone was a dead store
# (overwritten immediately) and has been removed.
X_test = pd.merge(t3,t4, left_index=True, right_index=True) #[columns_drop]

In [137]:
# Inspect the one-hot encoded categorical features of the labelled rows.
t1

Unnamed: 0,Doublet,Negative,Singlet,pMHC multiplet,pMHC singlet,missing chain,multiple chains,unique chains,is cell,is not cell,HLA match,HLA mismatch
2,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0
4,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0
10,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0
14,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0
16,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...
7113,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0
7114,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0
7115,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0
7120,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0


In [138]:
# Inspect the full training matrix (one-hot + UMI features).
X

Unnamed: 0,Doublet,Negative,Singlet,pMHC multiplet,pMHC singlet,missing chain,multiple chains,unique chains,is cell,is not cell,HLA match,HLA mismatch,umi_count_mhc,delta_umi_mhc,umi_count_cd8,delta_umi_cd8,umi_count_TRA,delta_umi_TRA,umi_count_TRB,delta_umi_TRB
2,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,7.0,5.600000,1724.0,77.483146,3.0,12.0,4.0,16.0
4,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,33.0,26.400000,2224.0,91.711340,1.0,4.0,10.0,40.0
10,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,52.0,208.000000,1560.0,102.295082,2.0,1.6,3.0,12.0
14,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,1.0,4.000000,593.0,32.493151,3.0,12.0,9.0,36.0
16,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,20.0,16.000000,1261.0,152.848485,6.0,24.0,13.0,52.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7113,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,8.0,3.555556,3808.0,118.077519,0.0,0.0,4.0,16.0
7114,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,4.000000,21.0,2.048780,0.0,0.0,1.0,4.0
7115,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,4.000000,17.0,2.344828,0.0,0.0,1.0,4.0
7120,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,3.0,2.400000,723.0,34.023529,0.0,0.0,4.0,16.0


# Random Forest

In [139]:
from sklearn.ensemble import RandomForestClassifier

## Hyperparameter testing

In [116]:
# Grid over min_samples_leaf x max_depth; one prediction column per setting,
# named rf_pred_<min_samples_leaf>_<max_depth>.
# NOTE(review): X_test contains the training rows too, so agreement with
# pep_match on labelled GEMs is training accuracy, not a held-out estimate.
min_samples = [1, 5, 10, 15]
max_depth = [None, 15, 10, 5]
grid = [(s, d) for s in min_samples for d in max_depth]
for s, d in grid:
    clf = RandomForestClassifier(
        n_estimators=20,
        min_samples_leaf=s,
        max_depth=d,
        random_state=0,
    ).fit(X, y)
    df[f'rf_pred_{s}_{d}'] = clf.predict(X_test)

In [121]:
# Export the grid-search prediction columns, indexed by GEM.
# Select them by name (rf_pred_<int>_<None|int>) instead of `iloc[:, -16:]`:
# the positional slice silently exports the wrong columns as soon as any other
# column (e.g. 'rf_pred', 'pred') has been appended to `df` or the grid changes.
(df.set_index('gem')
   .filter(regex=r'^rf_pred_\d+_(?:None|\d+)$')
   .to_csv('../tmp_files/random_forest_test.csv', index=True))

## Testing with one-hot

In [140]:
# Final random forest on the one-hot + UMI features (seeded for reproducibility).
clf = RandomForestClassifier(
    n_estimators=100,
    max_depth=15,
    min_samples_leaf=1,
    random_state=0,
).fit(X, y)

In [141]:
# Predict for every row of `df`; X_test also contains the rows used for training.
df['rf_pred'] = clf.predict(X_test)

In [142]:
# Side-by-side view of prediction and label for the labelled rows.
df[['rf_pred','pep_match']].dropna() #,'pred'

Unnamed: 0,rf_pred,pep_match
2,1,True
4,1,True
10,1,True
14,0,False
16,1,True
...,...,...
7113,1,True
7114,0,False
7115,0,False
7120,1,True


In [143]:
# Boolean Series (indexed like the labelled rows): True where the random-forest
# prediction equals the pep_match label.
scored = df[['pep_match', 'rf_pred']].dropna()
lol = scored.pep_match.astype(int) == scored.rf_pred

In [144]:
# Feature rows of the labelled GEMs that the random forest misclassifies.
X[~lol]

Unnamed: 0,Doublet,Negative,Singlet,pMHC multiplet,pMHC singlet,missing chain,multiple chains,unique chains,is cell,is not cell,HLA match,HLA mismatch,umi_count_mhc,delta_umi_mhc,umi_count_cd8,delta_umi_cd8,umi_count_TRA,delta_umi_TRA,umi_count_TRB,delta_umi_TRB
395,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,3.0,2.4,1107.0,119.675676,2.0,8.0,6.0,24.0
6064,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,4.0,3.2,252.0,14.608696,0.0,0.0,1.0,4.0
6584,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,4.0,14.0,1.69697,0.0,0.0,1.0,4.0
6748,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,4.0,24.0,1.078652,0.0,0.0,1.0,4.0
6754,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,4.0,20.0,1.632653,0.0,0.0,1.0,4.0
6840,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,3.0,2.4,892.0,36.783505,0.0,0.0,5.0,20.0
7016,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,4.0,20.0,3.2,0.0,0.0,1.0,4.0


In [145]:
# Labelled GEMs where the random forest disagrees with pep_match.
# Rows are selected by the index labels of the mismatches: the previous
# `df[[...]].dropna()[~lol]` also dropped rows with missing `ct`, which leaves
# the frame misaligned with the boolean mask `lol` (reindex warning or
# IndexingError) whenever a labelled row lacks `ct`.
df.loc[lol[~lol].index, ['ct','pep_match','rf_pred']]

Unnamed: 0,ct,pep_match,rf_pred
395,22.0,False,1
6064,22.0,False,1
6584,2.0,True,0
6748,22.0,True,0
6754,2.0,True,0
6840,3.0,False,1
7016,2078.0,True,0


In [146]:
# Write the random-forest prediction per GEM (GEM barcode as index column).
rf_by_gem = df.set_index('gem')[['rf_pred']]
rf_by_gem.to_csv('../tmp_files/random_forest_onehot.csv', index=True)

# Decision Tree

In [38]:
# Single decision tree on the same features.
# A fixed random_state makes the fitted tree reproducible across runs (tied
# splits are otherwise broken randomly) and matches the seeded random-forest cells.
# NOTE(review): `clf` is re-bound here, replacing the random-forest model.
clf = tree.DecisionTreeClassifier(random_state=0)
clf = clf.fit(X, y)

In [54]:
# Render the fitted tree to 'tester.pdf' via graphviz.
# `class_names` must list one name per class in ascending class order
# (0 -> 'artifact', 1 -> 'binder'). The previous per-sample array
# `np.where(y == 1, ...)` made graphviz read the first two samples' labels as
# the class names, mislabelling the nodes.
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=list(X.columns),
                                class_names=['artifact', 'binder'],
                                filled=True, rounded=True,
                                special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("tester")

'tester.pdf'

In [67]:
# Store predictions of whatever model `clf` currently refers to.
# NOTE(review): `clf` is shared by the random-forest and decision-tree cells,
# so the meaning of 'pred' depends on cell execution order — confirm it is the tree.
df['pred'] = clf.predict(X_test)

In [72]:
# Unlabelled GEMs (no pep_match): UMI features, annotation columns and prediction.
unlabelled_view_cols = [
    'umi_count_mhc', 'delta_umi_mhc',
    'umi_count_cd8', 'delta_umi_cd8',
    'umi_count_TRA', 'delta_umi_TRA',
    'umi_count_TRB', 'delta_umi_TRB',
    'ct', 'peptide_HLA', 'HLA_cd8', 'ct_hla', 'pred',
]
df.loc[df.pep_match.isna(), unlabelled_view_cols]

Unnamed: 0,umi_count_mhc,delta_umi_mhc,umi_count_cd8,delta_umi_cd8,umi_count_TRA,delta_umi_TRA,umi_count_TRB,delta_umi_TRB,ct,peptide_HLA,HLA_cd8,ct_hla,pred
0,4.0,3.200000,4134.0,9.222532,8.0,32.0,6.0,1.411765,1825.0,RPHERNGFTVL B0702,"[A0201, B0702]",,0
1,7.0,2.153846,2168.0,97.438202,10.0,40.0,8.0,32.000000,1136.0,RVRAYTYSK A0301,"[A0201, A0301, B0702]",,1
3,19.0,15.200000,4645.0,285.846154,1.0,4.0,5.0,20.000000,79.0,CLGGLLTMV A0201,"[A0201, B0702]",,1
5,3.0,12.000000,806.0,46.724638,5.0,20.0,8.0,32.000000,2023.0,RVRAYTYSK A0301,"[A0201, A0301, B0702]",,1
6,7.0,1.333333,2455.0,47.902439,1.0,4.0,11.0,44.000000,344.0,TPRVTGGGAM B0702,"[A0301, B0702]",,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...
7118,28.0,3.862069,2417.0,1.024913,0.0,0.0,7.0,28.000000,1636.0,TPRVTGGGAM B0702,"[A0201, A0301, B0702]",,1
7119,4.0,16.000000,3104.0,102.611570,0.0,0.0,10.0,40.000000,181.0,RVRAYTYSK A0301,"[A0201, B0702]",,1
7121,1.0,0.800000,1275.0,7.317073,0.0,0.0,4.0,16.000000,3579.0,TPRVTGGGAM B0702,"[A0201, A0301, B0702]",,0
7122,2.0,1.600000,2484.0,77.023256,0.0,0.0,1.0,4.000000,6790.0,RVRAYTYSK A0301,"[A0301, B0702]",,0


In [91]:
# Fraction of labelled GEMs whose pep_match is True.
pep_labels = df.pep_match.dropna()
pep_labels.sum() / len(pep_labels)

0.8653262518968133

In [88]:
# Fraction of labelled GEMs that are *predicted* positive.
# NOTE(review): this is the positive-prediction rate, not the agreement with
# pep_match — confirm that is the intended metric.
has_label = df.pep_match.notna()
df.loc[has_label, 'pred'].sum() / has_label.sum()

0.8645675265553869

In [89]:
# Simple threshold filter: at least 2 pMHC UMIs and a pMHC delta-UMI of at least 1.
idx = df.umi_count_mhc.ge(2) & df.delta_umi_mhc.ge(1)

In [92]:
# Fraction of True pep_match among labelled GEMs passing the threshold filter.
kept_labels = df.loc[idx, 'pep_match'].dropna()
kept_labels.sum() / len(kept_labels)

0.9539007092198581

In [84]:
# Number of rows that carry a pep_match label.
len(df.loc[df.pep_match.notna()])

2636

In [87]:
# Same count via dropna(); equals the row count from the previous cell.
len(df.pep_match.dropna())

2636

In [99]:
# Boolean Series: True where 'pred' equals the pep_match label.
# NOTE(review): this re-binds `lol`, which the random-forest section also uses.
scored_pred = df[['pep_match', 'pred']].dropna()
lol = scored_pred.pep_match.astype(int) == scored_pred.pred

In [101]:
# Feature rows of the labelled GEMs where 'pred' disagrees with pep_match.
# NOTE(review): the stored output lists only the eight UMI columns, whereas `X`
# as built above also has the one-hot columns — the output appears to come
# from an earlier definition of `X`.
X[~lol]

Unnamed: 0,umi_count_mhc,delta_umi_mhc,umi_count_cd8,delta_umi_cd8,umi_count_TRA,delta_umi_TRA,umi_count_TRB,delta_umi_TRB
6584,1.0,4.0,14.0,1.69697,0.0,0.0,1.0,4.0
6748,1.0,4.0,24.0,1.078652,0.0,0.0,1.0,4.0


In [104]:
# Labelled GEMs where 'pred' disagrees with pep_match.
# Select by the index labels of the mismatches: `df[[...]].dropna()[~lol]`
# also drops rows with missing `ct`, leaving the frame misaligned with the
# boolean mask `lol` whenever a labelled row lacks `ct`.
df.loc[lol[~lol].index, ['ct','pep_match','pred']]

Unnamed: 0,ct,pep_match,pred
6584,2.0,True,0
6748,22.0,True,0


In [70]:
# Random-forest prediction per GEM barcode.
# NOTE(review): this cell's execution count (70) precedes the cell that creates
# 'rf_pred' (141), so the stored output relies on out-of-order execution.
df.set_index('gem')[['rf_pred']]

Unnamed: 0_level_0,rf_pred
gem,Unnamed: 1_level_1
AAACCTGAGCCCAGCT-1,1
AAACCTGAGTCAATAG-1,1
AAACCTGCAATCCGAT-1,1
AAACCTGCAGCCAGAA-1,1
AAACCTGCATGCCACG-1,1
...,...
TTTGGTTGTTGAGGTG-1,1
TTTGTCAAGAATGTTG-1,0
TTTGTCAAGCGTGAAC-1,1
TTTGTCATCACCTCGT-1,1
