In [1]:
import numpy as np
import pandas as pd
import sys

# Share of pairs assigned to the training split; the remainder forms the test split.
FRACTION_TRAIN = 0.75

In [2]:
def candidate_split(df, ntrain, rng=None):
    """Draw one random train/test split of (sample, drug) pairs.

    Shuffling happens at the level of (sample, collapsed drug), where both
    Fulvestrant doses count as a single drug. So, for a given sample,
    (sample, Fulvestrant 40 mg/kg) and (sample, Fulvestrant 200 mg/kg) always
    end up on the same side of the split.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'sample' and 'drug' columns. It is not modified.
    ntrain : int
        Number of (sample, collapsed drug) pairs placed in the training set.
        All remaining collapsed pairs go to the test set.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness for the shuffle. Defaults to the global
        ``np.random`` state, which was the previous (and only) behaviour.

    Returns
    -------
    (train_pairs, test_pairs) : tuple of pandas.DataFrame
        Unique ('sample', 'drug') rows using the original, uncollapsed drug names.
    """
    if rng is None:
        rng = np.random
    # Collapse Fulvestrant drugs
    f_drugs = {'Fulvestrant (200 mg/kg)': 'Fulvestrant', 'Fulvestrant (40 mg/kg)': 'Fulvestrant'}
    # Build the (sample, drug, collapsed drug) table on a local frame so the
    # caller's DataFrame does not silently gain a 'drug_collapsed' column.
    tups = (df[['sample', 'drug']]
            .assign(drug_collapsed=df['drug'].replace(f_drugs))
            .drop_duplicates()
            .reset_index(drop=True))
    # Split on (sample, collapsed drug) pairs. This is to ensure that we don't end up with
    # (sample A, Fulv 40mg) in the test set and (sample A, Fulv 200mg) in the training set.
    collapsed = tups[['sample', 'drug_collapsed']].drop_duplicates().reset_index(drop=True)
    collapsed = collapsed.reindex(rng.permutation(collapsed.index))
    train_collapsed = collapsed.iloc[:ntrain]
    test_collapsed = collapsed.iloc[ntrain:]

    def _expand(collapsed_pairs):
        # Map each collapsed pair back to every original drug name it stands for.
        expanded = collapsed_pairs.merge(tups, on=['sample', 'drug_collapsed'], validate='one_to_many')
        return expanded[['sample', 'drug']].drop_duplicates()

    return _expand(train_collapsed), _expand(test_collapsed)

In [3]:
# Split constraint: every sample and every drug in the test split must also
# occur somewhere in the training split.
def check_subset(split):
    """Return True iff test samples and test drugs are both subsets of the training ones."""
    train_part, test_part = split
    for column in ('sample', 'drug'):
        if not set(test_part[column]) <= set(train_part[column]):
            return False
    return True

def check_drugs(split, df):
    """Return True iff every drug present in ``df`` also appears in the test split."""
    _, test_part = split
    drugs_in_test = test_part.drug.unique()
    return all(name in drugs_in_test for name in df.drug.unique())

def split_pairs(df, max_tries=10_000):
    """Rejection-sample candidate splits until one satisfies both split constraints.

    A split is accepted when ``check_subset`` (test samples/drugs all seen in
    train) and ``check_drugs`` (every drug represented in test) both pass.

    ``ntrain`` is FRACTION_TRAIN of the number of (sample, collapsed drug)
    pairs. Those are the units ``candidate_split`` shuffles and slices, so
    counting uncollapsed pairs here would over-fill the training set.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'sample' and 'drug' columns.
    max_tries : int, optional
        Upper bound on the number of candidate splits drawn.

    Returns
    -------
    (train_pairs, test_pairs) : tuple of pandas.DataFrame

    Raises
    ------
    RuntimeError
        If no valid split is found within ``max_tries`` draws. This happens,
        for example, when a drug was tested on a single sample and so cannot
        appear in both train and test.
    """
    # Must mirror the Fulvestrant collapsing in candidate_split.
    f_drugs = {'Fulvestrant (200 mg/kg)': 'Fulvestrant', 'Fulvestrant (40 mg/kg)': 'Fulvestrant'}
    collapsed_pairs = df[['sample']].assign(drug=df['drug'].replace(f_drugs)).drop_duplicates()
    ntrain = int(np.floor(len(collapsed_pairs) * FRACTION_TRAIN))
    for _ in range(max_tries):
        split = candidate_split(df, ntrain)
        if check_subset(split) and check_drugs(split, df):
            return split
    raise RuntimeError(
        'No valid train/test split found after {} attempts; check that every drug '
        'was tested on at least two samples.'.format(max_tries))
        
