# ToDo
* Add polar and Euclidean grid per world view
* Find a way to hide axes and images that are on the other side (white half sphere on the inside?)
* Show info about image (location, direction, resolution, focal length, etc.)
* Allow shooting a ray from a specific pixel of an image, or even from a region of the image by holding click and dragging a rectangle
* Add 3D model in the middle mimicking what's seen by each image
* Show image full screen when clicking on any image
* Read more about intrinsic/extrinsic matrices, and why the z-axis points away from the object
* Add comments
* Navigate in 3d like ghosts in video games
* Shoot rays to 3d model from each pixel of each image, find a way to visualize coverage density on 3d model
* Add keyboard shortcut to allow for showing and hiding information

In [None]:
import vtk
import os
import numpy as np

# Dataset locations: one intrinsic matrix file and one pose (extrinsic) file
# per camera, plus the PNG images to paste onto each camera's image frame.
intrinsics_path = 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images/fox/train/intrinsics/'
extrinsics_path = 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images/fox/train/pose/'
images_path = 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images/fox/imgs/'
#intrinsics_path = 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images/helmet/400x400/train/intrinsics/'
#extrinsics_path = 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images/helmet/400x400/train/pose/'
#images_path = 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images/helmet/400x400/imgs/'

# Upper bound on how many cameras are loaded (passed as `stop` to the importers).
img_stop = 90

cameras_intrinsic = import_intrinsics(intrinsics_path, stop=img_stop)
cameras_extrinsic = import_extrinsics(extrinsics_path, stop=img_stop)
axes = gen_axes(cameras_extrinsic, axes_length=0.25)
frames = gen_img_frame(cameras_intrinsic, cameras_extrinsic, zoom_factor=0.5)

# Create a renderer, render window, and interactor
renderer = vtk.vtkRenderer()
renderer.SetBackground(1.0, 1.0, 1.0)  # White background; VTK colour components are in [0, 1]
renderer.SetUseDepthPeeling(1)
renderer.SetMaximumNumberOfPeels(100)
renderer.SetOcclusionRatio(0.1)

renderWindow = vtk.vtkRenderWindow()
renderWindow.SetAlphaBitPlanes(1)  # Use alpha bit-planes
renderWindow.SetMultiSamples(0)  # Disable multi-sampling for depth peeling
renderWindow.AddRenderer(renderer)

style = vtk.vtkInteractorStyleTrackballCamera()

renderWindowInteractor = vtk.vtkRenderWindowInteractor()
renderWindowInteractor.SetRenderWindow(renderWindow)
renderWindowInteractor.SetInteractorStyle(style)


# Each entry of `axes` holds three segments (x, y, z), each stored as
# [start_xyz, end_xyz]; draw them in red, green and blue respectively.
axis_colors = ((1, 0, 0), (0, 1, 0), (0, 0, 1))
for axis_triplet in axes:
    for segment, axis_color in zip(axis_triplet, axis_colors):
        draw_line(renderer, segment[:3], segment[3:], color=axis_color)

# Each entry of `frames` holds the four edges of one image rectangle; the
# first edge is red and the last one green so the orientation is visible.
frame_edge_colors = ((1, 0, 0), (0, 0, 0), (0, 0, 0), (0, 1, 0))
for frame in frames:
    for edge, edge_color in zip(frame, frame_edge_colors):
        draw_line(renderer, edge[:3], edge[3:], color=edge_color)
    rectangle_coords = [tuple(edge[:3]) for edge in frame]
    draw_solid_rectangle(renderer, rectangle_coords, color=(1, 1, 1))

# os.listdir returns names in arbitrary order, so sort them to make the
# image-to-frame pairing deterministic.
img_file_names = sorted(f for f in os.listdir(images_path) if f.endswith('.png'))
# zip stops at the shorter sequence: extra images without a matching frame are
# skipped instead of raising IndexError. `frames` is already capped by img_stop.
for frame, img_file_name in zip(frames, img_file_names):
    rectangle_coords = [tuple(edge[:3]) for edge in frame]
    draw_png_img(renderer, os.path.join(images_path, img_file_name), rectangle_coords)

