# Simulation for "Simplified 2D Setting"

### Setup Notebook

Before running this interactive notebook, check the following parameters:
- `ENV`: Select your execution environment `"LOCAL"` (VS Code, Jupyter Notebook) or `"COLAB"` (Google Colab).
- `PLT_INTERACTIVE`: Select whether the plots should be interactive to allow zooming, panning and resizing. In general, this can be left enabled.

In [None]:
#@title { display-mode: "form" }

# Notebook setup: select the working directory, the matplotlib backend and
# module autoreloading, then import the project modules from `src/`.

ENV = "LOCAL" #@param ["LOCAL", "COLAB"]
PLT_INTERACTIVE = True #@param {type:"boolean"}

# setup environment
LOCAL = "LOCAL"
COLAB = "COLAB"
if ENV == COLAB:
    # Colab: clone the repository (including submodules), install both
    # requirement files and switch into `src/` so the project modules below
    # are importable
    %cd /content
    !git clone https://github.com/danielyxyang/active_reconstruction.git
    %cd active_reconstruction
    !git submodule update --init
    %pip install -q -r requirements.txt
    %pip install -q -r src/utils_ext/requirements.txt
    %cd src
elif ENV == LOCAL:
    # local: assumes the notebook is started from a directory that is a
    # sibling of `src/` — TODO confirm for your checkout
    %cd ../src

# setup interactive plots
if PLT_INTERACTIVE:
    if ENV == COLAB:
        # enable Colab's custom widget manager, used here for the
        # widget-based `%matplotlib widget` backend
        from google.colab import output
        output.enable_custom_widget_manager()
    %matplotlib widget
else:
    %matplotlib inline

# ensure automatic reload of imported modules
%load_ext autoreload
%autoreload 2


### CUSTOM SETUP ###

import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

import parameters as params
from algorithms.algorithms import GreedyAlgorithm, TwoPhaseAlgorithm, build_algorithms, ALGORITHMS, TRUE_ALGORITHM
from simulation.plotter import SimulationPlotter, WORLD_COLORS, ALGORITHM_COLORS
from simulation.simulation import Simulation
from simulation.camera import Camera
from utils.widgets import CameraControl, KernelSelector, ObjectSelector
from utils_ext.tools import Profiler
from utils_ext.widgets import build_widget_outputs, CheckboxList

plt.ioff() # prevent figures from being displayed without calling plt.show() or display()
# forward the interactivity choice from the form above to the project's plotter class
SimulationPlotter.set_interactive(PLT_INTERACTIVE)

### Setup Setting

Optionally, change the parameters of the simplified 2D setting.

In [None]:
#@title { display-mode: "form" }

#@markdown *All values are specified in `[m]` except for `CAM_FOV`, which is specified in `[deg]`.*

# The assignments below override attributes of the project's `parameters`
# module (imported as `params`); `run_simulation` in the next cell reads them
# when building the simulation.

# GRID_H presumably is the spacing of the world grid — TODO confirm
#@markdown world
params.GRID_H = 0.2 # @param {type:"slider", min:0.05, max:1, step:0.05}

# OBJ_D_MAX / OBJ_D_MIN presumably bound the size of the generated object — TODO confirm
#@markdown object
params.OBJ_D_MAX = 8 #@param {type:"number"}
params.OBJ_D_MIN = 2 #@param {type:"number"}

# CAM_D is also used as the radial plot limit for cameras (CAM_D * 1.1), so it
# presumably is the camera's distance from the origin; CAM_DOF and OBS_NOISE
# presumably are the camera's depth of field and the observation noise — TODO confirm
#@markdown camera
params.CAM_D = 10 #@param {type:"number"} 
params.CAM_DOF = 10 #@param {type:"number"}
params.CAM_FOV = 35  #@param {type:"number"}
params.OBS_NOISE = 0.2 #@param {type:"slider", min:0, max:5, step:0.1}

### Run Simulation

In [None]:
#@title { display-mode: "form" }

