In [None]:
import importlib
import os
import sys
import tomllib
from pathlib import Path
from pprint import pprint

import torch
import wandb
from monai.losses import DiceLoss
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.utils import set_determinism
from torch.utils.data import DataLoader, random_split

# Put the parent of the working directory on the import path *before* importing
# the project-local `src` package; inserting it afterwards (as was done
# previously) cannot help these imports resolve on a fresh kernel.
sys.path.insert(0, "..")

from src.datasets.acdc_dataset import ACDCDataset  # noqa: E402
from src.transforms.transforms import get_transforms  # noqa: E402
from src.utils import setup_dirs  # noqa: E402
from src.visualization import visualize_predictions  # noqa: E402

In [None]:
# Resolve the project directories relative to the parent of the current working
# directory, which is treated as the repository root.
root_dir = Path(os.getcwd()).parent
data_dir, log_dir, out_dir = setup_dirs(root_dir)
data_dir = data_dir / "ACDC" / "database"
out_dir = out_dir / "2d_UNet"
os.makedirs(out_dir, exist_ok=True)

# tomllib only accepts files opened in binary mode.
with open(root_dir / "config.toml", "rb") as file:
    config = tomllib.load(file)

pprint(config)

# Hyperparameters, each with a fallback for keys missing from config.toml.
hyperparameters = config["hyperparameters"]
# Read the batch size from its own key (it was previously read from "epochs",
# which made the batch size silently equal to the number of epochs).
batch_size = hyperparameters.get("batch_size", 1)
epochs = hyperparameters.get("epochs", 100)
learning_rate = hyperparameters.get("learning_rate", 1e-5)
percentage_data = hyperparameters.get("percentage_data", 1.0)
# NOTE(review): despite its name, this value is used further down as the
# fraction of samples assigned to the *training* split - confirm the intent.
validation_split = hyperparameters.get("validation_split", 0.8)

# The seed is mandatory: a missing "seed" key raises KeyError.
set_determinism(seed=hyperparameters["seed"])

In [None]:
# Build the dataset from the "Training" folder and split it into train and
# validation subsets.
train_transforms, val_transforms = get_transforms()
train_data = ACDCDataset(
    data_dir=data_dir / "Training",
    transform=train_transforms,
    percentage_data=percentage_data,
)

# `validation_split` is applied as the training fraction; the remainder of the
# samples goes to the validation subset.
total_training_number = len(train_data)
train_size = int(validation_split * total_training_number)
val_size = total_training_number - train_size

train_ds, val_ds = random_split(train_data, [train_size, val_size])

print(f"Total training number: {total_training_number}")
# Number of full batches the whole dataset yields at this batch size.
print(total_training_number // batch_size)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=1)
# Keep the validation order fixed: shuffling brings no benefit for evaluation
# and would make a batch index picked for visualisation point at different
# samples on every pass over the loader.
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=1)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2D U-Net with one input channel, four output channels and batch normalisation.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    norm=Norm.BATCH,
)
model = model.to(device)

# TODO: include_background = False?
loss_function = DiceLoss(to_onehot_y=True, softmax=True)

# Adam is created with its default settings here; the learning rate of its
# parameter groups is overwritten in a later cell.
# TODO: weight decay check
optimizer = torch.optim.Adam(model.parameters())

In [None]:


# Learning-rate finder call, currently disabled; it uses the config learning
# rate as a midpoint.
# optimal_learning_rate = find_optimal_learning_rate(
#     model=model,
#     optimizer=optimizer,
#     criterion=loss_function,
#     device=device,
#     train_loader=train_loader,
#     iterations=100,
#     learning_rate=learning_rate
# )

# Hard-coded value used in place of running the finder above.
optimal_learning_rate = 0.0008697490026177839

if optimal_learning_rate is not None:
    print(f"Optimal learning rate found: {optimal_learning_rate}")
else:
    # Fall back to the learning rate read from the config file.
    print("Optimal learning rate not found, using default learning rate.")
    optimal_learning_rate = learning_rate

# Apply the chosen learning rate to every parameter group of the optimizer.
for group in optimizer.param_groups:
    group["lr"] = optimal_learning_rate

# Store the value with the other hyperparameters in the config dict.
config["hyperparameters"].update(optimal_learning_rate=optimal_learning_rate)

In [None]:
# Start a Weights & Biases run, passing the hyperparameters as its config.
wandb.init(
    project="acdc-2D-UNet-baseline",
    config=config["hyperparameters"],
    reinit=True,
)

# Extra metadata describing this experiment.
wandb.config["dataset"] = "ACDC"
wandb.config["architecture"] = "2D-UNet"

In [None]:


# Import the training module before reloading it: looking it up in
# `sys.modules` raises KeyError on a fresh kernel where nothing has imported
# `src.train` yet. The reload picks up edits made to the module since the
# kernel started, and `train` is bound only afterwards so it refers to the
# reloaded function.
import src.train

importlib.reload(src.train)
from src.train import train

# Validation interval handed to `train`.
val_interval = 1

epoch_loss_values, metric_values = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_function=loss_function,
    optimizer=optimizer,
    val_interval=val_interval,
    epochs=epochs,
    device=device,
    out_dir=out_dir,
    dimensions=2,
)

In [None]:
# Close the Weights & Biases run that was opened with wandb.init earlier.
wandb.finish()

In [None]:
# Reload the visualisation module and rebind `visualize_predictions` from the
# reloaded module object. Reloading alone leaves the name imported at the top
# of the notebook pointing at the old function, so edits to the module would
# not take effect in the calls below.
reloaded_visualization = importlib.reload(sys.modules["src.visualization"])
visualize_predictions = reloaded_visualization.visualize_predictions

# visualize_loss_curves(epoch_loss_values, metric_values, val_interval, out_dir)

# Show predictions for a few selected slice numbers.
for slice_no in (0, 2, 4):
    visualize_predictions(
        model=model,
        val_loader=val_loader,
        device=device,
        slice_no=slice_no,
    )