In [1]:
from typing import List

from tqdm import tqdm
import numpy as np
from nilearn.datasets import fetch_neurovault
from nibabel import Nifti1Image
from nilearn import plotting
from torchio import transforms
import torchio as tio

from ai4sipmbda.utils import fetching, difumo_utils

In [3]:
# Machine-specific local data directory handed to `fetch_nv`; adjust for your setup.
NV_DATA_DIR = "/storage/store2/data/"
data = fetching.fetch_nv(NV_DATA_DIR, max_images=3)

3 images found on local disk.


In [7]:
data.__class__

sklearn.utils.Bunch

In [4]:
# Named torchio augmentations, looked up by name in `create_augmentation`.
# Each entry is instantiated once here with fixed hyper-parameters.
AUGMENTATIONS = {
    "RandomElasticDeformation": tio.RandomElasticDeformation(
        num_control_points=16, max_displacement=2
    ),
    "RandomMotion": tio.RandomMotion(
        degrees=0.2, translation=0.2, num_transforms=2
    ),
    "RandomGhosting": tio.RandomGhosting(
        num_ghosts=1, intensity=0.02, restore=1.0
    ),
    "RandomSpike": tio.RandomSpike(num_spikes=2, intensity=1.15),
    "RandomBiasField": tio.RandomBiasField(order=1, coefficients=0.05),
    "RandomBlur": tio.RandomBlur(std=1.05),
    "RandomNoise": tio.RandomNoise(mean=0.3, std=0.5),
    "RandomGamma": tio.RandomGamma(log_gamma=0.075),
    # flip_probability=1.0: whenever this entry is selected, the flip is applied.
    "RandomFlip": tio.RandomFlip(flip_probability=1.0),
    # Placeholder meaning "no augmentation"; note the value is a bare None,
    # not a torchio Transform.
    "None": None,
}

In [6]:
def create_augmentation(
    aug_names: List[str],
) -> transforms.Transform:
    """Build a transform that applies one augmentation drawn from ``aug_names``.

    Parameters
    ----------
    aug_names : list of str
        Keys of ``AUGMENTATIONS``. The special name ``"None"`` stands for
        "leave the image unchanged".

    Returns
    -------
    torchio Transform
        A ``tio.transforms.OneOf`` that picks one of the requested
        augmentations each time it is called.

    Raises
    ------
    ValueError
        If ``aug_names`` is empty.
    KeyError
        If a name is not a key of ``AUGMENTATIONS``.
    """
    if not aug_names:
        raise ValueError("aug_names must contain at least one augmentation name")

    unknown = [name for name in aug_names if name not in AUGMENTATIONS]
    if unknown:
        raise KeyError(
            f"Unknown augmentation(s) {unknown}; "
            f"expected keys of AUGMENTATIONS: {sorted(AUGMENTATIONS)}"
        )

    augmentation_list = [
        # AUGMENTATIONS["None"] is a bare None, not a Transform, so it cannot be
        # handed to OneOf as-is; substitute an identity transform so "None"
        # remains a selectable (no-op) choice.
        tio.Lambda(lambda tensor: tensor)
        if AUGMENTATIONS[name] is None
        else AUGMENTATIONS[name]
        for name in aug_names
    ]
    return tio.transforms.OneOf(augmentation_list)

In [7]:
flip = create_augmentation(aug_names=["RandomFlip"])

In [9]:
# Precomputed matrices, loaded from a path relative to the working directory.
# `mask` is used as an index into voxel arrays when projecting images;
# presumably a boolean mask (TODO confirm dtype in the saved file).
mask = np.load("hcp900_difumo_matrices/mask.npy")
# Projection matrix applied (via .dot) to the masked voxel vector.
Z_inv = np.load("hcp900_difumo_matrices/Zinv.npy")

In [13]:
def transform_based_augmentation(augmentation, images_paths, Y=None, nb_fakes=10):
    """Augment each image ``nb_fakes`` times and project every copy with ``Z_inv``.

    Each path is loaded as a ``tio.ScalarImage`` and passed ``nb_fakes`` times
    through ``augmentation``. Each result is masked with the module-level
    ``mask`` and projected with the module-level ``Z_inv`` matrix.

    Parameters
    ----------
    augmentation : callable
        Transform applied to the loaded torchio image (e.g. the output of
        ``create_augmentation``).
    images_paths : iterable of str
        Paths of the images to augment.
    Y : ignored
        Currently unused; kept for interface compatibility. TODO: labels for
        the augmented samples are not produced.
    nb_fakes : int
        Number of augmented copies drawn per image.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_images * nb_fakes, Z_inv.shape[0])``, one projected
        augmented copy per row, grouped by source image. Empty input (or
        ``nb_fakes == 0``) yields an array with zero rows.
    """
    X = list()
    # Iterate the paths directly so tqdm can use len() for a proper total.
    for image_path in tqdm(images_paths):
        image_tio = tio.ScalarImage(image_path)

        for _ in range(nb_fakes):
            # transform (may be random, so repeated draws can differ)
            trf_img_tio = augmentation(image_tio)

            # project: convert to numpy, keep in-mask voxels, apply Z_inv
            volume = trf_img_tio.data.numpy().squeeze()
            trf_difumo_vec = Z_inv.dot(volume[mask])

            # XXX (open question): should an extra transformation be applied here?
            X.append(trf_difumo_vec)

    if not X:
        # np.vstack raises on an empty list; return a zero-row array instead.
        return np.empty((0, Z_inv.shape[0]), dtype=Z_inv.dtype)
    return np.vstack(X)

In [14]:
augmented_X = transform_based_augmentation(
    augmentation=flip,
    images_paths=data["images"],
    Y=None,
    nb_fakes=2,
)

3it [00:03,  1.15s/it]


In [15]:
np.shape(augmented_X)

(6, 1024)