# Vesuvius PyTorch ⚡ MONAI
Adapted from https://www.kaggle.com/code/clemchris/vesuvis-pytorch-monai

## Nice [MONAI](https://docs.monai.io/en/stable/index.html) Features:
- [CSVDataset](https://docs.monai.io/en/stable/data.html#csvdataset) to easily create a dataset from a DataFrame containing paths to volumes, masks, and labels
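- [RandCropByPosNegLabeld](https://docs.monai.io/en/stable/transforms.html#randcropbyposneglabeld) to create multiple random crops per volume, with crop centres sampled from positive and negative label pixels ([RandWeightedCropd](https://docs.monai.io/en/stable/transforms.html#randweightedcropd) can instead weight the crops with the mask)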
- [matshow3d()](https://docs.monai.io/en/stable/visualize.html#monai.visualize.utils.matshow3d) function to quickly visualize volumes, masks, and labels
- [UNet](https://docs.monai.io/en/stable/networks.html#unet) implementation
- [DiceLoss](https://docs.monai.io/en/stable/losses.html#diceloss) implementation (Jaccard & DiceCELoss available as well)
- [sliding_window_inference](https://docs.monai.io/en/stable/inferers.html#monai.inferers.sliding_window_inference) to run prediction on whole volume using patches

## Nice [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) Features:
- [LightningDataModule](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningDataModule.html) to set up the train and val datasets, transforms, and dataloaders
- [LightningModule](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html) to set up the model, loss, metrics, optimizer, scheduler, logging, callbacks, training and validation steps
- [Trainer](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html) to run training on (multiple) GPUs with mixed precision

# Imports

In [None]:
from collections import defaultdict
from io import StringIO
from pathlib import Path
from typing import Tuple

import lovely_numpy as ln
import monai
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image as Image
import pytorch_lightning as pl
import seaborn as sns
import torch
from monai.data import CSVDataset, CacheDataset
from monai.data import DataLoader
from monai.inferers import sliding_window_inference
from monai.visualize import matshow3d
from torchmetrics import Dice
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryFBetaScore
from tqdm.auto import tqdm

In [None]:
# Sanity check: can PyTorch use CUDA at all?
torch.cuda.is_available()

In [None]:
# Number of GPUs visible to this process (respects CUDA_VISIBLE_DEVICES).
torch.cuda.device_count()

In [None]:
# Index of the GPU that new CUDA tensors are allocated on by default.
torch.cuda.current_device()

In [None]:
# Human-readable name of GPU 0.
torch.cuda.get_device_name(0)

# Paths & Settings

In [55]:
# --- Data locations ---------------------------------------------------------
COMPETITION_DATA_DIR = Path("/data")
TRAIN_DATA_CSV_PATH = COMPETITION_DATA_DIR / "data.csv"
TEST_DATA_CSV_PATH = COMPETITION_DATA_DIR / "test.csv"

# --- Model input ------------------------------------------------------------
NUM_Z_SLICES = 65  # Z slices per volume; used as the network's input channels

# --- Hardware / runtime -----------------------------------------------------
ACCELERATOR = "gpu"
DEVICES = 1
PRECISION = 16
NUM_WORKERS = 40
FAST_DEV_RUN = False
OVERFIT_BATCHES = 0
SEED = 2023

# --- Sampling & augmentation ------------------------------------------------
BATCH_SIZE = 1
NUM_SAMPLES = 12  # random crops drawn per training volume
PATCH_SIZE = (512, 512)
INTENSITY_TRANSFORM = "NormalizeIntensity"
RAND_TRANSFORMS = "-".join(
    (
        "RandZoom",
        "RandGaussianNoise",
        "RandGaussianSmooth",
        "RandScaleIntensity",
        "RandFlip0",
        "RandFlip1",
    )
)
VAL_FRAGMENT_ID = "1"  # fragment held out for validation (compared as int)

# --- Model & loss -----------------------------------------------------------
MODEL_NAME = "FlexibleUNet_efficientnet-b0"
DROPOUT = 0.0
LOSS = "BCE"

# --- Optimisation -----------------------------------------------------------
OPTIMIZER = "SGD"
LEARNING_RATE = 0.01
WEIGHT_DECAY = 1e-6
SCHEDULER = "CosineAnnealingLR"
ETA_MIN = 1e-6
MAX_EPOCHS = 100

# --- Inference --------------------------------------------------------------
SW_BATCH_SIZE = 4  # patches per forward pass in sliding-window inference
THRESHOLD = 0.5  # probability cut-off for the binary submission mask

# Lightning Datamodule

In [None]:
class VesuvisDataModule(pl.LightningDataModule):
    """LightningDataModule serving the Vesuvius fragments as cached MONAI datasets.

    The CSV at ``data_csv_path`` is read into ``self.df``. It needs a ``stage``
    column ("train" / "test"), a ``fragment_id`` column, and the ``volume_npy`` /
    ``mask_npy`` (plus ``label_npy`` for training rows) file paths — presumably
    one row per fragment. Training rows whose ``fragment_id`` equals
    ``val_fragment_id`` form the validation split; the other training rows form
    the training split.

    ``rand_transforms`` is a "-"-separated list of augmentation names (see the
    table in ``_rand_transforms``), applied after the random patch crop. ``None``
    or an empty string means "crop only".
    """

    def __init__(
        self,
        batch_size: int,
        data_csv_path: str,
        intensity_transform: str,
        num_workers: int,
        num_samples: int,
        patch_size: Tuple[int, int],
        rand_transforms: str,
        val_fragment_id: str,
    ):
        super().__init__()

        # Exposes every constructor argument as self.hparams.<name>.
        self.save_hyperparameters()

        self.df = pd.read_csv(data_csv_path)

        # Dict keys that spatial transforms must move together.
        self.keys = ("volume_npy", "mask_npy", "label_npy")
        # Built eagerly so an invalid configuration fails at construction time.
        self.train_transform = self._init_train_transform()
        self.val_transform = self._init_val_transform()
        self.predict_transform = self._init_predict_transform()

    def _load_transforms(self, predict: bool = False):
        """Return the file-loading transforms applied per CSV row.

        The volume is loaded as stored. Mask and label get a leading channel
        axis (``ensure_channel_first``). Prediction rows have no label, so only
        the mask is loaded in that case.
        """
        # NOTE(review): the volume is loaded without ensure_channel_first,
        # presumably because its first stored axis is already Z (= channels).
        return [
            monai.transforms.LoadImaged(
                keys="volume_npy",
            ),
            monai.transforms.LoadImaged(
                keys=("mask_npy", "label_npy") if not predict else "mask_npy",
                ensure_channel_first=True,
            ),
        ]

    @property
    def _intensity_transforms(self):
        """Return the volume intensity transform chosen by ``hparams.intensity_transform``."""
        if self.hparams.intensity_transform == "NormalizeIntensity":
            return [
                monai.transforms.NormalizeIntensityd(
                    keys="volume_npy",
                    nonzero=True,
                    channel_wise=True,
                ),
            ]
        elif self.hparams.intensity_transform == "ScaleIntensity":
            return [
                monai.transforms.ScaleIntensityd(
                    keys="volume_npy",
                ),
            ]
        else:
            raise ValueError(f"{self.hparams.intensity_transform} is not implemented")

    @property
    def _rand_transforms(self):
        """Return the random crop followed by the requested augmentations.

        Raises:
            ValueError: if ``hparams.rand_transforms`` names an unknown augmentation.
        """
        all_rand_transforms = {
            "RandAffine": monai.transforms.RandAffined(
                keys=self.keys,
                prob=0.75,
                rotate_range=(np.pi / 4, np.pi / 4),
                translate_range=(0.0625, 0.0625),
                scale_range=(0.1, 0.1),
            ),
            "RandFlip0": monai.transforms.RandFlipd(
                keys=self.keys,
                spatial_axis=0,
                prob=0.5,
            ),
            "RandFlip1": monai.transforms.RandFlipd(
                keys=self.keys,
                spatial_axis=1,
                prob=0.5,
            ),
            "RandGaussianNoise": monai.transforms.RandGaussianNoised(
                keys="volume_npy",
                prob=0.15,
                mean=0.0,
                std=0.01,
            ),
            "RandGaussianSmooth": monai.transforms.RandGaussianSmoothd(
                keys="volume_npy",
                prob=0.15,
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
            ),
            "RandScaleIntensity": monai.transforms.RandScaleIntensityd(
                keys="volume_npy",
                factors=0.3,
                prob=0.15,
            ),
            "RandZoom": monai.transforms.RandZoomd(
                keys=self.keys,
                min_zoom=0.9,
                max_zoom=1.2,
                # bilinear for the volume, nearest for the binary mask/label.
                mode=("bilinear", "nearest", "nearest"),
                align_corners=(True, None, None),
                prob=0.15,
            ),
        }

        # Always crop num_samples patches of patch_size around positive and
        # negative label pixels first.
        rand_transforms = [
            monai.transforms.RandCropByPosNegLabeld(
                keys=self.keys,
                label_key="label_npy",
                spatial_size=self.hparams.patch_size,
                num_samples=self.hparams.num_samples,
                image_key="volume_npy",
                image_threshold=0,
            ),
        ]

        # Skip empty tokens so that "" (or a stray "--") means "no augmentation".
        names = [name for name in (self.hparams.rand_transforms or "").split("-") if name]
        unknown = [name for name in names if name not in all_rand_transforms]
        if unknown:
            raise ValueError(f"{', '.join(unknown)} is not implemented")
        rand_transforms.extend(all_rand_transforms[name] for name in names)

        return rand_transforms

    def _init_train_transform(self):
        """Intensity normalisation, then random crop + augmentations (files are loaded by ``_dataset``)."""
        return monai.transforms.Compose(self._intensity_transforms + self._rand_transforms)

    def _init_val_transform(self):
        """Intensity normalisation only: validation runs on whole volumes."""
        return monai.transforms.Compose(self._intensity_transforms)

    def _init_predict_transform(self):
        """Intensity normalisation only: prediction runs on whole volumes."""
        return monai.transforms.Compose(self._intensity_transforms)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        "fit" (or None) builds train + val, "validate" builds val only, and
        "predict" (or None) builds the predict dataset from ``stage == "test"`` rows.
        """
        if stage in ("fit", "validate", None):
            train_val_df = self.df[self.df.stage == "train"].reset_index(drop=True)
            is_val = train_val_df.fragment_id == int(self.hparams.val_fragment_id)

            build_train = stage != "validate"
            if build_train:
                train_df = train_val_df[~is_val].reset_index(drop=True)
                self.train_dataset = self._dataset(train_df, self._load_transforms(), self.train_transform)

            val_df = train_val_df[is_val].reset_index(drop=True)
            self.val_dataset = self._dataset(val_df, self._load_transforms(), self.val_transform)

            if build_train:
                print(f"# train: {len(self.train_dataset)}")
            print(f"# val: {len(self.val_dataset)}")

        if stage == "predict" or stage is None:
            predict_df = self.df[self.df.stage == "test"].reset_index(drop=True)
            self.predict_dataset = self._dataset(predict_df, self._load_transforms(predict=True), self.predict_transform)

            print(f"# predict: {len(self.predict_dataset)}")

    def _dataset(self, df, load_transform, transform):
        """Wrap ``df`` in a CSVDataset that loads files, cached by a CacheDataset applying ``transform``."""
        return CacheDataset(
            data=CSVDataset(
                src=df,
                transform=monai.transforms.Compose(load_transform),
            ),
            transform=transform,
            cache_rate=1.0,
            runtime_cache="processes",
            copy_cache=False,
        )

    def train_dataloader(self):
        return self._dataloader(self.train_dataset, train=True)

    def val_dataloader(self):
        return self._dataloader(self.val_dataset)

    def predict_dataloader(self):
        return self._dataloader(self.predict_dataset)

    def _dataloader(self, dataset, train=False):
        """Return a DataLoader; shuffling and dropping the last partial batch happen only for training."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=train,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            drop_last=train,
        )

# Visualize Data

In [58]:
# Datamodule used only to eyeball the data. ScaleIntensity (not INTENSITY_TRANSFORM)
# and num_samples=2 are presumably chosen so the plots stay in the vmin=0 / vmax=1
# range of visualize_dataloaders and remain small — confirm.
data_module = VesuvisDataModule(
    batch_size=BATCH_SIZE,
    data_csv_path=TRAIN_DATA_CSV_PATH,
    intensity_transform="ScaleIntensity",
    num_workers=NUM_WORKERS,
    num_samples=2,
    rand_transforms=RAND_TRANSFORMS,
    patch_size=PATCH_SIZE,
    val_fragment_id=VAL_FRAGMENT_ID,
)


In [59]:
# stage=None builds the train, val and predict datasets.
data_module.setup()


# train: 2
# val: 1
# predict: 0


In [60]:

# Dataloaders handed to visualize_dataloaders below.
dataloaders = {
    "train": data_module.train_dataloader(),
    "val": data_module.val_dataloader(),
}



In [None]:
def visualize_dataloaders(dataloaders, train=True):
    """Plot every sample of every dataloader as three side-by-side 3D panels.

    Args:
        dataloaders: mapping of stage name -> iterable of batch dicts with
            ``volume_npy`` and ``mask_npy`` entries (and ``label_npy`` when
            ``train`` is True).
        train: when False the batches carry no labels, so the mask is shown
            again in the third panel.
    """
    for stage_name, loader in dataloaders.items():
        print(stage_name, len(loader))
        for idx, batch in enumerate(loader):
            print(idx)
            vols = batch["volume_npy"]
            msks = batch["mask_npy"]
            lbls = batch["label_npy"] if train else msks

            print(vols.shape)
            print(msks.shape)

            for panels in zip(vols, msks, lbls):
                _, axes = plt.subplots(1, 3, figsize=(15, 5))
                plt.suptitle(f"stage: {stage_name}, fragment: {idx}")

                # One panel each for volume, mask and label (or mask again).
                for ax, image in zip(axes, panels):
                    matshow3d(
                        volume=image,
                        fig=ax,
                        title=f"{list(image.shape)}, {image.min().item():.2f}, {image.max().item():.2f}",
                        vmin=0.0,
                        vmax=1.0,
                        every_n=1,
                        fill_value=1.0,
                        margin=4,
                        cmap="gray",
                    )
                plt.show()

In [None]:
# visualize_dataloaders(dataloaders)

# Lightning Module

In [61]:
class VesuvisModule(pl.LightningModule):
    """LightningModule for 2D ink segmentation of Vesuvius fragments.

    The ``num_z_slices`` slices of a volume are the network's input channels and
    the network emits a single logit channel. Training runs on the patches fed by
    the datamodule. Validation and prediction run Gaussian-weighted sliding-window
    inference over the whole input, and prediction additionally averages flips
    along spatial dims 2, 3 and both (test-time augmentation).
    """

    def __init__(
        self,
        dropout: float,
        eta_min: float,
        learning_rate: float,
        loss: str,
        model_name: str,
        max_epochs: int,
        num_z_slices: int,
        optimizer: str,
        patch_size: Tuple[int, int],
        scheduler: str,
        sw_batch_size: int,
        weight_decay: float,
    ):
        super().__init__()

        # Exposes every constructor argument as self.hparams.<name>.
        self.save_hyperparameters()

        self.model = self._init_model()

        self.loss = self._init_loss()

        self.metrics = self._init_metrics()

    def _init_model(self):
        """Build the network named by ``hparams.model_name``.

        "UNet" gives a MONAI residual UNet. Any name containing "FlexibleUNet"
        gives a pretrained FlexibleUNet whose backbone is the second
        "_"-separated part of the name (e.g. "FlexibleUNet_efficientnet-b0").
        """
        # TODO: add more models
        if self.hparams.model_name == "UNet":
            return monai.networks.nets.UNet(
                spatial_dims=2,
                in_channels=self.hparams.num_z_slices,
                out_channels=1,
                channels=(16, 32, 64, 128, 256),
                strides=(2, 2, 2, 2),
                num_res_units=2,
                dropout=self.hparams.dropout,
            )
        elif "FlexibleUNet" in self.hparams.model_name:
            return monai.networks.nets.FlexibleUNet(
                in_channels=self.hparams.num_z_slices,
                out_channels=1,
                backbone=self.hparams.model_name.split("_")[1],
                pretrained=True,
                spatial_dims=2,
                dropout=self.hparams.dropout,
            )
        else:
            raise ValueError(f"{self.hparams.model_name} is not implemented")

    def _init_loss(self):
        """Return the configured loss (all variants take logits) wrapped in MONAI's MaskedLoss.

        The fragment mask is passed as the third argument at call time (see
        ``_shared_step``).
        """
        if self.hparams.loss == "BCE":
            loss = torch.nn.BCEWithLogitsLoss()
        elif self.hparams.loss == "Dice":
            loss = monai.losses.DiceLoss(sigmoid=True)
        elif self.hparams.loss == "Jaccard":
            loss = monai.losses.DiceLoss(
                sigmoid=True,
                jaccard=True,
            )
        elif self.hparams.loss == "DiceCE":
            loss = monai.losses.DiceCELoss(sigmoid=True)
        else:
            raise ValueError(f"{self.hparams.loss} is not implemented")

        return monai.losses.MaskedLoss(loss)

    def _init_metrics(self):
        """Return Dice and F0.5 collections, cloned per stage so train/val states stay separate."""
        metric_collection = MetricCollection(
            {
                "dice": Dice(),
                "fbeta": BinaryFBetaScore(beta=0.5),
            }
        )

        return torch.nn.ModuleDict(
            {
                "train_metrics": metric_collection.clone(prefix="train_"),
                "val_metrics": metric_collection.clone(prefix="val_"),
            }
        )

    def configure_optimizers(self):
        """Return the optimizer and an LR scheduler that is stepped once per epoch."""
        optimizer = self._init_optimizer()
        scheduler = self._init_scheduler(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
            },
        }

    def _init_optimizer(self):
        """Build the optimizer named by ``hparams.optimizer``; every branch honours ``weight_decay``."""
        if self.hparams.optimizer == "Adam":
            return torch.optim.Adam(
                params=self.parameters(),
                lr=self.hparams.learning_rate,
                weight_decay=self.hparams.weight_decay,
            )
        elif self.hparams.optimizer == "AdamW":
            return torch.optim.AdamW(
                params=self.parameters(),
                lr=self.hparams.learning_rate,
                weight_decay=self.hparams.weight_decay,
            )
        elif self.hparams.optimizer == "SGD":
            return torch.optim.SGD(
                params=self.parameters(),
                lr=self.hparams.learning_rate,
                momentum=0.99,
                nesterov=True,
                # Apply the configured weight decay, as the Adam/AdamW branches do.
                weight_decay=self.hparams.weight_decay,
            )
        else:
            raise ValueError(f"{self.hparams.optimizer} is not implemented")

    def _init_scheduler(self, optimizer):
        """Build the epoch-level LR scheduler named by ``hparams.scheduler``."""
        if self.hparams.scheduler == "CosineAnnealingLR":
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.hparams.max_epochs,
                eta_min=self.hparams.eta_min,
            )
        elif self.hparams.scheduler == "StepLR":
            return torch.optim.lr_scheduler.StepLR(
                optimizer,
                # step_size=0 (max_epochs < 5) would make StepLR fail with a modulo by zero.
                step_size=max(1, self.hparams.max_epochs // 5),
                gamma=0.95,
            )
        elif self.hparams.scheduler == "PolyLR":
            max_epochs = self.hparams.max_epochs
            return torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                # Clamp at 0: past max_epochs a negative base ** 0.9 would be complex.
                lr_lambda=lambda epoch: max(0.0, 1 - epoch / max_epochs) ** 0.9,
            )
        else:
            raise ValueError(f"{self.hparams.scheduler} is not implemented")

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def predict_step(self, batch, batch_idx):
        """Return sigmoid probabilities of the flip-averaged logits, with singleton dims squeezed."""
        outputs = self._forward_pass(batch, "predict")
        return outputs.sigmoid().squeeze()

    def _shared_step(self, batch, stage):
        """Compute the masked loss, update and log the ``stage`` metrics, and return the loss."""
        outputs, labels, masks = self._forward_pass(batch, stage)

        loss = self.loss(outputs, labels, masks)

        # outputs are logits. Feed the metrics probabilities so their 0.5 threshold
        # acts on sigmoid(outputs), as at prediction time. BinaryFBetaScore only
        # auto-applies sigmoid when some value falls outside [0, 1].
        self.metrics[f"{stage}_metrics"](outputs.sigmoid(), labels)

        self._log(loss, stage, batch_size=len(outputs))

        return loss

    def _sliding_window(self, volumes):
        """Run Gaussian-weighted, 50%-overlap sliding-window inference with ``patch_size`` tiles."""
        return sliding_window_inference(
            inputs=volumes,
            roi_size=self.hparams.patch_size,
            sw_batch_size=self.hparams.sw_batch_size,
            predictor=self,
            overlap=0.5,
            mode="gaussian",
        )

    def _forward_pass(self, batch, stage):
        """Run the model for ``stage`` and return its logits.

        Returns ``(outputs, labels, masks)`` when the batch has labels, and
        ``(outputs, masks)`` otherwise. For ``stage == "predict"`` only the
        flip-averaged outputs are returned.

        Raises:
            ValueError: for an unknown ``stage``.
        """
        volumes = batch["volume_npy"].as_tensor()
        masks = batch["mask_npy"].as_tensor()

        if stage == "train":
            outputs = self(volumes)
        elif stage == "val":
            outputs = self._sliding_window(volumes)
        elif stage == "predict":
            outputs = self._sliding_window(volumes)

            # Test-time augmentation: flip the input, predict, flip back, average.
            n_views = 1.0
            for dims in ([2], [3], [2, 3]):
                outputs += torch.flip(self._sliding_window(torch.flip(volumes, dims)), dims)
                n_views += 1.0

            outputs /= n_views

            return outputs
        else:
            raise ValueError(f"{stage} is not implemented")

        try:
            labels = batch["label_npy"]
        except KeyError:
            return outputs, masks
        return outputs, labels.as_tensor().long(), masks

    def _log(self, loss, stage, batch_size):
        self.log(f"{stage}_loss", loss, batch_size=batch_size)
        self.log_dict(self.metrics[f"{stage}_metrics"], batch_size=batch_size)

# Train

In [62]:
def train(
    accelerator=ACCELERATOR,
    batch_size=BATCH_SIZE,
    data_csv_path=TRAIN_DATA_CSV_PATH,
    devices=DEVICES,
    dropout=DROPOUT,
    eta_min=ETA_MIN,
    fast_dev_run=FAST_DEV_RUN,
    intensity_transform=INTENSITY_TRANSFORM,
    learning_rate=LEARNING_RATE,
    loss=LOSS,
    model_name=MODEL_NAME,
    max_epochs=MAX_EPOCHS,
    num_workers=NUM_WORKERS,
    num_samples=NUM_SAMPLES,
    num_z_slices=NUM_Z_SLICES,
    optimizer=OPTIMIZER,
    overfit_batches=OVERFIT_BATCHES,
    patch_size=PATCH_SIZE,
    precision=PRECISION,
    rand_transforms=RAND_TRANSFORMS,
    scheduler=SCHEDULER,
    seed=SEED,
    sw_batch_size=SW_BATCH_SIZE,
    val_fragment_id=VAL_FRAGMENT_ID,
    weight_decay=WEIGHT_DECAY,
):
    """Seed all RNGs, then fit a ``VesuvisModule`` on a ``VesuvisDataModule``.

    Each argument defaults to the notebook-level setting with the same upper-case
    name. Metrics go to a ``CSVLogger`` under ``logs/``. DDP is selected only
    when more than one device is requested.

    Returns:
        ``(module, trainer)``: the fitted LightningModule and its Trainer.
    """
    # MONAI first, then Lightning (workers=True also seeds dataloader workers).
    monai.utils.set_determinism(seed)
    pl.seed_everything(seed, workers=True)

    # Keep construction order (datamodule, then model): both may draw from the seeded RNGs.
    datamodule = VesuvisDataModule(
        batch_size=batch_size,
        data_csv_path=data_csv_path,
        intensity_transform=intensity_transform,
        num_workers=num_workers,
        num_samples=num_samples,
        patch_size=patch_size,
        rand_transforms=rand_transforms,
        val_fragment_id=val_fragment_id,
    )

    lightning_module = VesuvisModule(
        dropout=dropout,
        eta_min=eta_min,
        learning_rate=learning_rate,
        loss=loss,
        model_name=model_name,
        max_epochs=max_epochs,
        num_z_slices=num_z_slices,
        optimizer=optimizer,
        patch_size=patch_size,
        scheduler=scheduler,
        sw_batch_size=sw_batch_size,
        weight_decay=weight_decay,
    )

    csv_logger = pl.loggers.CSVLogger(save_dir="logs/")
    strategy = "ddp" if devices > 1 else "auto"

    trainer = pl.Trainer(
        accelerator=accelerator,
        benchmark=True,
        check_val_every_n_epoch=1,
        devices=devices,
        fast_dev_run=fast_dev_run,
        logger=csv_logger,
        log_every_n_steps=1,
        max_epochs=max_epochs,
        overfit_batches=overfit_batches,
        precision=precision,
        strategy=strategy,
    )

    trainer.fit(lightning_module, datamodule=datamodule)
    return lightning_module, trainer

In [None]:
# Train with the notebook-level settings; keep the fitted module and trainer for the cells below.
module, trainer = train()

2023-05-04 08:36:54,257 - Global seed set to 2023
2023-05-04 08:36:55,742 - Using 16bit Automatic Mixed Precision (AMP)
2023-05-04 08:36:55,763 - GPU available: True (cuda), used: True
2023-05-04 08:36:55,764 - TPU available: False, using: 0 TPU cores
2023-05-04 08:36:55,765 - IPU available: False, using: 0 IPUs
2023-05-04 08:36:55,766 - HPU available: False, using: 0 HPUs


  rank_zero_warn(


# train: 2
# val: 1
2023-05-04 08:36:56,439 - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
2023-05-04 08:36:56,604 - 
  | Name    | Type         | Params
-----------------------------------------
0 | model   | FlexibleUNet | 7.5 M 
1 | loss    | MaskedLoss   | 0     
2 | metrics | ModuleDict   | 0     
-----------------------------------------
7.5 M     Trainable params
0         Non-trainable params
7.5 M     Total params
30.183    Total estimated model params size (MB)


  rank_zero_warn(


Sanity Checking: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  t = cls([], dtype=storage.dtype, device=storage.device)
  t = cls([], dtype=storage.dtype, device=storage.device)


Training: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  t = cls([], dtype=storage.dtype, device=storage.device)
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


Validation: 0it [00:00, ?it/s]

  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


In [None]:
# From https://www.kaggle.com/code/jirkaborovec?scriptVersionId=93358967&cellId=22
# Plot the loss / metric curves written by the trainer's CSVLogger, indexed by epoch.
metrics = (
    pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
    .loc[:, ["epoch", "train_loss", "val_loss", "val_dice", "val_fbeta"]]
    .set_index("epoch")
)

sns.relplot(data=metrics, kind="line", height=5, aspect=1.5)
plt.grid()

## Visualize

In [None]:
# Preview the test fragments that the predict dataloader serves.
# ScaleIntensity is used instead of INTENSITY_TRANSFORM because visualize_dataloaders
# draws with vmin=0 / vmax=1. Z-scored (NormalizeIntensity) values would be clipped to
# that range, and every negative value would render black. This datamodule is for
# display only; predict() builds its own with INTENSITY_TRANSFORM.
data_module = VesuvisDataModule(
    batch_size=BATCH_SIZE,
    data_csv_path=TEST_DATA_CSV_PATH,
    intensity_transform="ScaleIntensity",
    num_workers=NUM_WORKERS,
    num_samples=NUM_SAMPLES,
    rand_transforms=RAND_TRANSFORMS,
    patch_size=PATCH_SIZE,
    val_fragment_id=VAL_FRAGMENT_ID,
)
data_module.setup(stage="predict")

dataloaders = {
    "predict": data_module.predict_dataloader(),
}

# train=False: test rows have no labels, so the mask fills the third panel.
visualize_dataloaders(dataloaders, train=False)

# Predict

In [None]:
def predict(
    module,
    accelerator=ACCELERATOR,
    batch_size=BATCH_SIZE,
    data_csv_path=TEST_DATA_CSV_PATH,
    devices=DEVICES,
    intensity_transform=INTENSITY_TRANSFORM,
    num_workers=NUM_WORKERS,
    num_samples=NUM_SAMPLES,
    patch_size=PATCH_SIZE,
    precision=PRECISION,
    rand_transforms=RAND_TRANSFORMS,
    seed=SEED,
    val_fragment_id=VAL_FRAGMENT_ID,
):
    """Run ``module`` over the ``stage == "test"`` rows of ``data_csv_path``.

    A fresh datamodule is built from the given settings, and a minimal Trainer
    calls ``predict``. Returns whatever ``Trainer.predict`` returns, i.e. the
    ``predict_step`` outputs collected per batch.
    """
    monai.utils.set_determinism(seed)
    pl.seed_everything(seed, workers=True)

    datamodule_kwargs = dict(
        batch_size=batch_size,
        data_csv_path=data_csv_path,
        intensity_transform=intensity_transform,
        num_workers=num_workers,
        num_samples=num_samples,
        patch_size=patch_size,
        rand_transforms=rand_transforms,
        val_fragment_id=val_fragment_id,
    )
    datamodule = VesuvisDataModule(**datamodule_kwargs)

    runner = pl.Trainer(accelerator=accelerator, devices=devices, precision=precision)

    return runner.predict(module, datamodule=datamodule)

In [None]:
# Sliding-window + flip-TTA inference on the test fragments with the trained module.
predictions = predict(module)

# Submission

In [None]:
def plot_image(image, title):
    """Draw ``image`` in grayscale on a new figure, with ``title`` above it."""
    plt.figure()
    plt.title(title)
    plt.imshow(image, cmap="gray")
    
# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle(img, thr=None):
    """Run-length encode a binary mask in row-major (C) order.

    img: numpy array (or array-like), 1 - mask, 0 - background.
    thr: optional cut-off. When given, pixels strictly greater than ``thr`` are
        treated as mask and the rest as background, so probability maps can be
        passed directly.
    Returns the runs as a space-separated string of 1-indexed ``start length``
    pairs, or an empty string when the mask is empty.
    """
    pixels = np.asarray(img).flatten()
    if thr is not None:
        pixels = (pixels > thr).astype(np.uint8)

    # Pad with background so every run has both a start and an end transition.
    pixels = np.concatenate([[0], pixels, [0]])
    # Positions where the value changes; +1 makes them 1-indexed pixel numbers.
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    # Turn each (start, end) pair into (start, length).
    runs[1::2] -= runs[::2]
    return " ".join(map(str, runs))


In [None]:
submission_df = pd.read_csv(COMPETITION_DATA_DIR / "sample_submission.csv")

# The test rows, rebuilt exactly as VesuvisDataModule.setup(stage="predict") selects
# them (same CSV as predict()'s default, same `stage == "test"` filter, same order),
# so row i lines up with predictions[i].
test_df = pd.read_csv(TEST_DATA_CSV_PATH)
test_df = test_df[test_df.stage == "test"].reset_index(drop=True)


def load_mask(path):
    """Load a fragment mask (.npy or image file) as a float32 {0, 1} array with singleton dims squeezed."""
    path = Path(path)
    if path.suffix == ".npy":
        mask = np.load(path)
    else:
        # Works whether the PNG stores the mask as 0/1, 0/255 or 1-bit.
        mask = np.array(Image.open(path).convert("L"))
    return (np.squeeze(mask) > 0).astype(np.float32)


# NOTE(review): prefers a `mask_png` column if the CSV has one, otherwise falls
# back to the `mask_npy` column that the datamodule itself reads — confirm.
mask_column = "mask_png" if "mask_png" in test_df.columns else "mask_npy"

# zip() would silently drop fragments on a length mismatch.
if len(test_df) != len(predictions):
    raise ValueError(
        f"{len(test_df)} test rows but {len(predictions)} prediction batches; "
        "this loop assumes batch_size=1 and the same test CSV as predict()"
    )

# NOTE(review): assumes sample_submission.csv lists fragments in the same order as test_df.
predictions_rle = []
for mask_path, prediction in zip(test_df[mask_column].values, predictions):
    prediction = prediction.numpy()
    plot_image(prediction, f"{ln.lovely(prediction)}")

    # Zero out everything outside the fragment.
    prediction = prediction * load_mask(mask_path)
    plot_image(prediction, f"{ln.lovely(prediction)}")

    prediction = np.where(prediction > THRESHOLD, 1, 0).astype(np.uint8)
    plot_image(prediction, f"{ln.lovely(prediction)}")

    predictions_rle.append(rle(prediction))

submission_df["Predicted"] = predictions_rle
submission_df.to_csv("submission.csv", index=False)
submission_df