In [None]:
import importlib
import os
import sys
import tomllib
from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
import torch
import wandb
from monai.losses import DiceLoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.utils import set_determinism
from mpl_toolkits.axes_grid1 import ImageGrid
from torch.utils.data import DataLoader, random_split

# The local `src` package lives one directory above this notebook, so the
# parent directory must be on sys.path *before* any `src` import runs;
# otherwise a fresh kernel fails on the first `from src...` line.
# The membership check keeps re-running this cell from stacking duplicates.
if ".." not in sys.path:
    sys.path.insert(0, "..")

from src.metrics import METRICS  # noqa: E402
from src.utils import find_optimal_learning_rate, setup_dirs  # noqa: E402
from src.visualization import visualize_loss_curves, visualize_predictions  # noqa: E402

In [None]:
# The notebook's working directory is assumed to sit directly below the
# repository root, so the root is simply the parent of the cwd.
root_dir = Path.cwd().parent
data_dir, log_dir, out_dir = setup_dirs(root_dir)
data_dir = data_dir / "ACDC" / "database"

# tomllib requires the file to be opened in binary mode.
config_path = root_dir / "config.toml"
with config_path.open("rb") as config_file:
    config = tomllib.load(config_file)

pprint(config)

# Every hyperparameter except the seed falls back to a default when the
# key is absent from config.toml; a missing seed raises KeyError.
hyperparameters = config["hyperparameters"]
augment = hyperparameters.get("augment", True)
batch_size = hyperparameters.get("batch_size", 4)
epochs = hyperparameters.get("epochs", 100)
learning_rate = hyperparameters.get("learning_rate", 1e-5)
percentage_data = hyperparameters.get("percentage_data", 1.0)
validation_split = hyperparameters.get("validation_split", 0.8)

set_determinism(seed=hyperparameters["seed"])

In [None]:
# Pick up edits to the local modules without restarting the kernel.
# import_module() guarantees each module is loaded first, so this cell also
# works on a fresh kernel (a bare sys.modules[...] lookup raises KeyError
# when nothing has imported the module yet). The submodule
# `src.transforms.transforms` is reloaded too: reloading only the package
# does not re-execute the file that `get_transforms` is imported from.
for module_name in (
    "src.transforms",
    "src.transforms.transforms",
    "src.datasets.acdc_dataset",
):
    importlib.reload(importlib.import_module(module_name))

from src.transforms.transforms import get_transforms
from src.datasets.acdc_dataset import ACDCDataset

# Dataset over the full training data, used by the sanity-check plot below.
train_transforms = get_transforms(augment)
train_data = ACDCDataset(data_dir=data_dir, train=True, transform=train_transforms)

In [None]:
# Sanity check: pull the first training sample and show one image slice
# next to its segmentation label.
check_dataloader = DataLoader(train_data, batch_size=1, shuffle=False)
check_data = next(iter(check_dataloader))
# [0][0] strips the batch dimension and one further leading dimension
# (presumably the channel axis — TODO confirm against ACDCDataset).
image, label = check_data["end_diastole"][0][0], check_data["end_diastole_label"][0][0]

print(image.shape, label.shape)

# Number of slices to preview; one row of (image, label) per slice.
# NOTE(review): an earlier version used `image.shape[2]` here, i.e. treated
# axis 2 as the slice axis, while the indexing below walks axis 0 — confirm
# which axis actually holds the slices before raising this above 1.
slices = 1
fig = plt.figure(figsize=(10, 10))
grid = ImageGrid(fig, 111, nrows_ncols=(slices, 2), axes_pad=0.1)

# Interleave image and label slices so each grid row is one pair.
panels = []
for i in range(slices):
    panels.append(image[i, ...])
    panels.append(label[i, ...])

# A distinct loop name keeps `image` bound to the full volume after the
# cell runs instead of being overwritten by the last plotted panel.
for ax, panel in zip(grid, panels):
    ax.imshow(panel, origin="lower")

plt.show()

