In [9]:
%matplotlib notebook
import pydicom as dicom
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.patches import Wedge
import ipywidgets as widgets

In [10]:
def extract_points(x1, v1, v2, w, h):
    '''
    Compute three corners of a w-by-h rectangle centred at x1 in the plane spanned by v1 and v2.

    The "height" edge of the rectangle runs along v1; the "width" edge runs along the
    in-plane direction obtained by rotating v1 by 90 degrees about the plane normal.
    v1 and v2 must not be parallel, otherwise the normal has zero length and the
    returned points are NaN.

    :param x1: the center point of the slice
    :param v1: one vector in the slice (height direction); need not be unit length
    :param v2: another vector in the slice; need not be unit length
    :param w: width of the slice
    :param h: height of the slice
    :return:
    p1, p2, p3: 3 corner points; p2 is the shared corner, p1 - p2 = h * unit(v1)
    and p3 - p2 = w * (unit width direction), so the two edges are perpendicular
    '''
    center = np.array(x1)
    height_dir = np.array(v1)
    other_dir = np.array(v2)

    # work with unit vectors only
    height_dir = height_dir / np.linalg.norm(height_dir)
    other_dir = other_dir / np.linalg.norm(other_dir)

    # unit normal of the slice plane
    normal = np.cross(height_dir, other_dir)
    normal = normal / np.linalg.norm(normal)
    a, b, c = normal

    # Rodrigues' rotation matrix for a 90-degree turn about `normal`
    # (cos = 0, sin = 1): R = n n^T + [n]_x
    rot90 = np.array([
        [a**2,    a*b - c, a*c + b],
        [b*a + c, b**2,    b*c - a],
        [c*a - b, c*b + a, c**2],
    ])

    # in-plane direction perpendicular to height_dir, renormalized against rounding
    width_dir = rot90 @ height_dir.T
    width_dir = width_dir / np.linalg.norm(width_dir)

    half_h = h/2*height_dir
    half_w = w/2*width_dir
    p1 = center + half_h - half_w
    p2 = center - half_h - half_w
    p3 = center - half_h + half_w

    return p1, p2, p3

In [11]:
def slice_area(img, p1, p2, p3, rho_range, theta_range, phi_range, rx, ry):
    '''
    Sample a planar patch out of a volume stored on a spherical (rho, theta, phi) grid.

    The patch is the parallelogram with corner p2 and edges p2->p3 (rx samples, first
    axis of the result) and p2->p1 (ry samples, second axis of the result); it is a
    rectangle when the edges are perpendicular, as produced by extract_points. Each
    sample point is converted to spherical coordinates, mapped linearly onto voxel
    indices of `img` and trilinearly interpolated. Samples outside the volume stay 0.

    :param img: 3d pixel array indexed as img[rho_index, theta_index, phi_index]
    :param p1,p2,p3: 3 points that can form 2 perpendicular vectors (p2 is the shared corner)
    :param rho_range, theta_range, phi_range: (min, max) of each spherical coordinate;
        min maps to index 0 and max to the last index of the matching axis of img.
        rho is the distance from the origin, theta the elevation above the x-y plane,
        phi the azimuth in the x-y plane measured from +x (angles in radians)
    :param rx: the resolution on x-axis (number of samples along p2->p3)
    :param ry: the resolution on y-axis (number of samples along p2->p1)
    :return:
    slicer: 2d float pixel array of shape (rx, ry)
    '''
    # convert all the input points to be np.array
    p1, p2, p3 = np.array(p1), np.array(p2), np.array(p3)

    # linear interpolation ratios along the two edges that start at p2
    t1 = np.linspace(0, 1, rx)[:, None]
    u = t1 * p3[None, :] + (1 - t1) * p2[None, :]   # (rx, 3): p2 -> p3
    t2 = np.linspace(0, 1, ry)[:, None]
    v = t2 * p1[None, :] + (1 - t2) * p2[None, :]   # (ry, 3): p2 -> p1

    # position of every sample, shape (rx, ry, 3)
    p = u[:, None, :] + (v - v[0])[None, :, :]

    # convert to spherical coordinates
    rho = (p ** 2).sum(axis=-1) ** 0.5
    theta = np.arctan2(p[..., 2], (p[..., 0] ** 2 + p[..., 1] ** 2) ** 0.5)
    phi = np.arctan2(p[..., 1], p[..., 0])

    # where each sample sits inside each coordinate range: 0 at the lower bound, 1 at the upper
    frac = np.stack([
        (rho - rho_range[0]) / (rho_range[1] - rho_range[0]),
        (theta - theta_range[0]) / (theta_range[1] - theta_range[0]),
        (phi - phi_range[0]) / (phi_range[1] - phi_range[0]),
    ], axis=-1)

    # initialize the return slicer
    slicer = np.zeros((rx, ry))

    # keep samples inside the volume. Both bounds are inclusive so that samples lying
    # exactly on the last voxel plane are interpolated instead of being dropped.
    # NaN coordinates compare False and are therefore excluded as well.
    mask = np.all((frac >= 0) & (frac <= 1), axis=-1)

    # if no intersection points, return the whole black slicer
    if not mask.any():
        return slicer

    # fractional voxel index (i, j, k) of every in-bounds sample
    last = np.array(img.shape[:3]) - 1
    loc = frac[mask] * last

    #### trilinear interpolation ####

    # lower corner of the surrounding voxel cell; the upper corner is clamped so a
    # sample on the last plane reuses that plane instead of indexing past the array
    # (its fractional offset is 0 there, so the clamped neighbour carries no weight)
    lo = np.floor(loc).astype(int)
    hi = np.minimum(lo + 1, last)
    x_d, y_d, z_d = (loc - lo).T
    i0, j0, k0 = lo.T
    i1, j1, k1 = hi.T

    # interpolate along the first (rho) axis: eight corners -> four
    c00 = (1 - x_d) * img[i0, j0, k0] + x_d * img[i1, j0, k0]
    c01 = (1 - x_d) * img[i0, j0, k1] + x_d * img[i1, j0, k1]
    c10 = (1 - x_d) * img[i0, j1, k0] + x_d * img[i1, j1, k0]
    c11 = (1 - x_d) * img[i0, j1, k1] + x_d * img[i1, j1, k1]

    # along the second (theta) axis: four -> two
    c0 = (1 - y_d) * c00 + y_d * c10
    c1 = (1 - y_d) * c01 + y_d * c11

    # along the third (phi) axis: two -> the final value, written back in place
    slicer[mask] = (1 - z_d) * c0 + z_d * c1

    return slicer

