In [1]:
import os
import torch
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils.dataset import BrainMRISliceDataset
from utils.transforms import TransformWithLabels
from utils.utils import train, validate
from utils.vis import plot_mri

## Constants

In [2]:
ROOT_DIR = '../Data/'  # dataset root; the loaders read its 'train' and 'val' subdirectories
BATCH_SIZE = 1
# Prefer Apple-silicon MPS, then CUDA, then CPU.
# `torch.backends.mps.is_available()` is the long-standing public check (PyTorch >= 1.12);
# `torch.mps.is_available()` may be missing on older releases and raise AttributeError there.
DEVICE = (
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

## Transforms

In [3]:
IMAGE_SIZE = 256  # side length (pixels) every image/mask pair is resized to

# `mask` is one of A.Compose's built-in targets, so it needs no `additional_targets` entry.
# NOTE(review): ToTensorV2 runs with its default transpose_mask=False. Judging by the sample
# shape shown in cell 5 (torch.Size([256, 256, 1])), masks therefore stay channels-last while
# images become channels-first. Confirm utils.train/validate expect this before changing it.
train_transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),   # applied identically to image and mask
    A.HorizontalFlip(p=0.5),            # spatial: flips image and mask together
    A.RandomBrightnessContrast(p=0.2),  # pixel-level: changes the image only, mask untouched
    ToTensorV2(),                       # numpy HWC image -> torch CHW tensor (no value scaling)
])

# Evaluation pipeline: deterministic, no augmentation.
test_transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    ToTensorV2(),
])

In [4]:
# Each split directory becomes a slice dataset plus a DataLoader; only training shuffles.
# NOTE: the `test_*` objects are built from the 'val' split directory.
SLICE_AXIS = 2  # forwarded to BrainMRISliceDataset as `slice_axis` (orientation: TODO confirm)


def _make_split(split_name, transform, shuffle):
    """Create the slice dataset for ROOT_DIR/<split_name> and a DataLoader over it.

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    split_dataset = BrainMRISliceDataset(
        os.path.join(ROOT_DIR, split_name),
        slice_axis=SLICE_AXIS,
        transform=transform,
    )
    split_loader = DataLoader(split_dataset, batch_size=BATCH_SIZE, shuffle=shuffle)
    return split_dataset, split_loader


train_dataset, train_loader = _make_split('train', train_transform, shuffle=True)
test_dataset, test_loader = _make_split('val', test_transform, shuffle=False)

In [5]:
# Sanity check: shape of element 1 of training sample 150 (presumably the mask — confirm).
sample_target = train_dataset[150][1]
sample_target.shape

torch.Size([256, 256, 1])

## Models

## Loss & Optimizer