In [None]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import wandb

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import segmentation_models_pytorch as smp

from models.unet import *
from utils.data_utils.acdc_datamodule import *
from utils.data_utils.data_utils import *
from utils.model_utils.dice_score import *

In [None]:
# Seed Python's `random`, NumPy and torch (CPU and CUDA) with one call.
# `workers=True` additionally makes Lightning give every DataLoader worker process
# its own reproducible seed, which seeding the parent process alone does not do.
SEED = 42
pl.seed_everything(SEED, workers=True)

In [None]:
# Authenticate with Weights & Biases (uses a cached API key or prompts for one),
# then open a run in the project below; `wandb.finish()` in the last cell closes it.
wandb.login()
wandb.init(
    project="Medical Image Segmentation",
)

In [None]:
class SemanticSegmanter(pl.LightningModule):
    """LightningModule that trains a segmentation network with a pluggable loss.

    Args:
        model: network mapping a batch of images to per-class predictions.
        learning_rate: learning rate for the Adam optimizer.
        criterion: callable ``criterion(predictions, targets)`` returning a scalar
            loss to *minimise*. It must be differentiable with respect to the
            predictions; otherwise no gradient can reach ``model``.
    """

    def __init__(self, model, learning_rate, criterion) -> None:
        super().__init__()
        self.model = model
        self.criterion = criterion
        # Per-step losses, stored detached so they do not keep autograd graphs alive.
        self.train_losses = []
        self.val_losses = []
        self.test_losses = []
        self.lr = learning_rate

    def forward(self, x):
        return self.model(x)

    def _compute_loss(self, batch):
        """Run the model on an ``(images, ground_truths)`` batch and return the loss."""
        images, ground_truths = batch
        masks_pred = self.model(images)
        return self.criterion(masks_pred, ground_truths.long())

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        # The loss must carry an autograd graph back to the model parameters.
        # Forcing `loss.requires_grad = True` only turns it into a detached leaf:
        # backward() then "succeeds" but no parameter gets a gradient, so the
        # model silently never trains. Fail loudly instead.
        if torch.is_grad_enabled() and not loss.requires_grad:
            raise RuntimeError(
                "criterion returned a loss with no gradient path to the model; "
                "use a differentiable loss (e.g. a soft Dice loss), not a metric"
            )
        self.log('train_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.train_losses.append(loss.detach())
        return loss

    def validation_step(self, batch, batch_idx):
        # Same quantity as train_loss (lower is better). The previous `1 - loss`
        # logged the Dice *score*, which inverted EarlyStopping(mode="min").
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True, on_step=False)
        self.val_losses.append(loss.detach())
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("test_loss", loss, on_epoch=True, logger=True)
        self.test_losses.append(loss.detach())
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [None]:
# Constants and hyperparameters
NUM_CLASSES = 4  # number of output classes of the segmentation network
MAX_EPOCHS = 500

# The model is large and takes a lot of memory, so only small batches fit.
BATCH_SIZE_TRAIN = 8
BATCH_SIZE_VAL = 4
BATCH_SIZE_TEST = BATCH_SIZE_VAL

In [None]:
# Build the ACDC data module and prepare the train/validation splits.
datamodule = ACDCDataModule(
    "database",        # presumably the dataset root directory — TODO confirm
    BATCH_SIZE_TRAIN,
    BATCH_SIZE_VAL,
    BATCH_SIZE_TEST,
    (256, 256, 1),     # presumably (height, width, channels) — verify in ACDCDataModule
    convert_to_single=False,
)
# NOTE(review): Trainer.fit(datamodule=...) also runs setup("fit"); this explicit
# call is only needed if the data is inspected before training.
datamodule.setup("fit")


