# Inference and Visualization of UW-Madison Data


# Import Libraries

In [3]:
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
import random
from glob import glob
import os, shutil
from tqdm.notebook import tqdm
tqdm.pandas()
import time
import copy
import joblib
import gc
from IPython import display as ipd
from joblib import Parallel, delayed

# visualization
import cv2
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib import image as mpimg
from matplotlib.patches import Rectangle

# 3D Image Utilities

In [4]:
def load_3d(folder, subfolder, case_day):
    """Return the array stored in ``<folder>/<subfolder>/<case_day>.npy``.

    ``encoding='bytes'`` is forwarded to ``np.load``; per the NumPy docs it only
    matters when reading pickled Python-2 objects.
    """
    npy_path = "{}/{}/{}.npy".format(folder, subfolder, case_day)
    return np.load(npy_path, encoding='bytes')

def show_3d(folder, case_day, num_wanted, show_mask, scale):
    """Plot evenly spaced slices of a 3D volume, optionally beside its mask.

    Loads ``<folder>/images/<case_day>.npy`` (and ``<folder>/masks/<case_day>.npy``
    when ``show_mask`` is true) via ``load_3d``. Slices are taken along the last
    array axis. Each selected slice gets one row: the image on the left and, if
    requested, the mask on the right, titled with per-label pixel counts.

    Args:
        folder: Root directory holding the ``images`` and ``masks`` subfolders.
        case_day: File stem of the volume, e.g. ``'case_85_day_29'``.
        num_wanted: Number of slices to show, spread evenly along the last axis.
            ``-1`` shows every slice exactly once.
        show_mask: If true, also draw the matching mask slice.
        scale: Edge length, in inches, of each subplot cell.

    Raises:
        ValueError: If ``num_wanted`` is neither ``-1`` nor a positive integer.
    """
    image_3d = load_3d(folder, "images", case_day)
    mask_3d = load_3d(folder, "masks", case_day) if show_mask else None

    slices = _evenly_spaced_slices(image_3d.shape[2], num_wanted)
    print(f"slices {slices}")

    nrows, ncols = len(slices), 2
    # squeeze=False keeps `axs` 2D even when a single slice is shown.
    _, axs = plt.subplots(nrows, ncols, figsize=(ncols * scale, nrows * scale),
                          subplot_kw={'xticks': [], 'yticks': []}, squeeze=False)

    for (ax_image, ax_mask), slice_idx in zip(axs, slices):
        ax_image.set_title(f"Slice {slice_idx}")
        ax_image.imshow(image_3d[:, :, slice_idx], interpolation='none', cmap='bone')

        if mask_3d is None:
            ax_mask.axis('off')  # nothing to draw in the right-hand column
            continue

        mask_2d = mask_3d[:, :, slice_idx]
        ax_mask.set_title(_mask_counts_title(mask_2d), loc='left')
        # Fixed vmin/vmax keeps each label value the same colour on every slice.
        ax_mask.imshow(mask_2d, vmin=0, vmax=7, interpolation='none', cmap='nipy_spectral')


