In [None]:
import matplotlib.pyplot as plt
import numpy as np
from numpy.linalg import lstsq
import jax
import jax.numpy as jnp
import mbirjax
from wind_tomo.ARC import *
from scipy.ndimage import gaussian_filter
import scipy.signal as signal
import time
from math import ceil
import numpy.ma as ma
from scipy.special import factorial
from scipy import stats as st

## Analysis Tools

In [None]:
def zero_mean(recon):
    """Return ``recon`` with each index along axis 0 shifted to zero mean.

    The mean of every ``recon[k]`` is taken over axes 1 and 2 and subtracted from it;
    a new array is returned and ``recon`` is left unchanged.
    """
    per_slice_mean = np.mean(recon, axis=(1, 2), keepdims=True)
    return recon - per_slice_mean

def zero_mean_with_mask(recon, mask):
    """Subtract from each ``recon[k]`` the mean of its ``mask``-selected entries.

    Only entries where ``mask`` is True contribute to the mean (``np.mean(..., where=mask)``
    over axes 1 and 2); the subtraction is applied to every entry of ``recon``.
    """
    selected_means = np.mean(recon, axis=(1, 2), where=mask, keepdims=True)
    return recon - selected_means

def nrmse(GT, recon):
    """RMS error between ``recon`` and ``GT``, normalized by the dynamic range of ``GT``.

    Computes sqrt(sum((GT - recon)**2) / GT.size) / (max(GT) - min(GT)).
    """
    squared_error = (GT - recon) ** 2
    rmse = np.sqrt(np.sum(squared_error) / GT.size)
    gt_range = np.max(GT) - np.min(GT)
    return rmse / gt_range

def nrmse_cropped(GT, recon, mask):
    """NRMSE restricted to the entries selected by the boolean ``mask``.

    Both the RMS error and the normalizing range (max - min of ``GT``) are computed
    only over ``GT[mask]`` and ``recon[mask]``.
    """
    gt_selected = GT[mask]
    residual = gt_selected - recon[mask]
    rmse = np.sqrt(np.sum(residual ** 2) / np.sum(mask))
    return rmse / (np.max(gt_selected) - np.min(gt_selected))

def nrmse_masked(GT, recon):
    """NRMSE between two (masked) arrays, using only entries that are valid in both.

    Parameters:
        GT: ground truth, normally a ``numpy.ma.MaskedArray``; a plain ndarray is
            treated as fully valid.
        recon: reconstruction, a masked array or ndarray broadcastable to ``GT``.

    Returns:
        sqrt(mean of squared differences over jointly-unmasked entries), divided by the
        range (max - min) of the unmasked ``GT`` values.
    """
    # Masked-array arithmetic masks the result wherever EITHER operand is masked, so the
    # mean must use the difference's own valid count. Using GT.count() alone would count
    # entries that were masked only in `recon` and deflate the error.
    diff = np.ma.asarray(GT) - recon
    rmse = np.sqrt(np.ma.sum(diff ** 2) / diff.count())
    return rmse / (np.ma.max(GT) - np.ma.min(GT))


