In [1]:
import cv2 as cv
import numpy as np
import torch
from camera_simulator import CameraSimulator
from camera_utils import xy_axes_to_frame_rotation, get_torch3d_R_T
from matplotlib import pyplot as plt
from sam_segmentation import SAMSegmentation
from utils import (
    plot_segmentation_mask,
    masks_intersection_batch,
    compute_masks_IOU_batch,
)
from geometric_mesh_segmentation import CameraParameters, get_mesh_segmentation_batch

In [2]:
class ImageManipulator:
    """
    Image manipulation functions that work both on batched data and single images.
    Image batch dim is assumed to be [N, W, H, 3]
    Single image dim is assumed to be [W, H, 3]
    """
    @staticmethod
    def calc_mask(images: np.ndarray, bg_value: int = 0, orig_dims: bool = False) -> np.ndarray:
        """
        Foreground mask: a pixel is foreground if ANY value along the last axis
        differs from ``bg_value``.

        :param images: image or batch of images; the last axis is reduced over.
        :param bg_value: background value to compare against.
        :param orig_dims: if True, broadcast the mask back to ``images.shape``
            (a read-only broadcast view) instead of dropping the last axis.
        :return: boolean mask of shape ``images.shape[:-1]`` (or ``images.shape``
            when ``orig_dims`` is True).
        """
        mask = np.any(images != bg_value, axis=-1)
        if orig_dims:
            # np.broadcast_to returns a read-only view, not a copy.
            mask = np.broadcast_to(np.expand_dims(mask, axis=-1), images.shape)
        return mask

    @staticmethod
    def calc_bboxes(mask_batch: np.ndarray, margin_factor: float = 1.2) -> np.ndarray:
        """
        Bounding box of the True region of a mask (or of each mask in a batch).

        Boxes are ``(xmin, ymin, xmax, ymax)`` where x indexes axis -2 and y
        indexes axis -1 of the mask. Mins are inclusive and maxes are EXCLUSIVE,
        so ``image[xmin:xmax, ymin:ymax]`` crops exactly the (padded) region.

        Each side is padded by ``int(extent * (margin_factor - 1))`` pixels, so
        the total extent grows by roughly ``2 * (margin_factor - 1)``. The result
        is clipped to ``[0, length]``. An all-False mask yields the full-mask box.

        :param mask_batch: boolean mask of shape [..., W, H].
        :param margin_factor: relative padding; 1.0 means a tight box.
        :return: integer array of shape [..., 4].
        """
        # rows_any[..., i] is True if row i (axis -2) contains foreground;
        # cols_any[..., j] is True if column j (axis -1) does.
        rows_any = np.any(mask_batch, axis=-1)
        cols_any = np.any(mask_batch, axis=-2)

        def argmin_argmax(arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
            """First True index (inclusive) and one-past-last True index (exclusive), padded."""
            length = arr.shape[-1]
            imin = np.argmax(arr, axis=-1)
            # argmax on the flipped array finds the last True from the end;
            # subtracting from length gives an exclusive end index.
            imax = length - np.argmax(np.flip(arr, axis=-1), axis=-1)

            # Padding is truncated toward zero by the integer cast.
            margin = ((imax - imin) * (margin_factor - 1)).astype(imin.dtype)

            # Clip to the valid slice range; imax is exclusive, so its upper
            # bound is `length` (clamping to length - 1 would drop the last row/col).
            imin = np.maximum(imin - margin, 0)
            imax = np.minimum(imax + margin, length)
            return imin, imax

        xmin, xmax = argmin_argmax(rows_any)
        ymin, ymax = argmin_argmax(cols_any)

        return np.stack((xmin, ymin, xmax, ymax), axis=-1)

In [3]:
class ViewProvider:
    """
    Renders images of the object in ``world_file`` from a fixed camera.

    The object is placed at the origin and the camera at ``(0, 0, cam_height)``
    with an identity rotation. Different views are obtained by re-orienting the
    object (Euler angles) rather than by moving the camera.
    """

    def __init__(
        self,
        world_file: str,
        cam_res: tuple[int, int] = (300, 300),
        cam_fov: int = 45,
        cam_height: float = 0.75,
        render_depth: bool = False,
    ):
        """
        :param world_file: scene file passed to ``CameraSimulator``.
        :param cam_res: camera resolution passed to ``CameraSimulator``.
        :param cam_fov: field of view passed to ``CameraSimulator`` as ``fovy``.
        :param cam_height: z-coordinate of the camera (the object sits at z=0).
        :param render_depth: if True, views come from ``render_depth`` instead of ``render``.
        """
        self._camera = CameraSimulator(resolution=cam_res, fovy=cam_fov, world_file=world_file)
        self._camera.set_object_position([0, 0, 0])
        self._camera_height = cam_height
        self._camera_position = [0, 0, cam_height]
        self._render_depth = render_depth

    def _render_image(self):
        """Render the current scene: a depth render or a color render, per ``render_depth``."""
        if self._render_depth:
            return self._camera.render_depth(torch.eye(3), self._camera_position)
        return self._camera.render(torch.eye(3), self._camera_position)

    def get_view(self, orient: tuple[float, float, float]) -> np.ndarray:
        """Set the object's Euler orientation to ``orient`` and render one view."""
        self._camera.set_object_orientation_euler(orient)
        return self._render_image()

    def get_view_batch(self, orient_list: list[tuple[float, float, float]]) -> list[np.ndarray]:
        """Render one view per orientation in ``orient_list``, in order."""
        return [self.get_view(orient) for orient in orient_list]

    def get_view_cropped(self, orient: tuple[float, float, float], margin_factor: float = 1.2) -> np.ndarray:
        """
        Render a view and crop it to the object's bounding box plus a margin.

        Pixels whose values are all 0 along the last axis count as background.
        See ``ImageManipulator.calc_bboxes`` for how ``margin_factor`` pads the box.

        NOTE(review): the crop indexes three axes, so this assumes a multi-channel
        image. A depth render (``render_depth=True``) may be 2-D and may not
        use 0 as background; confirm before cropping depth views.
        """
        image = self.get_view(orient)
        mask = ImageManipulator.calc_mask(image, bg_value=0, orig_dims=False)
        x1, y1, x2, y2 = ImageManipulator.calc_bboxes(mask, margin_factor)
        return image[x1:x2, y1:y2, :]

In [4]:
# Step through cropped views of the mug at random orientations.
# Press any key for the next view; press "q" or Esc to stop early.
# A fixed seed keeps the sequence of views reproducible across runs.
rng = np.random.default_rng(seed=0)
random_orientations = rng.uniform(0, 2 * np.pi, size=(100, 3))

viewer = ViewProvider(world_file="data/world_mug_sim.xml")

try:
    for orient in random_orientations:
        im = viewer.get_view_cropped(orient)
        # NOTE(review): cv.imshow interprets 3-channel images as BGR; confirm the
        # simulator's channel order if colors look swapped.
        cv.imshow("img", im)
        if (cv.waitKey(0) & 0xFF) in (ord("q"), 27):
            break
finally:
    # Always release the preview window, even on KeyboardInterrupt.
    cv.destroyAllWindows()

QStandardPaths: wrong permissions on runtime directory /run/user/1000/, 0755 instead of 0700


KeyboardInterrupt: 

In [None]:


# --- Object mesh -------------------------------------------------------------
mesh = "./data/meshes/mug.obj"
scale = 0.02

# --- Camera --------------------------------------------------------------------
_camera_position = [0.5, -0.5, 1.75]
# x and y axes of the camera frame; converted to a rotation matrix below.
cam_xy_axes = [[0.685, 0.728, 0.000], [-0.487, 0.458, 0.743]]
cam_frame_R = xy_axes_to_frame_rotation(*cam_xy_axes)
cam_resx = cam_resy = 300
cam_fov = 45

# --- Object poses -------------------------------------------------------------
table_height = 0

# Pose used to render the "actual" observation (same values as candidate pose 1).
obj_position_actual = [0, 0, 0.08 + table_height]
obj_orientation_actual = [2.1, 0, 1.57]

# Candidate poses; each position sits a fixed z-offset above table_height.
obj_positions = [[0, 0, z_offset + table_height] for z_offset in (0.08, 0.08, 0.12, 0.12)]

# Euler angles; range is [0 - 2pi, -pi/2 - pi/2, 0 - 2pi]
obj_orientations = [[2.1, 0, 1.57], [0, 0, 1.57], [0, -0.2, 0], [3.14, -0.2, 0]]

In [None]:
# Render the object at random poses around `obj_position_actual` with a top-down
# camera (identity rotation at z=1.75) and step through the images.
# Press any key for the next image; press "q" or Esc to stop early.
# `masks` is intentionally not assigned here: the SAM comparison cell expects it to
# hold the per-pose geometric mask tensors, and a placeholder would clobber them.
cam_sim = CameraSimulator(resolution=(cam_resy, cam_resx), fovy=cam_fov, world_file="data/world_mug_sim.xml")

n_random_views = 100
rng = np.random.default_rng(seed=0)
random_orientations = rng.uniform(0, 2 * np.pi, size=(n_random_views, 3))
random_position = np.asarray(obj_position_actual) + rng.uniform(-0.1, 0.1, size=(n_random_views, 3))

images_list = []
try:
    for position, orient in zip(random_position, random_orientations):
        cam_sim.set_object_position(position.tolist())
        cam_sim.set_object_orientation_euler(orient)
        im = cam_sim.render(torch.eye(3), [0, 0, 1.75])
        images_list.append(im)
        cv.imshow("img", im)
        if (cv.waitKey(0) & 0xFF) in (ord("q"), 27):
            break
finally:
    # Always release the preview window, even on KeyboardInterrupt.
    cv.destroyAllWindows()

In [None]:
# Close any OpenCV preview windows still open from the rendering loops above.
cv.destroyAllWindows()

In [None]:
# Render the "actual" observation, segment it with SAM, compare the SAM mask with
# the geometric mask of each candidate pose, and turn the IoUs into a pose distribution.
cam_sim.set_object_position(obj_position_actual)
cam_sim.set_object_orientation_euler(obj_orientation_actual)
im_actual = cam_sim.render(cam_frame_R, _camera_position)

sam = SAMSegmentation()
mask_sam, score = sam.segment_image_center(im_actual, best_out_of_3=True)
plot_segmentation_mask(im_actual, mask_sam, mask_alpha=0.95, color=[30, 255, 30])
# save sam mask:
np.save("./sam_mask.npy", mask_sam)

# `masks` must hold one geometric mask tensor per candidate pose (presumably
# produced with get_mesh_segmentation_batch). No cell in view produces it, so
# fail with a clear message instead of an opaque torch.cat error.
if len(masks) == 0 or not all(torch.is_tensor(m) for m in masks):
    raise TypeError(
        "`masks` must be a non-empty sequence of per-pose geometric mask tensors; "
        "run the mesh-segmentation step before this cell"
    )
masks_geometric = torch.cat(masks, dim=0)
masks_sam = torch.from_numpy(mask_sam).unsqueeze(0)
iou = compute_masks_IOU_batch(masks_sam, masks_geometric)

# plot intersection images (one panel per candidate pose):
intersection = masks_intersection_batch(masks_geometric, masks_sam)
n_panels = len(intersection)
# squeeze=False keeps `axs` 2-D so a single pose still yields an iterable row.
fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)
for inter_mask, ax, iou_score in zip(intersection, axs[0], iou):
    ax.imshow(inter_mask)
    ax.axis("off")
    ax.set_title(f"IOU: {iou_score.item():.2f}", fontsize=30)
plt.show()

# Softmax of IoU * 3, i.e. temperature 1/3: a low temperature sharpens the
# distribution toward the best-matching pose.
pose_distribution = np.atleast_1d(torch.softmax(iou * 3, dim=0).numpy().squeeze())
n_poses = len(pose_distribution)
pose_labels = [f"Pose {i + 1}" for i in range(n_poses)]

# plot pose distribution wide:
fig, ax = plt.subplots(figsize=(20, 5))
ax.bar(range(n_poses), pose_distribution)
ax.set_xticks(range(n_poses))
ax.set_xticklabels(pose_labels, fontsize=20)
ax.set_ylabel("Probability", fontsize=20)
ax.set_title("Pose distribution", fontsize=30)
plt.show()