In [None]:
from collections import defaultdict
from dataclasses import dataclass
import io
import json
from pathlib import Path

from matplotlib import pyplot as plt
import numpy as np
from scipy import ndimage
from seaborn import color_palette

import nibabel as nib
from nilearn.image import new_img_like, resample_img
from nilearn.plotting import plot_anat, view_img
from niworkflows.viz.utils import cuts_from_bbox, robust_set_limits
from niworkflows.utils.images import rotation2canonical, rotate_affine

from halfpipe.utils.path import split_ext

In [None]:
%matplotlib inline

In [None]:
# Root directory that the cells below read the `exclude*.json` files and the
# `derivatives/fmriprep` outputs from. The path is machine-specific; adjust it
# before running the notebook elsewhere.
base_directory = Path("/mnt/mbServerData/newdata/moods/halfpipe/bonn")

In [None]:
# Collect the entries of every `exclude*.json` file found directly inside
# `base_directory`. The files are visited in sorted order because `Path.glob`
# yields paths in arbitrary, file-system dependent order, and a later cell
# picks an example by its position in the resulting list.
exclude_files = sorted(base_directory.glob("exclude*.json"))
exclude_entries = []

for exclude_file in exclude_files:
    # JSON text is UTF-8; do not depend on the locale's default encoding
    with open(exclude_file, encoding="utf-8") as file_handle:
        exclude_entries.extend(json.load(file_handle))

In [None]:
# Build a two-way index over the files below `derivatives/fmriprep`:
#   paths_by_tags[key][value] -> set of paths carrying the tag `key=value`
#   tags_by_paths[path]       -> dict with all tags parsed from that path
paths_by_tags = defaultdict(lambda: defaultdict(set))
tags_by_paths = dict()

bids_directory = base_directory / "derivatives" / "fmriprep"
for bids_path in bids_directory.glob("**/*"):
    if not bids_path.is_file():
        continue  # `**/*` also yields directories, which are not indexed

    stem, extension = split_ext(bids_path)

    if stem.startswith("."):
        continue  # skip hidden files

    # the last underscore-separated token is the suffix, all tokens before it
    # are expected to be `key-value` pairs
    *tokens, suffix = stem.split("_")

    # split on the first hyphen only, so that values may contain hyphens
    entities = [token.partition("-") for token in tokens]
    if not all(separator for _, separator, _ in entities):
        # a token without a hyphen means that the name does not follow the
        # `key-value_..._suffix` pattern (e.g. `dataset_description.json`)
        continue

    tags = dict(
        path=str(bids_directory),
        suffix=suffix,
        extension=extension,
    )
    for key, _, value in entities:
        tags[key] = value

    tags_by_paths[bids_path] = tags
    for key, value in tags.items():
        paths_by_tags[key][value].add(bids_path)

In [None]:
def get(**filters):
    """Query the global `paths_by_tags` index.

    Every keyword argument is one filter:

    * ``key=value`` keeps only paths that carry exactly this tag value.
    * ``key=None`` removes every path that carries the tag `key` at all.
      If no indexed path carries `key`, the filter is a no-op.

    All filters must hold at the same time; their order does not matter.

    Returns
    -------
    set or None
        A new set with the matching paths (it may be empty). None if no
        filter was given, or if a requested key or value does not occur in
        the index.
    """
    if not filters:
        return None

    required = {key: value for key, value in filters.items() if value is not None}
    forbidden = [key for key, value in filters.items() if value is None]

    res = None
    for key, value in required.items():
        values = paths_by_tags.get(key)  # `.get` avoids creating defaultdict entries
        if values is None or value not in values:
            return None  # unknown key or unknown value

        paths = values[value]
        if res is None:
            res = paths.copy()  # never mutate the sets stored in the index
        else:
            res &= paths

    if res is None:
        # only exclusion filters were given: start from every indexed path
        res = set()
        for values in paths_by_tags.values():
            res.update(*values.values())

    # exclusions are applied last so that they work regardless of argument order
    for key in forbidden:
        for paths in paths_by_tags.get(key, {}).values():
            res -= paths

    return res

