In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import nibabel

underlay = r'../evaldata/sub-001/ses-01/sub-001_ses-01_ref.nii'
overlay = r'evalresults/sub-001_ses-01_moco_ants_affine.nii'

ul = nibabel.load(underlay).get_fdata()
ol = nibabel.load(overlay).get_fdata()

ul = ul.swapaxes(0, 2)
ol = ol[:,:,:,4]
ol = ol.swapaxes(0, 2)


In [7]:
import matplotlib.pyplot as plt
%matplotlib widget
import torch
from ipywidgets import interact,fixed
import ipywidgets as widgets

def imyshow(im):
    def myshow(arr,vmin,vmax,z=0,t=0):
        plt.imshow(arr[t,z],vmin=vmin,vmax=vmax,cmap='gray')  
    if type(im)==torch.Tensor:
        arr=im.cpu().detach().numpy()
    elif type(im)==np.ndarray or type(im)==np.core.memmap:
        arr=np.copy(im)
    else:
        print('Unknown format, nothing to display!')
        return None
    if len(arr.shape)>3:
        arr=np.squeeze(arr)
    print(arr.shape)
    if len(arr.shape)==2:
        interact(myshow,arr=fixed(arr[None,None,:]),vmin=fixed(arr.min()),vmax=fixed(arr.max()),z=fixed(0),t=fixed(0))
    if len(arr.shape)==3:
        interact(myshow,arr=fixed(arr[None,:]),vmin=fixed(arr.min()),vmax=fixed(arr.max()),z=(1,arr.shape[0]-1),t=fixed(0))
    if len(arr.shape)==4:
        interact(myshow,arr=fixed(arr),vmin=fixed(arr.min()),vmax=fixed(arr.max()),z=(1,arr.shape[1]-1),t=(0,arr.shape[0]-1))

# expanded interactive display of 3d images, takes pytorch, sitk and np arrays
def imyshowWOL(im, ol, op=0.1):
    def myshow(arr,ol,vmin,vmax,z=0,t=0):
        plt.imshow(arr[t,z],vmin=vmin,vmax=vmax,cmap='gray')  
        plt.imshow(ol[t,z],vmin=vmin,vmax=vmax,cmap='hot',alpha=op)
    if type(im)==torch.Tensor:
        arr=im.cpu().detach().numpy()
        ol=ol.cpu().detach().numpy()
    elif type(im)==np.ndarray or type(im)==np.core.memmap:
        arr=np.copy(im)
        ol=np.copy(ol)
    else:
        print('Unknown format, nothing to display!')
        return None
    if len(arr.shape)>3:
        arr=np.squeeze(arr)
        ol=np.squeeze(ol)
    print(arr.shape)
    if len(arr.shape)==2:
        interact(myshow,arr=fixed(arr[None,None,:]),ol=fixed(ol[None,None,:]),vmin=fixed(arr.min()),vmax=fixed(arr.max()),z=fixed(0),t=fixed(0))
    if len(arr.shape)==3:
        interact(myshow,arr=fixed(arr[None,:]),ol=fixed(ol[None,:]),vmin=fixed(arr.min()),vmax=fixed(arr.max()),z=(1,arr.shape[0]-1),t=fixed(0))
    if len(arr.shape)==4:
        interact(myshow,arr=fixed(arr),ol=fixed(ol),vmin=fixed(arr.min()),vmax=fixed(arr.max()),z=(1,arr.shape[1]-1),t=(0,arr.shape[0]-1))


In [9]:

imyshowWOL(ul, ol, op=0.5)


(116, 512, 512)


interactive(children=(IntSlider(value=1, description='z', max=115, min=1), Output()), _dom_classes=('widget-in…

In [None]:
# expanded interactive display of 3d images, takes pytorch, sitk and np arrays

def imyshowRGBWOL(ul, ol, op=0.1):

    def myshow(arr1,arr2,z=0,t=0):
        image1 = arr1[t,z]
        image2 = arr2[t,z]

        # Normalize images to [0, 1] range for RGB overlay
        image1_norm = (image1 - np.min(image1)) / (np.max(image1) - np.min(image1))
        image2_norm = (image2 - np.min(image2)) / (np.max(image2) - np.min(image2))

        # Create an RGB image with image1 in red and image2 in green
        rgb_image = np.zeros((image1.shape[0], image1.shape[1], 3))
        rgb_image[..., 0] = image1_norm  # Red channel
        rgb_image[..., 1] = image2_norm  # Green channel
        
        plt.figure()
        plt.imshow(rgb_image)
        plt.axis("off")
        plt.show()

    if type(ul)==torch.Tensor:
        arr=ul.cpu().detach().numpy()
        ol=ol.cpu().detach().numpy()
    elif type(ul)==np.ndarray or type(ul)==np.core.memmap:
        arr=np.copy(ul)
        ol=np.copy(ol)
    else:
        print('Unknown format, nothing to display!')
        return None

    if ul.shape != ol.shape:
        print("underlay and overlay array shapes differ! exiting!")
        return None

    if len(arr.shape)>3:
        arr=np.squeeze(arr)
        ol=np.squeeze(ol)
    print(arr.shape)

    if len(arr.shape)==2:
        interact(myshow, 
                 arr=fixed(arr[None,None,:]), 
                 ol=fixed(ol[None,None,:]), 
                 vmin=fixed(arr.min()), 
                 vmax=fixed(arr.max()), 
                 z=fixed(0), 
                 t=fixed(0))
    if len(arr.shape)==3:
        interact(myshow, 
                 arr1=fixed(arr[None,:]), 
                 arr2=fixed(ol[None,:]), 
                 z=(1,arr.shape[0]-1), 
                 t=fixed(0))
    if len(arr.shape)==4:
        interact(myshow, 
                 arr=fixed(arr), 
                 ol=fixed(ol), 
                 vmin=fixed(arr.min()), 
                 vmax=fixed(arr.max()), 
                 z=(1,arr.shape[1]-1), 
                 t=(0,arr.shape[0]-1))

imyshowRGBWOL(ul, ol)

(116, 512, 512)


interactive(children=(IntSlider(value=1, description='z', max=115, min=1), Output()), _dom_classes=('widget-in…