def analyze_zernike_error(GT, recon, max_order,diameter=None,method='Subtractive'):
    """
    Compare a reconstruction to a ground truth after progressively removing their Zernike fits.

    Both images are fitted (``fit_zernike``) up to ``max_order``. For each ``i`` in
    ``0 .. max_order + 1`` the image rebuilt from the first ``i`` coefficient entries
    (``reconstruct_image(coeff[:i], diameter)``) is subtracted from each input, the residuals
    are restricted to a circular aperture of diameter ``diameter - 4`` via ``circ_block``, and
    the NRMSE between the two residuals is recorded.
    NOTE(review): the plotting cells label index ``i`` as "radial degree <= i-1 removed",
    which presumes fit_zernike groups coefficients by radial degree — confirm.

    Parameters:
        GT (2D array): The ground truth image.
        recon (2D array): The reconstructed image (same shape as ``GT``).
        max_order (int): The highest order of Zernike polynomials to be fitted.
        diameter (int, optional): Fitting diameter in pixels. Defaults to the smallest
            dimension of ``GT`` and ``recon``.
        method (str): Residual method; only ``'Subtractive'`` is implemented.

    Returns:
        GT_imgs (list of masked arrays): ``GT`` residual for each ``i``.
        recon_imgs (list of masked arrays): ``recon`` residual for each ``i``.
        NRMSE (list): ``nrmse_masked(GT_imgs[i], recon_imgs[i])`` for each ``i``.

    Raises:
        ValueError: If ``method`` is not ``'Subtractive'`` (previously this silently
            returned empty lists, which broke callers that index the result).
    """
    if method != 'Subtractive':
        raise ValueError(f"Unsupported method {method!r}; only 'Subtractive' is implemented.")
    if diameter is None:
        diameter = min(GT.shape[0], GT.shape[1], recon.shape[0], recon.shape[1])
    GT_coeff = fit_zernike(GT, max_order, diameter)
    Recon_coeff = fit_zernike(recon, max_order, diameter)

    # Residuals are compared on a slightly smaller aperture than the fit diameter.
    aperture = diameter - 4
    # True where circ_block zeroed an all-ones image, i.e. outside the aperture.
    mask = circ_block(np.ones(GT.shape), aperture) == 0
    # Loop-invariant: crop the inputs to the aperture once, not once per order.
    GT_cropped = circ_block(GT, aperture)
    recon_cropped = circ_block(recon, aperture)

    GT_imgs = []
    recon_imgs = []
    NRMSE = []
    for i in range(max_order + 2):
        GT_fit = circ_block(reconstruct_image(GT_coeff[:i], diameter).data, aperture)
        recon_fit = circ_block(reconstruct_image(Recon_coeff[:i], diameter).data, aperture)
        GT_resid = np.ma.masked_where(mask, GT_cropped - GT_fit)
        recon_resid = np.ma.masked_where(mask, recon_cropped - recon_fit)
        GT_imgs.append(GT_resid)
        recon_imgs.append(recon_resid)
        NRMSE.append(nrmse_masked(GT_resid, recon_resid))

    return GT_imgs, recon_imgs, NRMSE

# Experiment: Varying the location of the center of rotation

In [None]:
# Scan geometry
num_views=7
tilt_separation=2  # degrees
# NOTE(review): linspace(-angle_tilt, angle_tilt, num_views) below gives an actual view spacing of
# tilt_separation*(num_views+1)/(num_views-1) degrees (~2.67 deg here), not tilt_separation.
# If tilt_separation is meant to be the spacing, the factor should be (num_views-1) — confirm.
angle_tilt=(tilt_separation*(num_views+1)/2)*np.pi/180

# Beam parameters (cm)
diam=2
stack_offset_cm=0
num_stack=1
# Rotation-axis offsets to sweep (cm from center)
center_offset_cm_vec=np.array([0,2,4,6,8,10,12])
truncate=4

# PnP parameters
sigma=np.array([0.02,0.02,0.02])  # cm; converted to pixels below
num_iterations=10

# Phantom parameters
num_rows=600

# Dimensions and conversions
Window2WindowCMlength=20
cm_per_pixel=Window2WindowCMlength/num_rows
beam_pixel_diam=diam/cm_per_pixel

num_cols=int(25/cm_per_pixel)
num_slices=int(beam_pixel_diam)+2

# Convert offsets to pixels. stack_offset was previously hard-coded to 0, silently ignoring stack_offset_cm.
stack_offset=stack_offset_cm/cm_per_pixel
center_offset_vec=center_offset_cm_vec/cm_per_pixel

sigma=sigma/cm_per_pixel

volume=np.load('volume.npy')[:num_rows,:num_cols,:num_slices]
# Slicing silently truncates, so fail early if volume.npy is smaller than the requested phantom.
assert volume.shape==(num_rows,num_cols,num_slices), f"volume.npy too small: got {volume.shape}, need {(num_rows,num_cols,num_slices)}"

phantom=volume-np.average(volume)  # make zero mean
phantom_proj=jnp.array(phantom)

# Physical propagation angles (radians)
angles = jnp.linspace(-angle_tilt, angle_tilt, num_views)
ct_model=mbirjax.ParallelBeamModel((num_views,num_slices,num_cols), angles)
ct_model.set_params(recon_shape=(num_rows,num_cols,num_slices),verbose=0)
# Forward-project once; the same sinogram is windowed differently for each axis location below.
sinogram = ct_model.forward_project(phantom_proj)

