In [3]:
# --- Colab form parameters (`#@param`): editable from the form view or here. ---

# Training batch size.
BATCH_SIZE = 2 #@param

# Local directory the dataset is downloaded into and read from.
DATASET_PATH = "./temp-data" #@param {"type": "string"}

# YAML experiment config; later cells read its "model", "loss", "optimizer"
# and "lr_scheduler" sections.
# NOTE(review): a relative path is resolved against the working directory at the
# moment it is read (after the `%cd` into the cloned repo) — confirm it still
# points at the repo's configs directory.
CONFIG_FILE = "../configs/unet.yaml" #@param {"type": "string"}

# Output locations under /content/drive/MyDrive (assumes Google Drive is mounted).
CHECKPOINT_PATH = "/content/drive/MyDrive/BraTs/checkpoints" #@param {"type": "string"}
TENSORBOARD_DIR = "/content/drive/MyDrive/BraTs/tensorboard" #@param {"type": "string"}
MODEL_SAVE_PATH = "/content/drive/MyDrive/BraTs/models" #@param {"type": "string"}

### Set up the codebase (Google Colab)

In [3]:
! git clone https://github.com/arshamkhodajoo/brain-tumor-segmentation
%cd /content/brain-tumor-segmentation

Cloning into 'brain-tumor-segmentation'...
remote: Enumerating objects: 192, done.[K
remote: Counting objects: 100% (192/192), done.[K
remote: Compressing objects: 100% (118/118), done.[K
remote: Total 192 (delta 71), reused 174 (delta 53), pack-reused 0[K
Receiving objects: 100% (192/192), 9.65 MiB | 11.11 MiB/s, done.
Resolving deltas: 100% (71/71), done.
/content/brain-tumor-segmentation


In [4]:
! pip install pytorch_lightning monai

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_lightning
  Downloading pytorch_lightning-1.6.5-py3-none-any.whl (585 kB)
[K     |████████████████████████████████| 585 kB 28.3 MB/s 
[?25hCollecting monai
  Downloading monai-0.9.0-202206131636-py3-none-any.whl (939 kB)
[K     |████████████████████████████████| 939 kB 65.0 MB/s 
Collecting pyDeprecate>=0.3.1
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Collecting PyYAML>=5.4
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 44.9 MB/s 
[?25hCollecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.9.2-py3-none-any.whl (419 kB)
[K     |████████████████████████████████| 419 kB 61.7 MB/s 
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)
[K     |███████████████████████████

### Read the config file

In [4]:
from pathlib import Path

from bras.utils import read_config

# CONFIG_FILE may be relative, in which case it is resolved against the *current*
# working directory (changed by `%cd` in the setup cell). Check it up front so a
# wrong path reports the fully resolved location rather than whatever read_config
# happens to raise.
config_path = Path(CONFIG_FILE)
if not config_path.is_file():
    raise FileNotFoundError(
        f"Config file not found: {config_path.resolve()} (cwd: {Path.cwd()})"
    )
config = read_config(CONFIG_FILE)

### Load the datasets and set up DataLoaders

In [6]:
from torch.utils.data import DataLoader
from bras.utils.datasets import (
    BrainTumorSegmentaion, BRATS_TRAIN_TRANSFORM, BRATS_VALIDATION_TRANSFORM)


In [8]:
from pathlib import Path

# parents=True: DATASET_PATH is a user-editable form field and may name a nested,
# not-yet-existing folder. exist_ok=True keeps the cell safe to re-run.
Path(DATASET_PATH).mkdir(parents=True, exist_ok=True)

In [9]:
# Both datasets share the same location and download flag; only the transform
# pipeline differs between them.
# NOTE(review): no split/section argument is passed, so the validation dataset
# presumably reads the same volumes as the training dataset (train/val leakage).
# Confirm whether BrainTumorSegmentaion supports selecting a split.
_brats_dataset_kwargs = dict(dataset_path=DATASET_PATH, download=True)

brats_train_dataset = BrainTumorSegmentaion(
    transforms=BRATS_TRAIN_TRANSFORM, **_brats_dataset_kwargs
)

# With download=True the existing archive is checked again here (the recorded log
# shows a second "Verified ..." line followed by "skipped downloading").
brats_validation_dataset = BrainTumorSegmentaion(
    transforms=BRATS_VALIDATION_TRANSFORM, **_brats_dataset_kwargs
)

Task01_BrainTumour.tar: 7.09GB [05:19, 23.8MB/s]                            

2022-07-15 13:05:49,982 - INFO - Downloaded: temp-data/Task01_BrainTumour.tar





2022-07-15 13:06:21,863 - INFO - Verified 'Task01_BrainTumour.tar', md5: 240a19d752f0d9e9101544901065d872.
2022-07-15 13:06:21,865 - INFO - Writing into directory: temp-data.
2022-07-15 13:07:44,983 - INFO - Verified 'Task01_BrainTumour.tar', md5: 240a19d752f0d9e9101544901065d872.
2022-07-15 13:07:44,985 - INFO - File exists: temp-data/Task01_BrainTumour.tar, skipped downloading.
2022-07-15 13:07:44,991 - INFO - Non-empty folder exists in temp-data/Task01_BrainTumour, skipped extracting.


In [10]:
# Training batches are shuffled; validation is iterated in a fixed order with
# batch_size=1.
brats_train_dataloader = DataLoader(brats_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Must wrap the *validation* dataset (BRATS_VALIDATION_TRANSFORM). Wrapping
# brats_train_dataset here would compute validation metrics on training samples
# with training-time transforms.
brats_validation_dataloader = DataLoader(brats_validation_dataset, batch_size=1, shuffle=False)

### Define model and loss function

In [5]:
from bras.nn.unet import DynUnet3D
from bras.nn.losses import BraTsDiceFocalLoss

In [6]:
# Model and loss are each built from their own section of the experiment config.
model_section = config["model"]
unet_3d_model = DynUnet3D(config=model_section)

loss_section = config["loss"]
loss_fn = BraTsDiceFocalLoss(config=loss_section)

### Set up training

In [7]:
import pytorch_lightning as pl
from bras.utils.train import (
    create_optimizer, create_lr_scheduler, LightningSegmentationModel)

from bras.nn.metric import DiceLightningMetric

In [8]:
# Optimizer built for unet_3d_model from the "optimizer" config section.
optimizer_section = config["optimizer"]
optimizer = create_optimizer(optimizer_section, unet_3d_model)

# LR scheduler attached to that optimizer, from the "lr_scheduler" section.
scheduler_section = config["lr_scheduler"]
lr_scheduler = create_lr_scheduler(scheduler_section, optimizer)

# Dice metric handed to the Lightning wrapper below.
metric_fn = DiceLightningMetric()

In [9]:
# Bundle model, loss, optimizer, scheduler and metric into the project's Lightning
# wrapper. The metric is passed as a (name, metric) pair; the name is presumably
# the key it is logged under — confirm in LightningSegmentationModel.
named_metric = ("dice_metric", metric_fn)

lightning_model = LightningSegmentationModel(
    torch_model=unet_3d_model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    metric=named_metric,
)

### Configure the trainer and train

In [25]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

MAX_EPOCHS = 20

logger = TensorBoardLogger(TENSORBOARD_DIR, name="unet_3d_brats")

# When a logger is attached, Lightning's default ModelCheckpoint writes to
# <logger.save_dir>/<name>/<version>/checkpoints, i.e. under TENSORBOARD_DIR,
# not default_root_dir. An explicit dirpath makes checkpoints land in
# CHECKPOINT_PATH as intended.
checkpoint_callback = ModelCheckpoint(dirpath=CHECKPOINT_PATH)

trainer = Trainer(
    logger=logger,
    callbacks=[checkpoint_callback],
    max_epochs=MAX_EPOCHS,
    log_every_n_steps=5,
    default_root_dir=CHECKPOINT_PATH,
    # accelerator/devices replace the `gpus=` flag, which was deprecated in
    # Lightning 1.7 and removed in 2.0; this form also works on 1.6.
    accelerator="gpu",
    devices=1,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Both arguments must be DataLoaders (batched iterables). The validation Dataset
# itself is not batched, so pass the DataLoader that wraps it.
trainer.fit(
    model=lightning_model,
    train_dataloaders=brats_train_dataloader,
    val_dataloaders=brats_validation_dataloader,
)