def _evenly_spaced_slices(depth, num_wanted):
    """Return ``num_wanted`` indices spread evenly over ``range(depth)``; ``-1`` means all."""
    if num_wanted == -1:
        # Equivalent to the formula below with num_wanted == depth, which yields
        # exactly 0..depth-1 with no duplicates.
        return list(range(depth))
    if num_wanted < 1:
        raise ValueError(f"num_wanted must be -1 or a positive integer, got {num_wanted}")
    return [depth * (k + 1) // (num_wanted + 1) for k in range(num_wanted)]


def _mask_counts_title(mask_2d):
    """Build a title listing how many pixels of ``mask_2d`` hold each label 0..7.

    The label names suggest bit-encoded classes (a=1, b=2, c=4; overlaps are sums).
    """
    # One histogram bin per integer value 0..255 (bin edges sit on half-integers).
    counts, _ = np.histogram(mask_2d, bins=256, range=(-0.5, 255.5))
    return (f"noclass:{counts[0]}\n" +
            f"a:{counts[1]} b:{counts[2]} c:{counts[4]}\n" +
            f"ab:{counts[3]} ac:{counts[5]} bc:{counts[6]} abc:{counts[7]}")


# Check 3D Image (series of 2D slices)

In [5]:
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [6]:
import datetime

In [7]:
# Record when this notebook run happened; as the cell's last expression the
# current local datetime is echoed as the cell output.
datetime.datetime.now()

datetime.datetime(2023, 10, 17, 19, 59, 10, 824949)

# Load Model Checkpoint

In [17]:
from capsnet_model_3d import CapsNet3D
model = CapsNet3D()

path = '/mnt/d/code_medimg_aneja_lab/results_231010/saved_model.pth.tar'
# Nothing in this notebook moves `model` to a GPU, so map every checkpoint
# tensor onto the CPU as well. Without map_location, a checkpoint saved from
# GPU tensors fails to load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['state_dict'])
# Switch to evaluation mode for inference; `_ =` suppresses the module repr in the cell output.
_ = model.eval()

# Visualize test results

In [None]:
### in progress....

In [None]:
# Preview one preprocessed case: num_wanted=-1 requests all slices, each shown
# next to its mask, in 8-inch subplot cells.
base_folder = '/mnt/d/code_medimg_aneja_lab/data_uwmadison_01c_preprocessed_3d'
case_day = 'case_85_day_29'
show_3d(folder=base_folder,
        case_day=case_day,
        num_wanted=-1,
        show_mask=True,
        scale=8)

In [None]:
 def validate(self):
        """Evaluate the model on the validation set and record per-sample losses.

        Iterates ``self.valid_dataloader`` with gradients disabled, runs each batch
        through ``self.model`` on ``self.device``, and collects the individual losses
        returned by ``self.criterion_individual_losses``. The losses are appended to
        ``self.valid_losses`` as a new column named ``m{miniepoch}_e{epoch}``.

        The model's train/eval mode from before the call is restored on exit,
        including when an exception is raised.
        """
        print('>>>   Validating   <<<')
        was_training = self.model.training
        self.model.eval()

        this_epoch_losses = []
        try:
            with torch.no_grad():
                for inputs_cpu, targets_cpu in self.valid_dataloader:
                    # Insert a singleton axis at dim 1 (equivalent to unsqueeze(0)
                    # followed by permute(1, 0, 2, 3, 4)). Presumably a channel axis
                    # for 4D (batch, D, H, W) batches -- TODO confirm with the loader.
                    inputs = inputs_cpu.unsqueeze(1).to(self.device, dtype=torch.float32)
                    targets = targets_cpu.unsqueeze(1).to(self.device, dtype=torch.float32)

                    outputs = self.model(inputs)
                    losses = self.criterion_individual_losses(outputs, targets)
                    this_epoch_losses += list(losses.cpu().numpy())
        finally:
            # Restore the caller's mode instead of unconditionally forcing train().
            self.model.train(was_training)

        self.valid_losses = pd.concat([self.valid_losses,
                                       pd.DataFrame({f'm{self.miniepoch}_e{self.epoch}': this_epoch_losses})],
                                      axis=1)

# Visualizing 3D

Reference: [Displaying 3D images in Python (GeeksforGeeks)](https://www.geeksforgeeks.org/displaying-3d-images-in-python/)

In [None]:

# numpy (np) and matplotlib.pyplot (plt) come from the first import cell;
# importing Axes3D makes the '3d' projection available on older matplotlib.
from mpl_toolkits.mplot3d import Axes3D

# Points on the curve (x, sin x, sin^2 x), coloured by x + sin x.
x = np.arange(0, 20, 0.1)
y = np.sin(x)
z = y*np.sin(x)
c = x + y

fig = plt.figure(figsize=(10, 10))   # 10 x 10 inch canvas
ax = plt.axes(projection='3d')       # 3D axes on that figure

ax.scatter(x, y, z, c=c)
plt.axis('off')                      # turn the axis display off
plt.show()

In [None]:
# Draw a 5x5x5 block of voxels, one colour per layer along the first axis.
# numpy (np) and matplotlib.pyplot (plt) come from the first import cell.

# Figure size in inches.
fig = plt.figure(figsize=(10, 10))

# Create 3D axes on that figure.
ax = plt.axes(projection='3d')

# Voxel grid dimensions.
grid_shape = [5, 5, 5]

# Fill every voxel in the grid.
data = np.ones(grid_shape)

# Transparency shared by all voxels.
alpha = 0.7

# One RGBA colour per voxel. Named voxel_colors (not `colors`) so it does not
# shadow the `matplotlib.colors` module imported as `colors` in the first cell.
voxel_colors = np.empty(grid_shape + [4])

voxel_colors[0] = [1, 0, 0, alpha]  # red
voxel_colors[1] = [0, 1, 0, alpha]  # green
voxel_colors[2] = [0, 0, 1, alpha]  # blue
voxel_colors[3] = [1, 1, 0, alpha]  # yellow
voxel_colors[4] = [1, 1, 1, alpha]  # white

# Hide the axes.
plt.axis('off')

# Draw the filled voxels with per-voxel face colours and grey edges.
_ = ax.voxels(data, facecolors=voxel_colors, edgecolors='grey')

In [None]:

#Import libraries
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Figure size in inches.
fig = plt.figure(figsize=(10,10))

# Create 3D axes on that figure.
ax = plt.axes(projection='3d')

# 10 x 10 grid over [-1, 5] x [-1, 5].
x = np.linspace(-1, 5, 10)
y = np.linspace(-1, 5, 10)
X, Y = np.meshgrid(x, y)
# Flat plane z = 1; ones_like keeps Z the same shape as the meshgrid output.
Z = np.ones_like(X)

# creating the visualization
ax.plot_wireframe(X, Y, Z, color ='green')

# Label through the Axes3D object: pyplot has no zlabel()/set_zlabel() functions,
# so the original calls raised AttributeError.
plt.axis('on')
ax.set_title("Whatever")
ax.set_xlabel("The X")
ax.set_ylabel("The Y")
_ = ax.set_zlabel("The Z")

In [None]:

from numpy import linspace
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d


# Creating 3D figure
fig = plt.figure(figsize=(4, 4))
ax = plt.axes(projection='3d')

# Helix: z rises while (x, y) goes round the unit circle.
z = np.linspace(0, 15, 1000)
x = np.sin(z)
y = np.cos(z)
_ = ax.plot3D(x, y, z, 'green')

# Sweep the elevation angle through a full turn with the azimuth fixed at 30
# degrees, redrawing and pausing 0.1 s per frame.
for angle in range(0, 360):
    ax.view_init(elev=angle, azim=30)
    plt.draw()
    plt.pause(.1)

# Show once after the animation. Inside the loop, plt.show() blocks on every
# iteration with a non-interactive backend.
plt.show()