In [59]:
from os.path import join

import hydra
import lightning as L
import torch
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import TensorBoardLogger
from omegaconf import DictConfig
from src.data_module import DecathlonDataModule
#from src.model import DecathlonModel
from src.utils import generate_run_id
from monai.networks.nets import SegResNet




In [61]:
import lightning as L
import torch
from monai.losses import DiceLoss
from monai.networks.nets import DynUNet
from monai.networks.layers import Norm
from monai.metrics import compute_generalized_dice
from monai.inferers import sliding_window_inference


class DecathlonModel(L.LightningModule):
    """3D segmentation LightningModule built on MONAI's ``DynUNet``.

    Training uses a softmax Dice loss on the network logits. Validation runs
    sliding-window inference over each input volume once and derives both the
    validation loss and the generalized Dice score from that single prediction.

    Args:
        learning_rate: Learning rate for AdamW.
        use_scheduler: If True, attach a ``CosineAnnealingLR`` scheduler.
        num_classes: Number of output channels, background included. Must be >= 2:
            MONAI's ``DiceLoss`` ignores ``softmax``/``to_onehot_y`` for a
            single-channel prediction. The default of 2 assumes a binary
            (background vs. foreground) task -- TODO confirm against the data module.
        roi_size: Patch size used by sliding-window inference during validation.
        scheduler_t_max: ``T_max`` of the cosine schedule (Lightning steps a bare
            scheduler once per epoch by default).
        val_amp: If True, run validation inference under CUDA autocast; it has
            no effect when the module is on CPU.

    Raises:
        ValueError: If ``num_classes`` is smaller than 2.
    """

    def __init__(
        self,
        learning_rate: float = 3e-4,
        use_scheduler: bool = True,
        num_classes: int = 2,
        roi_size: tuple = (128, 128, 128),
        scheduler_t_max: int = 20,
        val_amp: bool = True,
    ):
        super().__init__()
        if num_classes < 2:
            raise ValueError(
                f"num_classes must be >= 2 for a softmax Dice loss, got {num_classes}"
            )
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.use_scheduler = use_scheduler
        self.num_classes = num_classes
        self.roi_size = tuple(roi_size)
        self.scheduler_t_max = scheduler_t_max
        self.val_amp = val_amp

        strides = [1, 2, 2, 2, 2, 2]
        self._model = DynUNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=num_classes,
            kernel_size=[3] * len(strides),
            strides=strides,
            # One transposed conv per downsampling stage: the upsample kernels
            # mirror strides[1:], i.e. one entry fewer than `strides`.
            upsample_kernel_size=strides[1:],
            norm_name=Norm.BATCH,
            deep_supervision=False,
        )

        # Expects labels as class indices with a singleton channel dim
        # (required by to_onehot_y=True).
        self.criterion = DiceLoss(to_onehot_y=True, softmax=True)

    def forward(self, x):
        """Return raw (pre-softmax) logits of shape (B, num_classes, *spatial)."""
        return self._model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the Dice loss for one training batch."""
        x, y = batch["image"], batch["label"]
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def _inference(self, input):
        """Run sliding-window inference over ``input`` and return the stitched logits."""
        # torch.cuda.amp.autocast is deprecated. torch.autocast with an explicit
        # device type behaves the same on GPU and is simply disabled on CPU.
        use_amp = self.val_amp and self.device.type == "cuda"
        with torch.autocast(device_type="cuda", enabled=use_amp):
            return sliding_window_inference(
                inputs=input,
                roi_size=self.roi_size,
                sw_batch_size=1,
                predictor=self,
                overlap=0.5,
            )

    @staticmethod
    def _one_hot(labels, num_classes: int):
        """Turn (B, *spatial) integer class indices into a float (B, C, *spatial) one-hot tensor."""
        return (
            torch.nn.functional.one_hot(labels.long(), num_classes)
            .movedim(-1, 1)
            .float()
        )

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and generalized Dice from one sliding-window pass."""
        x, y = batch["image"], batch["label"]
        # A single sliding-window pass feeds both the loss and the metric. This avoids
        # a second, full-volume forward pass that must fit the whole image at once.
        logits = self._inference(x)
        loss = self.criterion(logits, y)

        num_classes = logits.shape[1]
        # compute_generalized_dice expects binarized one-hot tensors for both arguments.
        y_pred = self._one_hot(logits.argmax(dim=1), num_classes)
        # Assumes y is (B, 1, *spatial), matching what DiceLoss(to_onehot_y=True) needs.
        y_true = self._one_hot(y.squeeze(1), num_classes)
        # nanmean skips NaN scores that can arise for cases with no foreground
        # in either prediction or label.
        dice = compute_generalized_dice(y_pred, y_true).nanmean()

        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("val_dice", dice, on_step=True, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """Create AdamW (weight decay 0.05) and, optionally, a cosine LR schedule."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=0.05
        )

        configuration = {
            "optimizer": optimizer,
            "monitor": "val_loss",
        }

        if self.use_scheduler:
            configuration["lr_scheduler"] = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.scheduler_t_max
            )

        return configuration


In [62]:
def train(
    root_dir: str = "./data",
    task: str = "Task06_Lung",
    learning_rate: float = 3e-4,
    max_epochs: int = 10,
    seed: int = 42,
) -> None:
    """Train ``DecathlonModel`` on one Decathlon task and checkpoint the best models.

    Generates a fresh run id and seeds everything. It then builds the data module,
    the model, a TensorBoard logger and three callbacks:

    - top-2 checkpoints ranked by ``val_loss``;
    - per-step learning-rate logging;
    - early stopping with patience 10.

    Finally it runs ``Trainer.fit`` with 16-bit mixed precision.

    Args:
        root_dir: Data directory passed to ``DecathlonDataModule``.
        task: Decathlon task name, e.g. ``"Task06_Lung"``.
        learning_rate: AdamW learning rate for the model.
        max_epochs: Maximum number of training epochs.
        seed: Seed for ``L.seed_everything`` and the data module.
    """
    # Use the logger from the same `lightning` package as Trainer and the callbacks.
    # The module-level import comes from the separate `pytorch_lightning` package.
    from lightning.pytorch.loggers import TensorBoardLogger

    # Unique run id (based on current date & time) names the log and checkpoint dirs.
    run_id = generate_run_id()

    L.seed_everything(seed, workers=True)
    torch.set_float32_matmul_precision("high")

    dm = DecathlonDataModule(
        root_dir=root_dir,
        task=task,
        batch_size=1,
        num_workers=3,
        seed=seed,
    )
    # No manual dm.setup(): Trainer.fit invokes the datamodule's setup itself.
    # Calling it here too may run the (possibly expensive) dataset setup twice.

    # The default of 3e-4 matches DecathlonModel's own default. The previous 3e-1
    # is far too large for AdamW.
    model = DecathlonModel(learning_rate=learning_rate, use_scheduler=True)

    logger = TensorBoardLogger(save_dir="logs", name="", version=run_id)

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=2,
        # The previous path ' artifacts/checkpoints' had a stray leading space.
        dirpath=join("artifacts", "checkpoints", run_id),
        filename="{epoch}-{step}-{val_loss:.2f}-{val_dice:.2f}",
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    early_stopping = EarlyStopping(
        monitor="val_loss",
        patience=10,
        verbose=True,
        mode="min",
    )

    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices="auto",
        logger=logger,
        callbacks=[checkpoint_callback, lr_monitor, early_stopping],
        # "16-mixed" is the Lightning 2.x spelling of the discouraged precision=16.
        precision="16-mixed",
    )

    trainer.fit(model, dm)


    

In [63]:
# Launch the full training run defined above. This is long-running: the saved
# output below shows the previous run was stopped with a KeyboardInterrupt.
train()

Global seed set to 42


KeyboardInterrupt: 