In [1]:
import os
import torch
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils.loss import DiceLoss
from utils.dataset import BrainMRISliceDataset
from utils.utils import train, validate
from utils.metric import MetricsMonitor
from utils.vis import plot_mri

## Constants

In [2]:
# Run configuration shared by every cell below.
ROOT_DIR = './Data/'   # dataset root; the 'train' and 'val' sub-folders are read from here
BATCH_SIZE = 64
EPOCHS = 10

# Pick the accelerator: prefer Apple MPS, then CUDA, then fall back to the CPU.
# torch.backends.mps.is_available() is the documented availability check;
# torch.mps.is_available is missing from older PyTorch releases.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# os.cpu_count() returns None when the count cannot be determined, and
# DataLoader rejects num_workers=None, so fall back to 0 (load in the main process).
NUM_WORKERS = os.cpu_count() or 0

## Transforms

In [4]:
# Training pipeline: resize to 256x256, then convert to PyTorch tensors.
# 'mask' is passed as an additional target of type 'mask'.
train_transform = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        # Optional augmentations, currently disabled:
        # A.HorizontalFlip(p=0.5),
        # A.RandomBrightnessContrast(p=0.2),
        ToTensorV2(),
    ],
    additional_targets={'mask': 'mask'},
)

# Evaluation pipeline: same resize + tensor conversion, no augmentation.
test_transform = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        ToTensorV2(),
    ],
    additional_targets={'mask': 'mask'},
)

In [5]:
# Training split: slices taken along axis 2, shuffled every epoch.
train_dataset = BrainMRISliceDataset(
    os.path.join(ROOT_DIR, 'train'),
    slice_axis=2,
    transform=train_transform,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Validation split: same slicing, evaluation transform, fixed order.
val_dataset = BrainMRISliceDataset(
    os.path.join(ROOT_DIR, 'val'),
    slice_axis=2,
    transform=test_transform,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

## Models

In [6]:
# Load the U-Net from the mateuszbuda/brain-segmentation-pytorch hub repository
# with pretrained weights (3 input channels, 1 output channel, 32 initial features).
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True)

# Optional: freeze encoder and bottleneck layers so only the decoder is trained.
# for name, parameter in model.named_parameters():
#     if 'decoder' not in name:
#         parameter.requires_grad = False

# Replace the final 1x1 convolution with a freshly initialised 3-channel head.
# Plain attribute assignment registers the submodule through nn.Module's public
# path instead of writing to the private `_modules` dict, and the input width is
# read from the layer being replaced so it cannot drift from `init_features`.
model.conv = torch.nn.Conv2d(model.conv.in_channels, 3, kernel_size=1, stride=1)
model = model.to(DEVICE)

Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


## Loss & Optimizer

In [7]:
# Dice loss is the active training criterion; the BCE-with-logits alternative
# is kept below for reference.
# criteria = torch.nn.BCEWithLogitsLoss()
criteria = DiceLoss()

# Adam over every model parameter with a learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

## Training

In [8]:
# One metrics monitor per split, all tracking the same three metrics.
# Each monitor receives its own list copy so no list object is shared between them.
_MONITORED_METRICS = ("loss", "accuracy", "dice_score")

train_monitor = MetricsMonitor(metrics=list(_MONITORED_METRICS))
# The validation monitor is additionally configured with patience=5 and mode="max".
val_monitor = MetricsMonitor(
    metrics=list(_MONITORED_METRICS),
    patience=5,
    mode="max",
)
test_monitor = MetricsMonitor(metrics=list(_MONITORED_METRICS))

In [None]:
# Run one training pass followed by one validation pass per epoch.
for epoch in range(EPOCHS):
    # Epoch banner: 1-based epoch counter on one line, a 10-dash rule on the next.
    print(f"Epoch {epoch + 1}/{EPOCHS}", "-" * 10, sep="\n")
    train(
        model, train_loader, criteria, optimizer, DEVICE, train_monitor
    )
    validate(
        model, val_loader, criteria, DEVICE, val_monitor
    )

Epoch 1/10
----------
[Train] Iteration 40/40 - loss: 0.8824, accuracy: 0.4545, dice_score: 0.1176
Train Metrics - loss: 0.8824, accuracy: 0.4545, dice_score: 0.1176
[Validation] Iteration 20/20 - loss: 0.8844, accuracy: 0.5157, dice_score: 0.1156
Validation Metrics - loss: 0.8844, accuracy: 0.5157, dice_score: 0.1156
Epoch 2/10
----------
[Train] Iteration 40/40 - loss: 0.8609, accuracy: 0.5613, dice_score: 0.1391
Train Metrics - loss: 0.8609, accuracy: 0.5613, dice_score: 0.1391
[Validation] Iteration 20/20 - loss: 0.8745, accuracy: 0.5739, dice_score: 0.1255
Validation Metrics - loss: 0.8745, accuracy: 0.5739, dice_score: 0.1255
Epoch 3/10
----------
[Train] Iteration 40/40 - loss: 0.8520, accuracy: 0.7530, dice_score: 0.1480
Train Metrics - loss: 0.8520, accuracy: 0.7530, dice_score: 0.1480
[Validation] Iteration 20/20 - loss: 0.8702, accuracy: 0.8909, dice_score: 0.1298
Validation Metrics - loss: 0.8702, accuracy: 0.8909, dice_score: 0.1298
Epoch 4/10
----------
[Train] Iteration 