In [3]:
import os
from glob import glob

import numpy as np
import pandas as pd
import SimpleITK as sitk
import torch
import torchio as tio
from sklearn.model_selection import StratifiedKFold, train_test_split
from tqdm import tqdm

In [4]:
# Load the case list. The sheet stores one group per column ("Healthy",
# "Adenomyosis"); each column holds the subject IDs belonging to that group.
cases_raw = pd.read_excel('/home/johannes/Data/SSD_2.0TB/GNN_pCR/data/adenomiosis/cases_list_w_endo.xlsx')

# Columns of unequal length are padded with NaN when read into one frame, so
# drop the padding from BOTH columns. Previously only "Adenomyosis" was
# cleaned, so any NaN in "Healthy" would have become a label-0 "subject".
healthy_subjetcts = cases_raw["Healthy"].dropna().tolist()
adenomiosis_subjects = cases_raw["Adenomyosis"].dropna().tolist()

# Flat subject list plus a tidy lookup table: label 0 = healthy, 1 = adenomyosis.
subjects = healthy_subjetcts + adenomiosis_subjects
df = pd.DataFrame(
    {
        "subjects": subjects,
        "label": [0] * len(healthy_subjetcts) + [1] * len(adenomiosis_subjects),
    }
)
df

Unnamed: 0,subjects,label
0,earthworm0001,0
1,earthworm0003,0
2,earthworm0004,0
3,earthworm0005,0
4,earthworm0006,0
...,...,...
79,earthworm0185,1
80,earthworm0192,1
81,earthworm0210,1
82,earthworm0212,1


In [5]:
# Class balance of the case list (0 = healthy, 1 = adenomyosis).
df["label"].value_counts()

label
0    59
1    25
Name: count, dtype: int64

In [6]:
# --- Pass 1: collect voxel spacing and label of every listed subject's scan ---

# glob() returns matches in arbitrary (filesystem-dependent) order. Sorting
# makes the order of `subjects_final` / `labels_final` deterministic, which any
# seeded split computed from these lists needs in order to be reproducible.
raw_files = sorted(glob("/home/johannes/Data/SSD_2.0TB/GNN_pCR/data/adenomiosis/Stefan/*/cropped/*.nii.gz"))

# Build the subject -> label lookup once instead of filtering `df` per file.
known_cases = df.dropna(subset=["subjects"])
if known_cases["subjects"].duplicated().any():
    # The per-file `.item()` lookup used before would also raise ValueError here.
    raise ValueError("Duplicate subject IDs in the case list; labels would be ambiguous.")
label_by_subject = dict(zip(known_cases["subjects"].tolist(), known_cases["label"].tolist()))

spacings_list = []
subjects_final = []
labels_final = []

for file in tqdm(raw_files):
    # Files match <root>/<subject>/cropped/<scan>.nii.gz, so the subject ID is
    # the name of the grandparent directory.
    subject = os.path.basename(os.path.dirname(os.path.dirname(file)))

    if subject not in label_by_subject:
        continue

    # Reorient to canonical orientation first so the spacing tuples of all
    # scans refer to the same axis order.
    img = tio.ToCanonical()(tio.ScalarImage(file))
    spacings_list.append(img.spacing)
    subjects_final.append(subject)
    labels_final.append(label_by_subject[subject])

if not spacings_list:
    raise RuntimeError("No scan on disk matches a subject from the case list; check the input path.")

print(f"Number of patient: {len(spacings_list)}")
# Per-axis median spacing over the cohort.
median_spacing = np.median(np.array(spacings_list), axis=0)
print(f"Median spacing: {median_spacing}")

# --- Pass 2: resample every listed scan to the cohort's median spacing ---

resampled_dir = "/home/johannes/Data/SSD_2.0TB/GNN_pCR/data/adenomiosis/data"
# Make sure the destination exists before the first save; saving into a
# missing folder is expected to fail (TODO confirm for the installed TorchIO).
os.makedirs(resampled_dir, exist_ok=True)

# The transforms do not depend on the file, so build them once.
to_canonical = tio.ToCanonical()
resample_to_median = tio.Resample(median_spacing)

for file in tqdm(glob("/home/johannes/Data/SSD_2.0TB/GNN_pCR/data/adenomiosis/Stefan/*/cropped/*.nii.gz")):
    # Files match <root>/<subject>/cropped/<scan>.nii.gz.
    subject = os.path.basename(os.path.dirname(os.path.dirname(file)))

    if subject not in subjects:
        continue

    img = resample_to_median(to_canonical(tio.ScalarImage(file)))

    # Output name is "<subject>_<scan>". Underscores inside the scan name are
    # turned into dashes so the first "_" always terminates the subject ID.
    filename = subject + "_" + os.path.basename(file).replace("_", "-")

    img.save(os.path.join(resampled_dir, filename))


# --- Pass 3: find the largest extent along each axis of the resampled volumes ---
size_list = [
    tio.ToCanonical()(tio.ScalarImage(path)).shape
    for path in tqdm(glob("/home/johannes/Data/SSD_2.0TB/GNN_pCR/data/adenomiosis/data/*.nii.gz"))
]

