In [24]:
import pandas as pd
import numpy as np

In [14]:
def stratified_sample(y, p=0.67, replace=False, seed=1234):
    """Draw a class-stratified random sample of positions from ``y``.

    For every distinct label, ``floor(p * count)`` positions carrying that
    label are drawn (but at least one per label), so the sample keeps the
    class proportions of ``y`` up to rounding.

    Parameters
    ----------
    y : array-like
        One label per item. Converted with ``np.asarray`` so list inputs of
        any label type (ints, strings) compare elementwise.
    p : float
        Fraction of each class to draw. With ``replace=False`` a value above 1
        makes ``choice`` raise ``ValueError``.
    replace : bool
        Whether to sample with replacement.
    seed : int or None
        Seed for a private ``RandomState``; the draws are identical to the
        former ``np.random.seed(seed)`` + ``np.random.choice`` behaviour, but
        NumPy's global generator is no longer reset. ``None`` draws from the
        global generator without reseeding it.

    Returns
    -------
    numpy.ndarray
        Integer positions into ``y``, grouped by label in sorted label order.
        Empty when ``y`` is empty.
    """
    y = np.asarray(y)
    unique_y, counts = np.unique(y, return_counts=True)
    if unique_y.size == 0:
        # np.concatenate([]) would raise; an empty sample is the natural answer.
        return np.array([], dtype=np.intp)

    # floor(p * count) per class, but never fewer than one item per class.
    n_per_class = np.maximum(np.floor(p * counts).astype(int), 1)

    # Private generator: same stream as seeding the global one, no side effect.
    rng = np.random if seed is None else np.random.RandomState(seed)

    inds = [rng.choice(np.where(y == label)[0], size=npc, replace=replace)
            for label, npc in zip(unique_y, n_per_class)]
    return np.concatenate(inds)

In [42]:
def get_split_label(ind, val_inds, main_inds):
    """Return the split name for a single row position.

    Validation membership wins over main-pool membership; anything in
    neither collection is labelled 'test'.

    Parameters
    ----------
    ind : int
        Row position to classify.
    val_inds : container
        Positions assigned to validation.
    main_inds : container
        Positions in the main (train + validation) pool.

    Returns
    -------
    str
        'valid', 'train' or 'test'.
    """
    if ind in val_inds:
        return 'valid'
    return 'train' if ind in main_inds else 'test'

In [32]:
# Root folder on the shared team drive holding the experiment CSVs.
teamdrive_root = '../../../../teamdrive/transmediasp/kate/'

# Fractions of each class drawn into the main (train + valid) pool.
volume = [0.15, 0.25, 0.35, 0.55, 0.85]

# One seed per repetition; every volume is sampled once per seed.
seeds = [
    1234, 6512, 3845, 4321, 5888,
    7356, 1834, 4628, 9375, 8372,
]

In [63]:
# Load the label-noise frame and drop the extra 'Unnamed: 0.1' column
# (presumably a stale index written by an earlier to_csv round-trip).
label_noise_csv = teamdrive_root + 'icons_experiment/label_noise_data_frame_with_splits.csv'
df = pd.read_csv(label_noise_csv, index_col=0).drop(columns=['Unnamed: 0.1'])

df.head()

Unnamed: 0,image_path,class,label,split,10_0,10_1,10_2,10_3,10_4,20_0,...,split_90_0,split_90_1,split_90_2,split_90_3,split_90_4,split_00_0,split_00_1,split_00_2,split_00_3,split_00_4
0,../../../data/testdotai/close/_e4530e1aae88750...,close,0,main,0,0,0,36,0,90,...,main,main,main,main,main,main,test,main,main,main
1,../../../data/testdotai/close/~02a0c54fd8374b4...,close,0,main,0,0,0,0,0,48,...,main,main,main,test,main,test,test,main,main,main
2,../../../data/testdotai/close/_10099f88fd8333f...,close,0,main,0,0,0,0,0,0,...,main,main,main,main,test,main,main,main,main,main
3,../../../data/testdotai/close/_bcd740021f1a62a...,close,0,main,0,0,0,0,0,0,...,main,main,main,main,main,main,main,main,main,main
4,../../../data/testdotai/close/_047b3f69c7c53b8...,close,0,main,0,0,0,0,0,0,...,main,main,main,test,main,main,main,main,main,test


In [64]:
# Keep only the identifying columns, discarding rows whose original
# `split` value is 'fine', and renumber the remaining rows from 0.
keep_rows = df['split'] != 'fine'
clean = df.loc[keep_rows, ['image_path', 'class', 'label']].reset_index(drop=True)
clean.head()

Unnamed: 0,image_path,class,label
0,../../../data/testdotai/close/_e4530e1aae88750...,close,0
1,../../../data/testdotai/close/~02a0c54fd8374b4...,close,0
2,../../../data/testdotai/close/_10099f88fd8333f...,close,0
3,../../../data/testdotai/close/_bcd740021f1a62a...,close,0
4,../../../data/testdotai/close/_047b3f69c7c53b8...,close,0


In [65]:
# Split into train, validation and test.
# For every (volume, seed) pair: draw a stratified main pool holding `vol`
# of each class; a stratified subset of that pool becomes validation (sized
# as roughly VAL_FRACTION of the full data set), the rest of the pool is
# train, and every row outside the pool is test.
VAL_FRACTION = 0.05  # validation size as a fraction of the full data set

# ndarray (not list) so label comparisons inside stratified_sample are
# elementwise regardless of the label dtype.
y = clean['label'].to_numpy()
ind_list = np.arange(len(clean))

