In [6]:
import pandas as pd
import numpy as np
import random
from sklearn.model_selection import train_test_split
import sys
sys.path.append('../..')
from modules.many_features import utils, constants

In [7]:
# Seed both Python's `random` module and NumPy's legacy global RNG from the
# shared project seed. `sample_train_set` draws through `random.sample`, so its
# results depend on this seed and on how many draws have happened since.
random.seed(constants.SEED)
np.random.seed(constants.SEED)

In [8]:
def sample_train_set(x, y, frac):
    """Draw a uniform random subset of the training data, without replacement.

    Row labels are picked with the global ``random`` module, so the outcome
    depends on whatever RNG state is active when the function is called.
    No stratification by class is applied.

    Args:
        x: pandas object holding the features; its index supplies the row
            labels that are drawn from.
        y: pandas object holding the targets; must be addressable by the same
            row labels as ``x``.
        frac: fraction of rows to keep. The number of rows drawn is
            ``int(frac * len(x))``.

    Returns:
        A ``(sampled_x, sampled_y)`` tuple, both selected with ``.loc`` using
        the same drawn labels, in draw order.
    """
    n_keep = int(frac * len(x))
    # One draw of labels is shared by both selections so features and targets
    # stay row-aligned.
    chosen_labels = random.sample(list(x.index), n_keep)
    return x.loc[chosen_labels], y.loc[chosen_labels]

In [9]:
def sample_train_set2(x, y, frac):
    """Draw a class-stratified random subset of the training data.

    Uses scikit-learn's ``train_test_split`` with ``stratify=y`` and a fixed
    ``random_state`` (``constants.SEED``), so repeated calls with the same
    inputs return the same subset regardless of global RNG state.

    Args:
        x: features, in any form accepted by ``train_test_split``.
        y: targets aligned with ``x``; also used as the stratification labels.
        frac: fraction of rows to keep, a float strictly between 0 and 1 as
            required by ``train_test_split`` for a float ``train_size``.

    Returns:
        A ``(X_train, y_train)`` tuple holding the kept portion of ``x`` and ``y``.
    """
    # Request the kept portion directly with train_size. The earlier
    # test_size=1-frac form let floating-point error in the subtraction leak
    # into the rounded split sizes: e.g. 1-0.7 == 0.30000000000000004, so on
    # 10 rows the discarded part rounds up to 4 and only 6 rows are kept
    # instead of 7. train_size=frac also agrees with the int(frac*len(x))
    # count used by sample_train_set.
    X_train, _, y_train, _ = train_test_split(
        x, y, train_size=frac, stratify=y, random_state=constants.SEED
    )
    return X_train, y_train

#### The dataset

In [10]:
# Load the noisy training set from CSV.
train_df = pd.read_csv('../../final/data/train_set_noisy_6.csv')
# Every column except the last is treated as a feature; the last column is the target.
X_set = train_df.iloc[:, 0:-1]
y_set = train_df.iloc[:, -1]
# Preview the first rows of the full frame.
train_df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,9.401651,-1.0,5.840045,0.0,116.415615,103.966653,-1.0,2.712885,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,28.204953,-1.0,2.0
1,7.029873,102.717562,2.079564,-1.0,-1.0,79.373697,212.817785,2.657004,1,0.675576,104.68878,50.452242,59.237694,8.549773,49.832026,21.08962,-1.0,3.0
2,15.900836,-1.0,-1.0,-1.0,-1.0,-1.0,58.36385,-1.0,1,0.999932,90.916733,60.845387,77.678965,1.513323,139.638756,47.702509,-1.0,0.0
3,12.710083,437.896009,2.773303,-1.0,392.155888,103.338649,91.089902,3.689834,1,1.431816,31.576192,124.321572,79.407308,24.911028,84.817826,38.130249,23.227983,7.0
4,9.256332,267.393719,2.347146,5.758205,258.911569,76.491274,189.160039,3.630348,1,1.259486,19.493158,-1.0,55.681222,-1.0,96.976233,27.768995,73.059709,3.0


In [11]:
# Fractions of the training set to sample and export.
fracs = [0.01, 0.05, 0.1, 0.5]