def split_data_by_pairs(df, pairs, vol_name):
    """Select the rows of ``df`` whose (sample, drug) pair is listed in ``pairs``.

    Parameters
    ----------
    df : pandas.DataFrame
        Grouped observations with 'sample', 'drug' and ``vol_name + '_obs'`` columns.
    pairs : pandas.DataFrame
        The (sample, drug) pairs to keep. It is not modified.
    vol_name : str
        Base name of the observation column; ``vol_name + '_obs'`` is returned.

    Returns
    -------
    pandas.DataFrame
        Columns ['sample', 'drug', vol_name + '_obs'], in the row order of ``pairs``.
    """
    # Merging on the two key columns is equivalent to merging on a derived
    # (sample, drug) tuple column. It avoids mutating ``pairs`` and does not
    # require ``df`` to carry that helper column.
    selected = pairs[['sample', 'drug']].merge(df, on=['sample', 'drug'], validate='one_to_many')
    return selected[['sample', 'drug', vol_name + '_obs']]

def split_data(df, vol_name):
    """Split grouped observations into train and test frames.

    Pairs are chosen by ``split_pairs``. Each side is then materialised with
    ``split_data_by_pairs``.

    Side effect: adds a 'sample_drug_pair' column (a (sample, drug) tuple) to
    ``df`` before the per-split selection.

    Returns
    -------
    (train, test) : tuple of pandas.DataFrame
    """
    pair_split = split_pairs(df)
    df['sample_drug_pair'] = df[['sample', 'drug']].apply(tuple, axis=1)
    train, test = (split_data_by_pairs(df, chosen, vol_name) for chosen in pair_split)
    return train, test

def validate_disjoint(train, test):
    """Assert that train and test share no (sample, drug) pair."""
    seen_in_train = set(zip(train['sample'], train['drug']))
    seen_in_test = set(zip(test['sample'], test['drug']))
    assert seen_in_train.isdisjoint(seen_in_test)
    
def validate_subset(train, test, col):
    """Assert every distinct value of ``test[col]`` also occurs in ``train[col]``."""
    unseen = set(test[col].unique()) - set(train[col].unique())
    assert not unseen
    
def validate_length(train, test, df):
    """Assert that the split neither drops nor duplicates rows of ``df``."""
    rows_in_split = len(train) + len(test)
    assert rows_in_split == len(df)
    
def validate_split(train, test, df):
    """Run every sanity check on a train/test split of ``df``; raises AssertionError on failure."""
    validate_length(train, test, df)
    validate_disjoint(train, test)
    for key_column in ('sample', 'drug'):
        validate_subset(train, test, key_column)

