In [1]:
from nilearn.image import threshold_img, load_img
from niworkflows import NIWORKFLOWS_LOG
from niworkflows.viz.utils import cuts_from_bbox, compose_view
from nipype.interfaces.base import File, isdefined
from nipype.interfaces.mixins import reporting



In [5]:
def plot_registration(
    anat_nii,
    div_id,
    plot_params=None,
    order=("z", "x", "y"),
    cuts=None,
    estimate_brightness=False,
    label=None,
    contour=None,
    compress="auto",
):
    """
    Plot the foreground and background views.
    Default order is: axial, coronal, sagittal
    """
    from uuid import uuid4

    from lxml import etree
    from nilearn.plotting import plot_anat
    from svgutils.transform import SVGFigure
    from niworkflows.viz.utils import robust_set_limits, extract_svg, SVGNS

    plot_params = plot_params or {}

    # Use default MNI cuts if none defined
    if cuts is None:
        raise NotImplementedError  # TODO

    out_files = []
    if estimate_brightness:
        plot_params = robust_set_limits(anat_nii.get_data().reshape(-1), plot_params)

    # Plot each cut axis
    for i, mode in enumerate(list(order)):
        plot_params["display_mode"] = mode
        plot_params["cut_coords"] = cuts[mode]
        if i == 0:
            plot_params["title"] = label
        else:
            plot_params["title"] = None

        # Generate nilearn figure
        display = plot_anat(anat_nii, **plot_params)
        if contour is not None:
            display.add_contours(contour, colors="g", levels=[0.5], linewidths=0.5)

        svg = extract_svg(display, compress=compress)
        display.close()

        # Find and replace the figure_1 id.
        xml_data = etree.fromstring(svg)
        find_text = etree.ETXPath("//{%s}g[@id='figure_1']" % SVGNS)
        find_text(xml_data)[0].set("id", "%s-%s-%s" % (div_id, mode, uuid4()))

        svg_fig = SVGFigure()
        svg_fig.root = xml_data
        out_files.append(svg_fig)

    return out_files

In [4]:
ref_nii = load_img()
fmap_nii = load_img()
contour_nii = load_img()
out_report = "test.svg"

mask_nii = threshold_img(ref_nii, 1e-3)
n_cuts = 7
cuts = cuts_from_bbox(mask_nii, cuts=n_cuts)

TypeError: load_img() missing 1 required positional argument: 'img'

In [None]:
fmap_data = fmap_nii.get_fdata()
vmax = max(fmap_data.max(), abs(fmap_data.min()))

In [None]:
compose_view(
    plot_registration(ref_nii, 'fixed-image',
                      estimate_brightness=True,
                      cuts=cuts,
                      label='reference',
                      contour=contour_nii,
                      compress=False),
    plot_registration(fmapnii, 'moving-image',
                      estimate_brightness=True,
                      cuts=cuts,
                      label='fieldmap (Hz)',
                      contour=contour_nii,
                      compress=False,
                      plot_params={'cmap': coolwarm_transparent(),
                                   'vmax': vmax,
                                   'vmin': -vmax}),