In [27]:
from itertools import islice
from pathlib import Path

import numpy as np
import torch
import torchio as tio

# Directory holding the input labelmaps (one NIfTI file per subject).
LABELMAP_DIR = Path('labelmaps')

labelmap_paths = sorted(LABELMAP_DIR.glob('*.nii.gz'))
# Path.glob on a missing or empty directory silently yields nothing, which
# would make every later cell run on zero subjects; fail loudly instead.
if not labelmap_paths:
    raise FileNotFoundError(f'No *.nii.gz labelmaps found in {LABELMAP_DIR.resolve()}')

# One subject per labelmap, stored under the 'labelmap' key that later
# cells read.
subjects = [tio.Subject(labelmap=tio.LabelMap(path)) for path in labelmap_paths]

dataset = tio.SubjectsDataset(subjects)


In [28]:
class GenerateSyntheticMRI(tio.Transform):
    """Paint a synthetic MR-like intensity image from the subject's labelmap.

    Every non-zero label is filled with Gaussian noise:

    * labels inside the inclusive ``label_range`` all share ``shared_mean`` /
      ``shared_std``, so they are indistinguishable in the synthetic image;
    * every other label gets its own mean drawn from U(50, 150) and its own
      std drawn from U(5, 20), re-drawn on each call.

    Background (label 0) stays 0. The result is stored in ``subject['image']``
    as a float32 ``tio.ScalarImage`` sharing the labelmap's affine.

    Args:
        label_range: inclusive ``(low, high)`` range of labels that share
            intensity statistics.
        shared_mean: mean intensity for labels inside ``label_range``.
        shared_std: standard deviation for labels inside ``label_range``.
        **kwargs: forwarded to ``tio.Transform`` (e.g. ``p``, ``include``,
            ``exclude``).
    """

    def __init__(self, label_range=(1, 6), shared_mean=100.0, shared_std=10.0, **kwargs):
        super().__init__(**kwargs)
        self.label_range = label_range
        self.shared_mean = float(shared_mean)
        self.shared_std = float(shared_std)

    def apply_transform(self, subject):
        labelmap = subject['labelmap']
        label_tensor = labelmap.data  # only read, never modified -> no clone needed
        synthetic_tensor = torch.zeros_like(label_tensor, dtype=torch.float32)
        low, high = self.label_range

        for label in torch.unique(label_tensor).tolist():
            if label == 0:
                continue  # background stays at zero intensity
            mask = label_tensor == label

            if low <= label <= high:
                mean, std = self.shared_mean, self.shared_std
            else:
                # Use torch's RNG (as torchio's own random transforms do) so that
                # torch.manual_seed alone makes the pipeline reproducible.
                mean = float(torch.empty(1).uniform_(50, 150))
                std = float(torch.empty(1).uniform_(5, 20))

            # Sample only the voxels being filled instead of a whole volume per label.
            n_voxels = int(mask.sum())
            synthetic_tensor[mask] = torch.randn(n_voxels) * std + mean

        subject['image'] = tio.ScalarImage(tensor=synthetic_tensor, affine=labelmap.affine)
        return subject


In [29]:
class RandomResample(tio.Transform):
    """Resample every image except the labelmap to a randomly drawn spacing.

    On each call a target spacing is drawn independently per axis from
    U(min_spacing, max_spacing). All images except ``'labelmap'`` are linearly
    resampled to that spacing.

    By default the labelmap is left untouched, so afterwards the image and the
    labelmap generally live on different voxel grids (different shapes and
    affines). Pass ``reference='labelmap'`` (or another image key) to resample
    the images back onto that image's grid after the random resampling. This
    keeps the resolution loss and restores voxel-wise alignment with the labels.

    Args:
        min_spacing: per-axis lower bound of the random spacing (same units as
            the images' spacing).
        max_spacing: per-axis upper bound of the random spacing.
        p: probability of applying the transform.
        reference: optional name of an image in the subject whose grid the
            resampled images are mapped back onto; ``None`` keeps the random grid.
        **kwargs: forwarded to ``tio.Transform``.

    Raises:
        ValueError: if any ``min_spacing`` entry exceeds the matching
            ``max_spacing`` entry.
    """

    def __init__(self, min_spacing=(.2, .2, .2), max_spacing=(1.0, 1.0, 1.0), p=1.0,
                 reference=None, **kwargs):
        super().__init__(p=p, **kwargs)
        self.min_spacing = np.asarray(min_spacing, dtype=float)
        self.max_spacing = np.asarray(max_spacing, dtype=float)
        if np.any(self.min_spacing > self.max_spacing):
            raise ValueError(
                f'min_spacing {self.min_spacing} must not exceed max_spacing {self.max_spacing}'
            )
        self.reference = reference

    def apply_transform(self, subject):
        # Draw from torch's RNG (as torchio's built-in random transforms do) so
        # torch.manual_seed alone makes the whole pipeline reproducible.
        shape = np.broadcast(self.min_spacing, self.max_spacing).shape
        unit = torch.rand(shape, dtype=torch.float64).numpy()
        random_spacing = self.min_spacing + unit * (self.max_spacing - self.min_spacing)
        # tio.Resample takes a float (isotropic) or a tuple of per-axis spacings.
        if random_spacing.ndim:
            target = tuple(random_spacing.tolist())
        else:
            target = float(random_spacing)

        resample = tio.Resample(target, image_interpolation='linear', exclude=['labelmap'])
        resampled = resample(subject)
        if self.reference is not None:
            to_reference = tio.Resample(
                self.reference, image_interpolation='linear', exclude=['labelmap'],
            )
            resampled = to_reference(resampled)
        return resampled


