In [50]:
import os
import re
from collections import Counter

import cv2
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt

In [51]:
_nsre = re.compile('([0-9]+)')
def natural_sort_key(s):
    """Sort key that orders strings "naturally" (img2 < img10), case-insensitively.

    Splitting on a capturing group yields alternating [text, digits, text, ..., text], so every
    odd position is an ASCII digit run. Using that position instead of ``str.isdigit`` avoids
    ``int()`` failing on non-ASCII digit characters such as '²', which ``isdigit`` accepts.
    """
    return [int(part) if index % 2 else part.lower()
            for index, part in enumerate(_nsre.split(s))]

# Data locations. Override them with environment variables to run the notebook on another
# machine; the defaults are the original paths.
segmented_images_path = os.environ.get(
    'SEG_IMAGES_PATH', '/home/demir/Desktop/jhu_project/needle-segmentation/data/test_seg_res')
oct_images_path = os.environ.get(
    'OCT_IMAGES_PATH', '/home/demir/Desktop/jhu_project/oct_scans/jun11/2.1/images')

num_b_scans_volume = 5  # number of B-scans per volume
target_depth = 0.5  # NOTE(review): not referenced by any visible cell — confirm it is still needed

def _load_images(dir_path, flags):
    """Read every file in `dir_path` with cv2.imread, in natural filename order.

    cv2.imread returns None instead of raising when a file cannot be decoded (e.g. a stray
    non-image file). That would only surface later as an obscure error, so fail here and
    report the offending path.
    """
    images = []
    for name in sorted(os.listdir(dir_path), key=natural_sort_key):
        path = os.path.join(dir_path, name)
        image = cv2.imread(path, flags)
        if image is None:
            raise IOError(f'cv2.imread could not read {path!r}')
        images.append(image)
    return images

seg_masks = _load_images(segmented_images_path, cv2.IMREAD_UNCHANGED)
oct_images = _load_images(oct_images_path, cv2.IMREAD_GRAYSCALE)

In [52]:
# Stack the first `num_b_scans_volume` segmentation masks into a (slices, rows, cols) label volume.
seg_volume = np.stack(seg_masks[:num_b_scans_volume], axis=0)
# The downstream helpers index volume[z, :, x], so the masks must be single-channel.
assert seg_volume.ndim == 3, f'expected single-channel masks, got a volume of shape {seg_volume.shape}'
print(seg_volume.shape)
print(np.unique(seg_volume))

(5, 1024, 1024)
[0 1 2 3]


In [87]:
def get_points_and_colors(volume, values=(1, 2, 3)):
    """
    Collect the first occurrence of each label along every A-scan of a label volume.

    Parameters:
    - volume (np.ndarray): 3-D label volume indexed as volume[slice, row, col]; the A-scan at
      (z, x) is volume[z, :, x].
    - values (iterable of int): labels to look for. Labels 1, 2 and 3 are coloured red, green
      and blue respectively; any other label is coloured black.

    Returns:
    - points (np.ndarray, float64, shape (N, 3)): [z, first_row, x] for every (A-scan, label)
      pair in which the label occurs, ordered by z, then x, then position in `values`.
    - colors (np.ndarray, float64, shape (N, 3)): RGB colour of each point.
    """
    seg_colors = {1: (1.0, 0.0, 0.0), 2: (0.0, 1.0, 0.0), 3: (0.0, 0.0, 1.0)}
    values = list(values)
    if not values:
        return np.empty((0, 3)), np.empty((0, 3))

    masks = [volume == seg_id for seg_id in values]
    # found[z, x, k]: A-scan (z, x) contains label values[k]; first_row[z, x, k]: its first row.
    # Rows are axis 1 and columns axis 2, so the column range comes from volume.shape[2].
    found = np.stack([mask.any(axis=1) for mask in masks], axis=-1)
    first_row = np.stack([mask.argmax(axis=1) for mask in masks], axis=-1)

    # np.nonzero walks (z, x, k) in C order, which is the z -> x -> label loop order.
    zs, xs, ks = np.nonzero(found)
    points = np.column_stack((zs, first_row[zs, xs, ks], xs)).astype(np.float64)
    palette = np.array([seg_colors.get(v, (0.0, 0.0, 0.0)) for v in values], dtype=np.float64)
    colors = palette[ks]
    return points, colors

def get_depth_map(volume, seg_index):
    """
    Build a depth map from the first row at which `seg_index` appears along each A-scan.

    Parameters:
    - volume (np.ndarray): 3-D label volume indexed as volume[slice, row, col].
    - seg_index (int): label to locate.

    Returns:
    - depth_map (np.ndarray, float64, shape (slices, cols)): first row index of the label in
      A-scan volume[z, :, x], or 0 where the label is absent. A label found at row 0 is therefore
      indistinguishable from "absent"; inpaint_layers relies on 0 meaning "missing".
    """
    mask = volume == seg_index
    # Rows are axis 1, so reducing over it yields one value per (slice, column) A-scan.
    depth_map = np.where(mask.any(axis=1), mask.argmax(axis=1), 0).astype(np.float64)
    return depth_map

def _inpaint_depth_map(depth_map, inpaint_radius=3):
    """Fill the zero-valued (missing) pixels of a depth map with OpenCV Navier-Stokes inpainting."""
    depth_max = depth_map.max()
    # An all-zero map (layer never found) would make the normalisation divide by zero -> NaNs.
    scale = depth_max if depth_max > 0 else 1.0
    # Normalise to [0, 1] before inpainting and undo it afterwards, as the original did.
    normalized = (depth_map / scale).astype(np.float32)
    missing_mask = (depth_map == 0).astype(np.uint8)
    inpainted = cv2.inpaint(normalized, missing_mask, inpaint_radius, cv2.INPAINT_NS)
    return inpainted * scale


def _depth_map_to_points(depth_map):
    """Turn a 2-D depth map into (rows*cols, 3) float64 points [i, depth, j] in row-major order."""
    rows, cols = np.indices(depth_map.shape)
    return np.column_stack((rows.ravel(), depth_map.ravel(), cols.ravel())).astype(np.float64)


def inpaint_layers(ilm_depth_map, rpe_depth_map):
    """
    Inpaint missing (zero) entries of the ILM and RPE depth maps and convert them to 3-D points.

    Parameters:
    - ilm_depth_map (np.ndarray): 2-D depth map of the ILM layer; 0 marks a missing value.
    - rpe_depth_map (np.ndarray): 2-D depth map of the RPE layer; 0 marks a missing value.

    Returns:
    - ilm_points, rpe_points (np.ndarray, shape (N, 3)): one [i, inpainted_depth, j] point per
      depth-map pixel (i = row index, j = column index of the depth map), in row-major order.
    """
    ilm_points = _depth_map_to_points(_inpaint_depth_map(ilm_depth_map))
    rpe_points = _depth_map_to_points(_inpaint_depth_map(rpe_depth_map))
    return ilm_points, rpe_points

def remove_outliers(point_cloud, nb_points=5, radius=4):
    """
    Drop points that have fewer than `nb_points` neighbours within `radius`.

    Parameters:
    - point_cloud (o3d.geometry.PointCloud): cloud to filter.
    - nb_points (int): minimum neighbour count required to keep a point.
    - radius (float): neighbourhood radius.

    Returns:
    - o3d.geometry.PointCloud: the points of `point_cloud` kept by remove_radius_outlier.
    """
    _filtered, kept_indices = point_cloud.remove_radius_outlier(nb_points=nb_points, radius=radius)
    return point_cloud.select_by_index(kept_indices)

def get_largest_cluster(point_cloud, eps=5, min_points=10):
    """
    Keep only the points of the most populated DBSCAN cluster.

    Parameters:
    - point_cloud (o3d.geometry.PointCloud): cloud to cluster.
    - eps (float): DBSCAN neighbourhood distance.
    - min_points (int): DBSCAN minimum points per cluster.

    Returns:
    - o3d.geometry.PointCloud: points of the largest cluster (ties go to the label seen first),
      or an empty selection when DBSCAN finds no cluster at all.
    """
    labels = np.array(point_cloud.cluster_dbscan(eps=eps, min_points=min_points, print_progress=False))
    # DBSCAN labels noise points -1; they must never win as the "largest cluster".
    cluster_sizes = Counter(label for label in labels if label != -1)
    if not cluster_sizes:
        return point_cloud.select_by_index([])
    largest_cluster_label = cluster_sizes.most_common(1)[0][0]
    largest_cluster_indices = np.where(labels == largest_cluster_label)[0]
    return point_cloud.select_by_index(largest_cluster_indices)

def find_lowest_point(point_cloud, axis=2):
    """
    Return the point with the largest coordinate along `axis`.

    Parameters:
    - point_cloud: object whose `.points` converts to an (N, 3) array (e.g. o3d.geometry.PointCloud).
    - axis (int): coordinate column to maximise. Defaults to 2, matching the original behaviour.
      NOTE(review): for the [slice, row, column] points built in this notebook, column 2 is the
      lateral index and column 1 is the row (depth); confirm which one "lowest" should mean.

    Returns:
    - np.ndarray of shape (3,): the first point attaining the maximum.

    Raises:
    - ValueError: if the point cloud is empty.
    """
    np_points = np.asarray(point_cloud.points)
    if np_points.size == 0:
        raise ValueError('cannot find the lowest point of an empty point cloud')
    lowest_index = np.argmax(np_points[:, axis])
    return np_points[lowest_index, :]

def find_needle_tip(needle_point_cloud, return_clean_point_cloud=False,
                    nb_points=5, radius=4, eps=5, min_points=10):
    """
    Clean the needle point cloud and locate the needle tip.

    Pipeline: radius outlier removal -> keep the largest DBSCAN cluster -> find_lowest_point.

    Parameters:
    - needle_point_cloud (o3d.geometry.PointCloud): raw needle points.
    - return_clean_point_cloud (bool): also return the cleaned cloud.
    - nb_points, radius: forwarded to remove_outliers.
    - eps, min_points: forwarded to get_largest_cluster.

    Returns:
    - the tip coordinates, or (tip coordinates, cleaned point cloud) when
      `return_clean_point_cloud` is True.
    """
    clean_cloud = remove_outliers(needle_point_cloud, nb_points=nb_points, radius=radius)
    clean_cloud = get_largest_cluster(clean_cloud, eps=eps, min_points=min_points)
    needle_tip_coords = find_lowest_point(clean_cloud)
    if return_clean_point_cloud:
        return needle_tip_coords, clean_cloud
    return needle_tip_coords
    
def create_mesh_sphere(center, radius=3, color=[1., 0., 1.]):
    """
    Build a uniformly coloured triangle-mesh sphere positioned at `center`.

    Parameters:
    - center (sequence of 3 floats): sphere centre, in the same coordinate order as the point
      clouds it annotates (the [slice, row, column] points built in this notebook).
    - radius (float): sphere radius.
    - color (list): RGB colour with components in [0, 1]; the default is magenta.

    Returns:
    - mesh_sphere (o3d.geometry.TriangleMesh): the coloured, positioned sphere.
    """
    sphere = o3d.geometry.TriangleMesh.create_sphere(radius=radius)
    sphere.paint_uniform_color(color)
    # create_sphere builds the mesh around the origin, so a pure translation puts its centre at `center`.
    sphere.translate(center)
    return sphere

In [91]:
# Create the needle point cloud and find the needle tip.
needle_first_occ_coords, needle_colors = get_points_and_colors(seg_volume, values=[1])
needle_point_cloud = o3d.geometry.PointCloud()
needle_point_cloud.points = o3d.utility.Vector3dVector(needle_first_occ_coords)
# get_points_and_colors already returns one red ([1, 0, 0]) colour per label-1 point,
# so its colours are used directly instead of being rebuilt.
needle_point_cloud.colors = o3d.utility.Vector3dVector(needle_colors)

needle_tip_coords, needle_point_cloud = find_needle_tip(needle_point_cloud, return_clean_point_cloud=True)

# Depth maps of the two layers (label 2 = ILM, label 3 = RPE), then inpaint their gaps.
ilm_depth_map = get_depth_map(seg_volume, seg_index=2)
rpe_depth_map = get_depth_map(seg_volume, seg_index=3)
assert ilm_depth_map.shape == rpe_depth_map.shape

ilm_points, rpe_points = inpaint_layers(ilm_depth_map, rpe_depth_map)
ilm_colors = np.tile([0, 1, 0], (ilm_points.shape[0], 1))
rpe_colors = np.tile([0, 0, 1], (rpe_points.shape[0], 1))

# Combined layers point cloud: ILM in green, RPE in blue.
layers_point_cloud = o3d.geometry.PointCloud()
layers_point_cloud.points = o3d.utility.Vector3dVector(np.vstack((ilm_points, rpe_points)))
layers_point_cloud.colors = o3d.utility.Vector3dVector(np.vstack((ilm_colors, rpe_colors)))

# Sphere marking the detected needle tip.
needle_tip_sphere = create_mesh_sphere(needle_tip_coords)

vis = o3d.visualization.Visualizer()
vis.create_window()
for geometry in (needle_point_cloud, layers_point_cloud, needle_tip_sphere):
    vis.add_geometry(geometry)
vis.run()
vis.destroy_window()

In [84]:
# First occurrence of each label along every A-scan, one point set per structure
# (label 1 = needle, 2 = ILM, 3 = RPE).
(needle_first_occ, needle_colors), (ilm_first_occ, ilm_colors), (rpe_first_occ, rpe_colors) = [
    get_points_and_colors(seg_volume, values=[label]) for label in (1, 2, 3)
]

### Create needle point cloud

In [85]:
# Needle point cloud built directly from its points, then coloured per point.
needle_point_cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(needle_first_occ))
needle_point_cloud.colors = o3d.utility.Vector3dVector(needle_colors)

### Apply outlier removal to needle point cloud
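The needle segmentation is more error-prone than the layer segmentations, so isolated needle points are removed with a radius outlier filter before clustering.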

In [56]:
# Reuse the remove_outliers helper: drop needle points with fewer than 5 neighbours within radius 4.
needle_point_cloud = remove_outliers(needle_point_cloud, nb_points=5, radius=4)

### Clustering with DBSCAN

In [57]:
labels = np.array(needle_point_cloud.cluster_dbscan(eps=5, min_points=10, print_progress=True))

# Find and select the largest cluster label, then filter the point cloud by its indices.
print(labels.shape)
counter = Counter(labels)
print(counter)
# DBSCAN labels noise points -1; exclude them so noise is never chosen as the "largest cluster".
cluster_counter = Counter({label: count for label, count in counter.items() if label != -1})
if not cluster_counter:
    raise ValueError('DBSCAN found no clusters in the needle point cloud (all points are noise)')
largest_cluster_label = cluster_counter.most_common(1)[0][0]
print(largest_cluster_label)

largest_cluster_indices = np.where(labels == largest_cluster_label)
print(largest_cluster_indices[0].shape)

needle_point_cloud = needle_point_cloud.select_by_index(largest_cluster_indices[0])



### Find needle tip
Assumption: the needle tip is the point of the largest cluster with the largest value in coordinate column 2 of the `[slice, row, column]` points.

In [60]:
# Convert the cleaned needle points to numpy once instead of once per use.
needle_points = np.asarray(needle_point_cloud.points)
print(needle_points.shape)

# Index of the point with the largest value in column 2 of the [slice, row, column] points.
lowest_coord_index = np.argmax(needle_points[:, 2])
print(lowest_coord_index)
lowest_zxy = needle_points[lowest_coord_index, :]
print(lowest_zxy)


(390, 3)
283
[  1. 207. 641.]


### Navier–Stokes inpainting for ILM and RPE

In [61]:
# Depth maps for the two retinal layers (label 2 = ILM, label 3 = RPE).
ilm_depth_map, rpe_depth_map = (get_depth_map(seg_volume, label) for label in (2, 3))

assert ilm_depth_map.shape == rpe_depth_map.shape

# Fill missing depths, then colour ILM points green and RPE points blue.
ilm_points, rpe_points = inpaint_layers(ilm_depth_map, rpe_depth_map)
ilm_colors = np.array([[0, 1, 0]] * ilm_points.shape[0])
rpe_colors = np.array([[0, 0, 1]] * rpe_points.shape[0])

### Combine the ILM and RPE point clouds and visualize everything together

In [62]:
# Stack ILM and RPE points (and their colours) into one layers point cloud.
layers_points = np.vstack([ilm_points, rpe_points])
layers_colors = np.vstack([ilm_colors, rpe_colors])

layers_point_cloud = o3d.geometry.PointCloud()
layers_point_cloud.points = o3d.utility.Vector3dVector(layers_points)
layers_point_cloud.colors = o3d.utility.Vector3dVector(layers_colors)

# Quick look without the tip marker:
# o3d.visualization.draw_geometries([needle_point_cloud, layers_point_cloud])

### Create sphere to mark needle tip point

In [80]:
mesh_sphere = o3d.geometry.TriangleMesh.create_sphere(radius=3.0)
mesh_sphere.paint_uniform_color([1., 0., 1.])

# Translate the sphere to the needle tip detected above. The coordinates were previously
# hard-coded from an earlier output, so they went stale whenever the data changed.
tip_transform = np.eye(4)
tip_transform[:3, 3] = lowest_zxy
mesh_sphere.transform(tip_transform)

TriangleMesh with 762 points and 1520 triangles.

In [81]:
# Interactive view of the cleaned needle, the inpainted layers and the tip marker.
vis = o3d.visualization.Visualizer()
vis.create_window()
for geometry in (needle_point_cloud, layers_point_cloud, mesh_sphere):
    vis.add_geometry(geometry)

# Optional: load saved render options.
# options = vis.get_render_option()
# options.load_from_json("render_option.json")
vis.run()
vis.destroy_window()

## All point clouds, unfiltered, for reference

In [82]:
all_point_cloud = o3d.geometry.PointCloud()

all_points = np.vstack((needle_first_occ, ilm_first_occ, rpe_first_occ))
# Build colours from the point counts here. ilm_colors / rpe_colors are reassigned by the
# inpainting cells to match the inpainted layers, so reusing them breaks the point/colour pairing.
all_colors = np.vstack((
    np.tile([1., 0., 0.], (len(needle_first_occ), 1)),
    np.tile([0., 1., 0.], (len(ilm_first_occ), 1)),
    np.tile([0., 0., 1.], (len(rpe_first_occ), 1)),
))
assert len(all_points) == len(all_colors)

# Convert the first-occurrence coordinates to a point cloud
all_point_cloud.points = o3d.utility.Vector3dVector(all_points)
all_point_cloud.colors = o3d.utility.Vector3dVector(all_colors)
o3d.visualization.draw_geometries([all_point_cloud])

#### Outlier removal on whole point cloud

In [50]:
# Statistical outlier removal on the combined, unfiltered cloud, shown for comparison.
_, inlier_indices = all_point_cloud.remove_statistical_outlier(nb_neighbors=20, std_ratio=0.7)
point_cloud = all_point_cloud.select_by_index(inlier_indices)
o3d.visualization.draw_geometries([point_cloud])