In [1]:
import numpy as np
import pandas as pd

FRACTION_TRAIN = .8
OUT_DIR = 'data/split/'

In [2]:
def candidate_split(df, ntrain):
    p = df[['sample', 'drug']].drop_duplicates().reset_index(drop=True)
    p = p.reindex(np.random.permutation(p.index))
    train = p[:ntrain]
    test = p[ntrain:]
    return train, test

# checks if all the test values in column col appear in the training set
def check_subset(split):
    train, test = split
    if ((set(test['sample']) <= set(train['sample'])) and (set(test['drug']) <= set(train['drug']))):
        return True
    return False

In [3]:
def split_pairs(df):
    npairs = len(df[['sample', 'drug']].drop_duplicates())
    ntrain = int(np.floor(npairs * FRACTION_TRAIN))
    while True:
        split = candidate_split(df, ntrain)
        if check_subset(split):
            return split
        
def split_data_by_pairs(df, pairs, vol_name):
    pairs['sample_drug_pair'] = pairs[['sample', 'drug']].apply(tuple, axis=1)
    pairs = pairs.merge(df, 
                        on=['sample', 'drug', 'sample_drug_pair'], 
                        validate='one_to_many')
    return pairs[['sample', 'drug', vol_name + '_obs']]

def split_data(df, vol_name):
    train_pairs, test_pairs = split_pairs(df)
    df['sample_drug_pair'] = df[['sample', 'drug']].apply(tuple, axis=1)
    train = split_data_by_pairs(df, train_pairs, vol_name)
    test = split_data_by_pairs(df, test_pairs, vol_name)
    return train, test

In [4]:
def validate_disjoint(train, test):
    train_pairs = train[['sample', 'drug']].apply(tuple, axis=1)
    test_pairs = test[['sample', 'drug']].apply(tuple, axis=1)
    assert set(train_pairs).isdisjoint(set(test_pairs))
    
def validate_subset(train, test, col):
    train_vals = train[col].unique()
    test_vals = test[col].unique()
    assert set(test_vals).issubset(set(train_vals))
    
def validate_length(train, test, df):
    assert len(train) + len(test) == len(df)
    
def validate_split(train, test, df):
    validate_length(train, test, df)
    validate_disjoint(train, test)
    validate_subset(train, test, 'sample')
    validate_subset(train, test, 'drug')

In [5]:
def group_observations(df, vol_name):
    return df.groupby(['sample', 'drug'])[vol_name].apply(list).reset_index(name = vol_name + '_obs')

In [10]:
df = pd.read_csv('data/welm_pdx_clean_mid_volume.csv')
df = df[['Sample', 'Drug', 'log(V_V0+1)']]
# map columns
df = df.rename(columns={'Sample': 'sample', 'Drug': 'drug'})

In [11]:
df.head()

Unnamed: 0,sample,drug,log(V_V0+1)
0,HCI-010,Navitoclax,1.267422
1,HCI-010,Navitoclax,0.949712
2,HCI-010,Navitoclax,0.876555
3,HCI-024,Navitoclax,1.19246
4,HCI-024,Navitoclax,1.143551


In [12]:
#t_data = {'sample': ['s0', 's0', 's0', 's0', 's1', 's1', 's1', 's2'], 
#          'drug': ['d0', 'd0', 'd1', 'd1', 'd0', 'd0', 'd1', 'd0'], 
#          'log(V_V0)': [3, 3.1, 4, 4.2, 6, 5.9, 8, 9] }
#df = pd.DataFrame.from_dict(t_data)

In [13]:
vol_name = 'log(V_V0+1)'
df = group_observations(df, vol_name)
train, test = split_data(df, vol_name)
validate_split(train, test, df)
train.to_pickle(OUT_DIR + '/train.pkl')
test.to_pickle(OUT_DIR + '/test.pkl')

In [14]:
train

Unnamed: 0,sample,drug,log(V_V0+1)_obs
0,HCI-015,Navitoclax,"[1.9432807362278128, 1.511835941371754, 1.4325..."
1,HCI-010,Docetaxel,"[0.6966110197902743, 0.6617240371529612, 1.095..."
2,HCI-001,Docetaxel,"[1.787158194749478, 0.8201789624151877, 1.0561..."
3,HCI-011,Vehicle,"[4.982720841804346, 3.9334932149831983, 4.0913..."
4,HCI-015,Docetaxel,"[0.588717879959435, 0.5574157459740647, 0.7234..."
5,HCI-016,Vehicle,"[1.6353472115141618, 2.2895301652099707, 2.088..."
6,HCI-027,RO4929097,"[1.0315830268210708, 1.7661963054899594, 1.686..."
7,HCI-011,Fulvestrant (200 mg/kg),"[1.5174499142254496, 1.0704556512712424, 1.071..."
8,HCI-002,Vehicle,"[2.2534170356235266, 2.570040567136758, 2.6695..."
9,HCI-017,Vehicle,"[2.7595312193712096, 2.050113725645231, 2.3037..."


In [15]:
test

Unnamed: 0,sample,drug,log(V_V0+1)_obs
0,HCI-015,Birinapant,"[0.0638224701636058, 0.5674480761444503, 0.490..."
1,HCI-003,Vehicle,"[1.8669104857265288, 2.9361881666172884]"
2,HCI-024,Docetaxel,"[0.6343423431401605, 0.1392747886021277, 0.326..."
3,HCI-001,Birinapant,"[5.113505485685279, 4.576495959133186, 3.87578..."
4,HCI-024,Navitoclax,"[1.1924604822460507, 1.1435513005412914, 1.171..."
5,HCI-003,Fulvestrant (200 mg/kg),[0.579485798718202]
6,HCI-002,Irinotecan,"[1.929492024545651, 0.9522673892158982, 1.4744..."
7,HCI-002,RO4929097,"[3.099328512453143, 2.698806786398127, 3.17931..."
8,HCI-015,RO4929097,"[1.2182528911690906, 0.9100578746872044, 1.370..."
9,HCI-019,Vehicle,"[1.1449696912291676, 1.393448426141541, 1.3080..."
