# Test-Driven Development for PyTorch Brain Generator

This notebook tests the PyTorch-based brain generator in a TDD style. Each test is a plain function whose assertions are executed through a small `run_test` helper built on try/except, rather than through `unittest`.

In [None]:
import os
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt

# Add the directory two levels above the current working directory to the import path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath('.'))))

# Import the PyTorch brain generator
from SynthSeg.brain_generator_torch import (
    BrainGenerator, 
    BrainGeneratorModel, 
    BrainGeneratorDataset,
    SpatialTransformer,
    RandomSpatialDeformation,
    RandomCrop,
    RandomFlip,
    SampleConditionalGMM,
    BiasFieldCorruption,
    IntensityAugmentation,
    GaussianBlur,
    DynamicGaussianBlur,
    MimicAcquisition,
    ConvertLabels,
    ImageGradients,
    SampleResolution,
    blurring_sigma_for_downsampling
)

## Test Helper Functions

First, let's define some helper functions for our tests.

In [None]:
def run_test(test_func, test_name):
    """Execute ``test_func`` and print a one-line verdict labelled ``test_name``.

    Returns True when the callable completes without raising. Returns False
    when it raises: an AssertionError is reported as FAILED, any other
    Exception as ERROR. The exception text is included in the printed line.
    """
    try:
        test_func()
        print(f"✅ {test_name} - PASSED")
        passed = True
    except AssertionError as failure:
        # A failed assertion means the behaviour under test is wrong.
        print(f"❌ {test_name} - FAILED: {failure}")
        passed = False
    except Exception as error:
        # Anything else means the test could not run to completion.
        print(f"❌ {test_name} - ERROR: {error}")
        passed = False
    return passed

