In [31]:
import cv2 as cv
import numpy as np
import torch
from camera_simulator import CameraSimulator
from camera_utils import xy_axes_to_frame_rotation, get_torch3d_R_T
from matplotlib import pyplot as plt
from sam_segmentation import SAMSegmentation
from utils import (
    plot_segmentation_mask,
    plot_grid_segmentation_masks,
    masks_intersection_batch,
    compute_masks_IOU_batch,
)
from geometric_mesh_segmentation import CameraParameters, get_mesh_segmentation_batch

In [32]:


# --- Object model ---------------------------------------------------------
mesh = "./data/meshes/mug.obj"
scale = 0.02

# --- Camera ---------------------------------------------------------------
cam_pos = [0.5, -0.5, 1.75]
# Two axis vectors handed to xy_axes_to_frame_rotation to build the camera frame.
cam_xy_axes = [
    [0.685, 0.728, 0.000],
    [-0.487, 0.458, 0.743],
]
cam_frame_R = xy_axes_to_frame_rotation(*cam_xy_axes)
cam_resx = 300
cam_resy = 300
cam_fov = 45  # presumably degrees -- TODO confirm against CameraSimulator
cam_znear = 0.1
cam_zfar = 100

# --- Object poses ---------------------------------------------------------
table_height = 1

# Pose used for the "actual" image.
obj_position_actual = [0, 0, 0.08 + table_height]
obj_orientation_actual = [2.1, 0, 1.57]

# Candidate positions: two at 0.08 above the table height, two at 0.12 above it.
obj_positions = [[0, 0, offset + table_height] for offset in (0.08, 0.08, 0.12, 0.12)]

# Candidate orientations (same count as obj_positions).
# Per-component ranges: [0, 2*pi], [-pi/2, pi/2], [0, 2*pi].
obj_orientations = [
    [2.1, 0, 1.57],
    [0, 0, 1.57],
    [0, -0.2, 0],
    [3.14, -0.2, 0],
]

In [33]:
# Number of random poses to render, and the seed that makes the draw repeatable
# across "Restart Kernel -> Run All".
N_RANDOM_POSES = 1000
RANDOM_SEED = 0
rng = np.random.default_rng(RANDOM_SEED)

cam_sim = CameraSimulator(resolution=(cam_resy, cam_resx), fovy=cam_fov, world_file="data/world_mug_sim.xml")

# One Euler-angle triple per pose, each component drawn uniformly from [0, 2*pi).
# NOTE(review): the configuration cell documents the second angle's range as
# [-pi/2, pi/2]; here all three components share [0, 2*pi) -- confirm this is intended.
random_orientations = rng.uniform(0, 2 * np.pi, size=(N_RANDOM_POSES, 3))
# Positions are the actual object position jittered by up to +/-0.1 on every axis.
random_position = np.asanyarray(obj_position_actual) + rng.uniform(-0.1, 0.1, size=(N_RANDOM_POSES, 3))

# Render one image per sampled pose from the fixed camera.
images_list = []
for position, orient in zip(random_position, random_orientations):
    cam_sim.set_manipulated_object_position(position.tolist())
    cam_sim.set_manipulated_object_orientation_euler(orient)
    images_list.append(cam_sim.render(cam_frame_R, cam_pos))

In [34]:
def calc_mask(images: np.ndarray, bg_value: int = 0) -> np.ndarray:
    """Flag every pixel that differs from the background in at least one channel.

    Args:
        images: Array whose last axis holds the per-pixel channel values.
        bg_value: Channel value treated as background.

    Returns:
        Boolean array with the same shape as ``images``. A pixel's flag is
        repeated across the whole last axis. The result is the read-only view
        produced by ``np.broadcast_to``.
    """
    is_foreground = (images != bg_value).any(axis=-1, keepdims=True)
    return np.broadcast_to(is_foreground, images.shape)


