In [85]:
from pathlib import Path

import pandas as pd
from skmultilearn.model_selection import IterativeStratification

In [86]:
# Site whose data is split below; change this value to process a different site.
site_name = 'vin'
df = pd.read_csv(f'data/data_{site_name}.csv')
df

Unnamed: 0,Image,Atelectasis,Cardiomegaly,Effusion,Pneumothorax,Edema
0,80caa435b6ab5edaff4a0a758ffaec6e.dcm,1.0,0.0,0.0,0.0,0
1,0622cd29e4e0e4f198abf15614819ae8.dcm,0.0,1.0,0.0,0.0,0
2,bd6eb525438d6da1ced0ed1810857772.dcm,0.0,0.0,1.0,0.0,0
3,25f2c7b53a6ed09a9aaf73c30357aaf6.dcm,0.0,1.0,0.0,0.0,0
4,f769eea17a2e7678f481f386c3c6261c.dcm,0.0,1.0,0.0,0.0,0
...,...,...,...,...,...,...
5161,c25370b989806c860e4f2025a7ecefbe.dcm,0.0,0.0,0.0,0.0,0
5162,a89d28de87c246614069931ea72e7d30.dcm,0.0,0.0,0.0,0.0,0
5163,48e47872a01b6f07a8bde354db0d65e0.dcm,0.0,0.0,0.0,0.0,0
5164,39f57c83a1026f43f87ec02311874dc3.dcm,0.0,0.0,0.0,0.0,0


In [87]:
# ref: https://madewithml.com/courses/mlops/splitting/

def iterative_train_test_split(X, y, train_size):
    """Split (X, y) into train and test parts using iterative multi-label stratification.

    Custom iterative train test split which 'maintains balanced representation
    with respect to order-th label combinations' (order=1 here, i.e. each
    label is balanced individually).

    Args:
        X: Samples, indexable by an integer index array (in this notebook a
            1-D numpy array of image filenames).
        y: Label matrix aligned row-for-row with ``X`` (in this notebook a
            2-D numpy array of 0/1 label indicators).
        train_size: Fraction of samples to place in the train part; must lie
            strictly between 0 and 1.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)``.

    Raises:
        ValueError: If ``train_size`` is not strictly between 0 and 1.
    """
    # A fold weight of 0 (or a negative one) gives a degenerate split, so reject it early.
    if not 0.0 < train_size < 1.0:
        raise ValueError(
            f"train_size must be strictly between 0 and 1, got {train_size!r}"
        )

    # Two folds weighted [test, train]. split() yields (train, test) index pairs;
    # the first pair is expected to hold out the fold weighted 1.0 - train_size.
    stratifier = IterativeStratification(
        n_splits=2,
        order=1,
        sample_distribution_per_fold=[1.0 - train_size, train_size],
    )
    train_indices, test_indices = next(stratifier.split(X, y))
    X_train, y_train = X[train_indices], y[train_indices]
    X_test, y_test = X[test_indices], y[test_indices]
    return X_train, X_test, y_train, y_test

In [88]:
# Features are the image filenames; targets are every remaining (label) column.
X = df['Image'].to_numpy()
y = df.drop(columns='Image').to_numpy()

# 70% train, then the remaining 30% split evenly into val and test (15% / 15%).
X_train, X_rest, y_train, y_rest = iterative_train_test_split(X, y, train_size=0.7)
X_val, X_test, y_val, y_test = iterative_train_test_split(X_rest, y_rest, train_size=0.5)

# Sanity check: together the three splits contain exactly as many rows as the input.
assert len(X_train) + len(X_val) + len(X_test) == len(X)
assert len(y_train) + len(y_val) + len(y_test) == len(y)

In [89]:
# One line per split: absolute size and fraction of the full dataset.
print("\n".join(
    f"{name}: {len(part)} ({len(part)/len(X):.2f})"
    for name, part in (("train", X_train), ("val", X_val), ("test", X_test))
))

train: 3616 (0.70)
val: 775 (0.15)
test: 775 (0.15)


In [90]:
# y was built from every non-Image column of df, so take the label names from
# the same place instead of hard-coding them (keeps names/order in sync with the CSV).
label_cols = df.columns.drop('Image')
train_df = pd.concat(
    [pd.Series(X_train, name='Image'), pd.DataFrame(y_train, columns=label_cols)],
    axis=1,
)
train_df

