In [7]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchio as tio

from models.Unet import UNet3D

In [2]:
# Define the dataset
def prepare_dataset(image_paths, label_paths):
    """
    Prepare a TorchIO dataset of Subjects for 3D segmentation.

    Args:
        image_paths (list): List of paths to MRI images.
        label_paths (list): List of paths to segmentation labels.

    Returns:
        torchio.SubjectsDataset: A TorchIO dataset.
    """
    subjects = []
    for img_path, label_path in zip(image_paths, label_paths):
        subject = tio.Subject(
            image=tio.ScalarImage(img_path),
            label=tio.LabelMap(label_path)
        )
        subjects.append(subject)
    return tio.SubjectsDataset(subjects)

In [23]:
def create_patch_loader(dataset, patch_size, batch_size, patch_overlap=(16, 16, 16), num_workers=0):
    """
    Create a DataLoader for patch-based segmentation.

    Args:
        dataset (torchio.SubjectsDataset): TorchIO dataset.
        patch_size (tuple): Size of each patch, e.g., (64, 64, 64).
        batch_size (int): Number of patches in a batch.
        patch_overlap (tuple): Overlap between patches.
        num_workers (int): Number of workers for DataLoader.

    Returns:
        DataLoader, list: Loader for patches and a list of aggregators for each subject.
    """
    patch_loaders = []
    aggregators = []

    for subject in dataset:  # Iterate over each subject in the dataset
        # Create a GridSampler for the subject
        sampler = tio.GridSampler(
            subject,
            patch_size=patch_size,
            patch_overlap=patch_overlap
        )

        # Create an aggregator for reassembling the volume
        aggregator = tio.GridAggregator(sampler, overlap_mode='average')

        # Add the sampler's patches to the DataLoader
        loader = DataLoader(
            sampler,
            batch_size=batch_size,
            num_workers=num_workers
        )

        patch_loaders.append(loader)
        aggregators.append(aggregator)

    return patch_loaders, aggregators

## Transforms

In [4]:
# Define data augmentations
train_transform = tio.Compose([
    tio.RandomFlip(axes=(0, 1, 2), flip_probability=0.5),
    tio.RandomAffine(scales=(0.9, 1.1), degrees=15, isotropic=True),
    tio.RandomElasticDeformation(num_control_points=5, max_displacement=7.5),
    tio.ZNormalization()  # Normalize image intensities
])

val_transform = tio.Compose([
    tio.ZNormalization()
])

In [28]:
def train_model(loader, model, optimizer, criterion, device):
    model.train()
    total_loss = 0

    for patches in loader:
        images = patches['image'][tio.DATA].to(device)  # Access image tensor
        labels = patches['mask'][tio.DATA].to(device)   # Access mask tensor

        labels = labels.squeeze(1).long()  # Convert to [B, D, H, W] for CrossEntropyLoss if needed

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)


def validate_model(loader, model, criterion, device, aggregator):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for patches in loader:
            images = patches['image'][tio.DATA].to(device)
            labels = patches['mask'][tio.DATA].to(device)
            labels = labels.squeeze(1).long()

            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # Add the model's predictions to the aggregator
            aggregator.add_batch(outputs, patches)

    val_loss = total_loss / len(loader)
    val_prediction = aggregator.get_output_tensor()  # Reconstruct full volume from patches
    return val_loss, val_prediction


## Training

In [9]:
DATA_DIR = '/root/huy/BrainSegmentation/Data'

In [24]:
# File paths
train_images, train_labels = get_data_paths(os.path.join(DATA_DIR, 'train'))
val_images, val_labels = get_data_paths(os.path.join(DATA_DIR, 'val'))

# Prepare datasets
train_dataset = prepare_dataset(train_images, train_labels)
val_dataset = prepare_dataset(val_images, val_labels)

In [20]:
train_dataset[0]['image']

ScalarImage(shape: (1, 256, 128, 256); spacing: (0.94, 1.50, 0.94); orientation: RAS+; dtype: torch.ShortTensor; memory: 16.0 MiB)

In [25]:
# Patch-based loaders
patch_size = (64, 64, 64)
batch_size = 4
train_loader, _ = create_patch_loader(train_dataset, patch_size, batch_size)
val_loader, val_aggregator = create_patch_loader(val_dataset, patch_size, batch_size)

In [35]:
# Get the first batch
first_batch = next(iter(train_loader[0]))

# Check the type of the batch
print(type(first_batch))

# If it's a TorchIO Subject or similar, inspect its attributes
if isinstance(first_batch, dict):
    print(first_batch.keys())
    print(first_batch['image'].shape)  # Access the image tensor
    print(first_batch['mask'].shape)   # Access the mask tensor
elif isinstance(first_batch, tio.Subject):
    # Iterate through the attributes in the Subject
    for key, value in first_batch.items():
        print(f"{key}: {value.shape if isinstance(value, torch.Tensor) else type(value)}")
else:
    print("Unknown batch structure:", first_batch)



<class 'torchio.data.subject.Subject'>
dict_keys(['image', 'label', 'location'])


ValueError: too many values to unpack (expected 4)

In [29]:
# Model, optimizer, and loss
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet3D(in_channels=1, out_channels=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training
epochs = 10
for epoch in range(epochs):
    train_loss = train_model(train_loader, model, optimizer, criterion, device)
    val_loss, val_prediction = validate_model(val_loader, model, criterion, device, val_aggregator)

    print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

# Save model
torch.save(model.state_dict(), "3d_unet_model.pth")

TypeError: 'DataLoader' object is not subscriptable