for vol in volume:
    for run, seed in enumerate(seeds):
        print('running %s volume round %s' % (vol, run))
        main_inds = stratified_sample(y, vol, seed=seed)
        y_main = y[main_inds]
        # positions *within* main_inds, stratified on the main pool's labels
        inds_val = stratified_sample(y_main, VAL_FRACTION / vol, seed=seed)
        validation_inds = main_inds[inds_val]
        # Set membership is O(1); `in` on a list/ndarray scans linearly,
        # which made this comprehension quadratic in the data-set size.
        val_set = set(validation_inds.tolist())
        main_set = set(main_inds.tolist())
        split = [get_split_label(i, val_set, main_set) for i in ind_list.tolist()]
        # round() instead of int(): int(0.29 * 100) == 28 from float error.
        clean['%d_%d' % (round(vol * 100), run)] = split
clean.head(20)

running 0.15 volume round 0
running 0.15 volume round 1
running 0.15 volume round 2
running 0.15 volume round 3
running 0.15 volume round 4
running 0.15 volume round 5
running 0.15 volume round 6
running 0.15 volume round 7
running 0.15 volume round 8
running 0.15 volume round 9
running 0.25 volume round 0
running 0.25 volume round 1
running 0.25 volume round 2
running 0.25 volume round 3
running 0.25 volume round 4
running 0.25 volume round 5
running 0.25 volume round 6
running 0.25 volume round 7
running 0.25 volume round 8
running 0.25 volume round 9
running 0.35 volume round 0
running 0.35 volume round 1
running 0.35 volume round 2
running 0.35 volume round 3
running 0.35 volume round 4
running 0.35 volume round 5
running 0.35 volume round 6
running 0.35 volume round 7
running 0.35 volume round 8
running 0.35 volume round 9
running 0.55 volume round 0
running 0.55 volume round 1
running 0.55 volume round 2
running 0.55 volume round 3
running 0.55 volume round 4
running 0.55 volume 

Unnamed: 0,image_path,class,label,15.0_0,15.0_1,15.0_2,15.0_3,15.0_4,15.0_5,15.0_6,...,85.0_0,85.0_1,85.0_2,85.0_3,85.0_4,85.0_5,85.0_6,85.0_7,85.0_8,85.0_9
0,../../../data/testdotai/close/_e4530e1aae88750...,close,0,test,test,test,test,valid,test,test,...,train,train,test,train,train,train,train,train,train,valid
1,../../../data/testdotai/close/~02a0c54fd8374b4...,close,0,test,test,test,test,train,test,test,...,train,train,train,train,train,test,train,train,train,train
2,../../../data/testdotai/close/_10099f88fd8333f...,close,0,test,test,test,test,test,test,test,...,train,train,test,train,train,train,test,train,train,train
3,../../../data/testdotai/close/_bcd740021f1a62a...,close,0,test,test,valid,test,test,test,test,...,test,test,train,train,train,test,train,train,train,train
4,../../../data/testdotai/close/_047b3f69c7c53b8...,close,0,test,test,test,test,train,test,test,...,train,train,train,train,train,train,train,train,test,train
5,../../../data/testdotai/close/_530464eb7c56a08...,close,0,test,train,test,valid,test,test,test,...,valid,train,test,train,test,train,test,train,train,train
6,../../../data/testdotai/close/_b9fca4f2b106c1f...,close,0,test,test,test,test,test,test,test,...,train,train,train,train,train,train,train,train,train,train
7,../../../data/testdotai/close/~9414624116f8235...,close,0,test,test,test,test,test,test,test,...,train,train,train,train,train,test,train,train,train,train
8,../../../data/testdotai/close/~b8a44d79b1165ed...,close,0,test,valid,test,test,test,test,test,...,train,train,train,test,test,train,valid,train,train,test
9,../../../data/testdotai/close/_def90b0d68090cb...,close,0,test,test,test,train,test,test,test,...,train,train,train,valid,test,train,train,train,train,valid


In [68]:
# Rename the split columns to '<volume %>_<run>' (e.g. '15_0').
# An earlier version of the split cell formatted the volume as a float
# ('15.0_0'), so the names are rebuilt positionally here. They are derived
# from `volume` and `seeds` (same nesting order as the split loop) instead of
# a hand-typed list that silently goes stale when either config list changes.
# pandas raises if the column count does not match exactly.
split_columns = ['%d_%d' % (round(vol * 100), run)
                 for vol in volume
                 for run in range(len(seeds))]
clean.columns = ['image_path', 'class', 'label'] + split_columns
clean.head()

Unnamed: 0,image_path,class,label,15_0,15_1,15_2,15_3,15_4,15_5,15_6,...,85_0,85_1,85_2,85_3,85_4,85_5,85_6,85_7,85_8,85_9
0,../../../data/testdotai/close/_e4530e1aae88750...,close,0,test,test,test,test,valid,test,test,...,train,train,test,train,train,train,train,train,train,valid
1,../../../data/testdotai/close/~02a0c54fd8374b4...,close,0,test,test,test,test,train,test,test,...,train,train,train,train,train,test,train,train,train,train
2,../../../data/testdotai/close/_10099f88fd8333f...,close,0,test,test,test,test,test,test,test,...,train,train,test,train,train,train,test,train,train,train
3,../../../data/testdotai/close/_bcd740021f1a62a...,close,0,test,test,valid,test,test,test,test,...,test,test,train,train,train,test,train,train,train,train
4,../../../data/testdotai/close/_047b3f69c7c53b8...,close,0,test,test,test,test,train,test,test,...,train,train,train,train,train,train,train,train,test,train


In [69]:
# Persist the per-(volume, seed) split assignments in the same
# icons_experiment folder the input frame was read from.
volume_splits_path = teamdrive_root + 'icons_experiment/data_frame_with_volume_splits.csv'
clean.to_csv(volume_splits_path)