In [None]:
import numpy as np
import matplotlib.pyplot as plt
from inference.tomogram_reconstruction import z_smooth, surface_reconstruction

In [None]:
import pyvista as pv

## Importing the segmented tomogram

First, we load the segmented slices produced by the `tomogram_2D_segmentation` notebook.

In [None]:
input_file_path = "../outputs/segmented_tomogram.npy"

seg_slices = np.load(input_file_path)

Note that the dimensions of the stack of segmented slices are ordered (z, x, y): `seg_slices[k]` is the k-th slice along z.
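If a downstream tool expects the more common (x, y, z) ordering, the axes can be moved without copying data. The small sketch below uses a hypothetical stand-in array (`demo_slices`, not the real tomogram) so that running it does not overwrite `seg_slices` or the shape variables.

```python
import numpy as np

# Hypothetical stand-in for the array stored in segmented_tomogram.npy:
# 4 slices along z, each 3 pixels along x and 2 pixels along y
demo_slices = np.zeros((4, 3, 2), dtype=np.uint8)
demo_slices[1, 2, 0] = 1  # one segmented voxel at z=1, x=2, y=0

demo_z, demo_x, demo_y = demo_slices.shape

# Move the z axis from the front to the back to obtain (x, y, z);
# np.moveaxis returns a view, so no data is copied
demo_xyz = np.moveaxis(demo_slices, 0, -1)
print(demo_xyz.shape)     # (3, 2, 4)
print(demo_xyz[2, 0, 1])  # 1: the same voxel, now addressed as (x, y, z)
```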

In [None]:
z_max, x_max, y_max = seg_slices.shape

print(f"Tomogram shape: L_x = {x_max}, L_y = {y_max}, L_z = {z_max}")

## Smoothing

We smooth the binary segmented tomogram with a Gaussian filter applied along the z-dimension only: each slice is replaced by a Gaussian-weighted average of the neighboring slices within $\pm$`smoothing_depth` slices of it, which smooths out abrupt changes in the segmentation between consecutive slices.
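The sketch below illustrates this kind of z-only Gaussian smoothing in plain NumPy. It is not the implementation of `z_smooth`: the helper name `z_smooth_sketch`, the default width `sigma = depth / 2`, and the replication of the first and last slices at the boundaries are all assumptions made for illustration.

```python
import numpy as np

def z_smooth_sketch(volume, depth, sigma=None):
    """Hypothetical Gaussian smoothing along axis 0 (z) of a (z, x, y) stack.

    Assumptions: sigma defaults to depth / 2 and edge slices are replicated
    so that the first and last slices also get 2 * depth + 1 neighbours.
    """
    if sigma is None:
        sigma = max(depth / 2.0, 1e-6)
    offsets = np.arange(-depth, depth + 1)          # -depth ... +depth
    weights = np.exp(-0.5 * (offsets / sigma) ** 2)
    weights /= weights.sum()                        # a constant volume stays unchanged

    vol = volume.astype(np.float64)
    padded = np.pad(vol, ((depth, depth), (0, 0), (0, 0)), mode="edge")
    n_z = vol.shape[0]
    out = np.zeros_like(vol)
    # padded[k + z] holds original slice z + (k - depth), i.e. offset offsets[k]
    for k, w in enumerate(weights):
        out += w * padded[k:k + n_z]
    return out

# Usage: a single segmented slice spreads symmetrically into its neighbours
demo_vol = np.zeros((5, 3, 3))
demo_vol[2] = 1.0
demo_smoothed = z_smooth_sketch(demo_vol, depth=1)
```

Restricting the kernel to z leaves the in-plane (x, y) resolution of each slice untouched, which fits the goal of smoothing the stacking direction only.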

In [None]:
smoothing_depth = 2

smoothed_slices = z_smooth(seg_slices, smoothing_depth)

## 3D reconstruction

The `surface_reconstruction` function uses the Marching Cubes algorithm to convert the voxel-based segmented tomogram into a tessellated (triangulated) surface mesh.
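For readers who want to see what a Marching Cubes step can look like, here is a sketch using scikit-image's `skimage.measure.marching_cubes`. We do not know whether `surface_reconstruction` uses scikit-image; the helper name `reconstruct_surface_sketch` and the iso-level of 0.5 (halfway between background 0 and foreground 1) are assumptions. Vertex coordinates come out in the same axis order as the input array, which for this notebook is (z, x, y).

```python
import numpy as np
from skimage import measure

def reconstruct_surface_sketch(volume, level=0.5, spacing=(1.0, 1.0, 1.0)):
    """Hypothetical Marching Cubes surface extraction with scikit-image.

    Returns the vertex coordinates (same axis order as `volume`) and a flat
    face array in the [3, i0, i1, i2, 3, ...] layout used by pyvista.PolyData.
    """
    verts, faces, _normals, _values = measure.marching_cubes(
        volume, level=level, spacing=spacing)
    # Prefix every triangle with its vertex count (3), then flatten
    counts = np.full((faces.shape[0], 1), 3, dtype=faces.dtype)
    faces_flat = np.hstack([counts, faces]).ravel()
    return verts, faces_flat

# Usage on a synthetic volume: a solid ball of radius 6 centred at index 10
zz, xx, yy = np.mgrid[-10:11, -10:11, -10:11]
ball = (xx**2 + yy**2 + zz**2 <= 36).astype(np.float32)
demo_verts, demo_faces = reconstruct_surface_sketch(ball, level=0.5)
```

If PyVista is installed, `pv.PolyData(demo_verts, demo_faces)` turns the result into a mesh that can be passed to `plotter.add_mesh`.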

In [None]:
mesh = surface_reconstruction(smoothed_slices)

## Visualization

Using the **PyVista** package (https://docs.pyvista.org), we can visualize the reconstructed tomogram in 3D. 
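A PyVista camera is described by a position, a focal point and a view-up vector. Placing the camera on a sphere around the focal point (polar angle theta, azimuth phi) and choosing the view-up as the tangent pointing toward the pole keeps it perpendicular to the line of sight. The hypothetical helper `orbit_cameras` below is a sketch, not part of this project, that generates evenly spaced "turntable" poses under this convention; the centre and radius values in the usage line are arbitrary.

```python
import numpy as np

def orbit_cameras(center, radius, theta, n_views):
    """Hypothetical helper: camera poses evenly spaced in azimuth on a sphere.

    Returns a list of (position, view_up). theta is the polar angle in radians
    measured from the third axis; view_up is the unit tangent in the direction
    of decreasing theta, so it is perpendicular to the line of sight.
    """
    center = np.asarray(center, dtype=float)
    poses = []
    for phi in np.linspace(0.0, 2.0 * np.pi, n_views, endpoint=False):
        radial = np.array([np.sin(theta) * np.cos(phi),
                           np.sin(theta) * np.sin(phi),
                           np.cos(theta)])
        view_up = np.array([-np.cos(theta) * np.cos(phi),
                            -np.cos(theta) * np.sin(phi),
                            np.sin(theta)])
        poses.append((center + radius * radial, view_up))
    return poses

demo_poses = orbit_cameras(center=(50.0, 40.0, 30.0), radius=200.0,
                           theta=np.deg2rad(65.0), n_views=8)
```

Each pose can then be applied with `plotter.set_position`, `plotter.set_focus` and `plotter.set_viewup`, one view at a time.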

In [None]:
def get_cam_position(center, radius, theta, phi):
    """Return a camera position and view-up vector on a sphere around `center`.

    theta is the polar angle from the third axis and phi the azimuthal angle,
    both in radians. The view-up vector is tangent to the sphere and points
    toward increasing third coordinate.
    """
    cx, cy, cz = center

    cam_pos = (cx + radius * np.sin(theta) * np.cos(phi),
               cy + radius * np.sin(theta) * np.sin(phi),
               cz + radius * np.cos(theta))

    cam_viewup = (-np.cos(theta) * np.cos(phi), -np.cos(theta) * np.sin(phi), np.sin(theta))

    return cam_pos, cam_viewup

In [None]:
plotter = pv.Plotter(notebook=True, off_screen=True, multi_samples=2, polygon_smoothing=True)

surf = plotter.add_mesh(mesh, smooth_shading=True)

plotter.set_background([0.1, 0.1, 0.1])

# Orbit around the centre of the mesh bounding box, so that the camera position
# and the focal point use the same axis order as the mesh coordinates
center = np.array(mesh.center)

cam_pos, cam_viewup = get_cam_position(center, 2.4 * x_max, np.deg2rad(65.0), np.deg2rad(-20.0))

plotter.set_position(cam_pos)
plotter.set_focus(center)
plotter.set_viewup(cam_viewup)

plotter.show_grid()
plotter.show_axes()

plotter.show()

## Exporting the reconstructed tomogram

The following command saves the reconstructed surface mesh to disk. `pv.save_meshio` delegates the writing to the **meshio** package (which must be installed), so the output format is chosen through the file extension and can be any format supported by meshio. See the PyVista documentation (https://docs.pyvista.org) for further information.
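To show what an exported `.obj` file contains, here is a hypothetical minimal Wavefront OBJ writer for a triangle mesh. It is only an illustration of the file layout (one `v x y z` line per vertex, one `f i j k` line per triangle with one-based indices), not a replacement for `pv.save_meshio`; the function name `write_obj_sketch` and the six-decimal formatting are assumptions.

```python
import os
import tempfile

import numpy as np

def write_obj_sketch(path, vertices, triangles):
    """Hypothetical minimal OBJ writer.

    vertices: (n, 3) coordinates; triangles: (m, 3) zero-based vertex indices.
    OBJ face indices are one-based, hence the + 1 below.
    """
    vertices = np.asarray(vertices, dtype=float)
    triangles = np.asarray(triangles, dtype=int)
    with open(path, "w") as f:
        for x, y, z in vertices:
            f.write(f"v {x:.6f} {y:.6f} {z:.6f}\n")
        for a, b, c in triangles + 1:
            f.write(f"f {a} {b} {c}\n")

# Usage: write a single triangle to a temporary file
demo_path = os.path.join(tempfile.mkdtemp(), "triangle.obj")
write_obj_sketch(demo_path, [[0, 0, 0], [1, 0, 0], [0, 1, 0]], [[0, 1, 2]])
with open(demo_path) as f:
    demo_lines = f.read().splitlines()
```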

In [None]:
pv.save_meshio("../outputs/reconstructed_tomogram.obj", mesh)