## Mouse Heart Segmentation and Visualization from 3D Micro-CT

In [13]:
import sys
sys.path.append("..") 
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path
import os
import imageio
import torch.cuda
import segmentation_models_pytorch as smp
from collections import defaultdict
import nibabel as nib
from tqdm import tqdm
from Utils.dataset_utils import *
from Utils.prediction_utils import *
from Utils.visualization_utils import *
from ipywidgets import *
from IPython.display import clear_output, display
from tkinter import Tk, filedialog
import SimpleITK as sitk

In [14]:
# Compute device: use the GPU when one is available, otherwise fall back to the
# CPU so the notebook still runs on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_NAME = 'Unet'
ENCODER = 'efficientnet-b4'
ENCODER_WEIGHTS = 'imagenet'  # passed together with ENCODER to smp's get_preprocessing_fn
# Default trained-weights file. This path is machine-specific; set the
# MOUSE_HEART_WEIGHTS environment variable to point elsewhere without editing it.
BEST_WEIGHTS = os.environ.get(
    "MOUSE_HEART_WEIGHTS",
    r"G:\Projects and Work\Mouse Heart Segmentation\Trained Weights - 13 Mice Volumes\Unet_efficientnet-b4\best_score.pt",
)

**Configuration** (cell above): compute device, model architecture and encoder, and the default path to the trained weights.

**Select Model Weights**

In [15]:
def select_files(b):
    """Button callback: pick one or more files with a native dialog.

    Clears the cell output first (this removes the displayed button). Then it
    opens a Tk multi-file dialog that stays above other windows. The selected
    paths are stored on the button as ``b.files`` and printed.

    Args:
        b: the ipywidgets Button that was clicked; ``b.files`` is set on it.
    """
    clear_output()
    root = Tk()
    root.withdraw()                      # hide the empty main Tk window
    root.attributes('-topmost', True)    # keep the dialog above all other windows
    try:
        b.files = filedialog.askopenfilename(parent=root, multiple=True)
    finally:
        # Every click creates a new Tk root. Destroy it so hidden roots do not
        # accumulate for the lifetime of the kernel.
        root.destroy()
    print(b.files)
def select_folder(b):
    """Button callback: pick a directory with a native dialog.

    Clears the cell output first (this removes the displayed button). Then it
    opens a Tk directory dialog that stays above other windows. The chosen
    directory is stored on the button as ``b.folder`` and printed.

    Args:
        b: the ipywidgets Button that was clicked; ``b.folder`` is set on it.
    """
    clear_output()
    root = Tk()
    root.withdraw()                      # hide the empty main Tk window
    root.attributes('-topmost', True)    # keep the dialog above all other windows
    try:
        b.folder = filedialog.askdirectory(parent=root)
    finally:
        # Every click creates a new Tk root. Destroy it so hidden roots do not
        # accumulate for the lifetime of the kernel.
        root.destroy()
    print(b.folder)

In [16]:
# Input normalization matching the encoder's pretraining.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
# map_location places the loaded tensors on DEVICE, so a checkpoint saved on a
# GPU can still be loaded when DEVICE is 'cpu'.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
# This file appears to hold a whole pickled model (it is used directly as
# `best_model`). Recent PyTorch releases default to weights_only=True, which
# refuses such files unless weights_only=False is passed. Confirm against the
# installed version.
# NOTE: BEST_WEIGHTS may be reassigned from the "Select Weights" button in a
# later cell; re-run this cell afterwards so the selected file is loaded.
best_model = torch.load(BEST_WEIGHTS, map_location=DEVICE)

FileNotFoundError: [Errno 2] No such file or directory: 'G:\\Projects and Work\\Mouse Heart Segmentation\\Trained Weights - 13 Mice Volumes\\Unet_efficientnet-b4\\best_score.pt'

In [None]:
# Use the first file picked with the "Select Weights" button. That button is
# created in a later cell: run it and choose a file before running this cell,
# then re-run the model-loading cell so `best_model` comes from this path.
if not getattr(weight_select, "files", None):
    raise RuntimeError("No weights file selected: click 'Select Weights' and choose a .pt file first.")
BEST_WEIGHTS = weight_select.files[0]

**Select Subject**

**Load Model**

In [8]:
# Button that opens a file dialog; select_files stores the chosen paths on
# `weight_select.files`, which the BEST_WEIGHTS cell reads.
weight_select = Button(description="Select Weights")
weight_select.on_click(select_files)
display(weight_select)

