# Imports

In [None]:
import numpy as np
from scipy.spatial.transform import Rotation

import plotly.graph_objects as go

from src.data.language_sequence import LanguageSequence
from src.data.point_cloud import PointCloud
from src.networks.scenescript_model import SceneScriptWrapper

# Plotting Lib

In [None]:
# Corners of an axis-aligned cube of side 1 centred at the origin.
# Vertex index i encodes its signs in binary: bit 2 -> x, bit 1 -> y,
# bit 0 -> z, where a set bit selects the negative half-extent.
UNIT_CUBE_VERTICES = (
    np.array([[x, y, z] for x in (1, -1) for y in (1, -1) for z in (1, -1)])
    * 0.5
)


# The 12 cube edges as [start, end] indices into UNIT_CUBE_VERTICES.
# Two corners share an edge exactly when their indices differ in one bit.
UNIT_CUBE_LINES_IDXS = np.array(
    [
        [start, end]
        for start in range(8)
        for end in range(start + 1, 8)
        if bin(start ^ end).count("1") == 1
    ]
)


# Hex colour used for each kind of rendered element.
PLOTTING_COLORS = dict(
    wall="#FBFAF5",
    door="#F7C59F",
    window="#53F4FF",
    bbox="#CC3FD1",
    points="#C7DAE8",
    trajectory="#F92A82",
)

In [None]:
def _xyz(params, prefix):
    """Return [params[prefix + "x"], params[prefix + "y"], params[prefix + "z"]] as an array."""
    return np.array([params[f"{prefix}{axis}"] for axis in "xyz"])


def _z_rotation(angle):
    """Return the 3x3 rotation matrix for `angle` radians about the +z axis."""
    return Rotation.from_rotvec([0, 0, angle]).as_matrix()


def _wall_angle(params):
    """Return the heading (radians, about +z) of the wall running from corner a to corner b."""
    direction = _xyz(params, "b_") - _xyz(params, "a_")
    return np.arctan2(direction[1], direction[0])


def _find_wall(params, wall_lookup):
    """Return the first wall referenced by a door/window's wall pointers, or None."""
    # NOTE: this part differs from the original implementation of this function:
    # any of wall_id / wall0_id / wall1_id may point at the host wall.
    for key in ("wall_id", "wall0_id", "wall1_id"):
        wall = wall_lookup.get(params.get(key))
        if wall is not None:
            return wall
    return None


def language_to_bboxes(entities):
    """
    Convert language entities into oriented boxes for plotting.

    Args:
        entities: List[BaseEntity]. Each entity exposes `COMMAND_STRING`
            (e.g. "make_wall") and a `params` dict.

    Returns:
        A list of dicts with keys "id", "cmd", "class", "centre" (3,),
        "rotation" (3, 3) and "scale" (3,), one per drawable entity, in input
        order. Doors/windows whose wall cannot be resolved and entities with
        an unrecognised command are skipped.
    """
    entities = list(entities)  # iterated twice below

    # Pre-compute every wall's heading so a door/window is oriented correctly
    # even if it appears before its wall in the sequence.
    wall_lookup = {
        int(entity.params["id"]): {**entity.params, "angle": _wall_angle(entity.params)}
        for entity in entities
        if entity.COMMAND_STRING == "make_wall"
    }

    box_definitions = []
    for entity in entities:
        cmd = entity.COMMAND_STRING
        params = entity.params
        class_name = cmd[len("make_"):]  # e.g. "make_wall" -> "wall"

        if cmd == "make_wall":
            corner_a = _xyz(params, "a_")
            corner_b = _xyz(params, "b_")
            height = params["height"]
            # Centre is the corners' midpoint lifted by half the wall height.
            centre = (corner_a + corner_b) * 0.5 + np.array([0, 0, 0.5 * height])
            rotation = _z_rotation(_wall_angle(params))
            # Walls are drawn as flat planes (zero thickness).
            scale = np.array([np.linalg.norm(corner_b - corner_a), 0.0, height])

        elif cmd in {"make_door", "make_window"}:
            wall = _find_wall(params, wall_lookup)
            if wall is None:
                continue
            centre = _xyz(params, "position_")
            # Doors/windows inherit their host wall's heading and thickness.
            rotation = _z_rotation(wall["angle"])
            scale = np.array([params["width"], wall["thickness"], params["height"]])

        elif cmd == "make_bbox":
            centre = _xyz(params, "position_")
            rotation = _z_rotation(params["angle_z"])
            scale = _xyz(params, "scale_")
            class_name = params["class"]

        else:
            # Unknown command: skip it instead of reusing the previous
            # entity's geometry.
            continue

        box_definitions.append(
            {
                "id": int(params["id"]),
                "cmd": cmd,
                "class": class_name,
                "centre": centre,
                "rotation": rotation,
                "scale": scale,
            }
        )

    return box_definitions


