In [53]:
from collections import defaultdict
from dataclasses import dataclass
import io
import json
from pathlib import Path
import os

from typing import Literal
from matplotlib import pyplot as plt
import random
import numpy as np
from scipy import ndimage
from seaborn import color_palette
from skimage.util import random_noise
from skimage.transform import resize
import cv2
from PIL import Image

import nibabel as nib
from nilearn.image import new_img_like, resample_img
from nilearn.plotting import plot_anat, view_img
from niworkflows.viz.utils import cuts_from_bbox, robust_set_limits
from niworkflows.utils.images import rotation2canonical, rotate_affine

from halfpipe.utils.path import split_ext

In [2]:
%matplotlib inline

In [3]:
base_directory = Path("/mnt/mbServerData/newdata/moods/halfpipe/bonn")
#Path("/Users/dominik/Downloads/usable")
#Path("/mnt/mbServerData/newdata/moods/halfpipe/bonn")

In [4]:
exclude_files = list(base_directory.glob("exclude*.json"))
exclude_entries = list()

for exclude_file in exclude_files:
    with open(exclude_file) as file_handle:
        exclude_entries.extend(json.load(file_handle))

In [5]:
paths_by_tags = defaultdict(lambda: defaultdict(set))
tags_by_paths = dict()

bids_directory = base_directory / "derivatives" / "fmriprep"
for bids_path in bids_directory.glob("**/*"):
    stem, extension = split_ext(bids_path)

    if stem.startswith("."):
        continue  # skip hidden files

    tokens = stem.split("_")

    tags = dict(
        path=str(bids_directory), 
        suffix=tokens.pop(-1),
        extension=extension,
    )
    tags_by_paths[bids_path] = tags

    for token in tokens:
        key, value = token.split("-")
        tags[key] = value
    for key, value in tags.items():
        paths_by_tags[key][value].add(bids_path)

In [6]:
def get(**filters):
    res = None
    
    for key, value in filters.items():
        if key not in paths_by_tags:
            # logger.info(f"Unknown key \"{key}\"")        
            return None
    
        values = paths_by_tags[key]
        if value is None:
            for paths in values.values():
                res -= paths
            continue
                
        elif value not in values:
            # logger.info(f"Unknown value \"{value}\"")        
            return None
        
        paths = values[value]
        if res is not None:
            res &= paths
        else:
            res = paths.copy()
    
    return res

def get_tag_value(path, key):
    return tags_by_paths[path].get(key)

In [7]:
def _plot_anat_with_contours(image, segs=None, **plot_params):
    nsegs = len(segs or [])
    plot_params = plot_params or {}
    # plot_params' values can be None, however they MUST NOT
    # be None for colors and levels from this point on.
    colors = plot_params.pop("colors", None) or []
    levels = plot_params.pop("levels", None) or []
    missing = nsegs - len(colors)
    if missing > 0:  # missing may be negative
        colors = colors + color_palette("husl", missing)

    colors = [[c] if not isinstance(c, list) else c for c in colors]

    if not levels:
        levels = [[0.5]] * nsegs

    # anatomical
    display = plot_anat(image, **plot_params)

    # remove plot_anat -specific parameters
    plot_params.pop("display_mode")
    plot_params.pop("cut_coords")

    plot_params["linewidths"] = 0.5
    for i in reversed(range(nsegs)):
        plot_params["colors"] = colors[i]
        display.add_contours(segs[i], levels=levels[i], **plot_params)
        
    return display

In [8]:
target_width = 2048

def to_rgb(display):
    figure = display.frame_axes.figure
    canvas = figure.canvas
    
    # scale to target_width
    width, height = canvas.get_width_height()
    figure.set_dpi(target_width / width * figure.get_dpi())
    
    canvas.draw()
    width, height = canvas.get_width_height()
    
    image = np.frombuffer(
        canvas.tostring_rgb(), dtype=np.uint8
    ).reshape((height, width, -1))[..., :3]
    
    image = image[:, :target_width, :]  # crop rounding errors
    
    return image