recons=np.zeros((len(center_offset_vec),num_rows,num_cols,num_slices))

for i,center_offset in enumerate(center_offset_vec):
    sinogram_weights=sino_window_and_circ_block(np.ones(sinogram.shape),angles,(num_rows,num_cols),beam_pixel_diam,num_stack,stack_offset,center_offset,for_weights=True)
    sinogram_tilt_removed=remove_tip_tilt(np.ma.masked_array(sinogram,mask=(sinogram_weights==0)),axis=0).data
    cropped_sinogram_tilt_removed=sino_window_and_circ_block(np.array(sinogram_tilt_removed),angles,(num_rows,num_cols),beam_pixel_diam,num_stack,stack_offset,center_offset,for_weights=False)
    recons[i]=Gaussian_plugandplay(ct_model,jnp.array(cropped_sinogram_tilt_removed),jnp.array(sinogram_weights),sigma,truncate,num_iterations,convg=0.07,show_iter=25)

In [None]:
# Derive the file name from the swept offsets so it stays correct if center_offset_cm_vec changes
# (evaluates to 'axis_location_sweep_0_to_12.npy' for the default sweep).
sweep_filename=f'axis_location_sweep_{center_offset_cm_vec.min()}_to_{center_offset_cm_vec.max()}.npy'
np.save(sweep_filename,recons)

### Analysis of results

In [None]:
# Compare a ground-truth z-slice with reconstructions at several rotation-axis locations.
z_slice=30  # NOTE(review): titled "mid z-slice", but num_slices//2 is 31 for the default config — confirm
display_indices=[0,2,4,6]  # indices into center_offset_cm_vec / recons
extent=(0,20,0,25)  # imshow (left, right, bottom, top) in cm

gt_slice=phantom[:,:,z_slice]
recon_slices=recons[display_indices,:,:,z_slice]
# Shared color scale across all panels
vmin=min(np.min(gt_slice),np.min(recon_slices))
vmax=max(np.max(gt_slice),np.max(recon_slices))

# Titles are derived from the swept offsets so they cannot drift from the data shown.
panels=[('Ground Truth: mid z-slice',gt_slice)]
for idx,recon_slice in zip(display_indices,recon_slices):
    offset_cm=center_offset_cm_vec[idx]
    label='Axis at center' if offset_cm==0 else f'Axis at {offset_cm} cm from center'
    panels.append((f'{label}: mid z-slice',recon_slice))

fig,axes=plt.subplots(1,len(panels),figsize=(5.5*len(panels),5))
for ax,(title,img) in zip(axes,panels):
    im=ax.imshow(img.T,vmin=vmin,vmax=vmax,extent=extent)
    ax.set_title(title)
    ax.set_xlabel('y-axis')
    ax.set_ylabel('x-axis')
    fig.colorbar(im,ax=ax)
fig.tight_layout()
fig.suptitle("Comparison of Z-slices",y=1.05)
plt.show()

In [None]:
# Aperture mask taken from the first view (angle 0) of a 3-view weight computation, then
# transposed and repeated along rows so it presumably matches phantom/recons[i] in shape.
# The last swept axis offset is passed explicitly. Previously this silently relied on the loop
# variable `center_offset` leaking out of the reconstruction loop, which holds the same value.
# NOTE(review): the shape argument here is (num_cols, num_rows), while the reconstruction loop
# passes (num_rows, num_cols) — confirm which order sino_window_and_circ_block expects.
mask_center_offset=center_offset_vec[-1]
weights=sino_window_and_circ_block(np.ones((3,num_slices,num_cols)),np.array([0,1,2]),(num_cols,num_rows),beam_pixel_diam,num_stack,stack_offset,mask_center_offset,for_weights=True)[0]
twoD_mask=(weights==0).T  # True where the weight is zero, i.e. outside the beam
threeD_mask=np.repeat(twoD_mask[np.newaxis,:,:],num_rows,axis=0)