Button(description='Select Weights', style=ButtonStyle())

In [None]:
# Use the first file picked with the "Select Subject" button. That button is
# created in the next cell: run it and choose a file before running this one.
if not getattr(subject_select, "files", None):
    raise RuntimeError("No subject selected: click 'Select Subject' and choose a file first.")
SUBJECT_PATH = subject_select.files[0]

In [None]:
# Button that opens a file dialog; select_files stores the chosen paths on
# `subject_select.files`, which the SUBJECT_PATH cell reads.
subject_select = Button(description="Select Subject")
subject_select.on_click(select_files)
display(subject_select)

In [None]:
# load_case comes from the Utils star-imports. Its result is unpacked into the
# image volume, a ground-truth mask and an affine.
# NOTE(review): the empty second argument presumably means "no ground-truth
# mask path"; confirm against load_case.
volume,gt_mask,affine=load_case(SUBJECT_PATH,"")

In [None]:
# File name of the selected subject (the last path component).
SUBJECT = os.path.basename(SUBJECT_PATH)

In [None]:
# Run the model over the whole volume. Round the (presumably probabilistic)
# output to integer labels in one step.
volume_pred_mask = np.round(
    predict_volume(best_model, volume, True, preprocessing_fn)
)

**Predict Volume**

**Save the predicted mask as NIfTI (.nii.gz)**