In [54]:
@dataclass
class SkullStrip:
    t1w: nib.Nifti1Image
    mask: nib.Nifti1Image
    label: Literal["usable", "unusable"]
    def to_image(self):
        plot_params = dict(colors=None)

        image_nii: nib.Nifti1Image = self.t1w
        seg_nii = self.mask

        canonical_r = rotation2canonical(image_nii)
        image_nii = rotate_affine(image_nii, rot=canonical_r)
        seg_nii = rotate_affine(seg_nii, rot=canonical_r)

        data = image_nii.get_fdata()
        plot_params = robust_set_limits(data, plot_params)

        bbox_nii = seg_nii

        cuts = cuts_from_bbox(bbox_nii, cuts=7)

        images = list()
        for d in plot_params.pop("dimensions", ("z", "x", "y")):
            plot_params["display_mode"] = d
            plot_params["cut_coords"] = cuts[d]
            display = _plot_anat_with_contours(
                image_nii, segs=[seg_nii], **plot_params
            )
            images.append(to_rgb(display))
            display.close()

        image = np.vstack(images)
        return image

In [10]:
skull_strips = list()

for exclude_entry in exclude_entries:
    rating = exclude_entry["rating"]
    type = exclude_entry["type"]
    sub = exclude_entry["sub"]
    
    if rating != "good":
        continue
        
    if type != "skull_strip_report":
        continue
        
    t1w_files = get(sub=sub, desc="preproc", res=None, suffix="T1w", extension=".nii.gz")
    if t1w_files is None:
        continue
    (t1w_file,) = t1w_files
    
    mask_files = get(sub=sub, desc="brain", res=None, suffix="mask", extension=".nii.gz")
    if mask_files is None:
        continue
    (mask_file,) = mask_files   
        
    skull_strips.append(
        SkullStrip(
            t1w=nib.load(t1w_file),
            mask=nib.load(mask_file),
        )
    )

### Nifti2Numpy

In [11]:
class NiftiToNumpy():
    def __call__(self,file):
        a = nib.load(file)
        a = nib.as_closest_canonical(a) # transform into RAS as proposed in nibabel documentation
        a = np.array(a.dataobj)
        a = np.float32(a)
        return a

## Augmentation Strategies functions

### Random Rotate Mask

In [43]:
def randomRotateMask(skullstrip: SkullStrip, max_angle: int):
    angle = random.uniform(-max_angle, max_angle)
    
    mask = skullstrip.mask
    mask_data = np.asanyarray(mask.dataobj).astype(bool)
    
    bad_mask_data = ndimage.rotate(mask_data, angle, reshape=False, axes=(0,1), output=float) > 0.5
    if not mask_data.shape == bad_mask_data.shape:
        return None
    bad_mask = new_img_like(mask, bad_mask_data, copy_header=True)
    skullstrip.mask = bad_mask
    
    new_skullstrip = skullstrip # save as variable to return as file if necessary
    # Show images
    #plot_anat(bad_mask)
    #image = new_skullstrip.to_image()
    #figure = plt.figure(figsize=(20,10))
    #plt.imshow(image)
    #print(angle)
    return new_skullstrip

### Random Flipping

In [42]:
def randomSkullFlip(t1w_file: Path, mask_file: Path):
    testskull = SkullStrip(t1w=nib.load(t1w_file), mask=nib.load(mask_file))
    a = NiftiToNumpy()
    b = a(t1w_file)
    #t1w = skullstrip.t1w
    #t1w_data = nib.as_closest_canonical(t1w)
    #t1w_data = np.array(t1w_data.dataobj).astype(np.float32)
    #plot_anat(nib.load(t1w_file))
    flipped_data = np.fliplr(b) # can change to ud and maybe on different axis (0,1)(1,2)(0,2)
    
    nifti_mask = nib.load(mask_file)
    ref_affine = nifti_mask.affine
    new_t1w = nib.Nifti1Image(flipped_data, affine=ref_affine)
    testskull.t1w = new_t1w
    return testskull
    #new_image = testskull.to_image()
    #figure = plt.figure(figsize=(20,10))
    #plt.imshow(new_image)