In [12]:
# Write one randomly sampled subset of the training set per fraction.
# sample_train_set draws from the global `random` state, so each iteration
# advances that state and the rows picked for a given fraction depend on the
# order of `fracs` and on any draws made earlier in the session.
for frac in fracs:
    x, y = sample_train_set(X_set, y_set, frac)
    # Drop the original row labels on both pieces so they line up positionally
    # when concatenated column-wise.
    df = pd.concat([x.reset_index(drop=True), y.reset_index(drop=True)], axis=1)
    print(f'frac:{frac} - {len(df)} samples')
    # NOTE(review): the output name has no `.csv` extension although the
    # content is CSV — confirm downstream readers expect this exact name.
    df.to_csv(f'../../final/data/train_set_noisy_6_{frac}', index=False)

frac:0.01 - 560 samples
frac:0.05 - 2800 samples
frac:0.1 - 5600 samples
frac:0.5 - 28000 samples


In [13]:
# Draw a 1% sample with the unstratified sampler. This is a fresh draw from the
# global `random` state, so it will generally not contain the same rows as the
# 0.01 file written by the export loop.
x1, y1 = sample_train_set(X_set, y_set, 0.01)
x1.shape, y1.shape

((560, 17), (560,))

In [14]:
# Draw a 1% sample with the stratified sampler for comparison.
x2, y2 = sample_train_set2(X_set, y_set, 0.01)
x2.shape, y2.shape

((560, 17), (560,))

In [15]:
# Per-class row counts in the unstratified 1% sample (y1).
unique, counts = np.unique(y1, return_counts=True)
dict(zip(unique, counts))

{0.0: 111, 1.0: 51, 2.0: 60, 3.0: 64, 4.0: 65, 5.0: 74, 6.0: 71, 7.0: 64}

In [16]:
# Per-class row counts in the stratified 1% sample (y2).
unique, counts = np.unique(y2, return_counts=True)
dict(zip(unique, counts))

{0.0: 128, 1.0: 65, 2.0: 65, 3.0: 64, 4.0: 60, 5.0: 65, 6.0: 65, 7.0: 48}

In [45]:
# Rebuild a single frame from the unstratified 1% sample, with row labels reset
# on both pieces so features and label align positionally.
# NOTE(review): the execution count jumps from 16 to 45 here, and the label
# column renders as integers in this cell's output but as floats in the
# earlier train_df.head() output — the saved outputs may come from a different
# kernel state; re-run from a fresh kernel to confirm.
df = pd.concat([x1.reset_index(drop=True), y1.reset_index(drop=True)], axis=1)
df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,7.30754,364.404722,4.55067,0.498705,306.276945,92.255672,204.157296,2.37629,0,0.909182,92.993546,47.196764,41.097484,25.527855,100.176889,21.92262,66.657742,5
1,11.984145,-1.0,1.376078,-1.0,-1.0,92.631555,-1.0,3.88123,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,35.952434,-1.0,6
2,12.417215,-1.0,4.117807,5.391325,246.164222,-1.0,-1.0,-1.0,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,37.251646,-1.0,0
3,8.641878,141.251735,1.06782,0.0,326.038721,102.364564,-1.0,2.532676,0,1.142712,144.087673,-1.0,-1.0,-1.0,-1.0,25.925633,-1.0,2
4,10.008272,-1.0,1.032614,0.0,-1.0,104.973975,247.352318,2.860215,1,-1.0,-1.0,102.204171,38.776599,20.365048,74.73624,30.024816,-1.0,2


In [46]:
# Show the last rows of the current `df` frame.
df.tail()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
555,7.632455,-1.0,-1.0,5.303657,-1.0,101.153114,-1.0,2.263634,0,0.25039,95.460474,56.169752,-1.0,2.167046,-1.0,22.897366,-1.0,1
556,11.559249,56.765972,-1.0,-1.0,479.054093,77.995586,-1.0,4.446117,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,34.677748,-1.0,4
557,13.318941,-1.0,1.084404,-1.0,-1.0,-1.0,-1.0,-1.0,1,1.329228,132.225907,36.636147,-1.0,9.721385,-1.0,39.956823,-1.0,0
558,7.152977,457.824948,1.391993,1.014386,-1.0,76.639277,166.595915,2.799991,1,1.693948,149.700606,31.555618,55.77997,16.713569,47.220864,21.45893,-1.0,3
559,7.992157,100.85527,-1.0,4.810671,411.088886,102.848344,161.217587,2.331245,1,1.653785,23.541767,86.224006,12.105319,7.733197,51.34826,23.976472,39.217209,1