# Parameters for the grid
#origin = (0, 0, 0)  # Origin of the grid
#spacing = (5, 5, 5)  # Spacing between lines
#num_lines = (25, 25, 25)  # Number of lines in x, y, z directions
#draw_grid(renderer, origin, spacing, num_lines)

# Render and start interaction
renderWindow.Render()
renderWindowInteractor.Start()

In [None]:
def import_intrinsics(file_path, stop = -1):
    """Load camera intrinsic matrices from the ``.txt`` files of a directory.

    Every file must contain 16 whitespace-separated numbers, which are
    reshaped row-major into a 4x4 matrix. Files are read in sorted name
    order because ``os.listdir`` itself returns names in arbitrary order.

    Args:
        file_path: Directory holding the intrinsic files (with or without a
            trailing path separator).
        stop: Maximum number of files to load. Any value below 1 (including
            the default ``-1``) loads every file.

    Returns:
        numpy array of shape ``(n, 4, 4)``; an empty array when the directory
        has no ``.txt`` file.

    Raises:
        ValueError: If a file holds non-numeric text or not exactly 16 values.
    """
    file_names = sorted(f for f in os.listdir(file_path) if f.endswith('.txt'))
    cams_intrinsic = []

    for i, file_name in enumerate(file_names):
        # Context manager closes the handle even if parsing fails.
        with open(os.path.join(file_path, file_name)) as matrix_file:
            s_intrinsic = matrix_file.read().split()
        intrinsic = np.array(s_intrinsic, dtype=float).reshape(4, 4)
        cams_intrinsic.append(intrinsic)
        if i == stop - 1:
            break

    return np.array(cams_intrinsic)

In [None]:
def import_extrinsics(file_path, stop = -1):
    """Load camera extrinsic (pose) matrices from the ``.txt`` files of a directory.

    Every file must contain 16 whitespace-separated numbers, which are
    reshaped row-major into a 4x4 matrix. Files are read in sorted name
    order because ``os.listdir`` itself returns names in arbitrary order.

    Args:
        file_path: Directory holding the pose files (with or without a
            trailing path separator).
        stop: Maximum number of files to load. Any value below 1 (including
            the default ``-1``) loads every file.

    Returns:
        numpy array of shape ``(n, 4, 4)``; an empty array when the directory
        has no ``.txt`` file.

    Raises:
        ValueError: If a file holds non-numeric text or not exactly 16 values.
    """
    file_names = sorted(f for f in os.listdir(file_path) if f.endswith('.txt'))
    cams_extrinsic = []

    for i, file_name in enumerate(file_names):
        # Context manager closes the handle even if parsing fails.
        with open(os.path.join(file_path, file_name)) as matrix_file:
            s_extrinsic = matrix_file.read().split()
        extrinsic = np.array(s_extrinsic, dtype=float).reshape(4, 4)
        cams_extrinsic.append(extrinsic)
        if i == stop - 1:
            break

    return np.array(cams_extrinsic)

In [None]:
def cam2world(camera_point, camera_extrinsic):
    """Map a 3D point from camera coordinates to world coordinates.

    Args:
        camera_point: Sequence of three coordinates in the camera frame.
        camera_extrinsic: Matrix whose upper-left 3x3 block is applied as a
            rotation and whose last column (first three rows) is added as a
            translation.

    Returns:
        numpy array with the three world coordinates.
    """
    # Rotate first, then shift by the camera position.
    return camera_extrinsic[:3, :3] @ camera_point + camera_extrinsic[:3, 3]