### Increase Contrast

In [39]:
def increaseContrast(skullstrip: SkullStrip):
    t1w = skullstrip.t1w
    #plot_anat(t1w)
    t1w_data = np.asanyarray(t1w.dataobj)
    
    minval = np.percentile(t1w_data, 2)
    maxval = np.percentile(t1w_data, 98)
    pixvals = np.clip(t1w_data, minval, maxval)
    pixvals = ((pixvals - minval) / (maxval - minval)) * 255
    new_t1w = new_img_like(t1w, pixvals, copy_header=True)
    skullstrip.t1w = new_t1w
    new_skullstrip = skullstrip # save as variable to return as file if necessary
    #plot_anat(new_t1w)
    #image = new_skullstrip.to_image()
    #figure = plt.figure(figsize=(20,10))
    #plt.imshow(image)
    return new_skullstrip

### Random Noise


In [40]:
def randomNoise(skullstrip: SkullStrip):
    t1w = skullstrip.t1w
    #plot_anat(t1w)
    t1w_data = np.asanyarray(t1w.dataobj).astype(np.uint8) # because gaussian noise adds a np.uint8 mask to the img
    
    noise_t1w_data = random_noise(t1w_data, mode='gaussian', var=0.01**2)
    noise_img = (255*noise_t1w_data).astype(np.uint8)

    noise_t1w = new_img_like(t1w, noise_img, copy_header=True)
    
    # skullstrip.t1w = noise_t1w
    # new_skullstrip = skullstrip # save as variable to return as file if necessary
    new_skullstrip = SkullStrip(t1w=noise_t1w, mask=skullstrip.mask)
    #plot_anat(noise_t1w)
    #image = new_skullstrip.to_image()
    #figure = plt.figure(figsize=(20,10))
    #plt.imshow(image)
    return new_skullstrip

### Remove text and axis from nifti Image (not working yet!)

In [16]:
def removeUnnecessaryText(skullstrip: SkullStrip):
    t1w = skullstrip.t1w
    plot_anat(t1w)
    t1w_data = np.asanyarray(t1w.dataobj).astype(np.uint8)
    print(t1w_data.shape)
    
    # coeffs = np.array([0.114, 0.587, 0.229])
    # image_gray = (t1w_data.astype(np.float) * coeffs).sum(axis=-1)
    # image_gray = image_gray.astype(t1w_data.dtype)
    #t1w_data = t1w_data[:,:,0]
    #grayscale_img = cv2.cvtColor(t1w_data, cv2.COLOR_BGR2GRAY)
    mask = cv2.threshold(t1w_data, 210, 255, cv2.THRESH_BINARY)[1][:,:,0]
    #mask2 = cv2.threshold(t1w_data, 210, 255, cv2.THRESH_BINARY)
    print(mask.dtype)
    #print(mask2.shape)
    #mask = np.uint8(mask)
    dst = cv2.inpaint(t1w_data[:,:,0], mask, 3, cv2.INPAINT_NS)
    print(dst.shape)
    np.expand_dims(dst, axis=2)
    print(dst.shape)
    #dst.shape = t1w_data.shape
    dst.shape = np.dstack((dst, t1w_data)).shape
    print(dst)
    new_t1w = new_img_like(t1w, dst, copy_header=True)
    skullstrip.t1w = new_t1w
    new_skullstrip = skullstrip # save as variable to return as file if necessary
    plot_anat(new_t1w)

## Load usable files


### Load one test skullstrip

In [17]:
t1wpath = "/Users/dominik/Downloads/usable/ds-integrament_site-berlin_sub-PIABR9GGF_desc-preproc_T1w.nii.gz"
maskpath = "/Users/dominik/Downloads/usable/ds-integrament_site-berlin_sub-PIABR9GGF_desc-brain_mask.nii.gz"
test_skullstrip = SkullStrip(t1w=nib.load(t1wpath),
                              mask=nib.load(maskpath),)

