# Data Augmentation

### Necessary imports

In [1]:
from collections import defaultdict
from dataclasses import dataclass
import io
import json
from pathlib import Path

from matplotlib import pyplot as plt
import random
import numpy as np
from scipy import ndimage
from seaborn import color_palette
from skimage.util import random_noise
import cv2

import nibabel as nib
from nilearn.image import new_img_like, resample_img
from nilearn.plotting import plot_anat, view_img
from niworkflows.viz.utils import cuts_from_bbox, robust_set_limits
from niworkflows.utils.images import rotation2canonical, rotate_affine

from halfpipe.utils.path import split_ext

## Skull Strip Class

In [2]:
def _plot_anat_with_contours(image, segs=None, **plot_params):
    nsegs = len(segs or [])
    plot_params = plot_params or {}
    # plot_params' values can be None, however they MUST NOT
    # be None for colors and levels from this point on.
    colors = plot_params.pop("colors", None) or []
    levels = plot_params.pop("levels", None) or []
    missing = nsegs - len(colors)
    if missing > 0:  # missing may be negative
        colors = colors + color_palette("husl", missing)

    colors = [[c] if not isinstance(c, list) else c for c in colors]

    if not levels:
        levels = [[0.5]] * nsegs

    # anatomical
    display = plot_anat(image, **plot_params)

    # remove plot_anat -specific parameters
    plot_params.pop("display_mode")
    plot_params.pop("cut_coords")

    plot_params["linewidths"] = 0.5
    for i in reversed(range(nsegs)):
        plot_params["colors"] = colors[i]
        display.add_contours(segs[i], levels=levels[i], **plot_params)
        
    return display

In [3]:
target_width = 2048

def to_rgb(display):
    figure = display.frame_axes.figure
    canvas = figure.canvas
    
    # scale to target_width
    width, height = canvas.get_width_height()
    figure.set_dpi(target_width / width * figure.get_dpi())
    
    canvas.draw()
    width, height = canvas.get_width_height()
    
    image = np.frombuffer(
        canvas.tostring_rgb(), dtype=np.uint8
    ).reshape((height, width, -1))[..., :3]
    
    image = image[:, :target_width, :]  # crop rounding errors
    
    return image

In [4]:
@dataclass
class SkullStrip:
    t1w: nib.Nifti1Image
    mask: nib.Nifti1Image
    
    def to_image(self):
        plot_params = dict(colors=None)

        image_nii: nib.Nifti1Image = self.t1w
        seg_nii = self.mask

        canonical_r = rotation2canonical(image_nii)
        image_nii = rotate_affine(image_nii, rot=canonical_r)
        seg_nii = rotate_affine(seg_nii, rot=canonical_r)

        data = image_nii.get_fdata()
        plot_params = robust_set_limits(data, plot_params)

        bbox_nii = seg_nii

        cuts = cuts_from_bbox(bbox_nii, cuts=7)

        images = list()
        for d in plot_params.pop("dimensions", ("z", "x", "y")):
            plot_params["display_mode"] = d
            plot_params["cut_coords"] = cuts[d]
            display = _plot_anat_with_contours(
                image_nii, segs=[seg_nii], **plot_params
            )
            images.append(to_rgb(display))
            display.close()

        image = np.vstack(images)
        return image

### Nifti2Numpy

In [5]:
class NiftiToNumpy():
    def __call__(self,file):
        a = nib.load(file)
        a = nib.as_closest_canonical(a) # transform into RAS as proposed in nibabel documentation
        a = np.array(a.dataobj)
        a = np.float32(a)
        return a

## Augmentation Strategies functions

### Random Rotate Mask

In [6]:
def randomRotateMask(skullstrip: SkullStrip, max_angle: int):
    angle = random.uniform(-max_angle, max_angle)
    
    mask = skullstrip.mask
    mask_data = np.asanyarray(mask.dataobj).astype(bool)
    
    bad_mask_data = ndimage.rotate(mask_data, angle, reshape=False, axes=(0,1), output=float) > 0.5
    if not mask_data.shape == bad_mask_data.shape:
        return None
    bad_mask = new_img_like(mask, bad_mask_data, copy_header=True)
    skullstrip.mask = bad_mask
    
    new_skullstrip = skullstrip # save as variable to return as file if necessary
    # Show images
    plot_anat(bad_mask)
    image = bad_skull_strip.to_image()
    figure = plt.figure(figsize=(20,10))
    plt.imshow(image)
    print(angle)

### Random Flipping