In [None]:
def pixel2world(x, y, camera_intrinsic, camera_extrinsic, zoom_factor=1):
    """Place a pixel in world space at distance ``zoom_factor`` from the camera.

    The pixel is turned into a viewing direction ``((x - cx) / fx,
    (y - cy) / fy, 1)``, scaled to length ``zoom_factor`` and then moved to
    world coordinates with ``cam2world``.

    Args:
        x: Pixel column.
        y: Pixel row.
        camera_intrinsic: Matrix read as fx = [0, 0], fy = [1, 1],
            cx = [0, 2], cy = [1, 2].
        camera_extrinsic: Camera pose matrix forwarded to ``cam2world``.
        zoom_factor: Length of the resulting camera-space vector.

    Returns:
        World coordinates of the point, as returned by ``cam2world``.
    """
    focal_x = camera_intrinsic[0, 0]
    focal_y = camera_intrinsic[1, 1]
    center_x = camera_intrinsic[0, 2]
    center_y = camera_intrinsic[1, 2]

    # Viewing direction through the pixel, with the depth component fixed to 1.
    ray = np.array([(x - center_x) / focal_x, (y - center_y) / focal_y, 1])
    ray_length = np.sqrt(ray[0]**2 + ray[1]**2 + ray[2]**2)

    # Unit direction stretched to the requested distance from the camera.
    point_in_camera = ray / ray_length * zoom_factor
    return cam2world(point_in_camera, camera_extrinsic)

In [None]:
def gen_axes(cameras_extrinsic, axes_length = 1.):
    """Build the axis segments of the world frame and of every camera frame.

    Args:
        cameras_extrinsic: Iterable of camera pose matrices, each forwarded
            to ``cam2world``.
        axes_length: Length of every axis segment.

    Returns:
        numpy array with one entry for the world frame followed by one entry
        per camera. Each entry holds three segments (x, y, z) and each
        segment is stored as ``[start_x, start_y, start_z, end_x, end_y, end_z]``.
    """
    origin = [0, 0, 0]
    axis_tips = (
        [axes_length, 0, 0],  # X axis
        [0, axes_length, 0],  # Y axis
        [0, 0, axes_length],  # Z axis
    )

    # First entry: the untransformed (world) frame.
    triads = [[np.concatenate([origin, tip], axis=0) for tip in axis_tips]]

    # Remaining entries: the same triad carried into world space by each pose.
    for extrinsic in cameras_extrinsic:
        world_origin = cam2world(origin, extrinsic)
        triads.append([
            np.concatenate([world_origin, cam2world(tip, extrinsic)], axis=0)
            for tip in axis_tips
        ])

    return np.array(triads)

In [None]:
def gen_img_frame(cameras_intrinsic, cameras_extrinsic, zoom_factor=1.0, img_width=400, img_height=400):
    """Build the world-space outline of every camera's image rectangle.

    The four image corners are mapped to world space with ``pixel2world``
    and joined into four edges going (0, 0) -> (w, 0) -> (w, h) -> (0, h)
    -> (0, 0) in pixel coordinates.

    Args:
        cameras_intrinsic: Iterable of intrinsic matrices, one per camera.
        cameras_extrinsic: Iterable of pose matrices, paired with
            ``cameras_intrinsic`` by position (extra entries on the longer
            side are ignored).
        zoom_factor: Forwarded to ``pixel2world``.
        img_width: Image width in pixels. Defaults to 400, the value that
            was previously hard-coded.
        img_height: Image height in pixels. Defaults to 400, the value that
            was previously hard-coded.

    Returns:
        numpy array with one entry per camera; each entry holds four edges
        stored as ``[start_x, start_y, start_z, end_x, end_y, end_z]``.
    """
    pixel_corners = ((0, 0), (img_width, 0), (img_width, img_height), (0, img_height))
    frames = []

    for camera_intrinsic, camera_extrinsic in zip(cameras_intrinsic, cameras_extrinsic):
        world_corners = [
            pixel2world(px, py, camera_intrinsic, camera_extrinsic, zoom_factor=zoom_factor)
            for px, py in pixel_corners
        ]
        # Join each corner to the next one, wrapping from the last back to the first.
        frames.append([
            np.concatenate([world_corners[k], world_corners[(k + 1) % 4]], axis=0)
            for k in range(4)
        ])

    return np.array(frames)

