In [None]:
import importlib
import os
import sys
import tomllib
from pathlib import Path
from pprint import pprint

import torch
import wandb
from monai.losses import DiceLoss
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.utils import set_determinism
from torch.utils.data import DataLoader

# The project root is the parent of the notebook's working directory (the config
# cell below uses the same convention). It must be on sys.path *before* the local
# `src` package is imported, otherwise the imports below fail on a fresh kernel.
# An absolute path survives a later os.chdir, and the membership check keeps
# sys.path from growing each time this cell is re-run.
_project_root = str(Path.cwd().parent)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

from src.datasets.acdc_dataset import ACDCDataset
from src.utils import setup_dirs
from src.visualization import visualize_loss_curves, visualize_predictions

In [None]:
# Resolve project directories (the notebook lives one level below the project root).
root_dir = Path.cwd().parent
data_dir, log_dir, root_out_dir = setup_dirs(root_dir)
data_dir = data_dir / "ACDC" / "database"

# Load run configuration from <project root>/config.toml and show it.
config_path = root_dir / "config.toml"
with config_path.open("rb") as config_file:
    config = tomllib.load(config_file)
pprint(config)

# Hyperparameters, with fallbacks when a key is missing from the config file.
# "seed" has no fallback: a missing seed raises KeyError.
hparams = config["hyperparameters"]
batch_size = hparams.get("batch_size", 4)
epochs = hparams.get("epochs", 100)
learning_rate = hparams.get("learning_rate", 1e-5)
percentage_data = hparams.get("percentage_data", 1.0)
validation_split = hparams.get("validation_split", 0.8)

set_determinism(seed=hparams["seed"])

In [None]:


import src.transforms.transforms as transforms_module

# Pick up edits to the transforms module without restarting the kernel.
# Reloading only the parent package (`src.transforms`) does not re-execute its
# submodules, and a name bound with `from ... import` before a reload keeps the
# stale function. So reload the defining module itself, then re-bind the name.
transforms_module = importlib.reload(transforms_module)
get_transforms = transforms_module.get_transforms

# The augmentation flag is passed to get_transforms and recorded in the
# hyperparameters (which are later sent to wandb.init as the run config).
augment = True
config["hyperparameters"]["augmentations"] = augment
train_transforms, val_transforms = get_transforms(augment)

In [None]:
# Train on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3D UNet with a single input channel and 4 output channels
# (presumably background + 3 ACDC label classes — TODO confirm against the dataset).
# Variants that were tried and left disabled:
#   channels=(26, 52, 104, 208, 416), num_res_units=4, dropout=0.5
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    norm=Norm.BATCH,
).to(device)

# Dice loss on softmax probabilities; integer labels are one-hot encoded internally.
loss_function = DiceLoss(softmax=True, to_onehot_y=True)

# Adam with its default learning rate; the next cell overrides the lr on every
# param group. TODO: weight decay check
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
# Shuffled loader over the whole training directory. It is only consumed by the
# (currently disabled) learning-rate finder below.
full_dataset_loader = DataLoader(
    ACDCDataset(data_dir=data_dir / "training", transform=train_transforms),
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
)

# Learning-rate range test over the full dataset, sweeping three orders of
# magnitude either side of the configured learning rate. Kept for reference;
# note that `find_optimal_learning_rate` is not imported in this notebook.
# optimal_learning_rate = find_optimal_learning_rate(
#     model=model,
#     optimizer=optimizer,
#     criterion=loss_function,
#     device=device,
#     train_loader=full_dataset_loader,
#     start_lr=learning_rate / 1000,
#     end_lr=learning_rate * 1000,
#     iterations=epochs,
# )

# Hard-coded value, presumably the result of an earlier finder run (TODO record provenance).
optimal_learning_rate = 0.00039565388658322663
if optimal_learning_rate is not None:
    print(f"Optimal learning rate found: {optimal_learning_rate}")
else:
    print("Optimal learning rate not found, using default learning rate.")
    optimal_learning_rate = learning_rate

# Apply the chosen rate to the existing optimizer and log it with the hyperparameters.
for param_group in optimizer.param_groups:
    param_group["lr"] = optimal_learning_rate
config["hyperparameters"]["optimal_learning_rate"] = optimal_learning_rate

In [None]:


import gc

import src.datasets.acdc_dataset
import src.train
import src.utils

# Pick up edits to project modules without restarting the kernel. A module must be
# imported before it can be reloaded (looking it up in sys.modules first raises
# KeyError on a fresh kernel). The dataset module is reloaded first so that
# modules reloaded after it can bind the fresh ACDCDataset class
# (assumes src.utils / src.train import it — TODO confirm).
for _module in (src.datasets.acdc_dataset, src.utils, src.train):
    importlib.reload(_module)

# Re-bind names from the freshly reloaded modules.
from src.datasets.acdc_dataset import ACDCDataset
from src.train import train
from src.utils import get_datasets, get_train_dataloaders

val_interval = 5  # passed to train() as the validation interval

# percentage_data is fixed for this sweep; only percentage_slices varies per run.
percentage_data = 1.0
config["hyperparameters"]["percentage_data"] = percentage_data

for percentage_slices in [0.1, 0.05]:
    config["hyperparameters"]["percentage_slices"] = percentage_slices

    # Re-seed every run so weight initialisation and shuffling do not depend on
    # how many runs preceded this one in the sweep.
    set_determinism(seed=config["hyperparameters"]["seed"])

    # Fresh model and optimizer for each run.
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=4,
        # channels=(26, 52, 104, 208, 416),
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        norm=Norm.BATCH,
        # num_res_units=4,
        # dropout=0.5,
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=optimal_learning_rate)

    train_data, val_data = get_datasets(
        augment=augment,
        percentage_data=percentage_data,
        percentage_slices=percentage_slices,
        data_dir=data_dir / "training",
    )
    # Report the size of the dataset this run actually trains on.
    print(f"Number of training samples: {len(train_data)}")

    train_loader, val_loader = get_train_dataloaders(
        train_dataset=train_data,
        val_dataset=val_data,
        batch_size=batch_size,
        validation_split=validation_split,
    )

    wandb.init(
        project="acdc-3D-UNet-baseline-restart",
        config=config["hyperparameters"],
        tags=["limited_slices"],
        dir=log_dir,
        reinit=True,
    )
    try:
        wandb.config.dataset = "ACDC"
        wandb.config.architecture = "UNet"

        out_dir = root_out_dir / f"percentage_data_{percentage_data}" / f"percentage_slices_{percentage_slices}"
        os.makedirs(out_dir, exist_ok=True)

        epoch_loss_values, metric_values = train(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            loss_function=loss_function,
            optimizer=optimizer,
            val_interval=val_interval,
            epochs=epochs,
            device=device,
            out_dir=out_dir,
            dimensions=3,
        )
    finally:
        # Always close the W&B run, even if training raised.
        wandb.finish()
        # Collect unreachable Python objects first so that any tensors they held
        # are freed, then return cached CUDA blocks.
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Safety net: close any W&B run still open, e.g. if the sweep cell above was
# interrupted before reaching its own wandb.finish() call.
wandb.finish()

In [None]:
# Inspect the last run of the sweep. The training loop leaves its final model,
# validation loader, loss/metric curves and output directory bound in the namespace.
visualize_loss_curves(epoch_loss_values, metric_values, val_interval, out_dir)

# Prediction overlays for a few slice indices.
for slice_index in (0, 2, 4):
    visualize_predictions(model=model, val_loader=val_loader, device=device, slice_no=slice_index)