In [1]:
import cv2 as cv
import numpy as np

In [2]:
import numpy as np
import tkinter as tk
import matplotlib
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

# Select the Tk-based Agg backend so figures can be embedded in a Tk window
# via FigureCanvasTkAgg (as the Animator class does).
matplotlib.use(backend="TkAgg")


class Animator:
    """Live 3-D scatter plot of orientations reported through a loss callback.

    Every call to ``callback_func`` records an ``(x, y, z)`` orientation and
    its loss. Once every ``update_freq`` calls, the whole history is redrawn in
    a Tk window, with points colored by loss using the ``"hot"`` colormap.
    """

    def __init__(
        self,
        update_freq: int = 100,
        fig_size: tuple[int, int] = (10, 10),
        alpha: float = 0.25,
        point_size: float = 40,
    ) -> None:
        """Create the Tk window with an embedded 3-D axes.

        Args:
            update_freq: redraw the plot once every ``update_freq`` callbacks.
                Must be >= 1 because it is used as a modulus.
            fig_size: matplotlib figure size (width, height) in inches.
            alpha: marker transparency passed to ``scatter``.
            point_size: marker size passed to ``scatter`` as ``s``.

        Raises:
            ValueError: if ``update_freq`` is smaller than 1.
        """
        # Validate before opening any window, so a bad argument fails cleanly.
        if update_freq < 1:
            raise ValueError(f"update_freq must be >= 1, got {update_freq}")

        self.point_size = point_size
        self.alpha = alpha
        self.update_freq = update_freq

        # Main window that hosts the plot.
        self.plot_window = tk.Tk()

        # Figure with a single 3-D axes. It is kept on the instance so callers
        # can inspect or save it.
        self.fig = Figure(figsize=fig_size, dpi=100)
        self.ax = self.fig.add_subplot(111, projection="3d")

        # Canvas widget that displays the figure inside the Tk window.
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.plot_window)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)

        # Initialise the sample buffers and the redraw counter.
        self.reset()

    def reset(self) -> None:
        """Discard all recorded samples and restart the redraw counter.

        Resetting ``call_num`` means the first callback after a reset redraws
        immediately, just like the first callback after construction.
        """
        self.list_x = []
        self.list_y = []
        self.list_z = []
        self.list_losses = []
        self.call_num = 0

    def callback_func(self, orient, loss):
        """Record one sample and redraw the plot every ``update_freq`` calls.

        Args:
            orient: iterable of exactly three values, unpacked as (x, y, z).
            loss: value used as the point's color (presumably a scalar loss;
                confirm against the registering algorithm).
        """
        x, y, z = orient

        self.list_x.append(x)
        self.list_y.append(y)
        self.list_z.append(z)
        self.list_losses.append(loss)

        if self.call_num % self.update_freq == 0:
            self.update_plot()

        # Keep the counter bounded. It cycles through 0 .. update_freq - 1.
        self.call_num = (self.call_num + 1) % self.update_freq

    def update_plot(self, event=None):
        """Redraw every recorded sample and refresh the Tk window.

        Args:
            event: unused. Accepted so the method can also be bound as a Tk
                event handler.
        """
        self.ax.clear()

        self.ax.scatter(
            self.list_x,
            self.list_y,
            self.list_z,
            c=self.list_losses,
            cmap="hot",
            s=self.point_size,
            alpha=self.alpha,
        )

        self.ax.set_xlabel("X")
        self.ax.set_ylabel("Y")
        self.ax.set_zlabel("Z")
        self.canvas.draw()
        # Process pending Tk events so the window repaints without a mainloop.
        self.plot_window.update()

In [3]:
from view_sampler import ViewSampler, CameraConfig
from manipulated_object import ObjectPosition
from loss_funcs import *
from algs import *

from evaluate.eval_funcs import *
from evaluate.evaluator import Evaluator

# Live 3-D view of the orientations explored by the optimizer.
animator = Animator()

# Seed NumPy's legacy global RNG so the random initial orientations below are
# reproducible. NOTE(review): the optimizers may also draw from this global
# state; confirm in `algs` before relying on full-run reproducibility.
SEED = 0
np.random.seed(SEED)

TIME_LIMIT = 100  # shared `time_limit` for every algorithm config below
NUM_INIT_POSITIONS = 5  # number of random starting orientations to evaluate

# Create a camera configuration
cam_config = CameraConfig(location=(0, 0, 0.1), rotation=(np.pi / 2, 0, 0), fov=60)
world_viewer = ViewSampler("data/world_mug.xml", cam_config, simulation_time=5)
sim_viewer = ViewSampler("data/world_mug_sim.xml", cam_config)

loss_func = StructuralSimilarity()

alg1 = ParticleSwarm(sim_viewer, loss_func=loss_func)
alg1_config = ParticleSwarm.Config(time_limit=TIME_LIMIT, population=30, num_iters=300)

alg2 = SimulatedAnnealing(sim_viewer, loss_func=loss_func)
alg2_config = SimulatedAnnealing.Config(time_limit=TIME_LIMIT, num_iters=100, samples_per_temp=150)

alg3 = RandomSampling(sim_viewer, loss_func=loss_func)
alg3_config = RandomSampling.Config(time_limit=TIME_LIMIT, num_samples=10000)

alg4 = DifferentialEvolution(sim_viewer, loss_func=loss_func)
alg4_config = DifferentialEvolution.Config(time_limit=TIME_LIMIT, population=30, mut_prob=0.3)

alg5 = UniformSampling(sim_viewer, loss_func=loss_func)
alg5_config = UniformSampling.Config(time_limit=TIME_LIMIT, min_samples=10000, randomized=False)

evaluator = Evaluator(world_viewer, eval_func=IOU_Diff(method="mae"))

init_location = (0, 1.3, 0.3)
random_orientations = np.random.uniform(0, 2 * np.pi, size=(NUM_INIT_POSITIONS, 3))
init_positions = [ObjectPosition(orient, init_location) for orient in random_orientations]

# Only DifferentialEvolution is evaluated here. Append other (alg, config)
# pairs to compare more algorithms; each one gets a fresh animator history.
algs_to_run = [(alg4, alg4_config)]

for alg, config in algs_to_run:
    animator.reset()
    alg.register_loss_callback(animator.callback_func)
    losses = evaluator.evaluate(alg, config, init_positions)
    print(f"{type(alg).__name__}: {losses}")

Evaluating Algorithm: DifferentialEvolution | Config: DifferentialEvolution.Config(time_limit=100, num_iters=100, population=30, mut_prob=0.3, F=0.5, silent=False)


Evaluating:   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Close any OpenCV HighGUI windows and the Animator's Tk plot window.
cv.destroyAllWindows()
try:
    animator.plot_window.destroy()
except tk.TclError:
    # The window was already destroyed (e.g. closed by hand). Nothing to do.
    pass