In [None]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import wandb
import multiprocessing


import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import segmentation_models_pytorch as smp

from models.unet import *
from models.simple_model import *
from utils.data_utils.acdc_datamodule import *
from utils.data_utils.data_utils import *
from utils.model_utils.dice_score import *
from utils.model_utils.resnet_loss import ResnetLoss 

from lightning.pytorch.callbacks import RichProgressBar



In [None]:
# Seed torch, NumPy and Python's `random` in one call. workers=True additionally makes the
# Trainer install a worker_init_fn on its DataLoaders, so each worker process gets a derived,
# reproducible seed (relevant for random data augmentation running inside workers).
SEED = 42
pl.seed_everything(SEED, workers=True)

In [None]:
# Authenticate with Weights & Biases; the WandbLogger created in a later cell logs runs to this account.
wandb.login()


In [None]:
class SemanticSegmanter(pl.LightningModule):
    """Lightning wrapper that trains a segmentation network with a given loss.

    Args:
        model: network called on an image batch; its output is passed to ``criterion``.
        learning_rate: learning rate for the Adam optimizer.
        criterion: callable ``criterion(prediction, targets)`` returning a scalar loss.
    """

    def __init__(self, model, learning_rate, criterion) -> None:
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.lr = learning_rate

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute the loss for one ``(images, ground_truths)`` batch.

        Targets are cast to ``torch.long`` in every stage so train, validation and test
        all feed the criterion identically typed targets; class-index losses (e.g. the
        multiclass Dice / cross-entropy variants tried in this notebook) need integer labels.
        """
        images, ground_truths = batch
        masks_pred = self.model(images)
        return self.criterion(masks_pred, ground_truths.long())

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True, on_step=False)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # trainer.test only reports values logged via self.log; without this the test loss
        # would be computed and then discarded.
        self.log("test_loss", loss, on_epoch=True, logger=True, on_step=False)
        return loss

    def configure_optimizers(self):
        """Adam over the wrapped network's parameters with the configured learning rate."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [None]:
# Constants and hyperparameters
NUM_CLASSES = 4  # background, RV, myocardium, LV
# NOTE(review): not referenced by the Trainer cell, which sets its own max_epochs.
MAX_EPOCHS = 500

# The large model needs a lot of GPU memory, so only small batches fit.
BATCH_SIZE_TRAIN = BATCH_SIZE_VAL = BATCH_SIZE_TEST = 8



In [None]:
# Augmentation passed to the data module; the meaning of the three positional
# arguments is defined by DualTransform (presumably applied jointly to image and mask).
transform = DualTransform(20, 0.2, 0.2)

datamodule = ACDCDataModule(
    "database",
    BATCH_SIZE_TRAIN,
    BATCH_SIZE_VAL,
    BATCH_SIZE_TEST,
    (256, 256, 1),  # presumably the target (H, W, C) shape — confirm in ACDCDataModule
    convert_to_single=False,
    transform=transform,
)

# Run the "fit" setup now so the validation split can be inspected in the next cell.
# NOTE(review): trainer.fit calls setup("fit") again; confirm the split is deterministic.
datamodule.setup("fit")


In [None]:
# Number of samples in the validation split's `x` collection.
val_inputs = datamodule.acdc_val.x
len(val_inputs)

In [None]:
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights


# U-Net with a ResNet-34 encoder: single-channel input, NUM_CLASSES output channels.
model = smp.Unet(encoder_name='resnet34', classes=NUM_CLASSES, in_channels=1)

# Alternative architectures that were tried (kept for reference):
#   FCN-ResNet50 adapted to 1-channel input and NUM_CLASSES outputs
#     model = fcn_resnet50()
#     model.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
#     model.classifier[4] = nn.Conv2d(512, NUM_CLASSES, kernel_size=(1, 1), stride=(1, 1))
#   Custom U-Net
#     model = UNet(n_channels=1, n_classes=4)
#   Simple baseline
#     model = SimpleSegmentationModel(1,4)

# W&B logger shared with the Trainer; watch() tracks the model during training.
wandb_logger = pl.loggers.WandbLogger(project="Medical Image Segmentation")
wandb_logger.watch(model)


