In [1]:
import pickle
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit, RepeatedStratifiedKFold
from utils import *

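This notebook splits the dataset for single-drug binary classification. Resistant (`R`) maps to 1 and susceptible (`S`) maps to 0; samples with any other phenotype value for that drug are excluded. Splits are saved as positional row indices into the raw data table.
The overall strategy is 10-fold stratified cross-validation repeated twice, giving 20 splits in total. Within each fold, 1/9 of the training portion is held out (stratified) as a validation set, so the validation and test sets are roughly the same size.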

In [2]:
# Load the genotype/phenotype table (file name suggests Walker et al., Lancet 2015).
# As the preview below shows, rows are indexed by sample ID, with a list of
# mutations per sample plus one phenotype column per drug.
geno_pheno = pd.read_pickle('../Data/Walker2015Lancet.pkl')
geno_pheno.head(5)

Unnamed: 0_level_0,MUTATIONS,SM,KAN,AK,CAP,EMB,CIP,OFX,MOX,INH,RIF,PZA,source,linName
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
00-R0025,"[eis_V163I, pncA_Y95D, rrs_A1401G, rpoB_S450L,...",,,,,R,,,,R,R,,Maha,EastAsia
00-R0086,"[gyrA_G668D, gyrA_S95T, gidB_E92D, katG_R463L,...",,,,,R,,,,R,R,R,Maha,EastAsia
00-R0178,"[rrs_A1401G, gyrA_G668D, gyrA_S95T, gyrA_E21Q,...",,,,,R,,,,R,R,R,Maha,European
00-R0223,"[gyrA_E21Q, gidB_S100F]",,,,,S,,,,S,S,,Maha,European
00-R0308,"[rpoB_S450L, gyrA_G668D, gyrA_S95T, gyrA_E21Q,...",,,,,R,,,,R,R,R,Maha,European


In [4]:
def singleBinarySplit(drug, n_splits=10, n_repeats=2, val_size=1/9, random_state=42, data=None):
    """Build repeated stratified train/val/test index splits for one drug.

    Only samples whose phenotype for `drug` is 'R' (label 1) or 'S' (label 0)
    are kept; every other value (e.g. missing) is excluded.

    Args:
        drug: phenotype column name, e.g. 'INH'.
        n_splits: folds per repetition of the outer stratified K-fold (default 10).
        n_repeats: number of repetitions of the outer K-fold (default 2).
        val_size: fraction of each fold's train+val portion held out, stratified,
            as the validation set (default 1/9; with n_splits=10 this makes the
            validation set about the same size as the test fold).
        random_state: seed shared by both splitters, for reproducible splits.
        data: DataFrame to split; defaults to the notebook-level `geno_pheno`.

    Returns:
        res_split: list of n_splits * n_repeats tuples (train_idx, val_idx, test_idx),
            each a sorted numpy array of *positional* row indices into `data`.
        x: sorted numpy array of positional row indices of all 'R'/'S' samples.
    """
    if data is None:
        data = geno_pheno
    labels = data[drug]
    # Positional (not label-based) indices: the frame is indexed by sample-ID
    # strings, so integer `labels[i]` would rely on pandas' deprecated fallback.
    x = np.flatnonzero(labels.isin(['R', 'S']).to_numpy())
    y = (labels.to_numpy()[x] == 'R').astype(int)

    res_split = []
    rskf = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=random_state)
    ssp = StratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=random_state)
    for fold_train, fold_test in rskf.split(x, y):
        train_val_index, test_index = x[fold_train], x[fold_test]
        # Carve a stratified validation set out of this fold's training portion
        # (ssp yields exactly one split because n_splits=1).
        inner_train, inner_val = next(ssp.split(train_val_index, y[fold_train]))
        res_split.append((
            np.sort(train_val_index[inner_train]),
            np.sort(train_val_index[inner_val]),
            np.sort(test_index),
        ))

    return res_split, x


# Save the splits for single drug binary classification
def save_split(out_dir="../Data/idx_splits/Walker_single_binary"):
    """Compute and pickle the train/val/test index splits for every drug in `walker_drug_list`.

    For each drug, writes two files into `out_dir`:
      - `{drug}_split.pickle`: list of (train, val, test) positional-index tuples
        returned by `singleBinarySplit`.
      - `{drug}_index.pickle`: array of all positional indices with an R/S label.
    `out_dir` is created if it does not already exist (on a fresh checkout it
    would otherwise make `open(..., "wb")` raise FileNotFoundError).
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    for drug in walker_drug_list:
        res_split, x = singleBinarySplit(drug)
        with open(os.path.join(out_dir, f"{drug}_split.pickle"), "wb") as fp:
            pickle.dump(res_split, fp)
        with open(os.path.join(out_dir, f"{drug}_index.pickle"), "wb") as fp:
            pickle.dump(x, fp)

# Not run by default, presumably because the split files were already generated.
# The cells below read them from disk, so run this once on a fresh checkout.
# Splits are seeded (random_state=42), so re-running should reproduce them.
# save_split()

In [9]:
# Sanity check: peek at the first saved INH split (train / val / test positional indices).
split_path = '../Data/idx_splits/Walker_single_binary/INH_split.pickle'
with open(split_path, 'rb') as fh:
    res_split = pickle.load(fh)
first_split = res_split[0]
train_idx, val_idx, test_idx = first_split
print(type(train_idx), len(val_idx), len(test_idx))

<class 'numpy.ndarray'> 1236 1236


In [11]:
def getMutDict(drug, data=None):
    """Map every mutation seen in the dataset to a 1-based integer column id.

    The vocabulary is built from the MUTATIONS lists of *all* samples; `drug`
    is currently unused and is kept only so existing call sites keep working.

    Args:
        drug: unused.
        data: DataFrame with a 'MUTATIONS' column holding a list of mutation
            names per row; defaults to the notebook-level `geno_pheno`.

    Returns:
        dict mapping mutation name -> id in 1..n_mutations (0 is reserved for padding).
    """
    if data is None:
        data = geno_pheno
    # One union over all rows instead of rebuilding the set per row (quadratic).
    mut_set = set().union(*data['MUTATIONS'])
    # Sort so ids are identical across kernel restarts: the iteration order of a
    # set of strings depends on per-process hash randomization.
    mut_list = sorted(mut_set)
    # 0 is reserved for padding
    return {mut: idx for idx, mut in enumerate(mut_list, start=1)}

# Prepare the train, val, test data into numpy array
def prepareData(drug):
    """Build the binary mutation feature matrix and cache per-split X arrays for one drug.

    Reads the index splits from
    ../Data/idx_splits/Walker_single_binary/{drug}_split.pickle and writes
    ../Data/idx_splits/Walker_single_binary/{drug}.pkl: a list with one
    (X_train, X_val, X_test) tuple per split. Only features are saved here;
    labels come from `geno_pheno[drug]` at the same positional indices.

    Feature column j >= 1 is 1.0 when the sample carries the mutation whose id
    in `getMutDict` is j; column 0 is the padding slot and is always 0.
    """
    mut_dict = getMutDict(drug)
    mut_matrix = np.zeros((len(geno_pheno), len(mut_dict) + 1))
    # Iterate the values directly: the frame is indexed by sample-ID strings, so
    # `series[i]` would rely on pandas' deprecated integer-position fallback.
    for row, muts in enumerate(geno_pheno['MUTATIONS']):
        for mut in muts:
            mut_matrix[row, mut_dict[mut]] = 1

    split_dir = '../Data/idx_splits/Walker_single_binary'
    # Load the splitted indices for the drug (positional row indices)
    with open(f'{split_dir}/{drug}_split.pickle', 'rb') as f:
        splits = pickle.load(f)

    # Split indices are positional, so plain numpy row indexing selects the samples.
    X_data = [
        (mut_matrix[train_idx], mut_matrix[val_idx], mut_matrix[test_idx])
        for train_idx, val_idx, test_idx in splits
    ]

    # save the data
    with open(f'{split_dir}/{drug}.pkl', 'wb') as f:
        pickle.dump(X_data, f)

for drug in walker_drug_list:
    prepareData(drug)

In [16]:
# Load the splitted indices for the drug
drug = 'INH'
with open(f'../Data/idx_splits/Walker_single_binary/{drug}_split.pickle', 'rb') as f:
    splits = pickle.load(f)

train_idx, val_idx, test_idx = splits[0]
print(len(train_idx), len(val_idx), len(test_idx))
# Split arrays hold positional row indices, so select with .iloc; plain
# `geno_pheno[drug][train_idx]` relies on pandas' deprecated integer fallback
# on this sample-ID (string) index.
train_pheno_raw = geno_pheno[drug].iloc[train_idx]
# Splits only ever contain R/S samples; fail loudly if the saved splits are stale.
assert train_pheno_raw.isin(['R', 'S']).all()
# transform the pheno to binary: 'R' -> 1, 'S' -> 0
train_pheno = (train_pheno_raw.to_numpy() == 'R').astype(int)
print(train_pheno)

9887 1236 1236
[1 1 0 ... 0 0 0]
