In [1]:
from sklearn.model_selection import StratifiedKFold
from glob import glob
import os
from tqdm import tqdm
import torch
import pandas as pd

In [2]:
def collect_patient_ids(data_root, files_3_timepoints=30, files_4_timepoints=40):
    """Collect patient IDs from the ``ISPY*`` sub-directories of ``data_root``.

    A patient directory is classified purely by how many entries it contains:

    * ``files_4_timepoints`` entries -> listed in BOTH returned lists,
    * ``files_3_timepoints`` entries -> listed only in the 3-timepoint list,
    * any other count                -> ignored.

    NOTE(review): the defaults 30/40 presumably correspond to 10 files per
    timepoint -- confirm against the preprocessing step that wrote them.

    Args:
        data_root: Directory that contains one sub-directory per patient.
        files_3_timepoints: Entry count identifying a 3-timepoint patient.
        files_4_timepoints: Entry count identifying a 4-timepoint patient.

    Returns:
        Tuple ``(ids_3_timepoints, ids_4_timepoints)`` of lists of directory
        base names, each sorted alphabetically.
    """
    ids_3_timepoints = []
    ids_4_timepoints = []

    # glob() yields matches in arbitrary (filesystem-dependent) order. Sorting
    # makes the saved lists -- and every index-based split derived from them --
    # reproducible across machines and re-runs.
    for patient_dir in sorted(glob(os.path.join(data_root, 'ISPY*'))):
        # A stray file matching the pattern would make os.listdir() raise.
        if not os.path.isdir(patient_dir):
            continue

        patient_id = os.path.basename(patient_dir)
        n_files = len(os.listdir(patient_dir))

        if n_files == files_4_timepoints:
            ids_4_timepoints.append(patient_id)
            ids_3_timepoints.append(patient_id)
        elif n_files == files_3_timepoints:
            ids_3_timepoints.append(patient_id)

    return ids_3_timepoints, ids_4_timepoints


final_patients_ids_3_timepoints, final_patients_ids_4_timepoints = collect_patient_ids(
    '/home/johannes/Data/SSD_2.0TB/GNN_pCR/data/breast_cancer/data_processed'
)

print(f'Number of patients with 3 timepoints: {len(final_patients_ids_3_timepoints)}')
print(f'Number of patients with 4 timepoints: {len(final_patients_ids_4_timepoints)}')

# Save the patient ids
torch.save(final_patients_ids_3_timepoints, 'final_patient_ids_3_timepoints.pt')
torch.save(final_patients_ids_4_timepoints, 'final_patient_ids_4_timepoints.pt')

Number of patients with 3 timepoints: 625
Number of patients with 4 timepoints: 585


In [3]:
def lookup_pcr_labels(patient_ids, clinical_df):
    """Return the 'pCR' value of every patient, in the order of ``patient_ids``.

    Each patient ID is split on '-' and the second piece is converted to an
    int; that integer is matched against the 'Patient_ID' column of
    ``clinical_df``. When several rows match, the first one is used.

    Args:
        patient_ids: Iterable of ID strings containing at least one '-'.
        clinical_df: DataFrame with 'Patient_ID' and 'pCR' columns.

    Returns:
        List of 'pCR' values, one per entry of ``patient_ids``.

    Raises:
        KeyError: If one or more patients have no row in ``clinical_df``.
        ValueError: If one or more patients have a missing (NaN) 'pCR' value,
            since such a label cannot be used for stratification.
    """
    labels = []
    not_found = []
    unlabeled = []

    for patient_id in patient_ids:
        patient_id_int = int(patient_id.split('-')[1])
        matches = clinical_df.loc[clinical_df['Patient_ID'] == patient_id_int, 'pCR'].values

        if len(matches) == 0:
            not_found.append(patient_id)
            continue

        label = matches[0]
        if pd.isna(label):
            unlabeled.append(patient_id)
            continue

        labels.append(label)

    # Report every offending patient at once instead of failing on the first.
    if not_found:
        raise KeyError(f"No clinical row found for patients: {not_found}")
    if unlabeled:
        raise ValueError(f"Missing pCR value for patients: {unlabeled}")

    return labels