In [None]:
def create_folder(path):
    """Create directory ``path``, including any missing parents.

    Does nothing if the directory already exists. Unlike an exists-then-mkdir
    check, this is safe to re-run and does not fail when parent directories
    are missing.

    Args:
        path: directory to create.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
# Outputs go next to the selected subject file, in 'Output' and 'Movie' subfolders.
save_folder = os.path.dirname(SUBJECT_PATH)
for sub_dir in ('Output', 'Movie'):
    create_folder(os.path.join(save_folder, sub_dir))

In [None]:
# save_mask_nii (from the Utils star-imports) presumably writes the rounded
# prediction, with the affine returned by load_case, to
# <save_folder>/Output/prediction.nii.gz. That path is read back by the
# heart-volume cell.
save_mask_nii(volume_pred_mask,affine,os.path.join(save_folder,'Output','prediction.nii.gz'))

**Display Results - 2D**

In [14]:
# Number of pixels that span 5 mm. calculate_dim converts pixel lengths to mm
# as pixels / REFERENCE_WIDTH_5MM * 5.
# NOTE(review): presumably measured from a scale reference for this scanner
# resolution; confirm it still holds for other acquisitions.
REFERENCE_WIDTH_5MM=272

In [None]:
def calculate_dim(mask_slice):
    """Return the bounding-box size of the largest region in a 2-D mask slice.

    Steps:
    1. Cast the slice to a C-contiguous uint8 copy.
    2. Find its contours with OpenCV and take the contour with the largest area.
    3. Box that contour with ``cv2.boundingRect``.
    4. Convert the box width and height from pixels using
       ``REFERENCE_WIDTH_5MM`` (pixels per 5 mm) and round to 3 decimals.

    NOTE(review): the voxel spacing stored in the NIfTI file might be a more
    reliable scale than a fixed constant; confirm.

    Args:
        mask_slice: 2-D array. After casting to uint8, non-zero pixels are
            treated as foreground.

    Returns:
        ``(width, height)``; ``(0, 0)`` when the slice contains no contour.
    """
    # Always copies, so OpenCV never touches the caller's array. C order
    # matches what the original .copy() produced.
    thresh = np.array(mask_slice, dtype=np.uint8, order="C")
    # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
    # (contours, hierarchy) in 4.x; [-2] is the contour list in both.
    contours = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return 0, 0
    largest = max(contours, key=cv2.contourArea)
    _, _, w, h = cv2.boundingRect(largest)
    width = round(w / REFERENCE_WIDTH_5MM * 5, 3)
    height = round(h / REFERENCE_WIDTH_5MM * 5, 3)
    return width, height
def display_2d(idx):
    """Show slice ``idx`` along each volume axis with the predicted mask overlaid.

    The panels, left to right, show ``volume[idx, :, :]``,
    ``volume[:, idx, :]`` and ``volume[:, :, idx]``. Each is passed through
    ``draw_mask`` with the matching slice of ``volume_pred_mask``. Each panel
    title reports the bounding-box width/height of the largest mask region in
    that slice (see ``calculate_dim``).

    Reads the notebook globals ``volume`` and ``volume_pred_mask``. The same
    ``idx`` is used on every axis, so it must be valid for all three.

    Args:
        idx: slice index applied to each of the three axes.
    """
    fig, axes = plt.subplots(1, 3, figsize=(16, 16))
    for axis, ax in enumerate(axes):
        # Select `idx` on `axis` and everything on the other axes,
        # e.g. axis=1 -> (slice(None), idx, slice(None)), i.e. [:, idx, :].
        index = tuple(idx if dim == axis else slice(None) for dim in range(3))
        image_slice = volume[index]
        mask_slice = volume_pred_mask[index]
        w, h = calculate_dim(mask_slice)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f'Max Width: {w}, Max Height: {h}')
        ax.imshow(draw_mask(image_slice, mask_slice), cmap="gray")
    plt.show()
# display_2d applies the same index to all three axes, so the slider must stop
# at the last index valid for the shortest axis. A hard-coded max=512 is one
# past the last index of a 512-voxel axis and raised IndexError at the end of
# the slider.
max_slice_idx = min(volume.shape) - 1
interact(display_2d, idx=widgets.IntSlider(min=0, max=max_slice_idx, step=1, value=max_slice_idx // 2))


### Calculate Heart Volume (mm³)

In [None]:
def calculate_volume(mask_image):
    """Return the foreground volume of a binary mask image, in mm^3.

    Args:
        mask_image: sitk.Image holding a binary mask (1 inside the organ,
            0 elsewhere).

    Returns:
        The sum of the mask values multiplied by the volume of one voxel (the
        product of the image spacing).
    """
    voxel_total = np.sum(sitk.GetArrayFromImage(mask_image))
    voxel_size = np.prod(mask_image.GetSpacing())
    return voxel_size * voxel_total
# Re-read the saved prediction so its NIfTI spacing is used for the voxel volume.
v = sitk.ReadImage(os.path.join(save_folder, 'Output', 'prediction.nii.gz'))
heart_volume_mm3 = round(calculate_volume(v), 3)
print(f'Volume of mice Heart: {heart_volume_mm3} mm3')

**Save GIF**

In [None]:
def make_gif(volume, volume_pred_mask=None, out_path=None):
    """Write an animated GIF that steps through the volume with the mask overlaid.

    Frame ``idx`` (for ``idx`` in ``range(volume.shape[0])``) places three
    images side by side: ``volume[:, :, idx]``, ``volume[:, idx, :]`` and
    ``volume[idx, :, :]``. Each is passed through ``draw_mask`` with the
    matching slice of ``volume_pred_mask``.

    NOTE(review): the same ``idx`` indexes every axis and the three images are
    stacked horizontally, so the axis sizes must be compatible (e.g. a cubic
    volume).

    Args:
        volume: 3-D image array.
        volume_pred_mask: mask indexed like ``volume``. It is required; the
            ``None`` default is kept only for signature compatibility.
        out_path: destination file. Defaults to
            ``<save_folder>/Movie/movie.gif``.

    Raises:
        ValueError: if ``volume_pred_mask`` is None.
    """
    if volume_pred_mask is None:
        raise ValueError("make_gif requires volume_pred_mask to overlay on the volume")
    if out_path is None:
        out_path = os.path.join(save_folder, 'Movie', 'movie.gif')
    frames = []
    for idx in tqdm(range(volume.shape[0])):
        Z = draw_mask(volume[idx, :, :], volume_pred_mask[idx, :, :])
        Y = draw_mask(volume[:, idx, :], volume_pred_mask[:, idx, :])
        X = draw_mask(volume[:, :, idx], volume_pred_mask[:, :, idx])
        frames.append(np.hstack((X, Y, Z)))
    imageio.mimsave(out_path, frames)
# Render the overlay movie to <save_folder>/Movie/movie.gif.
make_gif(volume,volume_pred_mask)

**Display Results - 3D**

In [None]:
import meshplot as mp
try:
    # Newer scikit-image releases removed `marching_cubes_lewiner`;
    # `marching_cubes` uses the Lewiner algorithm by default.
    from skimage.measure import marching_cubes
except ImportError:  # older scikit-image without the unified `marching_cubes`
    from skimage.measure import marching_cubes_lewiner as marching_cubes
# Extract a triangle surface mesh (vertices, faces) from the predicted mask.
# Normals and values are unused.
v1, f1, _, _ = marching_cubes(volume_pred_mask)
p = mp.plot(v1, f1, return_plot=True)