def plot_box_wireframe(box):
    """
    Build a Scatter3d line trace outlining one oriented box.

    Args:
        box: dict as produced by `language_to_bboxes`, with keys "id", "cmd",
            "class", "centre" (3,), "rotation" (3, 3) and "scale" (3,).

    Returns:
        go.Scatter3d drawing the 12 cube edges. A None is appended after each
        edge so the edges are not joined into one polyline.
    """
    # Scale the unit cube, rotate it, then translate it to the box centre.
    corners = (box["rotation"] @ (UNIT_CUBE_VERTICES * box["scale"]).T).T + box["centre"]

    xs, ys, zs = [], [], []
    for start, end in UNIT_CUBE_LINES_IDXS:
        for coords, axis in ((xs, 0), (ys, 1), (zs, 2)):
            coords.extend((corners[start, axis], corners[end, axis], None))

    if box["cmd"] != "make_bbox":
        # wall / door / window: the colour is keyed by the class name itself.
        label = box["class"]
        colour = PLOTTING_COLORS[label]
    else:
        label = f"bbox_{box['class']}"
        colour = PLOTTING_COLORS["bbox"]

    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode="lines",
        name=f"{label}_{box['id']}",
        line=dict(color=colour, width=10),
    )


def plot_point_cloud(point_cloud, max_points_to_plot=50_000):
    """
    Build a Scatter3d marker trace for a point cloud.

    If the cloud holds more than `max_points_to_plot` points, a notice is
    printed and the points are subsampled without replacement using NumPy's
    global random state.

    Args:
        point_cloud: indexable array of shape (N, >=3); columns 0/1/2 are
            plotted as x/y/z.
        max_points_to_plot: upper bound on the number of points drawn.

    Returns:
        go.Scatter3d marker trace.
    """
    total = len(point_cloud)
    if total > max_points_to_plot:
        print(
            f"The number of points ({total}) exceeds the maximum that can be reliably plotted."
        )
        print(f"Randomly subsampling {max_points_to_plot} points for the plot.")
        keep = np.random.choice(total, max_points_to_plot, replace=False)
        point_cloud = point_cloud[keep]

    xs, ys, zs = (point_cloud[:, axis] for axis in range(3))
    marker_style = dict(size=1.0, opacity=0.3, color=PLOTTING_COLORS["points"])
    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode="markers",
        name="Semi-dense Point Cloud",
        marker=marker_style,
    )


# Main plotting function
def plot_3d_scene(
    language_sequence=None,
    point_cloud=None,
    max_points_to_plot=50_000,
    fig_width=1000,
):
    """
    Render a point cloud and/or the wireframe boxes decoded from a language sequence.

    Args:
        language_sequence: object with an `entities` attribute (passed to
            `language_to_bboxes`), or None to skip the structure overlay.
        point_cloud: indexable (N, >=3) array of points, or None to skip it.
        max_points_to_plot: forwarded to `plot_point_cloud`; larger clouds are
            randomly subsampled.
        fig_width: figure width; the height is set to half of it.

    Raises:
        ValueError: if nothing is available to draw (no point cloud and no boxes).
    """
    traces = []
    if point_cloud is not None:
        traces.append(plot_point_cloud(point_cloud, max_points_to_plot))

    if language_sequence is not None:
        traces.extend(
            plot_box_wireframe(box)
            for box in language_to_bboxes(language_sequence.entities)
        )

    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if not traces:
        raise ValueError("Nothing to visualize.")

    fig = go.Figure(data=traces)
    fig.update_layout(
        template="plotly_dark",
        scene={
            "xaxis": {"showticklabels": False, "title": ""},
            "yaxis": {"showticklabels": False, "title": ""},
            "zaxis": {"showticklabels": False, "title": ""},
        },
        width=fig_width,
        height=fig_width // 2,
        scene_aspectmode="data",
        hoverlabel={"namelength": -1},  # -1: never truncate trace names in hover labels
    )
    fig.show()

# Load Model + Point Cloud

In [None]:
ckpt_path = "..."  # TODO: set to the model checkpoint path
# Restore the wrapper from the checkpoint, then move it to the GPU via .cuda().
model_wrapper = SceneScriptWrapper.load_from_checkpoint(ckpt_path)
model_wrapper = model_wrapper.cuda()

In [None]:
# Load the semi-dense point cloud; its `.points` are fed to the model and plotted below.
point_cloud_path = "..."  # TODO: path to semidense point cloud
point_cloud_obj = PointCloud.load_from_file(point_cloud_path)

# Run Model

In [None]:
# Decode a language sequence from the point cloud.
# nucleus_sampling_thresh: 0.0 is argmax, 1.0 is random sampling.
lang_seq = model_wrapper.run_inference(
    point_cloud_obj.points,
    verbose=True,
    nucleus_sampling_thresh=0.05,
)

# Visualisation

In [None]:
# Overlay the predicted entities on the (possibly subsampled) input point cloud.
plot_3d_scene(
    language_sequence=lang_seq,
    point_cloud=point_cloud_obj.points,
    fig_width=1100,
    max_points_to_plot=50_000,
)