# Label values in the masks: 1 = RV, 2 = MYO, 3 = LV.

In [None]:
# Rebuild the dataset, this time honouring `percentage_data` from the config.
train_data = ACDCDataset(
    data_dir=data_dir,
    train=True,
    transform=train_transforms,
    percentage_data=percentage_data,
)

# NOTE: despite its name, `validation_split` is used as the fraction of
# samples kept for *training* (default 0.8); the remainder is validation.
total_training_number = len(train_data)
train_size = int(validation_split * total_training_number)
val_size = total_training_number - train_size

# TODO: cache dataset
# train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=1)
# val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=1)

train_ds, val_ds = random_split(train_data, [train_size, val_size])
print(f"Training size: {len(train_ds)}")
print(f"Validation size: {len(val_ds)}")

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=1)
# Validation must not be shuffled: it gains nothing from it, and a fixed
# order keeps repeated passes over `val_loader` (e.g. the prediction plots)
# looking at the same samples.
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=1)

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3D UNet with a single input channel and four output channels
# (labels 1-3 in the masks plus, presumably, background).
# Disabled alternatives kept for reference:
#   channels=(26, 52, 104, 208, 416), num_res_units=4, dropout=0.5
unet = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    norm=Norm.INSTANCE,
)
model = unet.to(device)


In [None]:
# `find_optimal_learning_rate` is already imported in the notebook's first
# cell, so it is not re-imported here.

# Dice loss over softmax probabilities; integer labels are one-hot encoded
# inside the loss.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
# TODO: weight decay check
optimizer = torch.optim.Adam(model.parameters())

# Use the config learning rate as a midpoint.
# NOTE(review): whether the search restores the model/optimizer state it
# touches over its 100 iterations depends on src.utils — confirm.
optimal_learning_rate = find_optimal_learning_rate(
    model=model,
    optimizer=optimizer,
    criterion=loss_function,
    device=device,
    train_loader=train_loader,
    learning_rate=learning_rate,
    iterations=100,
)

# Apply the discovered rate to every parameter group of the optimizer.
for group in optimizer.param_groups:
    group["lr"] = optimal_learning_rate

# Stored in the hyperparameters so it is logged with the rest of the run config.
config["hyperparameters"]["optimal_learning_rate"] = optimal_learning_rate

In [None]:
# Start a Weights & Biases run, logging the hyperparameters (including the
# learning rate found above), then tag it with dataset and architecture.
wandb.init(
    project="acdc-3D-UNet-baseline",
    config=config["hyperparameters"],
    reinit=True,
)
for key, value in (("dataset", "ACDC"), ("architecture", "UNet")):
    setattr(wandb.config, key, value)

In [None]:
from src.train import train

# Validation runs every `val_interval` epochs; the loss-curve plot later in
# the notebook reuses this value.
val_interval = 5

# TODO: if early stopping is desired
# early_stopper = EarlyStopper(patience=50, min_delta=10)
# Pass as parameter
epoch_loss_values, metric_values = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_function=loss_function,
    optimizer=optimizer,
    val_interval=val_interval,
    epochs=epochs,
    # `metrics` (lowercase) was never defined in this notebook and raised
    # NameError on a fresh kernel; use the METRICS collection imported from
    # src.metrics. NOTE(review): confirm METRICS has the shape train() expects.
    metrics=METRICS,
    device=device,
    out_dir=out_dir,
)

In [None]:
# Close the Weights & Biases run opened by the earlier wandb.init(...) cell.
wandb.finish()

In [None]:
# Plot the training loss and validation metric curves.
visualize_loss_curves(epoch_loss_values, metric_values, val_interval, out_dir)

# Show predictions for a few slices, using the weights stored as
# "best_metric_model.pth" in the output directory.
best_model_path = os.path.join(out_dir, "best_metric_model.pth")
for slice_no in (0, 2, 4):
    visualize_predictions(
        model=model,
        model_file=best_model_path,
        val_loader=val_loader,
        device=device,
        slice_no=slice_no,
    )