In [30]:
transform = tio.Compose([
    # 1. Anatomical deformations. At this point the labelmap is the only image,
    #    so these act on it; the synthetic image generated next inherits them.
    tio.RandomAffine(scales=(0.95, 1.05), degrees=10, translation=5, p=0.75),
    tio.RandomElasticDeformation(num_control_points=7, max_displacement=7.5, p=0.5),

    # 2. Generate the synthetic image from the deformed labelmap.
    GenerateSyntheticMRI(),

    # 3. Simulate acquisition: random resampling of the image only. The labelmap
    #    keeps its original grid, so the two may differ in shape/affine after this.
    #    NOTE(review): confirm the later spatial artifacts (e.g. RandomAnisotropy)
    #    behave as intended when image and labelmap grids differ.
    RandomResample(min_spacing=(.2, .2, .2), max_spacing=(1.0, 1.0, 1.0), p=1.0),

    # 4. Normalize the image to [0, 1]. GenerateSyntheticMRI produces label means
    #    of roughly 50-150, while the artifact parameters below are on a unit
    #    scale. Without this, noise std <= 0.05 is negligible and gamma blows up
    #    the intensity range.
    tio.RescaleIntensity(out_min_max=(0, 1), exclude=['labelmap']),

    # 5. MRI acquisition-like artifacts (image only).
    tio.RandomMotion(p=0.3, exclude=['labelmap']),
    tio.RandomGhosting(p=0.2, exclude=['labelmap']),
    tio.RandomSpike(p=0.2, exclude=['labelmap']),
    tio.RandomBiasField(p=0.3, exclude=['labelmap']),
    tio.RandomNoise(p=0.2, mean=0.0, std=(0, 0.05), exclude=['labelmap']),
    tio.RandomBlur(p=0.2, std=(0.25, 1.0), exclude=['labelmap']),
    tio.RandomGamma(log_gamma=(-0.3, 0.3), p=0.2, exclude=['labelmap']),
    tio.RandomAnisotropy(p=0.2, downsampling=(1, 4), exclude=['labelmap']),
])


In [31]:
augmented_dataset = tio.SubjectsDataset(subjects, transform=transform)

# One augmented subject per batch, reshuffled each epoch.
# NOTE(review): `loader` is not used by the cells shown (save_samples iterates
# the dataset directly); newer torchio releases may expect tio.SubjectsLoader
# rather than a plain DataLoader for batching Subjects — confirm before use.
loader = torch.utils.data.DataLoader(
    dataset=augmented_dataset,
    batch_size=1,
    shuffle=True,
)


In [32]:
from pathlib import Path

def save_samples(dataset, output_dir='output_samples', n=10,
                 image_key='image', label_key='labelmap'):
    """Write the first ``n`` subjects of ``dataset`` to NIfTI files.

    Sample ``i`` is written as ``image_{i:03}.nii.gz`` and
    ``label_{i:03}.nii.gz`` inside ``output_dir``, which is created (with
    parents) if missing.

    Args:
        dataset: iterable of subjects (e.g. a ``tio.SubjectsDataset``). Each
            subject must map ``image_key`` and ``label_key`` to objects with a
            ``save(path)`` method.
        output_dir: destination directory.
        n: maximum number of samples to write; values <= 0 write nothing.
        image_key: subject key of the image to save.
        label_key: subject key of the labelmap to save.

    Returns:
        None
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # islice stops after n items without fetching an (n+1)-th subject. With an
    # augmenting dataset every fetch runs the whole transform pipeline, so the
    # previous `if i >= n: break` wasted one full augmentation.
    for i, subject in enumerate(islice(dataset, max(n, 0))):
        subject[image_key].save(output_dir / f'image_{i:03}.nii.gz')
        subject[label_key].save(output_dir / f'label_{i:03}.nii.gz')


In [33]:
# Seed every RNG the pipeline may draw from: torchio's random transforms use
# torch, and the custom transforms may use numpy. Re-running this cell then
# regenerates the same samples.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

save_samples(augmented_dataset, n=10)


  spectrum[i, j, k] += artifact