In [18]:
#randomNoise(skullstrip=test_skullstrip)

In [19]:
#increaseContrast(skullstrip=test_skullstrip)

In [20]:
#randomRotateMask(skullstrip=test_skullstrip, max_angle=90)

In [21]:
#randomSkullFlip(t1w_file=t1wpath, mask_file=maskpath)

In [22]:
#skull_strip = test_skullstrip
#image = skull_strip.to_image()
#print(image.shape)
#print(type(image))
#print(image.dtype)
#new_img = resize(image, (227,227,256))
#print(new_img.shape)
#print(new_img.dtype)
#u_img = np.copy(new_img).astype('uint8')
#print(u_img.dtype)
#figure = plt.figure(figsize=(20,10))
#plt.imshow(image)

In [23]:
def randomRotation(image, max_angle): # add types
    # randomly rotate image at random angle
    # add random selection on which axis
    
    # rotate along z-axis (xy plane)
    angle = random.uniform(-max_angle, max_angle)
    image1 = scipy.ndimage.interpolation.rotate(image, angle, mode='nearest', axes=(0, 1), reshape=False)

    # rotate along y-axis (xz plane)
    #angle = random.uniform(-max_angle, max_angle)
    #image = scipy.ndimage.interpolation.rotate(image, angle, mode='nearest', axes=(0, 2), reshape=False)

    # rotate along x-axis (yz plane)
    #angle = random.uniform(-max_angle, max_angle)
    #image = scipy.ndimage.interpolation.rotate(image, angle, mode='nearest', axes=(1, 2), reshape=False)

In [24]:
# show images
#nifti_img = nib.load(t1wpath)
#nifti_img = nib.as_closest_canonical(nifti_img)
#nii_data = nifti_img.get_fdata()
#nii_aff = nifti_img.affine
#nii_hdr = nifti_img.header
#print(nii_aff, "\n", nii_hdr)
#print(nii_data.shape)
#for slice_Number in range(nii_data.shape[2]):
#    plt.imshow(nii_data[:,:,slice_Number])
#    plt.show()

In [25]:
#im = nib.load(t1wpath).get_fdata()
#result1 = resize(im, (227,227,256), order=1,preserve_range=False)
#for slice_Number in range(result1.shape[2]):
#    plt.imshow(result1[:,:,slice_Number])
#    plt.show()

## Load complete datafolder

In [26]:
root_dir = "/Users/dominik/Downloads/usable"

Iterate through folder; ###Merke: only skullflipping requires paths all other just need skullstrips
skullstrips zusammenbauen aus gleichnamigen t1w + masks
do augmentation -> add to usable/unusable list
keep 50/50 ratio between lists

## training data generator

In [51]:
def build_trainingData(file_dir):
    file_list = sorted(os.listdir(file_dir))
    usable_skullstrips = []
    unusable_skullstrips = []
    
    file_list_len = len(file_list)
    idx = 0
    for cur_t1w,cur_mask in zip(file_list[0::2], file_list[1::2]):
        old_skullstrip = SkullStrip(t1w=nib.load(cur_t1w), mask=nib.load(cur_mask))
        
        if idx<file_list_len//4:
            usable_skullstrips.append(increaseContrast(skullstrip=old_skullstrip))
            
        rand_num = random.randint(0,2)
        if rand_num == 0:
            unusable_skullstrips.append(randomNoise(skullstrip=old_skullstrip))
        elif rand_num == 1:
            unusable_skullstrips.append(randomSkullFlip(t1w_file=t1w, mask_file=mask))
        elif rand_num == 2:
            unusable_skullstrips.append(randomRotateMask(skullstrip=old_skullstrip))
        idx +=1
    print(len(usable_skullstrips))
    print(len(unusable_skullstrips))
    


In [52]:
build_trainingData(root_dir)

FileNotFoundError: No such file or no access: 'ds-integrament_site-berlin_sub-PIAB16T5Z_desc-brain_mask.nii.gz'