patient_ids_3_timepoints = torch.load('final_patient_ids_3_timepoints.pt')
patient_ids_4_timepoints = torch.load('final_patient_ids_4_timepoints.pt')
df = pd.read_excel("/home/johannes/Data/SSD_2.0TB/GNN_pCR/data/breast_cancer/ISPY2-Imaging-Cohort-1-Clinical-Data.xlsx")

patient_labels_3_timepoints = lookup_pcr_labels(patient_ids_3_timepoints, df)
patient_labels_4_timepoints = lookup_pcr_labels(patient_ids_4_timepoints, df)

assert len(patient_ids_3_timepoints) == len(patient_labels_3_timepoints)
assert len(patient_ids_4_timepoints) == len(patient_labels_4_timepoints)

In [4]:
def make_stratified_splits(patient_ids, patient_labels, n_splits=5, seed=42):
    """Build stratified train/val/test patient-ID splits for every outer fold.

    The outer ``StratifiedKFold`` picks the test set of each fold. The
    remaining patients are then divided into train and validation by the first
    split of an inner ``StratifiedKFold`` with ``n_splits - 1`` folds, so the
    validation set is about the same size as the test set and is itself
    shuffled and stratified by label.

    Why not slice the tail of the outer train indices: ``split`` yields sorted
    index arrays, so a tail slice would always select the highest-index
    patients -- an unshuffled, unstratified validation set that is nearly the
    same in every fold and depends on the order of ``patient_ids``.

    Args:
        patient_ids: Sequence of patient IDs.
        patient_labels: Sequence of class labels aligned with ``patient_ids``.
        n_splits: Number of outer folds (must be >= 3 so that the inner
            splitter has at least 2 folds).
        seed: ``random_state`` used for both the outer and inner splitter.

    Returns:
        Dict mapping fold number to ``{'train': [...], 'val': [...],
        'test': [...]}`` lists of patient IDs.

    Raises:
        ValueError: If ``n_splits`` is smaller than 3.
    """
    if n_splits < 3:
        raise ValueError(f"n_splits must be >= 3 to carve out a validation set, got {n_splits}")

    outer_skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    inner_skf = StratifiedKFold(n_splits=n_splits - 1, shuffle=True, random_state=seed)

    splits = {}
    for fold, (train_val_index, test_index) in enumerate(outer_skf.split(patient_ids, patient_labels)):

        # Inner split works on positions within train_val_index; map them back
        # to indices into patient_ids afterwards.
        train_val_labels = [patient_labels[i] for i in train_val_index]
        train_pos, val_pos = next(inner_skf.split(train_val_index, train_val_labels))
        train_index = train_val_index[train_pos]
        val_index = train_val_index[val_pos]

        print(f"Fold_{fold}: Train size: {len(train_index)}, Val size: {len(val_index)}, Test size: {len(test_index)}")

        train_ids = [patient_ids[i] for i in train_index]
        val_ids = [patient_ids[i] for i in val_index]
        test_ids = [patient_ids[i] for i in test_index]

        # The three subsets must not share patients and must cover everyone.
        assert set(train_ids).isdisjoint(val_ids)
        assert set(train_ids).isdisjoint(test_ids)
        assert set(val_ids).isdisjoint(test_ids)
        assert len(train_ids) + len(val_ids) + len(test_ids) == len(patient_ids)

        splits[fold] = {
            'train': train_ids,
            'val': val_ids,
            'test': test_ids
        }

    return splits


patients_3_timepoints = make_stratified_splits(patient_ids_3_timepoints, patient_labels_3_timepoints)
patients_4_timepoints = make_stratified_splits(patient_ids_4_timepoints, patient_labels_4_timepoints)

torch.save(patients_3_timepoints, 'data_splits_3_timepoints.pt')
torch.save(patients_4_timepoints, 'data_splits_4_timepoints.pt')

Fold_0: Train size: 375, Val size: 125, Test size: 125
Fold_1: Train size: 375, Val size: 125, Test size: 125
Fold_2: Train size: 375, Val size: 125, Test size: 125
Fold_3: Train size: 375, Val size: 125, Test size: 125
Fold_4: Train size: 375, Val size: 125, Test size: 125
Fold_0: Train size: 351, Val size: 117, Test size: 117
Fold_1: Train size: 351, Val size: 117, Test size: 117
Fold_2: Train size: 351, Val size: 117, Test size: 117
Fold_3: Train size: 351, Val size: 117, Test size: 117
Fold_4: Train size: 351, Val size: 117, Test size: 117