In [None]:
class DiceLoss:
    """Soft (differentiable) multi-class Dice loss: ``1 - mean over classes of Dice_c``.

    The previous version wrapped ``torchmetrics.classification.Dice``. That is a
    metric: it thresholds/argmaxes the predictions and is evaluated without
    autograd, so ``1 - Dice(...)`` had no gradient path back to the network. It
    was also pinned to ``"cuda"``. This version computes Dice directly on the
    predicted probabilities, so gradients flow, and it is device-agnostic.

    Args:
        num_classes: number of classes C. If None, it is inferred from ``input.shape[1]``.
        smooth: additive smoothing in the numerator and denominator. It keeps the
            loss finite (and 0 for a perfect match) when a class is absent.
        **kwargs: accepted and ignored, for backward compatibility with the
            torchmetrics-style arguments previously passed (e.g. ``threshold``).
    """

    def __init__(self, num_classes=None, smooth=1.0, **kwargs):
        self.num_classes = num_classes
        self.smooth = smooth

    def __call__(self, input, target):
        """Return the scalar Dice loss in [0, 1]; 0 means a perfect match.

        Args:
            input: class probabilities of shape (N, C, *spatial). Assumes softmax
                over dim 1 has already been applied by the model — TODO confirm for
                models built without an output activation.
            target: integer class indices of shape (N, *spatial) or (N, 1, *spatial),
                or a one-hot tensor of shape (N, C, *spatial).
        """
        num_classes = self.num_classes or input.shape[1]

        # (N, 1, *spatial) index maps -> (N, *spatial)
        if target.dim() == input.dim() and target.shape[1] == 1 and input.shape[1] != 1:
            target = target.squeeze(1)

        if target.dim() == input.dim() - 1:
            # class indices -> one-hot with the class axis at dim 1
            target_one_hot = F.one_hot(target.long(), num_classes).movedim(-1, 1)
        else:
            target_one_hot = target
        target_one_hot = target_one_hot.to(input.dtype)

        # Sum over batch and spatial dims, keeping one Dice value per class.
        reduce_dims = (0,) + tuple(range(2, input.dim()))
        intersection = (input * target_one_hot).sum(dim=reduce_dims)
        cardinality = input.sum(dim=reduce_dims) + target_one_hot.sum(dim=reduce_dims)
        dice_per_class = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1 - dice_per_class.mean()

In [None]:
# TODO: the ImageNet-pretrained encoder may need matching input preprocessing, e.g.
# smp.encoders.get_preprocessing_fn('resnet18', pretrained='imagenet'), which would
# have to be applied inside the data module.

# 'softmax2d' applies softmax explicitly over the class dimension (dim=1). Plain
# 'softmax' builds nn.Softmax() with an implicit dim, which PyTorch deprecates;
# for 4-D outputs both normalise over dim 1.
unet = smp.Unet(
    'resnet18',
    encoder_weights='imagenet',
    classes=NUM_CLASSES,
    activation='softmax2d',
    in_channels=1,
)
criterion = DiceLoss(num_classes=NUM_CLASSES, threshold=0.5)
segmenter = SemanticSegmanter(model=unet, learning_rate=1e-3, criterion=criterion)

# TODO: decide whether to also track IoU (torchmetrics JaccardIndex or smp.metrics)
# as an evaluation metric next to the Dice loss.

In [None]:
# Logging: a Lightning logger backed by Weights & Biases, with the network
# registered via `watch` so W&B can track it during training.
wandb_logger = pl.loggers.WandbLogger()
wandb_logger.watch(unet)

# Early stopping: halt once `val_loss` (lower is better) has not improved for
# 5 consecutive validation checks.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode="min",
    patience=5,
    verbose=True,
)


In [None]:
# Use the epoch budget from the hyperparameter cell (MAX_EPOCHS) rather than a
# separate hard-coded value, so the config cell stays the single source of truth.
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, logger=wandb_logger, callbacks=[early_stopping])
try:
    trainer.fit(segmenter, datamodule=datamodule)
except BaseException:
    # Close the W&B run even if training fails, and mark it as failed.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()