In [None]:
import os
import numpy as np
from tomography_radon import get_rmse, plot_fbp_reconstruction, get_system
from skimage.filters import threshold_mean
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

## Calculate the RMSE for each reconstruction method, image, and matrix size

In [None]:
# Root folder: one sub-directory per test image, each containing one
# sub-directory per matrix size (named by the size, e.g. "4", "8", ...)
# that holds the saved .npz results.
search_path = 'Results/ResultsBinary/'

# Keep only directories (later cells save .png summaries into this folder) and
# sort them so the image axis of `rms` has a stable order; the per-image plot
# cells index rms[0], rms[1], ... and depend on that order.
images = sorted(
    name for name in os.listdir(search_path)
    if os.path.isdir(os.path.join(search_path, name))
)
print(images)

METHODS = ['FBP', 'SART', 'PI', 'QA']
N_SIZES = 8  # number of matrix sizes; must match `mat_sizes` in the plotting cells

# rms[image, size, method]; the method axis follows METHODS.
# Entries stay 0 where a size folder holds no .npz result.
rms = np.zeros((len(images), N_SIZES, len(METHODS)))


def _binarize_mean(recon):
    """Return a boolean mask of `recon` thresholded at skimage's `threshold_mean`."""
    return recon > threshold_mean(recon)


for i, image_name in enumerate(images):
    image_dir = os.path.join(search_path, image_name)
    # Size folders are named by the matrix size -> numeric sort (4 < 8 < 12 ...).
    sizes = sorted((s for s in os.listdir(image_dir) if s.isdigit()), key=int)
    for j, size in enumerate(sizes):
        size_dir = os.path.join(image_dir, size)
        # Sorted so the chosen result file is deterministic if several exist.
        res = sorted(x for x in os.listdir(size_dir) if x.endswith('.npz'))
        if not res:
            continue
        # The context manager closes the underlying zip file handle.
        with np.load(os.path.join(size_dir, res[0])) as res_npz:
            image = res_npz['image']
            recon_fbp = res_npz['reconstruction_fbp']
            recon_sart = res_npz['reconstruction_sart']
            recon_qa = res_npz['reconstruction_qa'][0].reshape(image.shape)
            sinogram = res_npz['sinogram']

        # Pseudo-inverse reconstruction from the system matrix and the sinogram.
        system = get_system(image)
        recon_pi = np.linalg.pinv(system).dot(sinogram.flatten()).reshape(image.shape)

        # FBP, SART and PI are thresholded at their mean before scoring;
        # the QA reconstruction is scored as stored.
        rms[i, j] = [
            get_rmse(_binarize_mean(recon_fbp).astype(int), image),
            get_rmse(_binarize_mean(recon_sart).astype(int), image),
            get_rmse(_binarize_mean(recon_pi).astype(int), image),
            get_rmse(recon_qa.astype(int), image),
        ]

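## Plot each original image next to its thresholded FBP, SART and PI reconstructions and the QA reconstruction, for every matrix size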

In [None]:
# One figure per test image: rows are matrix sizes, columns are the original
# image followed by the FBP, SART, PI and QA reconstructions.
search_path = 'Results/ResultsBinary/'
# Only sub-directories hold per-image results; summary .png/.npy files live alongside them.
images = sorted(
    name for name in os.listdir(search_path)
    if os.path.isdir(os.path.join(search_path, name))
)
print(images)

col_titles = ["Original", "FBP", "SART", "PI", "QA"]

for f in images:
    image_dir = os.path.join(search_path, f)
    # Size folders are named by the matrix size -> numeric sort.
    sizes = sorted((s for s in os.listdir(image_dir) if s.isdigit()), key=int)
    if not sizes:
        continue
    # squeeze=False keeps `ax` 2-D even when there is a single size row.
    fig, ax = plt.subplots(len(sizes), len(col_titles), figsize=(16, 16), squeeze=False)
    for j, s in enumerate(sizes):
        size_dir = os.path.join(image_dir, s)
        res = sorted(x for x in os.listdir(size_dir) if x.endswith('.npz'))
        if not res:
            continue
        # The context manager closes the underlying zip file handle.
        with np.load(os.path.join(size_dir, res[0])) as res_npz:
            image = res_npz['image']
            recon_fbp = res_npz['reconstruction_fbp']
            recon_sart = res_npz['reconstruction_sart']
            recon_qa = res_npz['reconstruction_qa'][0].reshape(image.shape)
            sinogram = res_npz['sinogram']

        # Threshold the continuous reconstructions at their mean value.
        recon_fbp = recon_fbp > threshold_mean(recon_fbp)
        recon_sart = recon_sart > threshold_mean(recon_sart)
        system = get_system(image)
        recon_pi = np.linalg.pinv(system).dot(sinogram.flatten()).reshape(image.shape)
        recon_pi = recon_pi > threshold_mean(recon_pi)

        panels = [image, recon_fbp, recon_sart, recon_pi, recon_qa]
        for axis, panel in zip(ax[j], panels):
            axis.imshow(panel, cmap='gray')
            axis.set_xticks([])
            axis.set_yticks([])

    for axis, title in zip(ax[0], col_titles):
        axis.set_title(title, size=40)
    # Row labels are the matrix sizes, taken from the size folder names.
    for axis, size_label in zip(ax[:, 0], sizes):
        axis.set_ylabel(size_label, rotation=0, size=40)
        axis.yaxis.set_label_coords(-.5, 0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(search_path, f + ".png"))
    plt.show()
    # Release the figure; many figures are created in this loop.
    plt.close(fig)
    
            

## Plot the mean RMSE (± one standard deviation across images) for each reconstruction method over matrix size

In [None]:
# Mean and (population) standard deviation of the RMSE across images;
# both have shape (n_sizes, n_methods). Method order: FBP, SART, PI, QA.
rms_mean = rms.mean(axis=0)
rms_std = rms.std(axis=0)
mat_sizes = [4, 8, 12, 16, 20, 24, 28, 32]
# The PI curve is drawn shifted up by this amount, presumably so it stays
# visible where it coincides with another curve. Its std band is shifted
# by the same amount so the band stays centred on the drawn line.
PI_OFFSET = 0.005
x = range(rms_mean.shape[0])

fig, ax = plt.subplots()
for k, label in enumerate(['FBP', 'SART', 'PI', 'QA']):
    shift = PI_OFFSET if label == 'PI' else 0.0
    centre = rms_mean[:, k] + shift
    (line,) = ax.plot(x, centre, label=label)
    # Match the band colour to its line explicitly rather than relying on the cycler.
    ax.fill_between(x, centre - rms_std[:, k], centre + rms_std[:, k],
                    color=line.get_color(), alpha=.1)
ax.set_xticks(list(x))
ax.set_xticklabels(mat_sizes)
ax.set_xlabel("Image size")
ax.set_ylabel("RMSE")
ax.set_title("Mean RMSE: Image size")
ax.legend()
fig.savefig("Results/ResultsBinary/mean_rmse.png")
plt.show()

## Plot RMSE for each reconstruction over matrix size for Foam.png

In [None]:
# RMSE of each method over matrix size for the image at index 0 of `rms`.
# NOTE(review): assumes index 0 is Foam, i.e. the image folders were listed in
# alphabetical order (Foam, Molecule, Snowflake, Tree) — confirm against `images`.
rms_foam = rms[0, ...]
mat_sizes = [4, 8, 12, 16, 20, 24, 28, 32]
# Vertical shift applied to the PI curve, presumably to keep it visible where
# it coincides with another curve; plotted PI values are not the true RMSE.
PI_OFFSET = 0.005

fig, ax = plt.subplots()
ax.plot(rms_foam[:, 0], label='FBP')
ax.plot(rms_foam[:, 1], label='SART')
ax.plot(rms_foam[:, 2] + PI_OFFSET, label='PI')
ax.plot(rms_foam[:, 3], label='QA')
ax.set_xticks(range(len(mat_sizes)))
ax.set_xticklabels(mat_sizes)
ax.set_xlabel("Image size")
ax.set_ylabel("RMSE")
ax.set_title("Foam RMSE: Image size")
ax.legend()
fig.savefig("Results/ResultsBinary/foam_rmse.png")
plt.show()

## Plot RMSE for each reconstruction over matrix size for Molecule.png

In [None]:
# Per-method RMSE curves over matrix size for the image at index 1 of `rms`
# (Molecule, assuming alphabetical image order — TODO confirm against `images`).
rms_molecule = rms[1, ...]
mat_sizes = [4, 8, 12, 16, 20, 24, 28, 32]

# (column of rms_molecule, legend label, vertical shift). PI is lifted by
# 0.005, presumably only to keep it visible where curves overlap.
for column, (method, shift) in enumerate([('FBP', 0), ('SART', 0), ('PI', 0.005), ('QA', 0)]):
    plt.plot(rms_molecule[:, column] + shift, label=method)

plt.xticks(range(8), mat_sizes)
plt.xlabel("Image size")
plt.ylabel("RMSE")
plt.title("Molecule RMSE: Image size")
plt.legend()
plt.savefig("Results/ResultsBinary/molecule_rmse.png")
plt.show()

## Plot RMSE for each reconstruction over matrix size for Snowflake.png

In [None]:
# Per-method RMSE curves over matrix size for the image at index 2 of `rms`
# (Snowflake, assuming alphabetical image order — TODO confirm against `images`).
rms_snowflake = rms[2, ...]
mat_sizes = [4, 8, 12, 16, 20, 24, 28, 32]

# Legend label -> curve; insertion order fixes the plotting order.
curves = {
    'FBP': rms_snowflake[:, 0],
    'SART': rms_snowflake[:, 1],
    'PI': rms_snowflake[:, 2] + 0.005,  # shifted up, presumably for visibility only
    'QA': rms_snowflake[:, 3],
}
for method, values in curves.items():
    plt.plot(values, label=method)

plt.xticks(range(8), mat_sizes)
plt.xlabel("Image size")
plt.ylabel("RMSE")
plt.title("Snowflake RMSE: Image size")
plt.legend()
plt.savefig("Results/ResultsBinary/snowflake_rmse.png")
plt.show()

## Plot RMSE for each reconstruction over matrix size for Tree.png

In [None]:
# Per-method RMSE curves over matrix size for the image at index 3 of `rms`
# (Tree, assuming alphabetical image order — TODO confirm against `images`).
rms_tree = rms[3, ...]
mat_sizes = [4, 8, 12, 16, 20, 24, 28, 32]

method_names = ('FBP', 'SART', 'PI', 'QA')
# Only PI is shifted (by 0.005), presumably to separate overlapping curves.
vertical_shifts = (0, 0, 0.005, 0)
for col, (method, dy) in enumerate(zip(method_names, vertical_shifts)):
    plt.plot(rms_tree[:, col] + dy, label=method)

plt.xticks(range(8), mat_sizes)
plt.xlabel("Image size")
plt.ylabel("RMSE")
plt.title("Tree RMSE: Image size")
plt.legend()
plt.savefig("Results/ResultsBinary/tree_rmse.png")
plt.show()

In [None]:
# Seaborn version of the mean-RMSE plot: one line per reconstruction method.
# The method axis of `rms` is ordered FBP, SART, PI, QA (see the RMSE cell).
# A single wide DataFrame with named columns is plotted once; previously the
# same unnamed frame was plotted three times, tripling every line and legend entry.
rms_mean = rms.mean(axis=0)
mat_sizes = [4, 8, 12, 16, 20, 24, 28, 32]
df_rms_mean = pd.DataFrame(rms_mean, index=mat_sizes, columns=['FBP', 'SART', 'PI', 'QA'])
df_rms_mean.index.name = "Image size"

ax = sns.lineplot(data=df_rms_mean)
ax.set_xlabel("Image size")
ax.set_ylabel("RMSE")
ax.set_title("Mean RMSE: Image size")
plt.show()