In [1]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import wandb

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import segmentation_models_pytorch as smp

from models.unet import *
from utils.data_utils.acdc_datamodule import *
from utils.data_utils.data_utils import *
from utils.model_utils.dice_score import *

In [None]:
# Seed Python's `random`, NumPy and PyTorch in a single call.
# `workers=True` additionally makes Lightning derive distinct, reproducible
# seeds for DataLoader worker processes, which the three separate
# per-library seed calls did not cover.
SEED = 42
pl.seed_everything(SEED, workers=True);

In [2]:
# Authenticate with Weights & Biases, then open the run for this notebook.
WANDB_PROJECT = "Medical Image Segmentation"
wandb.login()
wandb.init(project=WANDB_PROJECT)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbanfizsombor1999[0m ([33mdrigba[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
class SemanticSegmanter(pl.LightningModule):
    def __init__(self, model, learning_rate, criterion) -> None:
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.train_losses = []
        self.val_losses = []
        self.test_losses = []
        self.lr = learning_rate

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        images, ground_truths = batch
        masks_pred = self.model(images)
        ground_truths = ground_truths.long()
        loss = self.criterion(masks_pred, ground_truths)
        loss.requires_grad = True
        self.log('train_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.train_losses.append(loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, ground_truths = batch
        masks_pred = self.model(images)
        ground_truths = ground_truths.long()
        loss = 1-self.criterion(masks_pred, ground_truths)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True, on_step=False)
        self.val_losses.append(loss)
        return loss

    def test_step(self, batch, batch_idx):
        images, ground_truths = batch
        masks_pred = self.model(images)
        ground_truths = ground_truths.long()
        loss = self.criterion(masks_pred, ground_truths)
        self.test_losses.append(loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [4]:
# Constans and Hyperparams
NUM_CLASSES = 4
MAX_EPOCHS = 500

# Big model takes lots of space in memory -> small batch size fits in
BATCH_SIZE_TRAIN = 8
BATCH_SIZE_VAL = 4
BATCH_SIZE_TEST  = 4

In [5]:
datamodule = ACDCDataModule("database", BATCH_SIZE_TRAIN,BATCH_SIZE_VAL,BATCH_SIZE_TEST,(256,256,1), convert_to_single=False)
datamodule.setup("fit")


In [6]:
# May need to add new preprocessing arg, to include pretrained model preprocessing
# preprocess_input = smp.encoders.get_preprocessing_fn('resnet18', pretrained='imagenet')

unet = smp.Unet('resnet18', encoder_weights='imagenet', classes=NUM_CLASSES, activation='softmax', in_channels=1)
criterion = torchmetrics.classification.Dice(num_classes=NUM_CLASSES, threshold=0.5)
segmenter = SemanticSegmanter(model = unet, learning_rate=1e-3 ,criterion=criterion)

# tsmp.metrics.functional.IoU or torch metric?
# do we need this?
# metric = smp.metrics.functional.IoU(threshold=0.5)

In [7]:
# Configure callbacks and logger
# Route Lightning's logging to Weights & Biases and register the U-Net
# with `wandb.watch`.
wandb_logger = pl.loggers.WandbLogger()
wandb_logger.watch(unet)

# Stop training once `val_loss` (lower is better) has not improved for
# 5 consecutive validation checks.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode="min",
    verbose=True,
)


  rank_zero_warn(
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [8]:
# Train with the epoch budget from the hyperparameter cell instead of a separate
# hard-coded value. Early stopping may end the run sooner.
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, logger=wandb_logger, callbacks=[early_stopping])
trainer.fit(segmenter, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type | Params
-----------------------------------
0 | model     | Unet | 3.1 M 
1 | criterion | Dice | 0     
-----------------------------------
3.1 M     Trainable params
0         Non-trainable params
3.1 M     Total params
12.351    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  return self.activation(x)
  return self.activation(x)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 0.544


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.542


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.541


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 0.541. Signaling Trainer to stop.