def get_tag_value(path, key):
    """Return the value of tag `key` for an indexed `path`.

    The result is None when the path does not carry that tag. A path that is
    missing from the global `tags_by_paths` index raises KeyError.
    """
    tags = tags_by_paths[path]
    return tags.get(key)

In [None]:
def _plot_anat_with_contours(image, segs=None, **plot_params):
    """Plot an anatomical image and draw each segmentation on top as contours.

    Parameters
    ----------
    image
        Image passed to `plot_anat` as the background.
    segs : list, optional
        Segmentation images; each one is drawn with `display.add_contours`.
    **plot_params
        `colors` and `levels` (one entry per segmentation) are consumed here.
        All other parameters are forwarded to `plot_anat`; except for
        `display_mode` and `cut_coords` they are forwarded to `add_contours`
        as well.

    Returns
    -------
    The display object returned by `plot_anat`.
    """
    segs = segs or []
    nsegs = len(segs)

    # plot_params' values can be None, however they MUST NOT
    # be None for colors and levels from this point on.
    colors = plot_params.pop("colors", None) or []
    levels = plot_params.pop("levels", None) or []

    missing = nsegs - len(colors)
    if missing > 0:  # missing may be negative
        colors = colors + color_palette("husl", missing)

    # every segmentation gets a list of colors
    colors = [c if isinstance(c, list) else [c] for c in colors]

    if not levels:
        levels = [[0.5]] * nsegs

    # anatomical
    display = plot_anat(image, **plot_params)

    # remove plot_anat-specific parameters; both are optional for plot_anat,
    # so they may be absent here
    plot_params.pop("display_mode", None)
    plot_params.pop("cut_coords", None)

    plot_params["linewidths"] = 0.5
    # drawn in reverse order, so that the first segmentation is added last
    for i in reversed(range(nsegs)):
        plot_params["colors"] = colors[i]
        display.add_contours(segs[i], levels=levels[i], **plot_params)

    return display

In [None]:
target_width = 2048  # width in pixels that rendered figures are scaled to


def to_rgb(display):
    """Render the figure behind `display` into an RGB pixel array.

    The DPI of the figure is rescaled so that the drawn canvas is (up to
    rounding) `target_width` pixels wide. This changes the DPI of the figure
    in place.

    Parameters
    ----------
    display
        Object whose `frame_axes.figure` is the figure to render.

    Returns
    -------
    numpy.ndarray
        Array of shape (height, width, 3) with width at most `target_width`.
    """
    figure = display.frame_axes.figure
    canvas = figure.canvas

    # scale to target_width
    width, _ = canvas.get_width_height()
    figure.set_dpi(target_width / width * figure.get_dpi())

    canvas.draw()

    # `tostring_rgb` is deprecated since Matplotlib 3.8 and absent from newer
    # releases; the RGBA buffer is the supported way to read the pixels.
    # Taking the shape from the buffer itself also avoids a manual reshape.
    # The data is copied because the buffer belongs to the renderer.
    rgba = np.array(canvas.buffer_rgba())
    image = rgba[..., :3]  # drop the alpha channel

    image = image[:, :target_width, :]  # crop rounding errors

    return image