In [12]:
def plot_2d(s,w,h):
    '''
    Show a 2d slice as a grayscale image in a new 8x8 figure.

    :param s: 2d pixel array of shape (rx, ry), e.g. the output of slice_area
    :param w: width of the slice, used as the horizontal image extent
    :param h: height of the slice, used as the vertical image extent
    :return: None; the figure is displayed with plt.show()
    '''
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot()
    # transpose so the first array axis (rx) runs horizontally;
    # the gray scale is stretched from 0 to the brightest pixel of this slice
    ax.imshow(s.T, cmap='gray', vmin=0, vmax=s.max(), extent=(0, w, 0, h))
    ax.set_aspect('equal')
    plt.show()

In [13]:
def convert_3d_2d(img,c,v1,v2,x_range,y_range,z_range,w,h,rx,ry):
    '''
    Cut a w-by-h planar slice out of a volume stored on a spherical grid.

    Thin wrapper: extract_points builds the slice corners, slice_area samples them.

    :param img: 3d pixel array
    :param c: the center point of the slice
    :param v1: one vector in the slice
    :param v2: another vector in the slice
    :param x_range, y_range, z_range: (min, max) ranges forwarded to slice_area
        as rho_range, theta_range and phi_range respectively
    :param w: width of the slice
    :param h: height of the slice
    :param rx: the resolution on x-axis
    :param ry: the resolution on y-axis
    :return:
    slicer: 2d pixel array of shape (rx, ry)
    '''
    corners = extract_points(c, v1, v2, w, h)
    return slice_area(img, *corners, x_range, y_range, z_range, rx, ry)

In [14]:
# Location of the example volume; point DATA_DIR at your own copy to run this notebook elsewhere.
DATA_DIR = Path('/workspace/data/NAS2/3D_DICOMs/examples')
img = np.load(DATA_DIR / 'test_img.npy')

In [15]:
# Extent of the spherical grid that slice_area maps onto the axes of img (angles in radians):
# rho is the radial range, theta the elevation range, phi the azimuth range.
rho = [0, 110.82]
phi = [deg * np.pi / 180 for deg in (22.486, 82.480)]
theta = [deg * np.pi / 180 for deg in (-28.7, 28.7)]

In [16]:
#### Interactive 2d Slice User Interface ####

# Shared settings: every slider moves in whole steps and fires while being dragged.
_slider_opts = dict(step=1, continuous_update=True)

# center of the slice
x = widgets.IntSlider(min=-100, max=100, value=50, description='x', **_slider_opts)
y = widgets.IntSlider(min=-100, max=100, value=40, description='y', **_slider_opts)
z = widgets.IntSlider(min=-100, max=100, value=0, description='z', **_slider_opts)
# orientation angles in degrees (plot / plot2 turn them into the in-plane vectors v1, v2)
a = widgets.IntSlider(min=-90, max=90, value=0, description='angle1', **_slider_opts)
b = widgets.IntSlider(min=-90, max=90, value=0, description='angle2', **_slider_opts)
# slice size, bounded by the radial extent rho[1]
w = widgets.IntSlider(min=0, max=2*rho[1], value=130, description='width', **_slider_opts)
h = widgets.IntSlider(min=0, max=rho[1], value=90, description='height', **_slider_opts)