In [7]:
def randomFlippingSkull(t1w_file: Path, mask_file: Path):
    testskull = SkullStrip(t1w=nib.load(t1w_file), mask=nib.load(mask_file))
    a = NiftiToNumpy()
    b = a(t1w_file)
    #t1w = skullstrip.t1w
    #t1w_data = nib.as_closest_canonical(t1w)
    #t1w_data = np.array(t1w_data.dataobj).astype(np.float32)
    plot_anat(nib.load(t1w_file))
    flipped_data = np.fliplr(b)
    
    nifti_mask = nib.load(mask_file)
    ref_affine = nifti_mask.affine
    new_t1w = nib.Nifti1Image(flipped_data, affine=ref_affine)
    testskull.t1w = new_t1w
    #new_skullstrip = SkullStrip(t1w=new_t1w, mask=nifti_mask)
    #show_img = new_skullstrip.to_image()
    new_image = testskull.to_image()
    figure = plt.figure(figsize=(20,10))
    plt.imshow(new_image)

### Increase Contrast

In [8]:
def increaseContrast(skullstrip: SkullStrip):
    t1w = skullstrip.t1w
    plot_anat(t1w)
    t1w_data = np.asanyarray(t1w.dataobj)
    
    minval = np.percentile(t1w_data, 2)
    maxval = np.percentile(t1w_data, 98)
    pixvals = np.clip(t1w_data, minval, maxval)
    pixvals = ((pixvals - minval) / (maxval - minval)) * 255
    new_t1w = new_img_like(t1w, pixvals, copy_header=True)
    skullstrip.t1w = new_t1w
    new_skullstrip = skullstrip # save as variable to return as file if necessary
    plot_anat(new_t1w)
    image = new_skull_strip.to_image()
    figure = plt.figure(figsize=(20,10))
    plt.imshow(image)

### Random Noise

In [9]:
def randomNoise(skullstrip: SkullStrip):
    t1w = skullstrip.t1w
    plot_anat(t1w)
    t1w_data = np.asanyarray(t1w.dataobj).astype(np.uint8) # because gaussian noise adds a np.uint8 mask to the img
    
    noise_t1w_data = random_noise(t1w_data, mode='gaussian', var=0.01**2)
    noise_img = (255*noise_t1w_data).astype(np.uint8)

    noise_t1w = new_img_like(t1w, noise_img, copy_header=True)
    
    # skullstrip.t1w = noise_t1w
    # new_skullstrip = skullstrip # save as variable to return as file if necessary
    new_skullstrip = SkullStrip(t1w=noise_t1w, mask=skullstrip.mask)
    plot_anat(noise_t1w)
    image = new_skullstrip.to_image()
    figure = plt.figure(figsize=(20,10))
    plt.imshow(image)

### Remove text and axis from nifti Image (not working yet!)

In [10]:
def removeUnnecessaryText(skullstrip: SkullStrip):
    t1w = skullstrip.t1w
    plot_anat(t1w)
    t1w_data = np.asanyarray(t1w.dataobj).astype(np.uint8)
    print(t1w_data.shape)
    
    # coeffs = np.array([0.114, 0.587, 0.229])
    # image_gray = (t1w_data.astype(np.float) * coeffs).sum(axis=-1)
    # image_gray = image_gray.astype(t1w_data.dtype)
    #t1w_data = t1w_data[:,:,0]
    #grayscale_img = cv2.cvtColor(t1w_data, cv2.COLOR_BGR2GRAY)
    mask = cv2.threshold(t1w_data, 210, 255, cv2.THRESH_BINARY)[1][:,:,0]
    #mask2 = cv2.threshold(t1w_data, 210, 255, cv2.THRESH_BINARY)
    print(mask.dtype)
    #print(mask2.shape)
    #mask = np.uint8(mask)
    dst = cv2.inpaint(t1w_data[:,:,0], mask, 3, cv2.INPAINT_NS)
    print(dst.shape)
    np.expand_dims(dst, axis=2)
    print(dst.shape)
    #dst.shape = t1w_data.shape
    dst.shape = np.dstack((dst, t1w_data)).shape
    print(dst)
    new_t1w = new_img_like(t1w, dst, copy_header=True)
    skullstrip.t1w = new_t1w
    new_skullstrip = skullstrip # save as variable to return as file if necessary
    plot_anat(new_t1w)

## Load usable files

### Load one test skullstrip

In [11]:
t1wpath = "/Users/dominik/Downloads/usable/ds-integrament_site-berlin_sub-PIABR9GGF_desc-preproc_T1w.nii.gz"
maskpath = "/Users/dominik/Downloads/usable/ds-integrament_site-berlin_sub-PIABR9GGF_desc-brain_mask.nii.gz"
test_skullstrip = SkullStrip(t1w=nib.load(t1wpath),
                              mask=nib.load(maskpath),)

### Load all usable files from folder

In [12]:
base_directory = Path("/Users/dominik/Downloads/usable")

In [13]:
randomFlippingSkull(t1w_file=t1wpath, mask_file=maskpath)

ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 2048 and the array at index 1 has size 2047

In [92]:
randomNoise(skullstrip=test_skullstrip)

ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 2048 and the array at index 1 has size 2047