In [1]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import wandb
import multiprocessing


import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import segmentation_models_pytorch as smp

from models.unet import *
from models.simple_model import *
from models.tri_unet import *
from models.divergent_nets import *
from utils.data_utils.acdc_datamodule import *
from utils.data_utils.data_utils import *
from utils.model_utils.dice_score import *
from utils.model_utils.resnet_loss import ResnetLoss 

from lightning.pytorch.callbacks import RichProgressBar



In [2]:
# Fix the global random seed so training runs are reproducible.
pl.seed_everything(seed=42)

Seed set to 42


42

In [3]:
# Authenticate with Weights & Biases using the locally stored login (see the
# output below); no API key is written into the notebook.
wandb.login()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbanfizsombor1999[0m ([33mdrigba[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
class SemanticSegmanter(pl.LightningModule):
    """LightningModule that trains, validates and tests a multiclass segmentation network.

    Args:
        model: network called as ``model(images)``. Its output is passed to
            ``criterion``, the torchmetrics metrics and ``argmax(dim=1)``, so it is
            expected to be per-class logits of shape (N, C, H, W).
        learning_rate: learning rate for the Adam optimizer.
        criterion: loss called as ``criterion(logits, target.long())``.
        num_classes: number of classes including background (default 4).
    """

    def __init__(self, model, learning_rate, criterion, num_classes=4) -> None:
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.lr = learning_rate
        self.num_classes = num_classes
        self.val_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        # ignore_index=0 leaves class 0 (background) out of the Dice score.
        self.val_dice = torchmetrics.Dice(num_classes=num_classes, ignore_index=0)

    def forward(self, x):
        model_output = self.model(x)
        return model_output

    def training_step(self, batch, batch_idx):
        images, ground_truths = batch
        masks_pred = self.model(images)
        ground_truths = ground_truths.long()
        loss = self.criterion(masks_pred, ground_truths)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, ground_truths = batch
        masks_pred = self.model(images)
        ground_truths = ground_truths.long()

        loss = self.criterion(masks_pred, ground_truths)
        self.val_acc(masks_pred, ground_truths)
        self.val_dice(masks_pred, ground_truths)

        # Logging the metric objects lets Lightning accumulate them over the epoch.
        self.log('val_dice', self.val_dice, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_acc", self.val_acc, on_epoch=True, prog_bar=True, logger=True, on_step=False)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True, on_step=False)
        return loss

    def test_step(self, batch, batch_idx):
        """Log test loss and per-class Dice (argmax predictions) for one batch.

        Returns:
            The test loss for the batch.
        """
        images, ground_truths = batch
        logits = self(images)
        # NOTE(review): earlier code indexed the output with [2]; if a model variant
        # returns a tuple of outputs, the last one is taken as the final
        # prediction — TODO confirm for TriUnet.
        if isinstance(logits, (tuple, list)):
            logits = logits[-1]
        loss = self.criterion(logits, ground_truths.long())

        # Assumes class-index targets, (N, H, W) or (N, 1, H, W), as used by the
        # multiclass criterion and metrics — TODO confirm against the datamodule.
        target = ground_truths.long()
        if target.dim() == logits.dim() and target.shape[1] == 1:
            target = target.squeeze(1)
        num_classes = logits.shape[1]
        # one_hot appends the class axis last; move it to dim 1 -> (N, C, ...).
        pred_onehot = torch.movedim(F.one_hot(logits.argmax(dim=1), num_classes), -1, 1)
        target_onehot = torch.movedim(F.one_hot(target, num_classes), -1, 1)
        dice = self.compute_dice(pred_onehot, target_onehot)

        log_kwargs = dict(on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('test_loss', loss, **log_kwargs)
        # Mean over the foreground classes only (index 0 is background).
        self.log('dice/test_all_dice', dice[1:].mean(), **log_kwargs)
        self.log('dice/test_LV_dice', dice[3], **log_kwargs)
        self.log('dice/test_RV_dice', dice[1], **log_kwargs)
        self.log('dice/test_MYO_dice', dice[2], **log_kwargs)
        return loss

    @torch.no_grad()
    def compute_dice(self, pred_y, y, epsilon=1e-6):
        """Per-class Dice coefficient between two one-hot masks (classes on dim 1).

        Counts are pooled over the batch and all spatial positions. A class absent
        from both masks scores 1 (``epsilon / epsilon``).

        Returns:
            float32 tensor of shape (C,).
        """
        reduce_dims = [d for d in range(pred_y.dim()) if d != 1]
        intersection = (pred_y * y).sum(dim=reduce_dims)
        totals = pred_y.sum(dim=reduce_dims) + y.sum(dim=reduce_dims)
        return ((2.0 * intersection + epsilon) / (totals + epsilon)).float()

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [5]:
# Constants and hyperparameters
NUM_CLASSES = 4    # background, RV, myocardium, LV
MAX_EPOCHS = 500   # upper bound; the EarlyStopping callback may end a run sooner

# The big models take a lot of memory, so only a small batch fits.
BATCH_SIZE_TRAIN, BATCH_SIZE_VAL, BATCH_SIZE_TEST = 8, 8, 8

In [6]:
# Data augmentation; the meaning of the three arguments lives in
# utils.data_utils (presumably rotation / shift / scale limits — TODO confirm).
transform = DualTransform(20, 0.2, 0.2)

datamodule = ACDCDataModule(
    "database",          # presumably the dataset root directory
    BATCH_SIZE_TRAIN,
    BATCH_SIZE_VAL,
    BATCH_SIZE_TEST,
    (256, 256, 1),       # presumably the target (H, W, C) — TODO confirm
    convert_to_single=False,
    transform=transform,
)
datamodule.setup("fit")

In [7]:
# Notes on loss / evaluation choices:
# - Dice loss for imbalanced classes
#   - cross-entropy may also be good, but Dice is good too
# - Jaccard index (IoU)
# - Try ignore_index -> definitely worth it for the metric
#   - ignoring a single class may have no effect
#   - with an ignored class, cross-entropy may be equally effective
#   - per-class Dice coefficient
# - Split by patient and keep the validation set fixed
#   - K-fold if the split is not fixed -> this could also give an ensemble
# - Pixel-level confusion matrix, pixel accuracy
#   - include example predictions
#   - qualitative evaluation

# Active loss: multiclass Dice. Swap in one of the alternatives below to compare.
loss = criterion = smp.losses.DiceLoss(mode="multiclass")
# loss = ResnetLoss(criterion)
# loss = torch.nn.CrossEntropyLoss()

In [8]:
from lightning import Callback

# Mask value -> human-readable class name, used for the W&B mask overlays.
class_labels = dict(enumerate(["background", "RV", "myocardium", "LV"]))

class Visualizer(Callback):
    """Logs a GT / prediction / error triptych for one validation sample to W&B.

    Hooked on ``on_validation_epoch_start``. Expects ``trainer.logger.experiment``
    to be a wandb run (i.e. a ``WandbLogger``) and the datamodule to expose
    ``acdc_val``.

    Args:
        sample_idx: index into ``trainer.datamodule.acdc_val`` of the sample to show.
    """

    def __init__(self, sample_idx=0):
        super().__init__()
        self.sample_idx = sample_idx

    def on_validation_epoch_start(self, trainer, model):
        # Read image and mask from ONE __getitem__ call: indexing the dataset twice
        # would apply any random transform twice and could pair x with a
        # differently transformed y.
        sample = trainer.datamodule.acdc_val[self.sample_idx]
        x, y = sample[0], sample[1]

        # Run on the module's own device instead of assuming CUDA.
        with torch.no_grad():
            pred = model(x.unsqueeze(0).to(model.device))

        gt_mask = np.array(y.squeeze()).astype(np.int64)
        pred_mask = torch.argmax(pred.squeeze(0).cpu(), dim=0).numpy()
        # 255 where prediction and ground truth disagree, 0 elsewhere, so the
        # wrong pixels show up as white in the logged image.
        error = (gt_mask != pred_mask).astype(np.uint8) * 255

        trainer.logger.experiment.log(
            {"visualizing": [
                    wandb.Image(x, caption="GT", masks={
                        "segmentation": {
                            "mask_data": gt_mask,
                            "class_labels": class_labels
                        },
                    }),
                    wandb.Image(x, caption="pred", masks={
                        "segmentation": {
                            "mask_data": pred_mask,
                            "class_labels": class_labels
                        },
                    }),
                    wandb.Image(error, caption="error"),
                ]
            })

In [9]:
# Stop a run once val_loss has not improved for `patience` validation checks.
# NOTE: EarlyStopping keeps internal state between fits, so each Trainer should
# receive its own instance (e.g. a copy of this one).
early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=5,
    verbose=True,
)

In [10]:
class ResnetWrapper(nn.Module):
    """Adapter that makes a torchvision segmentation model return a plain tensor.

    The wrapped model returns a mapping of outputs; only the ``"out"`` entry is
    kept, so the wrapper behaves like any module that returns logits.
    """

    def __init__(self, model):
        super().__init__()
        # The attribute name determines the state_dict keys ("model.*").
        self.model = model

    def forward(self, x):
        outputs = self.model(x)
        logits = outputs["out"]
        return logits

In [11]:
from torchvision.models.segmentation import fcn_resnet50

def fcn_factory(in_channel=1, num_classes=None):
    """Build a torchvision FCN-ResNet50 adapted to this task, wrapped to return logits.

    Args:
        in_channel: number of input channels; the backbone's first conv is replaced
            so it accepts this many channels.
        num_classes: number of output classes; defaults to the notebook-level
            ``NUM_CLASSES`` (looked up at call time, as before).

    Returns:
        ``ResnetWrapper`` around the model, returning only the ``"out"`` logits.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    # NOTE(review): depending on the torchvision version, fcn_resnet50() may load
    # ImageNet weights for the backbone by default (weights_backbone); conv1 is
    # replaced below either way — confirm which is intended.
    model = fcn_resnet50()
    # Same geometry as the ResNet stem conv, but with `in_channel` inputs.
    model.backbone.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Swap the head's final 1x1 conv (512 features) for one with `num_classes` outputs.
    model.classifier[4] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
    return ResnetWrapper(model)

In [None]:
# U-Net variants (segmentation_models_pytorch).
# NOTE: smp.Unet defaults to encoder_weights="imagenet", so the from-scratch
# baseline must pass encoder_weights=None explicitly; otherwise it would be
# identical to unet_model_pretrained.
unet_model = smp.Unet('resnet34', classes=NUM_CLASSES, in_channels=1, encoder_weights=None)
unet_model_pretrained = smp.Unet('resnet34', classes=NUM_CLASSES, in_channels=1, encoder_weights='imagenet')
unet_model_resnet50_pretrained = smp.Unet('resnet50', classes=NUM_CLASSES, in_channels=1, encoder_weights='imagenet')

# Unet++ (disabled)
# unet_plusplus_model = smp.UnetPlusPlus('resnet34', classes=NUM_CLASSES, in_channels=1, encoder_weights='imagenet')
# unet_plusplus_model_resnet50 = smp.UnetPlusPlus('resnet50', classes=NUM_CLASSES, in_channels=1, encoder_weights='imagenet')

# FCN-ResNet50 (torchvision), 1 input channel.
fcn_model = fcn_factory()

# TriUnet - 2 feature U-Nets + 1 U-Net taking len(feature_models) * NUM_CLASSES
# input channels (presumably the concatenated feature-model outputs — TODO confirm).
# These U-Nets do not pass encoder_weights, so they use smp's default ("imagenet").
feature_models = torch.nn.ModuleList([smp.Unet('resnet34', classes=NUM_CLASSES, in_channels=1) for _ in range(2)])
triunet_3unet_model = TriUnet(feature_models, smp.Unet('resnet34', classes=NUM_CLASSES, in_channels=len(feature_models) * NUM_CLASSES))

# TriUnet - 3 FCNs (2 feature FCNs + 1 FCN on their outputs)
feature_models = torch.nn.ModuleList([fcn_factory(1) for _ in range(2)])
triunet_3fcn_model = TriUnet(feature_models, fcn_factory(len(feature_models) * NUM_CLASSES))

# TriUnet - 2 FCNs + 1 U-Net (disabled)
# feature_models = torch.nn.ModuleList([fcn_factory(1) for _ in range(2)])
# triunet_2fcn_1unet_model = TriUnet(feature_models, smp.Unet('resnet34',classes=NUM_CLASSES, in_channels=len(feature_models)*NUM_CLASSES))

# Runs to train, in order: W&B run name -> model. Comment an entry out to skip it.
_runs_to_train = {
    "unet_model": unet_model,
    "unet_model_pretrained": unet_model_pretrained,
    "unet_model_resnet50_pretrained": unet_model_resnet50_pretrained,
    # "unet++_pretrained": unet_plusplus_model,
    # "unet++_pretrained_resnet50": unet_plusplus_model_resnet50,
    "fcn_model": fcn_model,
    "triunet_3unet_model": triunet_3unet_model,
    "triunet_3fcn_model": triunet_3fcn_model,
    # "triunet_2fcn_1unet_model": triunet_2fcn_1unet_model,
}
# List of (name, model) pairs; dicts keep insertion order, so the run order is unchanged.
models = list(_runs_to_train.items())

import copy

# Train every configured model in turn, one W&B run per model.
for name, model in models:
    wandb_logger = pl.loggers.WandbLogger(project="Medical Image Segmentation", name=name)
    wandb_logger.watch(model)
    segmenter = SemanticSegmanter(model=model, learning_rate=1e-4, criterion=loss)
    # EarlyStopping is stateful (best score, wait counter). Handing the same
    # instance to every Trainer would carry the previous model's state into this
    # run and could stop it after a single epoch. Each run instead gets a copy of
    # the `early_stopping` config object, which is itself never given to a Trainer.
    run_early_stopping = copy.deepcopy(early_stopping)
    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        logger=wandb_logger,
        callbacks=[RichProgressBar(), Visualizer(), run_early_stopping],
    )
    try:
        trainer.fit(segmenter, datamodule=datamodule)
    finally:
        # Close the W&B run even if fitting fails, so the next run starts cleanly.
        wandb.finish()

In [None]:
# If this hand-written version misbehaves, use torchmetrics' Dice with per-class
# output (the original note says reduce=None; recent torchmetrics spells it
# average="none" — TODO confirm). It returns the Dice coefficient per class, and
# loss = 1 - diceCoeff.
@torch.no_grad()
def compute_dice(self, pred_y, y, epsilon=1e-6):
    """Per-class Dice coefficient between two one-hot / binary masks.

    Args:
        self: unused; kept only so existing ``compute_dice(obj, pred, gt)`` calls
            keep working. The result lives on ``pred_y.device``.
        pred_y: prediction mask of shape (N, C, ...) with classes on dim 1.
        y: ground-truth mask in the same layout and shape as ``pred_y``.
        epsilon: smoothing term; a class absent from both masks scores 1.

    Returns:
        float32 tensor of shape (C,): one Dice score per class, with counts pooled
        over the batch and all spatial positions.
    """
    reduce_dims = [d for d in range(pred_y.dim()) if d != 1]
    intersection = (pred_y * y).sum(dim=reduce_dims)
    totals = pred_y.sum(dim=reduce_dims) + y.sum(dim=reduce_dims)
    return ((2.0 * intersection + epsilon) / (totals + epsilon)).float()

In [None]:
class Evaluation(Callback):
    """Logs per-class Dice for a single sample at the start of the test epoch.

    Metrics are logged through the LightningModule (callbacks have no ``log``).

    Args:
        sample_idx: index of the sample read from ``trainer.datamodule.acdc_val``.
        num_classes: number of classes including background (default 4).
    """

    def __init__(self, sample_idx=0, num_classes=4):
        super().__init__()
        self.sample_idx = sample_idx
        self.num_classes = num_classes

    @staticmethod
    @torch.no_grad()
    def _per_class_dice(pred_labels, gt_labels, num_classes, epsilon=1e-6):
        """Dice coefficient per class between two integer label maps of equal shape."""
        scores = []
        for c in range(num_classes):
            pred_c = pred_labels == c
            gt_c = gt_labels == c
            intersection = (pred_c & gt_c).sum()
            total = pred_c.sum() + gt_c.sum()
            scores.append((2.0 * intersection + epsilon) / (total + epsilon))
        return torch.stack(scores)

    def on_test_epoch_start(self, trainer, model):
        # NOTE(review): reads the validation split, as before; switch to the test
        # split once the datamodule attribute for it is confirmed.
        # One __getitem__ call so any random transform is shared by x and y.
        sample = trainer.datamodule.acdc_val[self.sample_idx]
        x, y = sample[0], sample[1]

        with torch.no_grad():
            logits = model(x.unsqueeze(0).to(model.device))
        pred_labels = torch.argmax(logits.squeeze(0), dim=0).cpu()
        gt_labels = torch.as_tensor(y).cpu().squeeze().long()

        dice = self._per_class_dice(pred_labels, gt_labels, self.num_classes)
        log_kwargs = dict(on_step=False, on_epoch=True, batch_size=1)
        # Foreground mean excludes index 0 (background); 1=RV, 2=myocardium, 3=LV.
        model.log('dice/sample_all_dice', dice[1:].mean(), **log_kwargs)
        model.log('dice/sample_LV_dice', dice[3], **log_kwargs)
        model.log('dice/sample_RV_dice', dice[1], **log_kwargs)
        model.log('dice/sample_MYO_dice', dice[2], **log_kwargs)