In [1]:
import os
import torch
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils.dataset import BrainMRISliceDataset
from utils.utils import train, validate
from utils.metric import MetricsMonitor
from utils.vis import plot_mri

## Constants

In [22]:
ROOT_DIR = '../Data/'
BATCH_SIZE = 16
EPOCHS = 10
DEVICE = 'mps' if torch.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

## Transforms

In [16]:
train_transform = A.Compose([
    A.Resize(256, 256),  # Resize both image and mask
    A.HorizontalFlip(p=0.5),  # Random horizontal flip
    A.RandomBrightnessContrast(p=0.2),  # Adjust brightness/contrast for images
    ToTensorV2()  # Convert to PyTorch tensors
], additional_targets={'mask': 'mask'})  # Specify the target name for the label

test_transform = A.Compose([
    A.Resize(256, 256),
    ToTensorV2()
], additional_targets={'mask': 'mask'})

In [31]:
train_dataset = BrainMRISliceDataset(os.path.join(ROOT_DIR, 'train'), slice_axis=2, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

val_dataset = BrainMRISliceDataset(os.path.join(ROOT_DIR, 'val'), slice_axis=2, transform=test_transform)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [18]:
train_dataset[0][1].shape

torch.Size([3, 256, 256])

## Models

In [19]:
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True)

## Freeze encoder and bottleneck layers
for name, parameter in model.named_parameters():
    if 'decoder' not in name:
        parameter.requires_grad = False

model._modules['conv'] = torch.nn.Conv2d(32, 3, kernel_size=1, stride=1)
model = model.to(DEVICE)

Using cache found in /Users/huytrq/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


## Loss & Optimizer

In [20]:
criteria = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## Training

In [21]:
# Monitors
train_monitor = MetricsMonitor(metrics=["loss", "accuracy"])
val_monitor = MetricsMonitor(
    metrics=["loss", "accuracy"], patience=5, mode="max"
)
test_monitor = MetricsMonitor(metrics=["loss", "accuracy"])

In [None]:
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    print("-" * 10)
    train(model, train_loader, criteria, optimizer, DEVICE, train_monitor)
    validate(model, val_loader, criteria, DEVICE, val_monitor)

Epoch 1/10
----------
[Train] Iteration 160/160 - loss: 0.8950, accuracy: 0.8106
Train Metrics - loss: 0.8950, accuracy: 0.8106

Validation Metrics - loss: 0.0000, accuracy: 0.0000
Epoch 2/10
----------
[Train] Iteration 69/160 - loss: 0.8573, accuracy: 0.8331

KeyboardInterrupt: 

In [33]:
validate(model, val_loader, criteria, DEVICE, val_monitor)

[Validation] Iteration 80/80 - loss: 0.8512, accuracy: 0.8919
Validation Metrics - loss: 0.8512, accuracy: 0.8919


{'loss': 0.8512429669499397, 'accuracy': 0.8918538331985474}