In [None]:
def draw_grid(renderer, origin, spacing, num_lines):
    """Draw a light-gray 3D lattice of lines with ``draw_line``.

    Lines are drawn parallel to each of the three axes through every lattice
    node of the two remaining axes.

    Args:
        renderer: Renderer forwarded to ``draw_line``.
        origin: ``(x, y, z)`` of the lattice corner with the smallest indices.
        spacing: ``(dx, dy, dz)`` distance between neighbouring lattice nodes.
        num_lines: ``(nx, ny, nz)`` number of intervals per axis; each axis
            therefore has ``n + 1`` node positions.
    """
    # VTK colour components are in [0, 1]; 0.75 is the light gray that the
    # 8-bit value 192 stands for.
    grid_color = (0.75, 0.75, 0.75)
    count_x, count_y, count_z = num_lines[0], num_lines[1], num_lines[2]

    # Lines parallel to the z axis, one per (i, j) node.
    for i in range(count_x + 1):
        for j in range(count_y + 1):
            start = (origin[0] + i * spacing[0], origin[1] + j * spacing[1], origin[2])
            end = (start[0], start[1], start[2] + count_z * spacing[2])
            draw_line(renderer, start, end, color=grid_color, width=1)

    # Lines parallel to the x axis, one per (j, k) node. Each is drawn once;
    # drawing the identical segment a second time only adds redundant actors.
    for j in range(count_y + 1):
        for k in range(count_z + 1):
            start = (origin[0], origin[1] + j * spacing[1], origin[2] + k * spacing[2])
            end = (start[0] + count_x * spacing[0], start[1], start[2])
            draw_line(renderer, start, end, color=grid_color, width=1)

    # Lines parallel to the y axis, one per (i, k) node.
    for i in range(count_x + 1):
        for k in range(count_z + 1):
            start = (origin[0] + i * spacing[0], origin[1], origin[2] + k * spacing[2])
            end = (start[0], start[1] + count_y * spacing[1], start[2])
            draw_line(renderer, start, end, color=grid_color, width=1)

def draw_point(renderer, position, color=(1, 0, 0), size=10):
    """Add a single point marker to ``renderer``.

    Args:
        renderer: Object whose ``AddActor`` receives the new actor.
        position: Coordinates of the point, passed to ``vtkPoints.InsertNextPoint``.
        color: Colour passed to the actor property's ``SetColor``.
        size: Point size passed to the actor property's ``SetPointSize``.
    """
    # Geometry: one point wrapped in a poly data.
    vertex_coords = vtk.vtkPoints()
    vertex_coords.InsertNextPoint(position)
    point_data = vtk.vtkPolyData()
    point_data.SetPoints(vertex_coords)

    # The glyph filter turns the bare point into a vertex cell for the mapper.
    vertex_filter = vtk.vtkVertexGlyphFilter()
    vertex_filter.SetInputData(point_data)
    vertex_filter.Update()

    point_mapper = vtk.vtkPolyDataMapper()
    point_mapper.SetInputConnection(vertex_filter.GetOutputPort())

    point_actor = vtk.vtkActor()
    point_actor.SetMapper(point_mapper)
    appearance = point_actor.GetProperty()
    appearance.SetPointSize(size)
    appearance.SetColor(color)

    renderer.AddActor(point_actor)


def draw_line(renderer, start_point, end_point, color=(1, 0, 0), width=2):
    """Add a straight line segment to ``renderer``.

    Args:
        renderer: Object whose ``AddActor`` receives the new actor.
        start_point: Coordinates of the first end of the segment.
        end_point: Coordinates of the second end of the segment.
        color: Colour passed to the actor property's ``SetColor``.
        width: Line width passed to the actor property's ``SetLineWidth``.
    """
    # Two points: index 0 is the start, index 1 is the end.
    endpoints = vtk.vtkPoints()
    endpoints.InsertNextPoint(start_point)
    endpoints.InsertNextPoint(end_point)

    # One line cell connecting point 0 to point 1.
    segment = vtk.vtkLine()
    segment_ids = segment.GetPointIds()
    segment_ids.SetId(0, 0)
    segment_ids.SetId(1, 1)
    segment_cells = vtk.vtkCellArray()
    segment_cells.InsertNextCell(segment)

    segment_data = vtk.vtkPolyData()
    segment_data.SetPoints(endpoints)
    segment_data.SetLines(segment_cells)

    segment_mapper = vtk.vtkPolyDataMapper()
    segment_mapper.SetInputData(segment_data)

    segment_actor = vtk.vtkActor()
    segment_actor.SetMapper(segment_mapper)
    appearance = segment_actor.GetProperty()
    appearance.SetColor(color)
    appearance.SetLineWidth(width)

    renderer.AddActor(segment_actor)