def get_bboxes(image_batch: np.ndarray, bg_value: int = 0, margin_factor: float = 1.2) -> np.ndarray:
    """Compute a bounding box around the non-background pixels of each image.

    A pixel is foreground when any value along the last (channel) axis differs
    from ``bg_value``.

    Args:
        image_batch: Array of shape ``(..., A, B, C)`` with channels last.
        bg_value: Channel value treated as background.
        margin_factor: Controls the padding. Each side of the box is pushed
            outwards by ``int(extent * (margin_factor - 1))`` pixels, where
            ``extent`` is the tight box size along that axis. ``1`` gives the
            tight box.

    Returns:
        Integer array of shape ``(..., 4)`` holding ``(xmin, ymin, xmax, ymax)``.
        Despite the names, the "x" values index axis ``-3`` of ``image_batch``
        and the "y" values index axis ``-2``. The max values are exclusive, so
        ``image[xmin:xmax, ymin:ymax]`` crops the box. All values are clamped
        to the image extent. An image with no foreground pixel yields the box
        of the whole image.
    """
    masks = np.any(image_batch != bg_value, axis=-1)
    x = np.any(masks, axis=-1)  # per index along image axis -3: any foreground?
    y = np.any(masks, axis=-2)  # per index along image axis -2: any foreground?

    def argmin_argmax(arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        length = arr.shape[-1]

        # First True index, and one past the last True index (exclusive end).
        imin = np.argmax(arr, axis=-1)
        imax = length - np.argmax(np.flip(arr, axis=-1), axis=-1)

        # Pad both ends by the same number of pixels (truncated towards zero).
        diff = imax - imin
        margin = (diff * (margin_factor - 1)).astype(imin.dtype)

        # Clamp to the valid slice range. The end is exclusive, so its upper
        # limit is ``length`` -- clamping to ``length - 1`` would cut off the
        # last row/column whenever the box reaches the image border.
        imin = np.maximum(imin - margin, 0)
        imax = np.minimum(imax + margin, length)
        return imin, imax

    xmin, xmax = argmin_argmax(x)
    ymin, ymax = argmin_argmax(y)

    return np.stack((xmin, ymin, xmax, ymax), axis=-1)

In [36]:
# Combine the rendered images into one array with a new leading batch axis.
image_batch = np.stack(images_list, axis=0)

In [37]:
# Per-pixel foreground mask for every rendered image (same shape as image_batch).
masks = calc_mask(image_batch)
# Optional visual check of the first render (left disabled):
#cv.imshow("img", image_batch[0])
#cv.waitKey(0)
# Close any OpenCV windows that are still open.
cv.destroyAllWindows()
print(masks.shape)

(1000, 300, 300, 3)


In [38]:
# One bounding box per rendered image; margin_factor=1 adds no padding around the object.
bboxes = get_bboxes(image_batch, margin_factor=1)

In [30]:
# Spot-check one arbitrary bounding box.
print(bboxes[214])

# Show every render cropped to its bounding box. Iterate over the boxes
# themselves (not over the four numbers of the first box) so that `bbox` is a
# full 4-element row and `i` indexes the matching image.
for i, bbox in enumerate(bboxes):
    cv.imshow("img", image_batch[i][bbox[0] : bbox[2], bbox[1] : bbox[3]])
    # Any key advances to the next crop; Esc (key code 27) stops the preview early.
    if cv.waitKey(0) & 0xFF == 27:
        break

cv.destroyAllWindows()

In [None]:
# Paint the foreground mask of image 0 onto a *copy* of image 1, so that the
# rendered batch itself stays untouched and earlier cells can be re-run safely.
overlay = image_batch[1].copy()
overlay[masks[0]] = 100

cv.imshow("img", overlay)
cv.waitKey(0)
# Close the preview window once a key has been pressed.
cv.destroyAllWindows()

In [None]:
# Render the image for the "actual" object pose, segment it with SAM and display the result.
cam_sim.set_manipulated_object_position(obj_position_actual)
cam_sim.set_manipulated_object_orientation_euler(obj_orientation_actual)
im_actual = cam_sim.render(cam_frame_R, cam_pos)

sam = SAMSegmentation()
mask_sam, score = sam.segment_image_center(im_actual, best_out_of_3=True)
plot_segmentation_mask(im_actual, mask_sam, mask_alpha=0.95, color=[30, 255, 30])
# IoU between the SAM mask and the candidate masks.
# NOTE(review): in this notebook `masks` is assigned from calc_mask(...), which
# returns a NumPy array; torch.cat presumably needs a sequence of tensors --
# confirm this cell still runs with the current `masks`.
iou = compute_masks_IOU_batch(torch.from_numpy(mask_sam).unsqueeze(0), torch.cat(masks))
# Save the SAM mask to disk:
np.save("./sam_mask.npy", mask_sam)

# Plot the intersection images, one per candidate, titled with its IoU score.
# (The two tensors below repeat the expressions already passed to compute_masks_IOU_batch.)
masks_geometric = torch.cat(masks, dim=0)
masks_sam = torch.from_numpy(mask_sam).unsqueeze(0)
intersection = masks_intersection_batch(masks_geometric, masks_sam)
# The grid is hard-coded to 4 panels; zip() stops at the shortest input, so at
# most 4 candidates are drawn.
fig, axs = plt.subplots(1, 4, figsize=(20, 5))
for im, ax, iou_score in zip(intersection, axs, iou):
    ax.imshow(im)
    ax.axis("off")
    ax.set_title(f"IOU: {iou_score.item():.2f}")
    ax.title.set_size(30)
plt.show()

# Turn the IoU scores into a pose distribution with a softmax. The scores are
# multiplied by 3 first (i.e. a softmax temperature of 1/3), which sharpens the
# distribution compared with a plain softmax.
pose_distribution = torch.softmax(iou * 3, dim=0).numpy().squeeze()
# Plot the pose distribution as a wide bar chart.
plt.figure(figsize=(20, 5))
plt.bar(range(len(pose_distribution)), pose_distribution)
# NOTE(review): exactly 4 tick labels are supplied, so this presumably requires
# len(pose_distribution) == 4 -- verify against the number of candidate masks.
plt.xticks(range(len(pose_distribution)), ["Pose 1", "Pose 2", "Pose 3", "Pose 4"])
plt.title("Pose distribution", fontsize=30)
plt.xticks(fontsize=20)
plt.show()