if ENV == COLAB and PLT_INTERACTIVE:
    # HACK: in Colab, the toolbar of interactive matplotlib figures is only
    # shown once the figure container is positioned relatively, so inject a
    # small CSS rule doing exactly that.
    html_hack = widgets.HTML(
        "<style> .jupyter-matplotlib-figure { position: relative; } </style>"
    )
    display(html_hack)

def run_simulation():
    """Build, display and start the interactive simplified-2D simulation.

    Creates the plotters and the ipywidgets controls, displays them, and
    immediately starts a first simulation with the currently selected object,
    kernel and camera position. The simulation state (``simulation``,
    ``output``, ``algorithms``) lives in this function's closure and is rebound
    by the nested ``start`` handler via ``nonlocal``.
    """

    ### DEFINE FUNCTIONS ###

    def compute(all_cameras=False):
        """Refresh the algorithm estimates stored in ``output``.

        Args:
            all_cameras: If True, additionally compute every algorithm's
                estimates at all camera locations and its NBV. The handlers
                request this after the observations change (start, reset,
                measure). If False, only the estimates at the current camera
                location are refreshed (camera move).
        """
        with profiler.cm("computation"):
            for name, algorithm in algorithms.items():
                # compute visualizable estimates at current camera location
                if isinstance(algorithm, GreedyAlgorithm) and not algorithm.objective.closed_form:
                    output["current"][name] = algorithm.compute_estimate_points(simulation.camera)
                # compute estimates at all camera locations and NBV
                if all_cameras:
                    nbv, estimates = algorithm.compute_nbv(with_estimates=True)
                    output["nbv"][name] = nbv
                    if isinstance(algorithm, GreedyAlgorithm):
                        output["all"][name] = np.array([algorithm.thetas, estimates])
                    elif isinstance(algorithm, TwoPhaseAlgorithm):
                        # two-phase algorithms provide one estimate curve per phase
                        output["all"][name] = {
                            "phase1": np.array([algorithm.thetas, estimates[0]]),
                            "phase2": np.array([algorithm.thetas, estimates[1]]),
                        }

    def plot():
        """Redraw all figures, update the NBV buttons and print the log.

        The checkbox widgets decide which components are drawn. The NBV
        buttons are disabled once the object is fully observed.
        """
        with profiler.cm("plotting"):
            # map checkbox states to arguments of plotting functions
            show_grid = cboxes_components.value["grid"]
            show_camera = {
                "show_camera":      cboxes_components.value["camera.camera_position"],
                "show_fov":         cboxes_components.value["camera.camera_fov"],
                "show_los":         cboxes_components.value["camera.camera_los"],
                "show_view_circle": cboxes_components.value["camera.view_circle"],
            }
            show_object = {
                "show_object":      cboxes_components.value["object.surface"],
                "show_points":      cboxes_components.value["object.surface_points"],
                "show_pixels":      cboxes_components.value["object.surface_pixels"],
                "show_bounds":      cboxes_components.value["object.bounds"],
            }
            show_confidence = {
                "show_confidence":  cboxes_components.value["confidence.region"],
                "show_points":      cboxes_components.value["confidence.boundary_points"],
                "show_pixels":      cboxes_components.value["confidence.boundary_pixels"],
            }
            show_observations = cboxes_components.value["observations"]
            show_nbv_cameras = {name: {
                "show_camera":      cboxes_nbv_cameras.value[name],
                "show_fov":         cboxes_nbv_cameras.value[name],
                "show_los":         cboxes_nbv_cameras.value[name],
                "show_view_circle": False,
            } for name in cboxes_nbv_cameras.value.keys()}
            show_objectives = cboxes_algorithms.value

            def plot_world(plotter):
                """Draw grid, cameras, object, confidence region, observations and estimates."""
                plotter.plot_grid(show_grid=show_grid)
                for name, nbv in output["nbv"].items():
                    plotter.plot_camera(Camera(nbv), **show_nbv_cameras[name], color=ALGORITHM_COLORS[name], name=name)
                plotter.plot_camera(simulation.camera, **show_camera)
                plotter.plot_object(simulation.obj, **show_object)
                plotter.plot_confidence(simulation.algorithm.gp, **show_confidence)
                plotter.plot_observations(simulation.algorithm.observations, show=show_observations)
                for name, estimate_points in output["current"].items():
                    kwargs = dict(markersize=4, alpha=1) if "Observed" in name else dict(markersize=1, alpha=0.25)
                    plotter.plot_points(estimate_points, show=show_objectives[name], color=ALGORITHM_COLORS[name], name=name, **kwargs)

            # plot real world
            plot_world(plt_real_world)

            # plot polar world (pixels are not drawn in polar mode)
            show_object["show_pixels"] = False
            show_confidence["show_pixels"] = False
            plot_world(plt_polar_world)

            # plot objectives
            with plt_estimates.use_axis("main", ylabel="estimated #points"):
                # plot estimates
                plt_estimates.axis.set_yscale("symlog", linthresh=10, linscale=0.25, subs=[2,3,4,5,6,7,8])
                plt_estimates.axis.set_ylim(auto=True)
                plt_estimates.axis.set_ymargin(0.2)
                for name, estimates in output["all"].items():
                    if isinstance(estimates, dict): # estimates of two-phase algorithm
                        plt_estimates.dynamic_plot(name + "_phase1", *estimates["phase1"], color=ALGORITHM_COLORS[name], alpha=0.5, linestyle="--", visible=show_objectives[name])
                        plt_estimates.dynamic_plot(name + "_phase2", *estimates["phase2"], color=ALGORITHM_COLORS[name], visible=show_objectives[name])
                    else:
                        plt_estimates.dynamic_plot(name, *estimates, color=ALGORITHM_COLORS[name], visible=show_objectives[name])
                plt_estimates.axis.set_ylim(bottom=0)
            with plt_estimates.use_axis("camera"):
                # plot cameras
                plt_estimates.axis.set_ylim([0, params.CAM_D * 1.1])
                plt_estimates.axis.set_yticks([])
                for name, nbv in output["nbv"].items():
                    plt_estimates.plot_camera(Camera(nbv), show_camera=show_nbv_cameras[name]["show_camera"], show_fov=False, show_los="position" if show_nbv_cameras[name]["show_los"] else False, show_view_circle=False, color=ALGORITHM_COLORS[name], name=name)
                plt_estimates.plot_camera(simulation.camera, show_camera=show_camera["show_camera"], show_fov=False, show_los="position" if show_camera["show_los"] else False, show_view_circle=False)

            results = simulation.results()

            # plot relative number of total observations
            x, y = results.rounds, results.n_total_rel
            plt_total.axis.set_xlim([0, np.max(x, initial=10) + 2])
            plt_total.axis.set_ylim([0, 1.2])
            plt_total.static("n_total_target", lambda: plt_total.axis.axhline(1, color="limegreen", linestyle="--", linewidth=1))
            plt_total.dynamic_plot("n_total", x, y, marker="o", markersize=4, color="limegreen")
            # first round in which the object is fully observed (empty if none yet)
            x_finish = x[y == 1][:1]
            plt_total.dynamic_plot("finish", x_finish, np.full_like(x_finish, 1), marker="*", markersize=10, color="limegreen")

            # plot absolute number of marginal observations
            x, y = results.rounds, results.n_marginal
            plt_marginal.axis.set_xlim([0, np.max(x, initial=10) + 2])
            plt_marginal.axis.set_ylim([0, np.max(y, initial=30) * 1.2])
            plt_marginal.dynamic_plot("n_marginal", x, y, marker="o", markersize=4, color="green")

            # plot regret
            x, y = results.rounds, results.regret
            plt_regret.axis.set_xlim([0, np.max(x, initial=10) + 2])
            plt_regret.axis.set_ylim([-2, np.max(y, initial=10) * 1.2])
            plt_regret.dynamic_plot("regret", x, y, marker="o", markersize=4, color="red")

            # enable NBV buttons only while the object is not fully observed
            finished = len(x_finish) > 0
            button_nbv.disabled = finished
            button_nbv_measure.disabled = finished

        with profiler.cm("display"):
            # display plots
            plt_real_world.display(out["fig_world"])
            plt_polar_world.display(out["fig_polar"])
            plt_estimates.display(out["fig_estimates"])
            plt_total.display(out["fig_total"])
            plt_marginal.display(out["fig_marginal"])
            plt_regret.display(out["fig_regret"])

        with out["log"]:
            clear_output(wait=True)
            # print finish
            if finished:
                print(f"Object fully observed after {x_finish[0]} measurements!")
                print()
            # print profiling, then reset all profilers for the next update
            profiler.merge(simulation.obj.profiler)
            profiler.merge(simulation.algorithm.gp.profiler)
            profiler.print(["initialization", "discretization (object)", "discretization (GP)", "computation", "plotting", "display"])
            profiler.reset()
            simulation.obj.profiler.reset()
            simulation.algorithm.gp.profiler.reset()

    ### DEFINE WIDGET HANDLERS ###

    def enable_start(*args):
        """Re-enable the start button after the object or kernel selection changed."""
        button_start.disabled = False

    def start(*args):
        """(Re)build simulation, estimate storage and algorithms, then compute and plot."""
        button_start.disabled = True

        # initialize simulation
        nonlocal simulation
        simulation = Simulation.build(
            object=object_selector.value,
            camera=Camera(slider_cam.value),
            kernel=kernel_selector.value,
        )
        # initialize estimates
        nonlocal output
        output = {
            "current": {}, # algorithm -> list of estimated points at current camera location
            "all": {},     # algorithm -> list of number of estimated points at all camera locations
            "nbv": {},     # algorithm -> NBV
        }
        # initialize algorithms
        nonlocal algorithms
        algorithms = build_algorithms(object=object_selector.value)
        for algorithm in algorithms.values():
            algorithm.link(simulation.algorithm)

        compute(all_cameras=True)
        plot()

    def refresh_plot(*args):
        """Redraw after a checkbox change; no recomputation is needed."""
        plot()

    def reset(*args):
        """Reset the simulation and recompute all estimates and NBVs."""
        simulation.reset()
        compute(all_cameras=True)
        plot()

    def move_camera(*args):
        """Move the camera to the slider position and refresh the current estimates."""
        simulation.move_camera(slider_cam.value)
        compute()
        plot()

    def measure(*args):
        """Take a measurement and recompute all estimates and NBVs."""
        simulation.take_measurement()
        compute(all_cameras=True)
        plot()

    def move_nbv(*args):
        """Move the camera slider to the NBV of the selected algorithm.

        Setting the slider value triggers ``move_camera`` via the slider's
        ``value`` observer.
        """
        slider_cam.value = output["nbv"][dropdown_algorithm.value]

    def move_nbv_and_measure(*args):
        """Move to the selected algorithm's NBV, then take a measurement."""
        move_nbv()
        button_measure.click()

    def run(*args):
        """Repeat NBV & measure until the simulation reports convergence.

        The ``converged`` flag is cleared first so that the loop also starts
        after an earlier converged run. NOTE(review): if ``is_converged()``
        never becomes True, this loop does not terminate.
        """
        simulation.converged = False
        while not simulation.is_converged():
            move_nbv_and_measure()

    ### INITIALIZATION ###

    profiler = Profiler()
    with profiler.cm("initialization"):
        # state shared with the nested handlers (rebound via `nonlocal` in `start`)
        simulation = None
        output = None
        algorithms = None

        # initialize plotters
        plt_real_world = SimulationPlotter(mode="real", figsize=3)
        plt_polar_world = SimulationPlotter(mode="polar", figsize=3)
        plt_estimates = SimulationPlotter(mode="polar", figsize=3, title="Objectives")
        plt_total = SimulationPlotter(figsize=(3, 2), title="Reconstruction Progress", xlabel="rounds", ylabel="progress")
        plt_marginal = SimulationPlotter(figsize=(3, 2), title="Marginal Observations", xlabel="rounds", ylabel="#points")
        plt_regret = SimulationPlotter(figsize=(3, 2), title="Individual Regret", xlabel="rounds", ylabel="#points")

        ### SETUP WIDGETS ###

        # define widgets for simulation setup
        object_selector = ObjectSelector()
        object_selector.observe(enable_start)
        kernel_selector = KernelSelector()
        kernel_selector.observe(enable_start)
        button_start = widgets.Button(description="start")
        button_start.on_click(start)

        display(
            widgets.HBox([object_selector, kernel_selector]),
            button_start,
        )

        # define checkbox widgets for simulation
        layout_cboxes = dict(margin="0 10px 0 10px", max_height="300px", min_width="250px")
        cboxes_components = CheckboxList(
            options=[
                "grid",
                "camera",
                "camera.view_circle",
                "camera.camera_position",
                "camera.camera_fov",
                "camera.camera_los",
                "object",
                "object.bounds",
                "object.surface",
                "object.surface_points",
                "object.surface_pixels",
                "confidence",
                "confidence.region",
                "confidence.boundary_points",
                "confidence.boundary_pixels",
                "observations",
            ],
            value=["camera", "camera.*", "object.surface", "object.bounds", "confidence", "confidence.region", "observations"],
            colors={
                "grid": WORLD_COLORS["grid"],
                "camera": WORLD_COLORS["camera"],
                "camera.*": WORLD_COLORS["camera"],
                "object": WORLD_COLORS["object"],
                "object.*": WORLD_COLORS["object"],
                "confidence": WORLD_COLORS["confidence"],
                "confidence.*": WORLD_COLORS["confidence"],
                "observations": WORLD_COLORS["observations"],
            },
            layout=layout_cboxes,
            description="show components",
        )
        cboxes_components.observe(refresh_plot, names="value")
        cboxes_algorithms = CheckboxList(
            options=ALGORITHMS,
            value=[TRUE_ALGORITHM],
            colors=ALGORITHM_COLORS,
            layout=layout_cboxes,
            description="show algorithms",
        )
        cboxes_algorithms.observe(refresh_plot, names="value")
        cboxes_nbv_cameras = CheckboxList(
            options=ALGORITHMS,
            colors=ALGORITHM_COLORS,
            layout=layout_cboxes,
            description="show NBV cameras",
        )
        cboxes_nbv_cameras.observe(refresh_plot, names="value")

        # define action widgets for simulation
        slider_cam = CameraControl()
        slider_cam.observe(move_camera, names="value")
        dropdown_algorithm = widgets.Dropdown(options=ALGORITHMS, value=TRUE_ALGORITHM, description="algorithm")
        button_nbv = widgets.Button(description="NBV", layout=dict(width="var(--jp-widgets-inline-width-tiny)"))
        button_nbv.on_click(move_nbv)
        button_measure = widgets.Button(description="measure", layout=dict(width="var(--jp-widgets-inline-width-tiny)"))
        button_measure.on_click(measure)
        button_nbv_measure = widgets.Button(description="NBV & measure")
        button_nbv_measure.on_click(move_nbv_and_measure)
        button_run = widgets.Button(description="run")
        button_run.on_click(run)
        button_reset = widgets.Button(description="reset")
        button_reset.on_click(reset)

        # define output widgets for simulation
        out = build_widget_outputs(["fig_world", "fig_polar", "fig_estimates", "fig_total", "fig_marginal", "fig_regret", "log"])

        display(
            widgets.HBox([cboxes_components, cboxes_algorithms, cboxes_nbv_cameras]),
            slider_cam,
            dropdown_algorithm,
            widgets.HBox([button_nbv, button_measure, button_nbv_measure]),
            widgets.HBox([button_run, button_reset]),
            widgets.HBox([out["fig_world"], out["fig_polar"], out["fig_estimates"]]),
            widgets.HBox([out["fig_total"], out["fig_marginal"], out["fig_regret"]]),
            out["log"],
        )

    ### START ###

    # run a first simulation with the default selections
    start()

# close all open figures (e.g. from a previous execution of this cell) before
# building the new simulation
plt.close(fig="all")
run_simulation()