def group_observations(df, vol_name):
    """Collapse per-row observations into one row per (sample, drug) pair.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain lowercase 'sample' and 'drug' columns, plus ``vol_name``
        and 'MID'.
    vol_name : str
        Name of the observation column to aggregate.

    Returns
    -------
    pandas.DataFrame
        Columns ['sample', 'drug', vol_name + '_obs', 'MID_list']. The two list
        columns hold the per-row values of ``vol_name`` and 'MID' for that pair.
        Pairs appear in order of first occurrence in ``df``.
    """
    # All steps use the same lowercase keys. Previously the merge step used
    # 'Sample'/'Drug', which raised KeyError after the caller's rename.
    keys = ['sample', 'drug']
    grouped = df.groupby(keys)
    vol_list = grouped[vol_name].apply(list).reset_index(name=vol_name + '_obs')
    mid_list = grouped['MID'].apply(list).reset_index(name='MID_list')
    # groupby sorts its keys; merging onto the de-duplicated pairs restores
    # first-appearance order.
    sd = df[keys].drop_duplicates()
    sd = sd.merge(vol_list, on=keys, validate='one_to_one')
    sd = sd.merge(mid_list, on=keys, validate='one_to_one')
    return sd

# Seed for the random train/test split so that re-running the notebook
# reproduces the same split.
RANDOM_SEED = 0

read_fn = '../results/2023-05-26/clean_and_split_data/welm_pdx_clean_mid_volume.csv'
vol_name = 'log(V_V0+1)'
df = pd.read_csv(read_fn)
df = df[['Sample', 'Drug', vol_name, 'MID']]
# map columns: the helper functions key on lowercase 'sample' / 'drug'
df = df.rename(columns={'Sample': 'sample', 'Drug': 'drug'})
df = group_observations(df, vol_name)
np.random.seed(RANDOM_SEED)
train, test = split_data(df, vol_name)
validate_split(train, test, df)
# TODO: define `write_dir` before enabling these exports.
#train.to_pickle(write_dir + '/train.pkl')
#test.to_pickle(write_dir + '/test.pkl')

In [13]:
# Reload the raw CSV for inspection (this rebinds `df`, replacing the grouped frame).
df = pd.read_csv(read_fn)
df.head(5)

Unnamed: 0.1,Unnamed: 0,MID,Sample,Drug,start_vol,end_vol,duration,V_V0,log(V_V0+1),log(V_V0+1)_sm,log(V_V0+1)_cen
0,0,0,HCI-010,Navitoclax,163.9208,230.687446,21.0,1.40731,1.267422,1.03123,0.236192
1,1,1,HCI-010,Navitoclax,119.794563,111.587111,21.0,0.931487,0.949712,1.03123,-0.081518
2,2,2,HCI-010,Navitoclax,132.027026,110.372796,21.0,0.835986,0.876555,1.03123,-0.154675
3,3,3,HCI-024,Navitoclax,285.77,367.335,21.0,1.285422,1.19246,1.169276,0.023184
4,4,4,HCI-024,Navitoclax,176.157,213.0164,21.0,1.209242,1.143551,1.169276,-0.025725


In [14]:
# Scratch check of the grouping on the raw (capitalised) columns: one row per
# (Sample, Drug) pair with its list of volume observations and its list of MIDs.
vol_list = (df.groupby(['Sample', 'Drug'])[vol_name]
            .apply(list)
            .reset_index(name=vol_name + '_obs'))
mid_list = (df.groupby(['Sample', 'Drug'])['MID']
            .apply(list)
            .reset_index(name='MID_list'))
sd = (df[['Sample', 'Drug']]
      .drop_duplicates()
      .merge(vol_list, on=['Sample', 'Drug'], validate='one_to_one')
      .merge(mid_list, on=['Sample', 'Drug'], validate='one_to_one'))
sd.head()

Unnamed: 0,Sample,Drug,log(V_V0+1)_obs,MID_list
0,HCI-010,Navitoclax,"[1.2674221734888569, 0.9497121725523964, 0.876...","[0, 1, 2]"
1,HCI-024,Navitoclax,"[1.1924604822460507, 1.1435513005412914, 1.171...","[3, 4, 5]"
2,HCI-015,Navitoclax,"[1.9432807362278128, 1.511835941371754, 1.4325...","[6, 7, 8]"
3,HCI-027,Navitoclax,"[2.231724789216095, 2.698815312853063, 1.65234...","[9, 10, 11]"
4,HCI-002,Navitoclax,"[2.1605774655083625, 2.2689939926685523, 1.579...","[12, 13, 14]"