In [None]:
# Design notes (translated from Hungarian):
# - Dice loss suits imbalanced classes.
# - CrossEntropy could also work, but Dice is good too.
# - Jaccard index = IoU.
# - Try ignore_index -> definitely worthwhile for the metric.
#    - Ignoring one class may have no effect.
#    - With an ignored class, cross-entropy may be just as effective.
#    - Per-class Dice coefficient.
# - Split by patient; fix the validation set.
#    - KFold if it is not fixed -> this could also yield an ensemble.
# - Pixel-level confusion matrix, pixel accuracy.
#    - Include example predictions.
#    - Qualitative evaluation.


# Dice averaged over the foreground classes only (1=RV, 2=myocardium, 3=LV), so the
# background class does not dominate the score. ignore_index=0 was used before, but it
# masks background pixels out of both prediction and target, so predicting foreground on
# background was never penalised. `classes` keeps every pixel in the computation.
criterion = smp.losses.DiceLoss(mode="multiclass", classes=[1, 2, 3])


def loss(y_pred, y_true):
    """Apply `criterion` with the targets cast to integer class indices.

    Without ignore_index, DiceLoss one-hot encodes the targets directly in multiclass
    mode, which requires an integer tensor.
    """
    return criterion(y_pred, y_true.long())


# loss = ResnetLoss(criterion)
# loss = torch.nn.CrossEntropyLoss()

In [None]:
from lightning import Callback

# Display names for the integer labels in the segmentation masks (used by the W&B overlays).
class_labels = {
  0: "background",
  1: "RV",
  2: "myocardium",
  3: "LV"
}

class Visualizer(Callback):
    """Log GT / prediction / error images for the first validation sample to W&B.

    Fires at the start of every validation epoch. Assumes ``acdc_val[i]`` returns
    ``(image, mask, ...)`` with the image as a tensor without a batch dimension — TODO confirm.
    """

    def on_validation_epoch_start(self, trainer, model):
        # Fetch the sample once: indexing the dataset twice could apply two different
        # random transforms, leaving the image and mask out of sync.
        sample = trainer.datamodule.acdc_val[0]
        x, y = sample[0], sample[1]

        # Use the module's own device rather than a hard-coded 'cuda'.
        pred = model(x.unsqueeze(0).to(model.device))

        gt_mask = np.array(y.squeeze())
        # Drop the batch dimension, then take the most likely class per pixel.
        pred_mask = np.array(torch.argmax(pred[0].cpu(), dim=0))
        # 255 where the prediction is wrong, 0 where it is correct. (uint8 values of
        # 0/1 would render as an essentially black image.)
        error = (gt_mask != pred_mask).astype(np.uint8) * 255

        trainer.logger.experiment.log(
            {"visualizing": [
                wandb.Image(x, caption="GT", masks={
                    "segmentation": {
                        "mask_data": gt_mask,
                        "class_labels": class_labels
                    },
                }),
                wandb.Image(x, caption="pred", masks={
                    "segmentation": {
                        "mask_data": pred_mask,
                        "class_labels": class_labels
                    },
                }),
                wandb.Image(error, caption="error"),
            ]}
        )

In [None]:
# May need to add a new preprocessing arg to include pretrained-model preprocessing:
# preprocess_input = smp.encoders.get_preprocessing_fn('resnet18', pretrained='imagenet')


segmenter = SemanticSegmanter(model=model, learning_rate=1e-4, criterion=loss)

# Open question: smp.metrics or torchmetrics for IoU? Is it needed at all?
# metric = smp.metrics.functional.IoU(threshold=0.5)


# Stop once val_loss has not improved for 5 consecutive validation epochs.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, mode="min", verbose=True)

trainer = pl.Trainer(
    max_epochs=30,  # NOTE(review): differs from MAX_EPOCHS in the config cell; reconcile.
    logger=wandb_logger,
    callbacks=[
        RichProgressBar(),
        Visualizer(),
        early_stopping,
    ],
)
try:
    trainer.fit(segmenter, datamodule=datamodule)
finally:
    # Close the W&B run even if training raises or is interrupted.
    wandb.finish()

In [None]:
# Manual safety net: close the current W&B run (the training cell also calls wandb.finish()).
# NOTE(review): presumably a no-op when no run is active — confirm.
wandb.finish()