In [None]:
import importlib
import os
import sys
import tomllib
from pathlib import Path
from pprint import pprint

import torch
import wandb
from monai.losses import DiceLoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.utils import set_determinism
from torch.utils.data import DataLoader, random_split

# The project-local `src` package is imported below, so the parent directory
# must be on sys.path *before* those imports run; extending the path afterwards
# is too late for a fresh kernel to resolve them.
# NOTE(review): assumes `src` lives in the parent of the notebook's working
# directory (the same directory later used as root_dir) — confirm.
if ".." not in sys.path:  # keep the cell idempotent when it is re-run
    sys.path.insert(0, "..")

from src.datasets.acdc_dataset import ACDCDataset  # noqa: E402
from src.train import train  # noqa: E402
from src.transforms import get_transforms  # noqa: E402
from src.utils import find_optimal_learning_rate, setup_dirs  # noqa: E402
from src.visualization import (  # noqa: E402
    visualize_loss_curves,
    visualize_predictions,
)

In [None]:
# Project layout: the parent of the current working directory is used as the
# project root; setup_dirs derives the data, log and output directories from it.
root_dir = Path.cwd().parent
data_dir, log_dir, root_out_dir = setup_dirs(root_dir)
data_dir = data_dir / "ACDC" / "database"

# tomllib only accepts files opened in binary mode.
with (root_dir / "config.toml").open("rb") as config_file:
    config = tomllib.load(config_file)

pprint(config)

# Each hyperparameter falls back to the default given here when the config
# file does not provide it.
hyperparameters = config["hyperparameters"]
batch_size = hyperparameters.get("batch_size", 4)
epochs = hyperparameters.get("epochs", 100)
learning_rate = hyperparameters.get("learning_rate", 1e-5)
percentage_data = hyperparameters.get("percentage_data", 1.0)
# Despite its name, get_train_dataloaders multiplies the dataset size by this
# value to obtain the *training* subset size.
validation_split = hyperparameters.get("validation_split", 0.8)

# The seed has no fallback: it must be present in the config file.
set_determinism(seed=hyperparameters["seed"])

In [None]:
# Re-execute src.transforms so that on-disk edits are picked up without
# restarting the kernel. importlib.reload does not rebind names that were
# imported with `from src.transforms import ...`, so get_transforms is fetched
# from the reloaded module explicitly; otherwise the stale, pre-reload function
# would still be the one called here.
transforms_module = importlib.reload(sys.modules["src.transforms"])
get_transforms = transforms_module.get_transforms

train_transforms = get_transforms()

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3D U-Net with one input channel and four output channels.
# Disabled alternatives kept for reference:
#   channels=(26, 52, 104, 208, 416), num_res_units=4, dropout=0.5
unet_kwargs = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    norm=Norm.BATCH,
)
model = UNet(**unet_kwargs).to(device)

loss_function = DiceLoss(softmax=True, to_onehot_y=True)

# No learning rate is passed here; the learning-rate search cell overwrites the
# "lr" entry of every parameter group afterwards.
# TODO: weight decay check
optimizer = torch.optim.Adam(params=model.parameters())

# Evaluation metrics, keyed by name. The "mean" reduction yields the aggregate
# variants and "mean_batch" the "*_per_label" variants; only the
# "*_with_background" Dice metrics include the background channel.
metrics = {}
for suffix, reduction in (("", "mean"), ("_per_label", "mean_batch")):
    metrics[f"dice{suffix}"] = DiceMetric(
        include_background=False, reduction=reduction
    )
    metrics[f"dice{suffix}_with_background"] = DiceMetric(
        include_background=True, reduction=reduction
    )
    metrics[f"hausdorff{suffix}"] = HausdorffDistanceMetric(
        include_background=False, reduction=reduction
    )

In [None]:
# The learning-rate search always runs on the complete training set
# (percentage_data=1), regardless of the configured percentage_data.
lr_search_dataset = ACDCDataset(
    data_dir=data_dir,
    train=True,
    transform=train_transforms,
    percentage_data=1,
)
full_dataset_loader = DataLoader(
    lr_search_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

# Sweep two orders of magnitude either side of the configured learning rate,
# which makes the configured value the (geometric) midpoint of the range.
# NOTE(review): whether the search leaves the model weights / optimizer state
# untouched depends on find_optimal_learning_rate — confirm.
optimal_learning_rate = find_optimal_learning_rate(
    model=model,
    optimizer=optimizer,
    criterion=loss_function,
    device=device,
    train_loader=full_dataset_loader,
    start_lr=learning_rate / 100,
    end_lr=learning_rate * 100,
    iterations=100,
)

if optimal_learning_rate is not None:
    print(f"Optimal learning rate found: {optimal_learning_rate}")
else:
    # Fall back to the configured learning rate when the search yields nothing.
    print("Optimal learning rate not found, using default learning rate.")
    optimal_learning_rate = learning_rate

# Apply the chosen rate to every parameter group of the existing optimizer.
for param_group in optimizer.param_groups:
    param_group["lr"] = optimal_learning_rate

# Store the rate that is actually used alongside the other hyperparameters.
config["hyperparameters"]["optimal_learning_rate"] = optimal_learning_rate

In [None]:
def get_train_dataloaders(
    dataset: torch.utils.data.Dataset,
) -> tuple[DataLoader, DataLoader]:
    """Randomly split ``dataset`` into a training and a validation loader.

    Relies on the notebook-level ``validation_split`` and ``batch_size``.
    Despite its name, ``validation_split`` is the fraction of samples assigned
    to the *training* subset; the remaining samples form the validation subset.

    Args:
        dataset: The dataset to split.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader reshuffles
        its samples every epoch; the validation loader keeps a fixed order.
    """
    total_samples = len(dataset)
    train_size = int(validation_split * total_samples)
    val_size = total_samples - train_size

    train_ds, val_ds = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=1
    )
    # Validation is not shuffled: the order does not affect the aggregated
    # metrics, and a fixed order keeps per-batch inspection (e.g. plotting
    # predictions for a given batch) reproducible between runs.
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=1
    )

    return train_loader, val_loader

In [None]:
# Presumably closes the currently active Weights & Biases run — confirm.
# NOTE(review): no wandb.init() call appears in the cells shown here, although
# the recorded output of this cell mentions it — confirm that the cell which
# starts the run (and performs training) still exists.
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mjosh-stein[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01669572708196938, max=1.0)…

In [None]:
# NOTE(review): epoch_loss_values, metric_values, val_interval, out_dir and
# val_loader are not defined in any cell shown here — presumably they are
# produced by the training cell; confirm this cell survives Restart & Run All.
visualize_loss_curves(epoch_loss_values, metric_values, val_interval, out_dir)

# The checkpoint path is the same for every slice, so it is built once.
best_model_path = os.path.join(out_dir, "best_metric_model.pth")

# Plot predictions for a few slice numbers of the end-diastole image/label pair.
for slice_index in (0, 2, 4):
    visualize_predictions(
        model=model,
        model_file=best_model_path,
        val_loader=val_loader,
        device=device,
        image_key="end_diastole",
        label_key="end_diastole_label",
        slice_no=slice_index,
    )