print(f"Number of patient: {len(size_list)}")
# Element-wise maximum over all shapes; the leading entry of each shape is
# dropped so that only the three remaining extents are kept as target size.
max_size = np.max(np.array(size_list), axis=0)[1:]
print(f"Max size: {max_size}")

# --- Pass 4: z-normalise, crop/pad to a common size, export tensors ---
#
# Each tensor is written straight to its final "<stem>_<label>.pt" name.
# The earlier version saved "<stem>.pt" and renamed it in a second pass over
# "*.pt"; on a re-run that pass also matched already-labelled files and renamed
# them again ("..._0_0.pt"), leaving stale duplicates in the folder.

# The transforms do not depend on the file, so build them once.
to_canonical = tio.ToCanonical()
z_normalize = tio.ZNormalization()
crop_or_pad = tio.CropOrPad(max_size)

for file in tqdm(glob("/home/johannes/Data/SSD_2.0TB/GNN_pCR/data/adenomiosis/data/*.nii.gz")):
    img = crop_or_pad(z_normalize(to_canonical(tio.ScalarImage(file))))

    # File names are "<subject>_<scan>.nii.gz"; the label comes from the case list.
    # `.item()` raises if the subject is missing from, or duplicated in, `df`.
    subject = os.path.basename(file).split("_")[0]
    label = df.loc[df["subjects"] == subject, "label"].item()

    # squeeze(0) removes the leading axis when it has size 1; the two
    # unsqueeze(0) calls then add two leading singleton axes, giving a 5-D
    # tensor such as [1, 1, 37, 399, 441] for a single-channel volume.
    img_tensor = img.tensor.squeeze(0)
    img_tensor = img_tensor.unsqueeze(0).unsqueeze(0)

    tensor_path = file[: -len(".nii.gz")] + f"_{label}.pt"
    torch.save(img_tensor, tensor_path)

    # Overwrite the resampled NIfTI with its normalised, padded version.
    img.save(file)


100%|██████████| 134/134 [00:02<00:00, 60.46it/s]


Number of patient: 67
Median spacing: [3.         0.42969999 0.42969999]


100%|██████████| 134/134 [00:11<00:00, 12.06it/s]
100%|██████████| 67/67 [00:01<00:00, 49.72it/s]


Number of patient: 67
Max size: [ 37 399 441]


100%|██████████| 67/67 [00:16<00:00,  4.04it/s]
100%|██████████| 67/67 [00:00<00:00, 5774.55it/s]


In [7]:
# --- 5-fold cross-validation splits at subject level ---
#
# The outer StratifiedKFold yields the test fold. The validation set is then
# carved out of the remaining subjects with a seeded, *stratified* split of the
# same size as the test fold. Previously validation was simply the tail of the
# sorted `train_val_index`, i.e. neither shuffled nor stratified, so its class
# balance depended purely on file order.
N_SPLITS = 5
SEED = 42

patients = {}
labels_array = np.asarray(labels_final)

skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

for fold, (train_val_index, test_index) in enumerate(skf.split(subjects_final, labels_final)):

    n_test = len(test_index)

    # An integer test_size is an absolute number of samples, so
    # len(val) == len(test) and train receives the rest.
    train_index, val_index = train_test_split(
        train_val_index,
        test_size=n_test,
        stratify=labels_array[train_val_index],
        shuffle=True,
        random_state=SEED,
    )

    print(f"Fold_{fold}: Train size: {len(train_index)}, Val size: {len(val_index)}, Test size: {len(test_index)}")

    train_ids = [subjects_final[i] for i in train_index]
    val_ids = [subjects_final[i] for i in val_index]
    test_ids = [subjects_final[i] for i in test_index]

    # Guard against subject-level leakage between the three partitions.
    assert set(train_ids).isdisjoint(val_ids)
    assert set(train_ids).isdisjoint(test_ids)
    assert set(val_ids).isdisjoint(test_ids)

    patients[fold] = {
        'train': train_ids,
        'val': val_ids,
        'test': test_ids
    }

# Persist the fold -> {'train', 'val', 'test'} mapping of subject IDs.
torch.save(patients, 'data_splits.pt')

Fold_0: Train size: 39, Val size: 14, Test size: 14
Fold_1: Train size: 39, Val size: 14, Test size: 14
Fold_2: Train size: 41, Val size: 13, Test size: 13
Fold_3: Train size: 41, Val size: 13, Test size: 13
Fold_4: Train size: 41, Val size: 13, Test size: 13


In [8]:
# Sanity check: load one exported adenomyosis tensor and show its shape.
adenomyosis_sample = torch.load("/home/johannes/Data/SSD_2.0TB/GNN_pCR/data/adenomiosis/data/earthworm0196_t2-tse-sag_0.pt")
adenomyosis_sample.shape

torch.Size([1, 1, 37, 399, 441])

In [9]:
# Shape of a tensor from the breast-cancer data folder
# (presumably loaded as a reference for the expected layout; confirm).
breast_sample = torch.load("/home/johannes/Data/SSD_2.0TB/GNN_pCR/data/breast_cancer/data_processed/ISPY2-100899/ISPY2-100899-T0_2002-10-26_pCR_0.pt")
breast_sample.shape

torch.Size([1, 3, 64, 64, 64])