In [None]:
@dataclass
class SkullStrip:
    """A T1w image together with a mask, renderable as one overview image."""

    t1w: nib.Nifti1Image  # anatomical image used as the background
    mask: nib.Nifti1Image  # mask whose outline is drawn on top of the T1w

    def to_image(self):
        """Render the T1w image with the mask outline as one RGB array.

        Both images are rotated with the rotation that `rotation2canonical`
        computes for the T1w image, the cut coordinates are derived from the
        mask with `cuts_from_bbox(..., cuts=7)`, and one row of cuts is
        rendered for each of the "z", "x" and "y" display modes. The rendered
        rows are stacked vertically.
        """
        plot_params = dict(colors=None)

        canonical_r = rotation2canonical(self.t1w)
        image_nii = rotate_affine(self.t1w, rot=canonical_r)
        seg_nii = rotate_affine(self.mask, rot=canonical_r)

        plot_params = robust_set_limits(image_nii.get_fdata(), plot_params)

        cuts = cuts_from_bbox(seg_nii, cuts=7)

        images = []
        for d in plot_params.pop("dimensions", ("z", "x", "y")):
            plot_params["display_mode"] = d
            plot_params["cut_coords"] = cuts[d]
            display = _plot_anat_with_contours(
                image_nii, segs=[seg_nii], **plot_params
            )
            try:
                images.append(to_rgb(display))
            finally:
                # close the display even if rendering fails, so that no
                # figure is left open
                display.close()

        return np.vstack(images)

In [None]:
# Load the T1w image and the brain mask for every entry whose skull strip
# report was rated "good".
skull_strips = list()

for exclude_entry in exclude_entries:
    if exclude_entry["rating"] != "good":
        continue

    # deliberately not assigned to a variable called `type`, which would
    # shadow the builtin for the rest of the notebook
    if exclude_entry["type"] != "skull_strip_report":
        continue

    sub = exclude_entry["sub"]

    # `get` returns None for unknown tags and an empty set when no path
    # matches all filters; both mean that there is nothing to load
    t1w_files = get(sub=sub, desc="preproc", res=None, suffix="T1w", extension=".nii.gz")
    if not t1w_files:
        continue
    (t1w_file,) = t1w_files  # raises ValueError if the match is ambiguous

    mask_files = get(sub=sub, desc="brain", res=None, suffix="mask", extension=".nii.gz")
    if not mask_files:
        continue
    (mask_file,) = mask_files  # raises ValueError if the match is ambiguous

    skull_strips.append(
        SkullStrip(
            t1w=nib.load(t1w_file),
            mask=nib.load(mask_file),
        )
    )

In [None]:
# Use the first collected skull strip as the example for the cells below.
# Raises IndexError if no skull strip was collected.
skull_strip = skull_strips[0]

In [None]:
# Render the example skull strip and show the resulting array in a wide figure.
image = skull_strip.to_image()
figure = plt.figure(figsize=(20,10))
plt.imshow(image)

In [None]:
# Take the mask image of the example and display its repr.
mask = skull_strip.mask
mask

In [None]:
# Plot the mask image on its own.
plot_anat(mask)

In [None]:
# Read the mask voxels as a boolean array and show its shape.
mask_data = np.asanyarray(mask.dataobj).astype(bool)
mask_data.shape

In [None]:
# Create a deliberately misaligned ("bad") mask: rotate the mask array by 45
# degrees without changing the array shape (`reshape=False`), interpolating
# into floats, and threshold at 0.5 to obtain a boolean array again.
bad_mask_data = ndimage.rotate(mask_data, 45, reshape=False, output=float) > 0.5
bad_mask_data.shape

In [None]:
# Wrap the rotated array in a new image that uses the original mask as the
# reference image, including a copy of its header.
bad_mask = new_img_like(mask, bad_mask_data, copy_header=True)

In [None]:
# Plot the rotated mask for comparison with the original mask plot.
plot_anat(bad_mask)

In [None]:
# Pair the unchanged T1w image of the example with the rotated mask.
bad_skull_strip = SkullStrip(
    t1w=skull_strip.t1w,
    mask=bad_mask,
)

In [None]:
# Render the skull strip with the rotated mask; note that this rebinds `image`
# and `figure` from the earlier rendering cell.
image = bad_skull_strip.to_image()
figure = plt.figure(figsize=(20,10))
plt.imshow(image)