In [5]:
# Number of samples per training batch; passed as `batch_size` to the training DataLoader.
# NOTE(review): the trailing `#@param` looks like a Colab form-field marker — confirm
# before moving or reformatting this line.
BATCH_SIZE = 5 #@param

### Read config file

In [None]:
from bras.utils import read_config
# Load the experiment settings from the YAML file. The returned object is indexed
# by section name later in the notebook ("model", "loss", "optimizer", "scheduler").
CONFIG_PATH = "./configs/unet.yaml"
config = read_config(CONFIG_PATH)

### Load dataset and set up DataLoaders

In [1]:
from torch.utils.data import DataLoader
from bras.utils.datasets import BrainTumorSegmentaion, BRATS_TRAIN_TRANSFORM, BRATS_VALIDATION_TRANSFORM

In [2]:
import os

# Directory that holds the BraTS sample dataset.
# A single developer's absolute Windows path used to be hard-coded here, which
# prevented the notebook from running on any other machine. Set the
# BRATS_DATASET_PATH environment variable to point at your own copy; when it is
# unset the previous path is used, so existing setups behave exactly as before.
dataset_path = os.environ.get(
    "BRATS_DATASET_PATH",
    "C:/Users/Arsham/Documents/Projects/brain-tumor-segmentation/dataset/BraTS-sample",
)

In [None]:
# Both splits are built from the same directory and both pass `download=True`;
# the only difference between them is the transform pipeline.
shared_dataset_kwargs = {"dataset_path": dataset_path, "download": True}

brats_train_dataset = BrainTumorSegmentaion(
    transforms=BRATS_TRAIN_TRANSFORM, **shared_dataset_kwargs
)
brats_validation_dataset = BrainTumorSegmentaion(
    transforms=BRATS_VALIDATION_TRANSFORM, **shared_dataset_kwargs
)

In [None]:
# Training loader: batches of BATCH_SIZE samples, reshuffled at every epoch.
brats_train_dataloader = DataLoader(
    brats_train_dataset, batch_size=BATCH_SIZE, shuffle=True
)

# Validation loader: one sample per batch, in a fixed order. It must wrap the
# validation dataset (the one built with BRATS_VALIDATION_TRANSFORM); wrapping
# the training dataset here would silently validate on training-transformed
# samples.
brats_validation_dataloader = DataLoader(
    brats_validation_dataset, batch_size=1, shuffle=False
)

### Define model and loss function

In [7]:
from bras.nn.unet import DynUnet3D
from bras.nn.losses import BraTsDiceFocalLoss

In [None]:
# Build the network from the "model" section of the config.
model_settings = config["model"]
unet_3d_model = DynUnet3D(config=model_settings)

# Build the training loss from the "loss" section of the config.
loss_settings = config["loss"]
loss_fn = BraTsDiceFocalLoss(config=loss_settings)

### Set up training

In [9]:
import pytorch_lightning as pl
from bras.utils.train import create_optimizer, create_lr_scheduler

In [None]:
# The optimizer is created from the "optimizer" config section and the model.
optimizer_settings = config["optimizer"]
optimizer = create_optimizer(optimizer_settings, unet_3d_model)

# The learning-rate scheduler is created from the "scheduler" config section
# and the optimizer built just above.
scheduler_settings = config["scheduler"]
lr_scheduler = create_lr_scheduler(scheduler_settings, optimizer)

In [None]:
class LightningModel(pl.LightningModule):

    def __init__(self, model, loss_fn, optimizer, scheduler):
        """Store the network and its training components on the module.

        Args:
            model: network invoked as ``self.model(...)`` by ``forward`` and
                ``training_step``.
            loss_fn: callable invoked as ``loss_fn(output, target)`` in
                ``training_step``.
            optimizer: optimizer object, kept as ``self.optimizer``.
            scheduler: learning-rate scheduler object, kept as ``self.scheduler``.
        """
        # The base initializer has to run before any attribute is assigned:
        # torch.nn.Module (which LightningModule extends) raises AttributeError
        # when a sub-module is assigned before Module.__init__() has created its
        # internal registries. Presumably `model` is an nn.Module, so the very
        # first assignment below would otherwise fail.
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.scheduler = scheduler
        self.optimizer = optimizer

    def forward(self, x):
        """Pass ``x`` through the wrapped network and return its output unchanged."""
        prediction = self.model(x)
        return prediction

    def training_step(self, batch, batch_idx):
        """Compute the training loss for a single batch.

        Args:
            batch: mapping with an ``"image"`` entry (fed to the network) and a
                ``"label"`` entry (passed to the loss as the target).
            batch_idx: index of the batch within the epoch; not used here, but
                part of the ``training_step`` hook signature.

        Returns:
            The value produced by ``self.loss_fn(output, segmentations)``.
        """
        channels, segmentations = batch["image"], batch["label"]

        output = self.model(channels)
        loss = self.loss_fn(output, segmentations)

        # `LightningModule.log` requires a metric name AND a value. The earlier
        # call passed only a formatted "running batch ..." message, which raises
        # TypeError on the first batch. Log the loss instead and surface it in
        # the progress bar, which already reports batch progress.
        # NOTE(review): assumes loss_fn returns a scalar tensor/number — confirm.
        self.log("train_loss", loss, prog_bar=True)
        return loss