In [None]:
# NRMSE of each reconstruction against ground truth: over the full beam path, per section of
# rows, and per section after averaging along the rows.
Sections=11
inc=num_rows//Sections  # rows per section; the last num_rows % Sections rows fall in no section
Full_nrmse_axis_sweep=[]

Full_nrmse_sections_axis_sweep=[]

Sectional_nrmse_axis_sweep=np.zeros((Sections,len(center_offset_cm_vec)))

Sectional_nrmse_summed_axis_sweep=np.zeros((Sections,len(center_offset_cm_vec)))

masked_GT=np.ma.masked_where(threeD_mask,phantom)
GT=remove_tip_tilt(masked_GT,axis=0).data

section_slices=[slice(j*inc,(j+1)*inc) for j in range(Sections)]

# The row-averaged ground-truth section images do not depend on the axis location,
# so compute them once instead of once per reconstruction.
gt_section_imgs=[]
for rows in section_slices:
    img1=np.average(phantom[rows],axis=0)
    img1=np.ma.masked_where(twoD_mask,img1)
    gt_section_imgs.append(remove_tip_tilt(img1))

for i in range(len(center_offset_cm_vec)):
    masked_recons=np.ma.masked_where(threeD_mask,recons[i])
    recons_axis_sweep=remove_tip_tilt(masked_recons,axis=0).data
    Full_nrmse_axis_sweep.append(nrmse_cropped(GT,recons_axis_sweep,~threeD_mask))

    for j,rows in enumerate(section_slices):
        Sectional_nrmse_axis_sweep[j,i]=nrmse_cropped(GT[rows],recons_axis_sweep[rows],~threeD_mask[rows])

        img3=np.average(recons[i,rows],axis=0)
        img3=np.ma.masked_where(twoD_mask,img3)
        img3=remove_tip_tilt(img3)

        Sectional_nrmse_summed_axis_sweep[j,i]=nrmse_masked(gt_section_imgs[j],img3)

    Full_nrmse_sections_axis_sweep.append(np.average(Sectional_nrmse_summed_axis_sweep[:,i],axis=0))

In [None]:
# Full-path NRMSE as a function of rotation-axis location.
fig,ax=plt.subplots(figsize=(10,5))
ax.plot(center_offset_cm_vec,Full_nrmse_axis_sweep,'o-')
ax.set_xlabel('Distance from center (cm)')
ax.set_ylabel('NRMSE')
ax.set_title("NRMSE over Central Beam's Path")
plt.show()

In [None]:
# Block-averaged sectional NRMSE plotted at each section's center position.
# The section count is read from the computed array so the two cannot drift apart.
Sections=Sectional_nrmse_summed_axis_sweep.shape[0]
step=2
# Sections are num_rows//Sections rows each (the remainder rows are unused), so place each
# point at the true center of its row block, measured in cm from the middle of the window.
rows_per_section=num_rows//Sections
section_centers=[(j+0.5)*rows_per_section*cm_per_pixel-Window2WindowCMlength/2 for j in range(Sections)]

fig,(ax_top,ax_bottom)=plt.subplots(2,1,figsize=(10*.7,8*.7))
ax_top.plot(section_centers,Sectional_nrmse_summed_axis_sweep[:,::step],'o-')
ax_top.legend([f"{i} cm from center" for i in center_offset_cm_vec[::step]])
ax_top.set_title('NRMSE Relative to Block Averaged Region in wind tunnel \n ')
ax_top.set_xlabel('Section center (cm from middle)')
ax_top.set_ylabel('NRMSE')

# axis=1 averages over the rotation-axis locations
ax_bottom.plot(section_centers,np.average(Sectional_nrmse_summed_axis_sweep,axis=1),'o-')
ax_bottom.legend(['Averaged across all axis locations'])
ax_bottom.set_title('NRMSE Relative to Block Averaged Region in wind tunnel \n ')
ax_bottom.set_xlabel('Section center (cm from middle)')
ax_bottom.set_ylabel('NRMSE')

fig.tight_layout()
plt.show()