Unnamed: 0,Image,Atelectasis,Cardiomegaly,Effusion,Pneumothorax,Edema
0,bd6eb525438d6da1ced0ed1810857772.dcm,0.0,0.0,1.0,0.0,0.0
1,25f2c7b53a6ed09a9aaf73c30357aaf6.dcm,0.0,1.0,0.0,0.0,0.0
2,28d769becacfbdeebab6d3fda7322cf7.dcm,0.0,0.0,1.0,0.0,0.0
3,33403064ce25caa5fda270e6158c6b03.dcm,0.0,1.0,0.0,0.0,0.0
4,1302aab3d9d19f6bcb9db728e3ce6306.dcm,0.0,1.0,0.0,0.0,0.0
...,...,...,...,...,...,...
3611,dffc86099028528a5b3c55ff3ae57722.dcm,0.0,0.0,0.0,0.0,0.0
3612,d568e16f214feab68fbadea220fb76ff.dcm,0.0,0.0,0.0,0.0,0.0
3613,c25370b989806c860e4f2025a7ecefbe.dcm,0.0,0.0,0.0,0.0,0.0
3614,a89d28de87c246614069931ea72e7d30.dcm,0.0,0.0,0.0,0.0,0.0


In [91]:
# y was built from every non-Image column of df, so take the label names from
# the same place instead of hard-coding them (keeps names/order in sync with the CSV).
label_cols = df.columns.drop('Image')
val_df = pd.concat(
    [pd.Series(X_val, name='Image'), pd.DataFrame(y_val, columns=label_cols)],
    axis=1,
)
val_df

Unnamed: 0,Image,Atelectasis,Cardiomegaly,Effusion,Pneumothorax,Edema
0,f769eea17a2e7678f481f386c3c6261c.dcm,0.0,1.0,0.0,0.0,0.0
1,6900482a91a538ead56b483f77bcf289.dcm,0.0,0.0,1.0,0.0,0.0
2,6a4f9965e83bfad45d66d4afa5d28cc5.dcm,0.0,1.0,0.0,0.0,0.0
3,6e4391555899c8474c4d32f42b2ba21b.dcm,0.0,0.0,1.0,1.0,0.0
4,c70dce909198abf8b39a7e0d41c9a895.dcm,0.0,1.0,0.0,0.0,0.0
...,...,...,...,...,...,...
770,16d08231028d769217747dab0bf47726.dcm,0.0,0.0,0.0,0.0,0.0
771,829d6b4c6299fe6ce434d23d2410427f.dcm,0.0,0.0,0.0,0.0,0.0
772,f02b95632470aab569016f570f07fab8.dcm,0.0,0.0,0.0,0.0,0.0
773,a85382cb3d93bc88ebd4c9064ac544a4.dcm,0.0,0.0,0.0,0.0,0.0


In [92]:
# y was built from every non-Image column of df, so take the label names from
# the same place instead of hard-coding them (keeps names/order in sync with the CSV).
label_cols = df.columns.drop('Image')
test_df = pd.concat(
    [pd.Series(X_test, name='Image'), pd.DataFrame(y_test, columns=label_cols)],
    axis=1,
)
test_df

Unnamed: 0,Image,Atelectasis,Cardiomegaly,Effusion,Pneumothorax,Edema
0,80caa435b6ab5edaff4a0a758ffaec6e.dcm,1.0,0.0,0.0,0.0,0.0
1,0622cd29e4e0e4f198abf15614819ae8.dcm,0.0,1.0,0.0,0.0,0.0
2,a4fc9faa46af26c5fc462772d88d0af3.dcm,0.0,0.0,1.0,0.0,0.0
3,b80be20a79e5a7539bd00f4907b444b2.dcm,0.0,1.0,0.0,0.0,0.0
4,1c1ef26e3b3323f74041f6dd2371cd24.dcm,0.0,1.0,0.0,0.0,0.0
...,...,...,...,...,...,...
770,5ba5f21b61f9d12ef2a88487cc71f6ec.dcm,0.0,0.0,0.0,0.0,0.0
771,abab3c0545772de7b3d73d7b457319d2.dcm,0.0,0.0,0.0,0.0,0.0
772,40d41571f7ee9a4583051e5ede025721.dcm,0.0,0.0,0.0,0.0,0.0
773,cb23c75b43410b0ef6133d9979fd5c89.dcm,0.0,0.0,0.0,0.0,0.0


In [93]:
# Write each split to data/data_<site>/<split>_<site>.csv. to_csv does not create
# missing parent directories, so make sure the output folder exists first.
out_dir = Path('data') / f'data_{site_name}'
out_dir.mkdir(parents=True, exist_ok=True)
for split_name, split_df in (('train', train_df), ('val', val_df), ('test', test_df)):
    split_df.to_csv(out_dir / f'{split_name}_{site_name}.csv', index=False)