# Figure layout: the slice image fills the bottom row (subplot 212); the top row shows
# two side views of the volume's sector (as wedges) with the slice outline drawn in red.
plt.figure(figsize=(5.5,5))
ax1 = plt.subplot(212)
# Placeholder image; the `plot` slider callback replaces its data and extent.
# NOTE(review): this placeholder uses w=100, h=50 but extent (0,200,0,110) and does not
# match the sliders' initial values; only visible until the first callback runs.
# NOTE(review): img[..., 0] keeps index 0 of img's last axis -- TODO confirm what that axis is.
image = ax1.imshow(convert_3d_2d(img[..., 0],[0,0,0],[1,0,0],[0,0,1],rho,theta,phi,100,50,800,800).T,
                   cmap='gray', vmin=0, vmax=255, extent=(0,200,0,110))

# Front view: Z horizontal, X drawn downward (plot2 plots -X on the vertical axis).
ax2 = plt.subplot(221)
plt.xlabel('Z')
plt.ylabel('X')
# Wedge angles are in degrees; subtracting 90 measures them from the downward direction,
# and width = rho[1]-rho[0] makes the inner radius rho[0].
front = ax2.add_patch(Wedge((0,0), rho[1], theta[0]/np.pi*180-90, theta[1]/np.pi*180-90, width = rho[1]-rho[0]))
plt.xlim(-rho[1],rho[1])
plt.ylim(-rho[1],0)
# Slice outline: 5 points so the rectangle closes; filled in by plot2.
line1, = ax2.plot([0,0,0,0,0], [0,0,0,0,0], 'ro-')
plt.gca().set_aspect('equal')
# Side view: Y horizontal, X drawn downward; the wedge spans the phi range.
ax3 = plt.subplot(222)
plt.xlabel('Y')
plt.ylabel('X')
side = ax3.add_patch(Wedge((0,0), rho[1], phi[0]/np.pi*180-90, phi[1]/np.pi*180-90, width = rho[1]-rho[0]))
plt.xlim(-rho[1],rho[1])
plt.ylim(-rho[1],0)
# Slice outline for the side view; filled in by plot2.
line2, = ax3.plot([0,0,0,0,0], [0,0,0,0,0], 'ro-')
plt.gca().set_aspect('equal')
plt.tight_layout()
 
def plot(x,y,z,a,b,w,h):
    '''
    Slider callback: recompute and redraw the bottom slice image.

    :param x, y, z: center of the slice
    :param a: angle1 in degrees; v1 = (cos a, sin a, 0) turns within the x-y plane
    :param b: angle2 in degrees; v2 = (sin b, 0, cos b) turns within the x-z plane
    :param w, h: width and height of the slice, also used as the image extent
    '''
    rad_a = a*np.pi/180
    rad_b = b*np.pi/180
    v1 = [np.cos(rad_a), np.sin(rad_a), 0]
    v2 = [np.sin(rad_b), 0, np.cos(rad_b)]
    pixels = convert_3d_2d(img[..., 0], [x, y, z], v1, v2,
                           rho, theta, phi, w, h, 800, 800)
    image.set_data(pixels.T)
    image.set_extent((0, w, 0, h))


def plot2(x,y,z,a,b,w,h):
    '''
    Slider callback: draw the current slice's outline on the two side views.

    Takes the same parameters as plot. The outline is the closed polygon
    p1 -> p2 -> p3 -> p4 -> p1, projected onto (Z, -X) for the front view
    and onto (Y, -X) for the side view.
    '''
    rad_a = a*np.pi/180
    rad_b = b*np.pi/180
    p1, p2, p3 = extract_points([x, y, z],
                                [np.cos(rad_a), np.sin(rad_a), 0],
                                [np.sin(rad_b), 0, np.cos(rad_b)],
                                w, h)
    # fourth corner, opposite p2
    p4 = p1 + (p3 - p2)
    outline = (p1, p2, p3, p4, p1)
    depth = [-pt[0] for pt in outline]
    line1.set_data([pt[2] for pt in outline], depth)
    line2.set_data([pt[1] for pt in outline], depth)
    
# Map each callback parameter name to the slider that drives it.
controls = {'x': x, 'y': y, 'z': z, 'a': a, 'b': b, 'w': w, 'h': h}
ui = widgets.VBox(list(controls.values()))

# Re-run the image callback and the outline callback whenever a slider changes.
out = widgets.interactive_output(plot, controls)
out2 = widgets.interactive_output(plot2, controls)

display(out, out2, ui)