In [None]:
# Zernike analysis of row-averaged patches covering the beam aperture.
Sections=5
max_order=8
# Column window spanning the beam diameter around the center column
ind2=int(num_cols/2+beam_pixel_diam/2)+1
ind1=int(num_cols/2-beam_pixel_diam/2)
# Patch shape follows from the crop below. It was hard-coded as 61x61, which only matches the default geometry.
patch_shape=(ind2-ind1,num_slices-1)
img1=np.zeros((Sections,)+patch_shape)
img2=np.zeros((len(center_offset_cm_vec),Sections)+patch_shape)

zern_nrmse=np.zeros((len(center_offset_cm_vec),Sections,max_order+2))
zern_gt_list=[]
zern_recon_list=[[] for _ in range(len(center_offset_cm_vec))]

inc=num_rows//Sections
for i in range(Sections):
    rows=slice(i*inc,(i+1)*inc)
    img1[i]=np.average(phantom[rows,ind1:ind2,1:],axis=0)
    for j in range(len(center_offset_cm_vec)):
        img2[j,i]=np.average(recons[j,rows,ind1:ind2,1:],axis=0)
        zern_gt,zern_recon,zern_nrmse[j,i]=analyze_zernike_error(img1[i],img2[j,i],max_order)
        zern_recon_list[j].append(zern_recon)
    # The ground-truth residuals depend only on img1[i]; the copy from the last axis location is kept.
    zern_gt_list.append(zern_gt)

In [None]:
# Section-averaged Zernike-residual NRMSE versus rotation-axis location, one curve per removal level.
top_order=min(8,zern_nrmse.shape[2])  # cannot exceed the number of computed orders
n_sections=zern_nrmse.shape[1]  # read from the data rather than the mutable global `Sections`
fig,ax=plt.subplots(figsize=(10*1.3,5*1.3))
ax.plot(center_offset_cm_vec,np.average(zern_nrmse[:,:,2:top_order],axis=1),'o-')
# Raw f-string: '\l' is an invalid escape in a normal literal (SyntaxWarning on Python 3.12+)
ax.legend([rf'Degree $\leq$ {i-1} removed' for i in range(2,top_order)])
ax.set_xlabel('Distance from Center (cm)')
ax.set_ylabel(f'NRMSE Averaged over {n_sections} Sections')
ax.set_title('Average NRMSE Relative to Axis Location')
plt.show()

In [None]:
# Section-averaged NRMSE versus the highest Zernike radial degree removed (-1 = nothing removed),
# one curve per rotation-axis location.
step=1
degrees_removed=list(range(-1,max_order+1))
mean_over_sections=np.average(zern_nrmse[::step,:,:],axis=1)
offset_labels=[f"{cm} cm from center" for cm in center_offset_cm_vec[::step]]

plt.figure(figsize=(10,5))
plt.plot(degrees_removed,mean_over_sections.T)
plt.legend(offset_labels)
plt.xlabel("Maximum Radial Degree Removed")
plt.ylabel(f'NRMSE Averaged over {Sections} Sections')
plt.title('NRMSE Relative to Maximum Radial Degree Removed')
plt.show()

In [None]:
# Residual maps after removing Zernike degrees <= zern_mode_index-1, for selected sections
# (rows) and rotation-axis locations (columns), on a shared color scale.
zern_mode_index=5

regions=[0,2,4]          # section indices (0-based) to display
axis_locs_cm=[0,4,8,12]  # rotation-axis offsets (cm) to compare

# Look each offset up in center_offset_cm_vec instead of assuming a 2 cm grid starting at 0
# (the former int(loc/2) only worked for that specific sweep).
missing=[loc for loc in axis_locs_cm if not np.isclose(center_offset_cm_vec,loc).any()]
assert not missing, f'axis locations {missing} cm were not swept (center_offset_cm_vec={center_offset_cm_vec})'
axis_indices=[int(np.flatnonzero(np.isclose(center_offset_cm_vec,loc))[0]) for loc in axis_locs_cm]

shown=[zern_gt_list[i][zern_mode_index] for i in regions]
shown+=[zern_recon_list[idx][i][zern_mode_index] for idx in axis_indices for i in regions]
vmin=min(np.ma.min(img) for img in shown)
vmax=max(np.ma.max(img) for img in shown)

