# Skull Segmentation

Training a skull segmentation model with PyTorch Lightning and MONAI.

In [None]:
from glob import glob
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from monai.transforms import (
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
)
from monai.data import CacheDataset, DataLoader
from autoimplantpipe import UNetSegmentation

In [None]:
def set_data_dict(
    data_dir: str,
    *,
    image_pattern: str = "*/*_resampled.nii",
    label_pattern: str = "*/*_label.nii",
):
    """
    Build a list of ``{"image": path, "label": path}`` dicts from a data directory.

    Image and label files are located with ``glob`` relative to ``data_dir``.
    By default this looks one subdirectory level down, e.g.
    ``<data_dir>/<case>/<case>_resampled.nii`` and
    ``<data_dir>/<case>/<case>_label.nii``.

    Both path lists are sorted before pairing. ``glob`` returns paths in an
    arbitrary, filesystem-dependent order, so zipping unsorted lists can
    match an image with another case's label. Sorting aligns the pairs when
    each case directory holds exactly one image and one label.

    Args:
        data_dir: Root directory of one dataset split.
        image_pattern: Glob pattern, relative to ``data_dir``, matching image files.
        label_pattern: Glob pattern, relative to ``data_dir``, matching label files.

    Returns:
        A list of dicts with keys ``"image"`` and ``"label"`` holding file
        paths, in sorted path order. The list is empty if nothing matches.

    Raises:
        ValueError: If the number of images and labels found differs. In that
            case they cannot be paired one-to-one, and ``zip`` would silently
            drop the extras.
    """
    images = sorted(glob(f"{data_dir}/{image_pattern}"))
    labels = sorted(glob(f"{data_dir}/{label_pattern}"))
    if len(images) != len(labels):
        raise ValueError(
            f"Found {len(images)} images but {len(labels)} labels under "
            f"{data_dir!r}; cannot pair them one-to-one."
        )
    return [
        {"image": image_name, "label": label_name}
        for image_name, label_name in zip(images, labels)
    ]

In [None]:
# Define the data directories and build the training and validation data dicts.
train_files = set_data_dict("Datasets_skull/train")
val_files = set_data_dict("Datasets_skull/val")

# An empty list usually means a wrong path or file-naming scheme. Fail here
# with a clear message rather than proceeding with an empty dataset.
for _split_name, _split_files in (("train", train_files), ("val", val_files)):
    if not _split_files:
        raise FileNotFoundError(
            f"No image/label pairs found for the {_split_name} split; "
            "check the dataset directory and file naming."
        )

In [None]:
# Set up transforms for training and validation data.

# Intensity window passed to ScaleIntensityRanged and rescaled to [0, 1].
# NOTE(review): presumably Hounsfield units for CT input — confirm.
INTENSITY_MIN = -43
INTENSITY_MAX = 453
# Target voxel spacing passed to Spacingd.
TARGET_PIXDIM = (0.5, 0.5, 0.625)


def _preprocessing_transforms():
    """Return fresh instances of the deterministic steps shared by train and val.

    Keeping these steps in one place ensures the training and validation
    pipelines preprocess volumes identically.
    """
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Clip to the intensity window, then map it linearly onto [0, 1].
        ScaleIntensityRanged(
            keys=["image"],
            a_min=INTENSITY_MIN,
            a_max=INTENSITY_MAX,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Images are interpolated bilinearly; labels use nearest-neighbour so
        # they keep their original label values.
        Spacingd(
            keys=["image", "label"],
            pixdim=TARGET_PIXDIM,
            mode=("bilinear", "nearest"),
        ),
    ]


# Training adds random patch sampling: num_samples=4 crops of 96^3 per volume.
# pos=1 and neg=1 weight foreground- and background-centred crops equally.
train_transforms = Compose(
    _preprocessing_transforms()
    + [
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
    ]
)

# Validation uses whole preprocessed volumes (no random cropping).
val_transforms = Compose(_preprocessing_transforms())

In [None]:
# Build the cached datasets and their data loaders for training and validation.
# Both splits share the same caching options (cache_rate=1.0 caches every item)
# and the same loader options (one item per batch, 4 loader workers).
_cache_options = {"cache_rate": 1.0, "num_workers": 4}
_loader_options = {"batch_size": 1, "num_workers": 4}

train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    **_cache_options,
)
# Only the training loader reshuffles its samples each epoch.
train_loader = DataLoader(train_ds, shuffle=True, **_loader_options)

val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    **_cache_options,
)
val_loader = DataLoader(val_ds, **_loader_options)

In [None]:
# Train the UNet segmentation model with PyTorch Lightning.

# Replace "<project_name>" with the Weights & Biases project to log to.
# log_model="all" uploads every saved checkpoint to W&B.
wandb_logger = WandbLogger(project="<project_name>", log_model="all")

# Keep the checkpoint with the highest validation Dice score.
# NOTE(review): assumes UNetSegmentation logs a metric named "val_dice" — confirm.
checkpoint_callback = ModelCheckpoint(
    monitor="val_dice",
    mode="max",
    dirpath="checkpoints/",
    # The previous name "sample-mnist-..." was a leftover from an MNIST example.
    filename="skull-seg-{epoch:02d}-{val_dice:.2f}",
)
trainer = pl.Trainer(
    # `gpus=` was deprecated in Lightning 1.7 and removed in 2.0;
    # `accelerator` + `devices` is the supported equivalent of `gpus=1`.
    accelerator="gpu",
    devices=1,
    max_epochs=5,
    logger=wandb_logger,
    log_every_n_steps=1,
    callbacks=[checkpoint_callback],
    default_root_dir="model/",
)
model = UNetSegmentation()
trainer.fit(model, train_loader, val_loader)