In [None]:
import csv
import os
import sys
import tomllib
from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from monai.losses import DiceLoss
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.utils import set_determinism

# The parent directory must be on sys.path *before* the project-local import
# below, otherwise the insert cannot help the import resolve on a fresh kernel.
# The membership guard avoids stacking duplicate entries when the cell is re-run.
if ".." not in sys.path:
    sys.path.insert(0, "..")

from src.utils import setup_dirs  # noqa: E402  (must follow the sys.path tweak)

In [None]:
# The project root is taken to be the parent of the current working directory.
root_dir = Path.cwd().parent
base_data_dir, log_dir, root_out_dir = setup_dirs(root_dir)
data_dir = base_data_dir / "MNMS"

# tomllib requires the file to be opened in binary mode.
config_path = root_dir / "config.toml"
with config_path.open("rb") as config_file:
    config = tomllib.load(config_file)

pprint(config)

# Every tunable except the seed falls back to a default when absent from the
# [hyperparameters] table; a missing seed raises KeyError.
hyperparameters = config["hyperparameters"]
batch_size = hyperparameters.get("batch_size", 4)
epochs = hyperparameters.get("epochs", 100)
learning_rate = hyperparameters.get("learning_rate", 1e-5)
percentage_data = hyperparameters.get("percentage_data", 1.0)
validation_split = hyperparameters.get("validation_split", 0.8)

set_determinism(seed=hyperparameters["seed"])

In [None]:
training_dir = data_dir / "Training"
labeled = training_dir / "Labeled"

# Although there are '_gt.nii.gz' files in the 'Unlabeled' folder, their segmentations are empty.
unlabeled = training_dir / "Unlabeled"

# NOTE(review): the '_gt.nii.gz' files under 'Validation' also appear to be unlabeled — confirm.
validation_dir = data_dir / "Validation"

# Each patient is one sub-directory. os.scandir yields entries in arbitrary,
# filesystem-dependent order, so the names are sorted to keep the lists (and
# anything that indexes into them, e.g. train_patients[0]) reproducible.
train_patients = sorted(f.name for f in os.scandir(labeled) if f.is_dir())
val_patients = sorted(f.name for f in os.scandir(validation_dir) if f.is_dir())
test_patients = sorted(f.name for f in os.scandir(data_dir / "Testing") if f.is_dir())

print("Num train", len(train_patients))
print("Num val", len(val_patients))
print("Num test", len(test_patients))

In [None]:
# Maps each patient code from the dataset CSV to its end-diastole / end-systole
# frame indices, and tallies how many CSV rows fall into each on-disk split.
cardiac_phase_indexes = {}

num_train = 0
num_val = 0
num_test = 0
num_unknown = 0

# Sets give constant-time membership checks for the per-row split lookup.
train_patient_set = set(train_patients)
val_patient_set = set(val_patients)
test_patient_set = set(test_patients)

# newline="" is what the csv module expects from a file object, so that
# newlines embedded in quoted fields are parsed correctly.
with open(
    data_dir / "211230_M&Ms_Dataset_information_diagnosis_opendataset.csv", newline=""
) as csvfile:
    reader = csv.reader(csvfile)
    headers = next(reader)
    patient_index = headers.index("External code")
    ed_index = headers.index("ED")
    es_index = headers.index("ES")
    for row in reader:
        patient_label = row[patient_index]
        cardiac_phase_indexes[patient_label] = {
            "end_diastole": int(row[ed_index]),
            "end_systole": int(row[es_index]),
        }

        # A patient is counted once, under the first split that contains it.
        if patient_label in train_patient_set:
            num_train += 1
        elif patient_label in val_patient_set:
            num_val += 1
        elif patient_label in test_patient_set:
            num_test += 1
        else:
            num_unknown += 1

print(f"Train: {num_train}\nVal: {num_val}\nTest: {num_test}\nUnknown: {num_unknown}")
print(f"Total: {num_train + num_val + num_test + num_unknown}")

In [None]:
# Show the names of the patient folders found under the labelled training directory.
print(train_patients)

In [None]:
# Preview the first labelled training patient: image next to its ground-truth
# segmentation at the end-diastole frame, for the first couple of slices.
patient = training_dir / "Labeled" / train_patients[0]
image = nib.load(patient / f"{patient.name}_sa.nii.gz").get_fdata(dtype=np.float32)
label = nib.load(patient / f"{patient.name}_sa_gt.nii.gz").get_fdata(dtype=np.float32)

# Look the end-diastole frame index up once instead of on every access.
ed_frame = cardiac_phase_indexes[patient.name]["end_diastole"]

# The four-way unpack requires a 4-D volume; the last two axes are indexed
# below as (slice, frame).
_, _, slices, times = image.shape
# Preview at most two slices, but never more than the volume actually has.
slices = min(2, slices)
for i in range(slices):
    plt.figure()
    plt.subplot(1, 2, 1)
    plt.imshow(image[..., i, ed_frame])

    plt.subplot(1, 2, 2)
    # Basic indexing returns a view, so the swap below edits `label` in place.
    ed_label = label[..., i, ed_frame]
    # Both masks are computed before either assignment so that swapping label
    # values 1 and 3 does not overwrite itself.
    # NOTE(review): the swap is presumably for display purposes only — confirm.
    ones = ed_label == 1
    threes = ed_label == 3
    ed_label[ones] = 3
    ed_label[threes] = 1

    plt.imshow(ed_label)

plt.show()

In [None]:
# Project-local import of the transform factory used immediately below.
from src.transforms.transforms import get_transforms

# The result is unpacked into two values, named here as the training and
# validation transforms.
train_transforms, val_transforms = get_transforms()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3-D U-Net taking a single input channel and producing four output channels.
# NOTE(review): four outputs presumably means background plus three labels — confirm.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=4,
    # channels=(26, 52, 104, 208, 416),
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    norm=Norm.BATCH,
    # num_res_units=4,
    # dropout=0.5,
).to(device)

# The loss one-hot encodes the target and applies softmax to the network output.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
# TODO: weight decay check
# Pass the configured learning rate explicitly; without it Adam silently uses
# its library default and the value read from config.toml has no effect.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)