def draw_solid_rectangle(renderer, points_list, color=(0.5, 0.5, 0.5)):
    """Add a filled four-sided polygon to ``renderer``.

    Args:
        renderer: Object whose ``AddActor`` receives the new actor.
        points_list: Four ``(x, y, z)`` corners, in the order they are
            connected around the polygon.
        color: Colour passed to the actor property's ``SetColor``.

    Raises:
        ValueError: If ``points_list`` does not hold exactly four points.
    """
    if len(points_list) != 4:
        raise ValueError("points_list must contain exactly four points.")

    # Corner i of the polygon refers to point i.
    corner_points = vtk.vtkPoints()
    quad = vtk.vtkPolygon()
    quad_ids = quad.GetPointIds()
    quad_ids.SetNumberOfIds(4)
    for corner_index, (x, y, z) in enumerate(points_list):
        corner_points.InsertNextPoint(x, y, z)
        quad_ids.SetId(corner_index, corner_index)

    quad_cells = vtk.vtkCellArray()
    quad_cells.InsertNextCell(quad)

    # Poly data holding the single polygon.
    quad_data = vtk.vtkPolyData()
    quad_data.SetPoints(corner_points)
    quad_data.SetPolys(quad_cells)

    quad_mapper = vtk.vtkPolyDataMapper()
    quad_mapper.SetInputData(quad_data)

    quad_actor = vtk.vtkActor()
    quad_actor.SetMapper(quad_mapper)
    quad_actor.GetProperty().SetColor(color)

    renderer.AddActor(quad_actor)


def draw_png_img(renderer, image_path, rectangle_coords):
    """Add a four-sided polygon textured with a PNG image to ``renderer``.

    The corners receive the texture coordinates (0, 0), (1, 0), (1, 1) and
    (0, 1), in the order they appear in ``rectangle_coords``.

    Args:
        renderer: Object whose ``AddActor`` receives the new actor.
        image_path: Path of the PNG file handed to ``vtkPNGReader``.
        rectangle_coords: Four ``(x, y, z)`` corners, in the order they are
            connected around the polygon.

    Raises:
        ValueError: If ``rectangle_coords`` does not hold exactly four entries.
    """
    if len(rectangle_coords) != 4:
        raise ValueError("rectangle_coords must contain exactly four (x, y, z) coordinate tuples.")

    # Geometry: corner i of the polygon refers to point i.
    corner_points = vtk.vtkPoints()
    quad = vtk.vtkPolygon()
    quad_ids = quad.GetPointIds()
    quad_ids.SetNumberOfIds(4)
    for corner_index, (x, y, z) in enumerate(rectangle_coords):
        corner_points.InsertNextPoint(x, y, z)
        quad_ids.SetId(corner_index, corner_index)

    quad_cells = vtk.vtkCellArray()
    quad_cells.InsertNextCell(quad)

    quad_data = vtk.vtkPolyData()
    quad_data.SetPoints(corner_points)
    quad_data.SetPolys(quad_cells)

    # One (u, v) pair per corner, in the same order as the corners.
    uv_array = vtk.vtkFloatArray()
    uv_array.SetNumberOfComponents(2)
    uv_array.SetName("TextureCoordinates")
    for uv in ((0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)):
        uv_array.InsertNextTuple(uv)
    quad_data.GetPointData().SetTCoords(uv_array)

    # Texture fed by the PNG reader.
    png_reader = vtk.vtkPNGReader()
    png_reader.SetFileName(image_path)
    image_texture = vtk.vtkTexture()
    image_texture.SetInputConnection(png_reader.GetOutputPort())

    quad_mapper = vtk.vtkPolyDataMapper()
    quad_mapper.SetInputData(quad_data)

    quad_actor = vtk.vtkActor()
    quad_actor.SetMapper(quad_mapper)
    quad_actor.SetTexture(image_texture)

    renderer.AddActor(quad_actor)