In [None]:
# Flag for Colab-specific setup steps. Detected from the running kernel instead of
# being hard-coded, so the notebook also behaves correctly outside Google Colab
# (Colab kernels have `google.colab` imported at startup).
import sys

is_on_colab = "google.colab" in sys.modules

In [5]:
# Run parameters. The `#@param` markers make Colab render these as form fields,
# so each line keeps the exact `NAME = literal #@param` shape.
# Batch size of the training DataLoader (the validation DataLoader uses 1).
BATCH_SIZE = 2 #@param
# Root directory handed to the dataset class as `dataset_path`.
DATASET_PATH = "./dataset" #@param
# Config file read by `read_config`; later cells use its model/loss/optimizer/lr_scheduler sections.
CONFIG_FILE = "./configs/unet.yaml" #@param
# Passed to the Lightning Trainer as `default_root_dir`.
CHECKPOINT_PATH = "" #@param
# Save directory for the TensorBoard logger.
TENSORBOARD_DIR = "" #@param
# NOTE(review): MODEL_SAVE_PATH appears unused by the training cells — confirm or wire it up.
MODEL_SAVE_PATH = "" #@param

### Set up the codebase (Google Colab)

In [None]:
! git clone https://github.com/arshamkhodajoo/brain-tumor-segmentation
%cd /content/brain-tumor-segmentation

In [None]:
! pip install pytorch_lightning monai

### Read the config file

In [None]:
from bras.utils import read_config

# Parse the experiment config named by CONFIG_FILE. Later cells index it by section.
config = read_config(CONFIG_FILE)

# Fail fast with a readable message, rather than a KeyError several cells later.
# Assumes `config` is mapping-like (it is indexed with string keys below).
REQUIRED_CONFIG_SECTIONS = ("model", "loss", "optimizer", "lr_scheduler")
missing_config_sections = [s for s in REQUIRED_CONFIG_SECTIONS if s not in config]
assert not missing_config_sections, (
    f"{CONFIG_FILE} is missing config section(s): {missing_config_sections}"
)

### Load the dataset and set up the DataLoaders

In [1]:
from torch.utils.data import DataLoader
from bras.utils.datasets import (
    BrainTumorSegmentaion, BRATS_TRAIN_TRANSFORM, BRATS_VALIDATION_TRANSFORM)


In [None]:
def _make_brats_dataset(transforms):
    """Return a BrainTumorSegmentaion dataset rooted at DATASET_PATH using `transforms`."""
    return BrainTumorSegmentaion(
        dataset_path=DATASET_PATH,
        transforms=transforms,
        download=True,
    )


# NOTE(review): both datasets point at the same DATASET_PATH and differ only in their
# transforms. Unless BrainTumorSegmentaion splits the data internally, validation
# covers the same cases as training — TODO confirm whether it accepts a split argument.
brats_train_dataset = _make_brats_dataset(BRATS_TRAIN_TRANSFORM)
brats_validation_dataset = _make_brats_dataset(BRATS_VALIDATION_TRANSFORM)

In [None]:
# Training batches are shuffled; validation goes one sample at a time in a fixed order.
brats_train_dataloader = DataLoader(brats_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Must wrap the validation dataset (built with the validation transforms). Wrapping the
# training dataset here would compute validation metrics on training samples.
brats_validation_dataloader = DataLoader(brats_validation_dataset, batch_size=1, shuffle=False)

### Define model and loss function

In [7]:
from bras.nn.unet import DynUnet3D
from bras.nn.losses import BraTsDiceFocalLoss

In [None]:
# Build the network and the loss, each from its own config section.
# (BraTsDiceFocalLoss is presumably a Dice + Focal combination, per its name.)
model_section = config["model"]
unet_3d_model = DynUnet3D(config=model_section)

loss_section = config["loss"]
loss_fn = BraTsDiceFocalLoss(config=loss_section)

### Set up training

In [9]:
import pytorch_lightning as pl
from bras.utils.train import (
    create_optimizer, create_lr_scheduler, LightningSegmentationModel)

from bras.nn.metric import DiceLightningMetric

In [None]:
# create_optimizer receives the model (presumably to collect its parameters);
# the LR scheduler is then attached to that same optimizer instance.
optimizer_section = config["optimizer"]
optimizer = create_optimizer(optimizer_section, unet_3d_model)

scheduler_section = config["lr_scheduler"]
lr_scheduler = create_lr_scheduler(scheduler_section, optimizer)

# Dice metric object; its bound `forward` is what the Lightning wrapper receives.
metric_fn = DiceLightningMetric()

In [None]:
# The metric is passed as a (name, callable) pair; the name is presumably the key it
# is logged under. NOTE(review): only the bound `forward` is handed over, not the
# metric object itself — confirm any metric state ends up on the training device.
dice_metric_spec = ("dice_metric", metric_fn.forward)

# Wrap model, loss, optimizer and scheduler into the project's LightningModule.
lightning_model = LightningSegmentationModel(
    torch_model=unet_3d_model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    metric=dice_metric_spec,
)

### Configure the trainer and train

In [1]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer

# Training-run settings.
MAX_EPOCHS = 20
LOG_EVERY_N_STEPS = 5

logger = TensorBoardLogger(TENSORBOARD_DIR, name="unet_3d_brats")

# Select one GPU via accelerator/devices. The old `gpus=` Trainer argument was removed
# in PyTorch Lightning 2.0, and the install cell does not pin a Lightning version.
# NOTE(review): with a logger attached, Lightning's default checkpoint callback usually
# derives its directory from the logger rather than `default_root_dir`. Confirm that
# checkpoints land under CHECKPOINT_PATH; pass ModelCheckpoint(dirpath=...) if not.
trainer = Trainer(
    logger=logger,
    accelerator="gpu",
    devices=1,
    max_epochs=MAX_EPOCHS,
    log_every_n_steps=LOG_EVERY_N_STEPS,
    default_root_dir=CHECKPOINT_PATH,
)

In [None]:
# Launch training. Lightning runs the validation loader periodically during `fit`.
fit_arguments = {
    "model": lightning_model,
    "train_dataloaders": brats_train_dataloader,
    "val_dataloaders": brats_validation_dataloader,
}
trainer.fit(**fit_arguments)