In [None]:
# Per-(Sample, Drug) list of volume observations. It is stored under its own
# name: overwriting `df` would drop the raw `vol_name` column, so re-running
# this cell (or any cell that reads the raw columns) would fail with KeyError.
vol_obs_by_pair = df.groupby(['Sample', 'Drug'])[vol_name].apply(list).reset_index(name=vol_name + '_obs')
vol_obs_by_pair.head()

In [4]:
# Preview of the training split (one row per (sample, drug) pair); avoids dumping the whole frame.
train.head(10)

Unnamed: 0,sample,drug,log(V_V0+1)_obs
0,HCI-010,Docetaxel,"[0.6966110197902743, 0.6617240371529612, 1.095..."
1,HCI-002,Birinapant + Irinotecan,"[0.9045428407520958, 1.8598931160713488, 1.454..."
2,HCI-002,Vehicle,"[2.2534170356235266, 2.570040567136758, 2.6695..."
3,HCI-011,Vehicle,"[4.982720841804346, 3.9334932149831983, 4.0913..."
4,HCI-010,RO4929097,"[1.4404166149088045, 1.2824200855022965, 1.740..."
5,HCI-019,Birinapant,"[4.249973388277115, 3.774034821334692, 4.70383..."
6,HCI-001,Vehicle,"[2.6044063778532296, 3.2476149246326367, 2.506..."
7,HCI-010,Vehicle,"[1.4025516524368766, 1.6483637912782196, 1.585..."
8,HCI-002,RO4929097,"[3.099328512453143, 2.698806786398127, 3.17931..."
9,HCI-027,RO4929097,"[1.0315830268210708, 1.7661963054899594, 1.686..."


In [5]:
# Preview of the test split (one row per (sample, drug) pair); avoids dumping the whole frame.
test.head(10)

Unnamed: 0,sample,drug,log(V_V0+1)_obs
0,HCI-002,Irinotecan,"[1.929492024545651, 0.9522673892158982, 1.4744..."
1,HCI-001,Docetaxel,"[1.787158194749478, 0.8201789624151877, 1.0561..."
2,HCI-012,Birinapant,"[2.95020461492366, 2.429847754464165, 1.478671..."
3,HCI-002,Navitoclax,"[2.1605774655083625, 2.2689939926685523, 1.579..."
4,HCI-027,Vehicle,"[2.945656764460985, 2.111072497393752, 2.57121..."
5,HCI-017,Fulvestrant (200 mg/kg),"[0.0113402174904839, 0.0088650582547598, 0.017..."
6,HCI-017,Fulvestrant (40 mg/kg),"[0.7354288606700838, 1.648721081380227]"
7,HCI-015,RO4929097,"[1.2182528911690906, 0.9100578746872044, 1.370..."
8,HCI-012,Birinapant + Irinotecan,"[0.0064492509762392, 0.0070311369456857, 0.007..."
9,HCI-011,Fulvestrant (200 mg/kg),"[1.5174499142254496, 1.0704556512712424, 1.071..."


In [7]:
# Earlier cells rebind `df` to frames that use a capitalised 'Drug' column
# (raw CSV reload), so `df.drug` fails under Restart & Run All. Normalise the
# column name first; `rename` is a no-op when 'drug' is already lowercase.
df.rename(columns={'Drug': 'drug'})['drug'].unique()

array(['Birinapant', 'Docetaxel', 'Vehicle', 'Birinapant + Irinotecan',
       'Irinotecan', 'Navitoclax', 'RO4929097', 'Fulvestrant (200 mg/kg)',
       'Fulvestrant (40 mg/kg)'], dtype=object)