def create_dummy_label_map(shape=(10, 10, 10), labels=(0, 1, 2, 3)):
    """Create a dummy label map for testing.

    The first entry of ``labels`` is treated as background and left as the
    zero fill value. Every following label is written into its own block
    placed along the main diagonal of the volume.

    Args:
        shape: Shape of the volume to create (any number of axes).
        labels: Sequence of label values; ``labels[0]`` is the background.

    Returns:
        An ``int32`` NumPy array of the requested shape.
    """
    label_map = np.zeros(shape, dtype=np.int32)

    # Each axis is split into len(labels) + 1 equal slots. The extent is
    # computed per axis so that non-cubic shapes get correctly sized blocks
    # instead of reusing the first axis' extent everywhere.
    n_slots = len(labels) + 1
    steps = [dim // n_slots for dim in label_map.shape]

    for i, label in enumerate(labels):
        if i == 0:  # Background
            continue

        # Block number i occupies slot i on every axis.
        region = tuple(slice(step * i, step * (i + 1)) for step in steps)
        label_map[region] = label

    return label_map

def create_dummy_dataset(temp_dir):
    """Fill ``temp_dir`` with three dummy label maps and return ``temp_dir``.

    The directory is created if it does not exist. Each map is written as
    ``label_map_<i>.npz`` with the array stored under the ``vol`` key.
    """
    os.makedirs(temp_dir, exist_ok=True)

    file_names = [f"label_map_{index}.npz" for index in range(3)]
    for file_name in file_names:
        destination = os.path.join(temp_dir, file_name)
        np.savez(destination, vol=create_dummy_label_map())

    return temp_dir

## Test Individual Components

Let's test each component of the PyTorch brain generator.

In [None]:
def test_spatial_transformer():
    """Test that SpatialTransformer displaces a cube under a constant flow.

    A 3x3x3 cube of ones is placed at indices 4:7 of a 10^3 volume and warped
    with a flow field of ones on all three components.
    """
    # Create a simple test volume
    volume = torch.zeros(1, 1, 10, 10, 10)
    volume[0, 0, 4:7, 4:7, 4:7] = 1.0

    # Create a simple flow field (shift everything by 1 voxel in each dimension)
    flow = torch.ones(1, 3, 10, 10, 10)

    # Apply transformation
    transformer = SpatialTransformer((10, 10, 10))
    transformed = transformer(volume, flow)

    # Validate the shape first so the element-wise comparison below is well defined.
    assert transformed.shape == volume.shape, "Transformed volume has wrong shape"

    # The 5:8 window checked further down overlaps the original 4:7 cube, so
    # that check alone would also pass for an identity transform. Require the
    # output to actually differ from the input.
    assert torch.sum(torch.abs(transformed - volume)) > 0, "Transformation did not move the cube"

    # Check that the cube still has mass in the neighbouring window
    assert torch.sum(transformed[0, 0, 5:8, 5:8, 5:8]) > 0, "Transformation did not move the cube"

run_test(test_spatial_transformer, "SpatialTransformer")

In [None]:
def test_random_spatial_deformation():
    """Check that RandomSpatialDeformation changes a volume but keeps its shape."""
    # Batch of two 20^3 volumes, each holding a 5x5x5 cube of ones
    volume = torch.zeros(2, 1, 20, 20, 20)
    volume[:, :, 8:13, 8:13, 8:13] = 1.0

    # Deformation settings used for this test
    deformation_settings = dict(
        scaling_bounds=0.2,
        rotation_bounds=15,
        shearing_bounds=0.01,
        translation_bounds=2.0,
        nonlin_std=3.0,
        nonlin_scale=0.05,
    )
    deformer = RandomSpatialDeformation(**deformation_settings)
    deformed = deformer(volume)

    # Any non-zero absolute difference means the input was altered
    total_change = torch.abs(deformed - volume).sum()
    assert total_change > 0, "Volume was not deformed"
    assert deformed.shape == volume.shape, "Deformed volume has wrong shape"

run_test(test_random_spatial_deformation, "RandomSpatialDeformation")

In [None]:
def test_random_crop():
    """Check that RandomCrop reduces a 20^3 batch to the requested 10^3 size."""
    crop_shape = (10, 10, 10)

    # Batch of two single-channel volumes; the content is irrelevant here
    volume = torch.zeros(2, 1, 20, 20, 20)

    cropper = RandomCrop(crop_shape)
    cropped = cropper(volume)

    # Batch and channel dimensions must be preserved; spatial ones must match the crop
    expected_shape = (2, 1, *crop_shape)
    assert cropped.shape == expected_shape, f"Cropped volume has wrong shape: {cropped.shape}"

run_test(test_random_crop, "RandomCrop")

In [None]:
def test_random_flip():
    """Check RandomFlip on a volume split into a label-1 half and a label-2 half.

    Only the first batch element is populated; the second stays all zeros.
    """
    # First half of the first spatial axis carries label 1, second half label 2
    volume = torch.zeros(2, 1, 10, 10, 10, dtype=torch.long)
    volume[0, 0, :5] = 1
    volume[0, 0, 5:] = 2

    # Label 0 is the only neutral label; 1 and 2 are the remaining labels
    generation_labels = torch.tensor([0, 1, 2])
    n_neutral_labels = 1

    # Second positional argument (1.0) requests a flip on every call
    flipper = RandomFlip(0, 1.0, generation_labels, n_neutral_labels)
    flipped = flipper(volume)

    # NOTE(review): for this symmetric input, mirroring the volume AND swapping
    # labels 1/2 would cancel out and reproduce the input - confirm which of
    # the two RandomFlip is expected to do here.
    first_half = flipped[0, 0, 0:5, :, :]
    second_half = flipped[0, 0, 5:10, :, :]
    assert (first_half == 2).any(), "Labels were not swapped correctly"
    assert (second_half == 1).any(), "Labels were not swapped correctly"

run_test(test_random_flip, "RandomFlip")

In [None]:
def test_sample_conditional_gmm():
    """Test that SampleConditionalGMM paints label regions with the given intensities.

    The label map is split into a label-1 half and a label-2 half. Means and
    stds are supplied with shape (2, 3, 1): one value per label for each of
    the two batch elements (presumably (batch, n_labels, n_channels) - confirm
    against SampleConditionalGMM).
    """
    # Create a simple test volume with labels
    labels = torch.zeros(2, 1, 10, 10, 10, dtype=torch.long)
    labels[:, :, 0:5, :, :] = 1  # Region with label 1
    labels[:, :, 5:10, :, :] = 2  # Region with label 2

    # Create means and stds.
    # The per-label values must sit on dim 1 *before* expanding: expand() can
    # only grow singleton dimensions, so a (1, 1, 3) tensor cannot be expanded
    # to (2, 3, 1) and would raise a RuntimeError.
    generation_labels = torch.tensor([0, 1, 2])
    means = torch.tensor([0.0, 50.0, 100.0], dtype=torch.float32).view(1, 3, 1).expand(2, 3, 1)
    stds = torch.tensor([0.0, 10.0, 20.0], dtype=torch.float32).view(1, 3, 1).expand(2, 3, 1)

    # Sample from GMM
    sampler = SampleConditionalGMM(generation_labels)
    image = sampler([labels, means, stds])

    # Check that the image has been generated with appropriate intensities
    assert image.shape == (2, 1, 10, 10, 10), f"Generated image has wrong shape: {image.shape}"
    assert torch.mean(image[:, :, 0:5, :, :]) > 40, "Region 1 has wrong intensity"
    assert torch.mean(image[:, :, 5:10, :, :]) > 80, "Region 2 has wrong intensity"

run_test(test_sample_conditional_gmm, "SampleConditionalGMM")

In [None]:
def test_bias_field_corruption():
    """Check that BiasFieldCorruption alters a constant image and keeps it positive."""
    # Constant image of ones, so any deviation comes from the corruption
    image = torch.ones(2, 1, 20, 20, 20)

    corruptor = BiasFieldCorruption(bias_field_std=0.5, bias_scale=0.1)
    corrupted = corruptor(image)

    # The output must differ from the input somewhere, with an unchanged shape
    total_change = torch.abs(corrupted - image).sum()
    assert total_change > 0, "Image was not corrupted"
    assert corrupted.shape == image.shape, "Corrupted image has wrong shape"

    # The test expects a positive input to stay strictly positive
    assert (corrupted > 0).all(), "Bias field created negative values"

run_test(test_bias_field_corruption, "BiasFieldCorruption")

In [None]:
def test_intensity_augmentation():
    """Check that IntensityAugmentation keeps the shape and outputs values in [0, 1]."""
    # Two-channel image with a constant intensity of 100
    image = 100 * torch.ones(2, 2, 10, 10, 10)

    augmenter = IntensityAugmentation(
        clip=150,
        normalise=True,
        gamma_std=0.5,
        separate_channels=True,
    )
    augmented = augmenter(image)

    assert augmented.shape == image.shape, "Augmented image has wrong shape"

    # With normalise=True every value is expected to lie within [0, 1]
    above_zero = torch.all(augmented >= 0)
    below_one = torch.all(augmented <= 1)
    assert above_zero and below_one, "Image was not normalized to [0, 1]"

run_test(test_intensity_augmentation, "IntensityAugmentation")

In [None]:
def test_gaussian_blur():
    """Check that GaussianBlur spreads a single-voxel impulse."""
    # Impulse at the centre of a 15^3 volume
    image = torch.zeros(1, 1, 15, 15, 15)
    image[0, 0, 7, 7, 7] = 1.0

    blurred = GaussianBlur(sigma=2.0)(image)

    assert blurred.shape == image.shape, "Blurred image has wrong shape"

    # Blurring must light up more than the single source voxel ...
    n_nonzero = (blurred > 0).sum()
    assert n_nonzero > 1, "Image was not blurred"

    # ... and lower the peak below the original value of 1
    peak = blurred.max()
    assert peak < 1.0, "Blur should reduce peak intensity"

run_test(test_gaussian_blur, "GaussianBlur")

In [None]:
def test_dynamic_gaussian_blur():
    """Check DynamicGaussianBlur with a per-sample sigma.

    The first sample gets non-zero sigmas and must be blurred; the second gets
    all-zero sigmas and must come back as the untouched impulse.
    """
    # One impulse per batch element at the centre of a 15^3 volume
    image = torch.zeros(2, 1, 15, 15, 15)
    image[:, 0, 7, 7, 7] = 1.0

    # One row of three sigma values per batch element
    sigma = torch.tensor([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])

    blur = DynamicGaussianBlur(max_sigma=3.0)
    blurred = blur([image, sigma])

    assert blurred.shape == image.shape, "Blurred image has wrong shape"

    blurred_sample, untouched_sample = blurred[0], blurred[1]

    # Sample 0: the impulse is spread out and its peak drops
    assert (blurred_sample > 0).sum() > 1, "First batch was not blurred"
    assert blurred_sample.max() < 1.0, "Blur should reduce peak intensity"

    # Sample 1: sigma = 0, so exactly one voxel stays lit at full intensity
    assert (untouched_sample > 0).sum() == 1, "Second batch should not be blurred"
    assert untouched_sample.max() == 1.0, "Second batch should not be blurred"

run_test(test_dynamic_gaussian_blur, "DynamicGaussianBlur")

In [None]:
def test_mimic_acquisition():
    """Check that MimicAcquisition returns volumes of the requested output shape."""
    # Constant 20^3 input, one channel, batch of two
    image = torch.ones(2, 1, 20, 20, 20)

    # One resolution triplet per batch element
    resolution = torch.tensor([[2.0, 2.0, 2.0], [1.0, 1.0, 1.0]])

    mimic = MimicAcquisition(
        atlas_res=[1.0, 1.0, 1.0],
        data_res=[1.0, 1.0, 1.0],
        output_shape=[10, 10, 10],
        downsample=True,
    )
    processed = mimic([image, resolution])

    # Batch and channel sizes are kept; the spatial size follows output_shape
    expected_shape = (2, 1, 10, 10, 10)
    assert processed.shape == expected_shape, f"Processed image has wrong shape: {processed.shape}"

run_test(test_mimic_acquisition, "MimicAcquisition")

In [None]:
def test_convert_labels():
    """Check that ConvertLabels maps 0/1/2 to 0/10/20 without changing the shape."""
    # Label map split in two halves along the first spatial axis
    labels = torch.zeros(2, 1, 10, 10, 10, dtype=torch.long)
    labels[:, :, :5] = 1
    labels[:, :, 5:] = 2

    # Source value at position i is mapped to the destination value at position i
    source_values = torch.tensor([0, 1, 2])
    dest_values = torch.tensor([0, 10, 20])

    converted = ConvertLabels(source_values, dest_values)(labels)

    assert converted.shape == labels.shape, "Converted labels have wrong shape"

    first_half = converted[:, :, 0:5, :, :]
    second_half = converted[:, :, 5:10, :, :]
    assert (first_half == 10).all(), "Label 1 was not converted to 10"
    assert (second_half == 20).all(), "Label 2 was not converted to 20"

run_test(test_convert_labels, "ConvertLabels")

In [None]:
def test_image_gradients():
    """Check that ImageGradients responds to a linear intensity ramp."""
    # Ramp from 0 to 1 along the first spatial axis, constant along the others
    ramp = torch.tensor([step / 9.0 for step in range(10)])
    image = torch.zeros(1, 1, 10, 10, 10)
    image[0, 0] = ramp.view(10, 1, 1)

    # Positional arguments are kept exactly as the module is called in this notebook
    gradient_op = ImageGradients('sobel', True)
    gradient = gradient_op(image)

    assert gradient.shape == image.shape, "Gradient has wrong shape"

    # At least one voxel must carry a strictly positive response
    assert (gradient > 0).sum() > 0, "No gradient was detected"

run_test(test_image_gradients, "ImageGradients")

In [None]:
def test_sample_resolution():
    """Check the shapes and value range of the outputs of SampleResolution."""
    # Placeholder input with a batch size of two
    dummy = torch.zeros(2, 1, 1, 1, 1)

    sampler = SampleResolution([1.0, 1.0, 1.0], [4.0, 4.0, 4.0], [8.0, 8.0, 8.0])
    resolution, blur_res = sampler(dummy)

    # One three-component vector per batch element for both outputs
    expected_shape = (2, 3)
    assert resolution.shape == expected_shape, "Resolution has wrong shape"
    assert blur_res.shape == expected_shape, "Blur resolution has wrong shape"

    # Sampled values must lie between the smallest (1.0) and largest (8.0) settings
    not_too_small = torch.all(resolution >= 1.0)
    not_too_large = torch.all(resolution <= 8.0)
    assert not_too_small and not_too_large, "Resolution is out of bounds"

run_test(test_sample_resolution, "SampleResolution")

In [None]:
def test_blurring_sigma_for_downsampling():
    """Check blurring_sigma_for_downsampling for a coarser and a finer target."""
    fine_res = [1.0, 1.0, 1.0]
    coarse_res = [2.0, 2.0, 2.0]

    # Going from 1 to 2 (downsampling): a strictly positive sigma is expected
    downsampling_sigma = blurring_sigma_for_downsampling(fine_res, coarse_res)
    assert torch.all(downsampling_sigma > 0), "Sigma should be positive for downsampling"

    # Going from 2 to 1 (upsampling): sigma is expected to be 0
    upsampling_sigma = blurring_sigma_for_downsampling(coarse_res, fine_res)
    assert torch.all(upsampling_sigma == 0), "Sigma should be 0 for upsampling"

run_test(test_blurring_sigma_for_downsampling, "blurring_sigma_for_downsampling")

## Test BrainGeneratorModel

Now let's test the BrainGeneratorModel.

In [None]:
def test_brain_generator_model():
    """Test one forward pass of BrainGeneratorModel.

    Builds a model for 20^3 label maps with 16^3 outputs, feeds it a two-sample
    batch (label map plus per-label means and stds) and checks the output
    shapes, the image intensity range and the set of output label values.
    """
    # Create a simple model
    model = BrainGeneratorModel(
        labels_shape=[20, 20, 20],
        n_channels=1,
        generation_labels=torch.tensor([0, 1, 2]),
        output_labels=torch.tensor([0, 1, 2]),
        n_neutral_labels=1,
        atlas_res=[1.0, 1.0, 1.0],
        target_res=[1.0, 1.0, 1.0],
        output_shape=[16, 16, 16],
        flipping=True,
        scaling_bounds=0.1,
        rotation_bounds=10,
        nonlin_std=2.0,
        bias_field_std=0.3,
        randomise_res=True
    )

    # Create inputs
    label_map = torch.zeros(2, 1, 20, 20, 20, dtype=torch.long)
    label_map[:, :, 0:10, :, :] = 1
    label_map[:, :, 10:20, :, :] = 2

    # The per-label values must sit on dim 1 *before* expanding: expand() can
    # only grow singleton dimensions, so a (1, 1, 3) tensor cannot be expanded
    # to (2, 3, 1) and would raise a RuntimeError.
    means = torch.tensor([0.0, 50.0, 100.0], dtype=torch.float32).view(1, 3, 1).expand(2, 3, 1)
    stds = torch.tensor([0.0, 10.0, 20.0], dtype=torch.float32).view(1, 3, 1).expand(2, 3, 1)

    inputs = {
        'label_map': label_map,
        'means': means,
        'stds': stds
    }

    # Generate image and labels
    image, labels = model(inputs)

    # Check that the outputs have the correct shape
    assert image.shape == (2, 1, 16, 16, 16), f"Generated image has wrong shape: {image.shape}"
    assert labels.shape == (2, 1, 16, 16, 16), f"Generated labels have wrong shape: {labels.shape}"

    # Check that the image has appropriate intensity values
    assert torch.min(image) >= 0 and torch.max(image) <= 1, "Image values should be in [0, 1]"

    # Check that the labels have the correct values
    unique_labels = torch.unique(labels)
    assert all(label in [0, 1, 2] for label in unique_labels), f"Labels contain unexpected values: {unique_labels}"

run_test(test_brain_generator_model, "BrainGeneratorModel")

## Test BrainGenerator

Finally, let's test the main BrainGenerator class.

In [None]:
def test_brain_generator():
    """Run BrainGenerator end to end on a throw-away directory of dummy label maps.

    The temporary directory is always removed afterwards, whether the checks
    pass or not.
    """
    import shutil
    import tempfile

    # Scratch directory that will hold the dummy label maps
    temp_dir = tempfile.mkdtemp()

    try:
        labels_dir = create_dummy_dataset(temp_dir)

        generator_settings = dict(
            labels_dir=labels_dir,
            batchsize=2,
            n_channels=1,
            flipping=True,
            scaling_bounds=0.1,
            rotation_bounds=10,
            nonlin_std=2.0,
            bias_field_std=0.3,
            randomise_res=True,
            device='cpu',  # Use CPU for testing
        )
        generator = BrainGenerator(**generator_settings)

        image, labels = generator.generate_brain()

        # The leading dimension of both outputs must equal the requested batch size
        assert image.shape[0] == 2, f"Generated image has wrong batch size: {image.shape}"
        assert labels.shape[0] == 2, f"Generated labels have wrong batch size: {labels.shape}"

        # Intensities are expected to lie within [0, 1]
        lowest, highest = np.min(image), np.max(image)
        assert lowest >= 0 and highest <= 1, "Image values should be in [0, 1]"

    finally:
        # Remove the scratch directory and everything written into it
        shutil.rmtree(temp_dir)

run_test(test_brain_generator, "BrainGenerator")

## Summary

Let's summarize the test results.

In [None]:
# Closing banner: one print call, one line of output per argument.
print(
    "\nTest Summary:",
    "=============",
    "All tests have been run. Check the results above to see if any tests failed.",
    "If all tests passed, the PyTorch brain generator is working correctly.",
    sep="\n",
)