n_cols=1+len(axis_indices)
fig,axes=plt.subplots(len(regions),n_cols,figsize=(4*n_cols,len(regions)*4),squeeze=False)
for row,i in enumerate(regions):
    ax=axes[row,0]
    im=ax.imshow(zern_gt_list[i][zern_mode_index].T,cmap='jet',vmin=vmin,vmax=vmax)
    ax.set_title(f'Ground Truth Section {i+1}')
    fig.colorbar(im,ax=ax)
    for col,idx in enumerate(axis_indices,start=1):
        ax=axes[row,col]
        im=ax.imshow(zern_recon_list[idx][i][zern_mode_index].T,cmap='jet',vmin=vmin,vmax=vmax)
        ax.set_title(f'{center_offset_cm_vec[idx]} cm from center \n Section {i+1} NRMSE={zern_nrmse[idx,i,zern_mode_index]*100: 0.3f}%')
        fig.colorbar(im,ax=ax)
# Raw f-string: '\l' is an invalid escape in a normal literal (SyntaxWarning on Python 3.12+)
fig.suptitle(rf"Radial Degree $\leq$ {zern_mode_index-1} Removed",y=1.01)
fig.tight_layout()
plt.show()

## Multi-axis visualizations

In [None]:

def plot_rectangle_with_lines(locations, angles, diameter=0):
    """Draw a 20 x 25 rectangle centered at the origin and cast rays from source points.

    Parameters:
        locations: sequence of (x, y) source positions, in the same (presumably cm) units
            as the rectangle.
        angles: sequence parallel to ``locations``; ``angles[i]`` is an iterable of ray angles
            in degrees for ``locations[i]``. A ray points along ``angle + 180`` degrees, so
            angle 0 points toward -x.
        diameter: if > 0, a translucent red band of this width is drawn centered on each ray.

    Each ray is a dashed red line from its source to the left/right plot edge (x = +/-12),
    or to the top/bottom edge (y = +/-14.5) for a vertical ray. The figure is shown, not returned.
    """
    x_edge, y_edge = 12, 14.5
    fig, ax = plt.subplots(figsize=(24, 29))
    ax.set_xlim(-x_edge, x_edge)
    ax.set_ylim(-y_edge, y_edge)

    # Rectangle centered at (0, 0) with width 20 and height 25
    rectangle = plt.Rectangle((-10, -12.5), 20, 25, edgecolor='black', facecolor='white', zorder=0)
    ax.add_patch(rectangle)

    for i, location in enumerate(locations):
        x, y = location
        for angle in angles[i]:
            angle_rad = np.deg2rad(angle + 180)
            cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)
            if np.isclose(cos_a, 0.0):
                # Vertical ray: go straight to the top/bottom edge. An exact `cos == 0` test
                # never fires in floating point (cos(deg2rad(270)) is about -1.8e-16), which
                # previously sent vertical rays to an astronomically distant endpoint via tan().
                x_end = x
                y_end = y_edge if sin_a > 0 else -y_edge
            else:
                x_end = x_edge if cos_a > 0 else -x_edge
                y_end = y + (x_end - x) * np.tan(angle_rad)

            ax.plot([x, x_end], [y, y_end], 'r--', linewidth=2, zorder=1)

            if diameter > 0:
                length = np.hypot(x_end - x, y_end - y)
                # Rectangle is anchored at a corner and rotated about it. Shifting the corner by
                # diameter/2 along the right-hand normal (sin, -cos) centers the band on the ray.
                corner = (x + diameter * sin_a / 2, y - diameter * cos_a / 2)
                band = plt.Rectangle(corner, length, diameter, angle=np.rad2deg(angle_rad),
                                     color='red', alpha=0.3, zorder=0.5)
                ax.add_patch(band)

    plt.show()

In [None]:
# Example usage
# Example: two sources at x = 12, y = +/-0.9, each casting four rays (degrees), with 1-unit-wide bands.
# Distinct names avoid overwriting the experiment's global `angles` (the projection angles
# passed to ParallelBeamModel) in the kernel namespace.
example_locations = [(12, -0.9), (12, 0.9)]
example_angles = [(-6, -3, 0, 3), (-3, 0, 3, 6)]
plot_rectangle_with_lines(example_locations, example_angles, diameter=1)