In [None]:
from monai.config import print_config

In [None]:
# Print MONAI and dependency version info so the environment is captured in the notebook output.
print_config()

In [None]:
import pytorch_lightning as pl

In [None]:
# Single global seed for the run. seed_everything seeds Lightning's supported RNGs
# (torch, NumPy, Python `random`); workers=True also makes DataLoader worker seeding reproducible.
SEED = 42
pl.seed_everything(SEED, workers=True)

In [None]:
# import math

from glob import glob
import os.path

from typing import Optional

from sklearn.model_selection import train_test_split

from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd, Spacingd,
    ScaleIntensityRanged, CropForegroundd, RandCropByLabelClassesd,
    RandAffined, ToTensord
)

from monai.data import CacheDataset, DataLoader, Dataset


# Number of foreground label classes. Background (class 0) is added on top wherever a
# class count is needed, hence the `NUM_LABELS + 1` used for crop classes and UNet outputs.
NUM_LABELS = 13

# Spatial size of random training patches and of the sliding-window ROI used at validation.
ROI_SIZE = (128, 128, 64)

class FlareDataModule(pl.LightningDataModule):
    """LightningDataModule for the labelled FLARE22 training cases.

    Pairs ``images/*.nii.gz`` with ``labels/*.nii.gz`` (matched by sorted file
    order) under ``<data_dir>/Training/FLARE22_LabeledCase50``, splits them into
    train/validation subsets and wraps each subset in a MONAI dataset.

    Args:
        batch_size: Number of patches per training batch. Every training volume
            is cropped into ``crop_num_samples`` patches, so the loader reads
            ``batch_size // crop_num_samples`` volumes per batch (at least 1).
        dev_ratio: Fraction of cases held out for validation.
        cache_ds: Wrap datasets in MONAI ``CacheDataset`` instead of ``Dataset``.
        max_workers: Upper bound on worker processes for loading and caching.
        data_dir: Root directory of the FLARE2022 dataset.
        crop_num_samples: Number of random patches cropped per training volume.
        split_seed: ``random_state`` for the train/validation split. ``None``
            draws from NumPy's global RNG (as seeded by ``pl.seed_everything``).
        **kwargs: Not used directly by this class.
    """

    def __init__(
        self,
        batch_size: int = 32,
        dev_ratio: float = 0.2,
        cache_ds: bool = True,
        max_workers: int = 4,
        data_dir: str = "/mnt/HDD2/flare2022/datasets/FLARE2022",
        crop_num_samples: int = 4,
        split_seed: Optional[int] = None,
        **kwargs
    ):
        super().__init__()

        self._dict_keys = ("image", "label")

        self.supervised_dir = os.path.join(data_dir, "Training", "FLARE22_LabeledCase50")

        # Stored here rather than in `setup` so `train_dataloader` never depends
        # on which setup stage ran first.
        self.crop_num_samples = crop_num_samples

        # os.cpu_count() returns None when the CPU count cannot be determined.
        self.num_workers = min(os.cpu_count() or 1, max_workers)

        self.save_hyperparameters()

    def setup(self, stage: Optional[str] = None):
        """Build ``train_ds`` and ``val_ds`` for the fit and validate stages.

        The datasets are built only once. Later calls, e.g. ``trainer.validate``
        after ``trainer.fit``, reuse them, so the split is not redrawn (which
        could leak training cases into validation) and the cache is not rebuilt.

        Raises:
            RuntimeError: if no images are found or image/label counts differ.
        """
        if stage not in (None, "fit", "validate"):
            return
        if getattr(self, "train_ds", None) is not None:
            return

        images = self.get_image_paths("images")
        labels = self.get_image_paths("labels")

        # Pairing relies on both sorted listings lining up one-to-one; zip()
        # would silently drop any unmatched trailing files.
        if not images or len(images) != len(labels):
            raise RuntimeError(
                f"Expected matching, non-empty image/label sets under "
                f"{self.supervised_dir!r}, found {len(images)} images and "
                f"{len(labels)} labels"
            )

        data_dicts = tuple(
            {"image": img, "label": lab} for img, lab in zip(images, labels)
        )

        train_files, val_files = train_test_split(
            data_dicts,
            test_size=self.hparams.dev_ratio,
            random_state=self.hparams.split_seed,
        )

        train_transforms = self.get_transform(
            RandCropByLabelClassesd(
                keys=self._dict_keys,
                label_key="label",
                spatial_size=ROI_SIZE,
                num_samples=self.crop_num_samples,
                num_classes=NUM_LABELS + 1,  # foreground labels + background
            ),
            # Further random augmentations can be appended here, e.g.
            # (requires `import math`):
            # RandAffined(
            #     keys=self._dict_keys,
            #     mode=("bilinear", "nearest"),
            #     prob=1.0,
            #     rotate_range=(0, 0, math.pi / 15),
            #     scale_range=(0.1, 0.1, 0.1),
            # ),
        )
        val_transforms = self.get_transform()

        self.train_ds = self.get_dataset(train_files, train_transforms)
        self.val_ds = self.get_dataset(val_files, val_transforms)

    def train_dataloader(self):
        """Shuffled loader over training volumes; each volume yields several patches."""
        # batch_size counts patches and each volume contributes crop_num_samples
        # of them. Clamp to 1 because DataLoader rejects batch_size=0.
        volumes_per_batch = max(1, self.hparams.batch_size // self.crop_num_samples)
        return DataLoader(
            self.train_ds,
            batch_size=volumes_per_batch,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Loader over whole validation volumes, one per batch."""
        return DataLoader(
            self.val_ds,
            batch_size=1,  # Because the images do not align and are not cropped
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def get_image_paths(self, baseDir: str):
        """Return the sorted ``*.nii.gz`` paths in ``supervised_dir/baseDir``."""
        return sorted(glob(os.path.join(self.supervised_dir, baseDir, "*.nii.gz")))

    def get_transform(self, *random_transforms):
        """Compose the shared loading pipeline with optional random transforms.

        ``random_transforms`` are inserted after foreground cropping and before
        tensor conversion.
        """
        keys = self._dict_keys
        return Compose((
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
            Orientationd(keys=keys, axcodes="RAS"),
            # NOTE(review): resampling and intensity scaling are disabled, so the
            # network sees raw intensities at native spacing; confirm intended.
            # Spacingd(keys=keys, pixdim=(
            #     1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            # ScaleIntensityRanged(
            #     "image", a_min=-57, a_max=164,
            #     b_min=0.0, b_max=1.0, clip=True,
            # ),
            CropForegroundd(keys=keys, source_key="image"),
            *random_transforms,
            ToTensord(keys=keys),
        ))

    def get_dataset(self, *dataset_args):
        """Build a ``CacheDataset`` or a plain ``Dataset``, per ``cache_ds``."""
        if self.hparams.cache_ds:
            return CacheDataset(*dataset_args, num_workers=self.num_workers)
        return Dataset(*dataset_args)

In [None]:
# batch_size counts patches: with the default 4 random crops per volume, each training batch loads 4 volumes.
datamodule = FlareDataModule(batch_size=16)

In [None]:
# datamodule.setup()

In [None]:
# datamodule.train_ds[0][0]["label"].shape

In [None]:
# datamodule.train_ds[0][0]["image"].shape

In [None]:
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference

import torch
from torchmetrics.functional import dice_score

class Segmentor(pl.LightningModule):
    """3D UNet trained with Dice loss over ``NUM_LABELS + 1`` classes.

    Args:
        learning_rate: Learning rate for the Adam optimizer.
        **kwargs: Not used directly by this class.
    """

    def __init__(
        self,
        learning_rate: float = 1e-4,
        **kwargs,
    ):
        super().__init__()

        self.save_hyperparameters()

        self.model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=NUM_LABELS + 1,
            channels=(4, 8, 16),
            strides=(2, 2),
            num_res_units=2,
            norm="batch",
            bias=False,  # no need for bias for batch norm
        )

        # to_onehot_y converts the integer label map to one channel per class;
        # softmax is applied to the raw logits inside the loss.
        self.criterion = DiceLoss(to_onehot_y=True, softmax=True)

    def forward(self, image):
        """Return per-class logits for ``image``."""
        return self.model(image)

    @staticmethod
    def _dice(output, label):
        """Foreground Dice between ``output`` logits and the ``label`` map.

        ``label`` keeps a singleton channel dim (assumed (N, 1, ...) after
        EnsureChannelFirstd), which ``DiceLoss(to_onehot_y=True)`` needs.
        torchmetrics' ``dice_score`` only argmaxes ``preds`` when it has exactly
        one more dim than ``target``. The channel is therefore dropped here;
        otherwise raw logits would be compared with class indices and the score
        would be ~0. ``squeeze(1)`` is a no-op if that dim is not of size 1.
        """
        return dice_score(output.detach(), label.squeeze(1), bg=False)

    def training_step(self, batch, batch_idx):
        """Dice loss on a batch of patches; also logs foreground Dice."""
        image, label = batch["image"], batch["label"]

        output = self(image)

        self.log("train_dice_score", self._dice(output, label), prog_bar=True)

        loss = self.criterion(output, label)
        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Sliding-window inference over a full volume; logs Dice and Dice loss."""
        image, label = batch["image"], batch["label"]

        sw_batch_size = 4  # ROI windows evaluated per forward pass
        output = sliding_window_inference(image, ROI_SIZE, sw_batch_size, self, overlap=0.1)

        # One whole volume per validation batch.
        self.log("val_dice_score", self._dice(output, label), batch_size=1, prog_bar=True)

        loss = self.criterion(output, label)
        self.log("val_loss", loss, batch_size=1)

    def configure_optimizers(self):
        """Adam over the UNet parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self.hparams.learning_rate)

In [None]:
# Segmentor with its default hyperparameters (Adam, learning_rate=1e-4).
model = Segmentor()

In [None]:
# Log metrics to the Weights & Biases project "flare" under run name "unet".
wandb_logger = pl.loggers.WandbLogger(
    project="flare",
    name="unet",
)

# Keep the 10 best checkpoints (weights only) ranked by lowest "val_loss".
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=10,
    monitor="val_loss",
    mode="min",
    dirpath="checkpoints",
    save_weights_only=True,
    filename="unet-{epoch:02d}-{val_loss:.2f}",
)

# `gpus=` is deprecated in favour of `devices=`; with accelerator="gpu",
# devices=[1] selects the same single GPU (index 1) as gpus=[1] did.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=[1],
    max_epochs=50,
    log_every_n_steps=5,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)

In [None]:
# Train; Lightning calls datamodule.setup() and builds the train/val loaders itself.
trainer.fit(model=model, datamodule=datamodule)

In [None]:
# with torch.inference_mode():
#     batch = next(iter(datamodule.train_dataloader()))
    
#     label = batch["label"]
#     image = batch["image"]
    
#     preds = model(image)

In [None]:
# preds.shape

In [None]:
# label.shape

In [None]:
# dice_score(preds, label)