In [None]:
import cv2 as cv
import numpy as np
from camera_simulator import CameraSimulator

In [None]:
class ImageManipulator:
    """
    Image manipulation functions that work both on batched data and single images.
    Image batch dim is assumed to be [N, W, H, 3]
    Single image dim is assumed to be [W, H, 3]

    Throughout this class "x" denotes the first spatial axis (W above) and "y"
    the second (H above). NOTE(review): arrays coming from OpenCV/NumPy are
    usually laid out row-major as [H, W, C]; confirm which layout the renderer
    returns before reading x/y as horizontal/vertical.
    """

    @staticmethod
    def calc_mask(images: np.ndarray, bg_value: int = 0, orig_dims: bool = False) -> np.ndarray:
        """Return a boolean foreground mask.

        A pixel counts as foreground when ANY of its channels (the last axis)
        differs from ``bg_value``.

        Args:
            images: single image [W, H, C] or batch [N, W, H, C].
            bg_value: value treated as background.
            orig_dims: if True, broadcast the mask back to ``images.shape`` (a
                read-only view) so it lines up with ``images`` element-wise;
                otherwise the channel axis is dropped.

        Returns:
            Boolean array of shape ``images.shape[:-1]``, or ``images.shape``
            when ``orig_dims`` is True.
        """
        mask = np.any(images != bg_value, axis=-1)
        if orig_dims:
            mask = np.broadcast_to(np.expand_dims(mask, axis=-1), images.shape)
        return mask

    @staticmethod
    def calc_bboxes(mask_batch: np.ndarray, margin_factor: float = 1.2) -> np.ndarray:
        """Compute foreground bounding boxes, enlarged by a margin.

        Args:
            mask_batch: boolean mask [W, H] or batch of masks [N, W, H]
                (e.g. the output of ``calc_mask``).
            margin_factor: along each axis, both sides of the box are pushed
                outwards by ``trunc((margin_factor - 1) * extent)`` pixels, so
                the total extent grows by roughly ``2 * (margin_factor - 1)``.
                Values below 1 shrink the box.

        Returns:
            Integer array of shape [..., 4] holding ``(xmin, ymin, xmax, ymax)``.
            Minimums are inclusive and maximums exclusive, so they can be used
            directly as slice bounds: ``image[xmin:xmax, ymin:ymax]``. Boxes are
            clipped to the array extent. An all-background mask yields a box
            covering the whole array.
        """
        # Collapse one spatial axis to get, per x (resp. y) index, whether that
        # line contains any foreground.
        x = np.any(mask_batch, axis=-1)
        y = np.any(mask_batch, axis=-2)

        def argmin_argmax(arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
            """Return (first True index, one past last True index) along the last axis, with margin."""
            length = arr.shape[-1]
            # argmax on a boolean array returns the index of the first True
            # (0 when there is none).
            imin = np.argmax(arr, axis=-1)
            # On the reversed array it gives the distance of the last True from
            # the end; converting back yields an EXCLUSIVE end index.
            imax = length - np.argmax(np.flip(arr, axis=-1), axis=-1)

            # add margin to the indices
            diff = imax - imin
            # Round away float representation error before truncating
            # (1.2 - 1 == 0.19999999999999996), so e.g. 10 px * 0.2 -> 2 px, not 1.
            margin = np.trunc(np.round(diff * (margin_factor - 1), 6)).astype(imin.dtype)
            imin = imin - margin
            imax = imax + margin

            # Make sure we're within bounds. imax is exclusive, so its largest
            # valid value is `length`; clamping to `length - 1` would cut off
            # the last row/column when the object touches the border.
            imin = np.maximum(imin, 0)
            imax = np.minimum(imax, length)
            return imin, imax

        xmin, xmax = argmin_argmax(x)
        ymin, ymax = argmin_argmax(y)

        return np.stack((xmin, ymin, xmax, ymax), axis=-1)

In [None]:
class ImageSampler:
    """Render views of an object from a fixed camera via CameraSimulator.

    The camera pose is fixed at construction time; each view call only sets the
    object's Euler orientation and renders.
    """

    def __init__(
        self,
        world_file: str,
        obj_pos: tuple[float, float, float],
        cam_pos: tuple[float, float, float],
        cam_rot: np.ndarray,
        cam_res: tuple[int, int] = (300, 300),
        cam_fov: int = 45,
        cam_depth: bool = False,
    ):
        """
        Args:
            world_file: scene file passed to CameraSimulator as ``world_file``.
            obj_pos: object position, applied once via ``set_object_position``.
            cam_pos: camera position passed to every render call.
            cam_rot: camera rotation passed to every render call
                (OrientSampler passes a 3x3 matrix).
            cam_res: forwarded to CameraSimulator as ``resolution``.
            cam_fov: forwarded to CameraSimulator as ``fovy``
                (presumably degrees — confirm).
            cam_depth: if True, views come from ``render_depth`` instead of
                ``render``.
        """
        self._cam = CameraSimulator(resolution=cam_res, fovy=cam_fov, world_file=world_file)
        self._cam.set_object_position(obj_pos)

        self._cam_pos = cam_pos
        self._cam_rot = cam_rot
        self._obj_pos = obj_pos
        self._cam_depth = cam_depth

    def _render(self):
        """Render the current scene as a depth or colour image depending on ``cam_depth``."""
        if self._cam_depth:
            return self._cam.render_depth(self._cam_rot, self._cam_pos)
        return self._cam.render(self._cam_rot, self._cam_pos)

    def get_view(self, orient: tuple[float, float, float]) -> np.ndarray:
        """Set the object's Euler orientation and return the rendered image."""
        self._cam.set_obj_orient_euler(orient)
        return self._render()

    def get_view_batch(self, orient_list: list[tuple[float, float, float]]) -> list[np.ndarray]:
        """Render one view per orientation, in order.

        Note: the object keeps the last orientation afterwards.
        """
        return [self.get_view(orient) for orient in orient_list]

    def get_view_cropped(self, orient: tuple[float, float, float], margin_factor: float = 1.2) -> np.ndarray:
        """Render a view and crop it to the object's bounding box plus a margin.

        Pixels equal to 0 in every channel are treated as background.
        ``margin_factor`` is forwarded to ``ImageManipulator.calc_bboxes``.
        """
        image = self.get_view(orient)
        # NOTE(review): assumes a channel-last image ([W, H, C]). If render_depth
        # returns an array without a channel axis, calc_mask would not yield a
        # 2-D mask here — confirm before using this with cam_depth=True.
        mask = ImageManipulator.calc_mask(image, bg_value=0, orig_dims=False)
        x1, y1, x2, y2 = ImageManipulator.calc_bboxes(mask, margin_factor)
        return image[x1:x2, y1:y2]

    def simulate_seconds(self, seconds: float):
        """Advance the underlying simulation by ``seconds`` (delegates to CameraSimulator)."""
        self._cam.simulate_seconds(seconds)


class OrientSampler(ImageSampler):
    """ImageSampler with the camera fixed at the world origin with identity rotation.

    Only the object's position (and render settings) are configurable; views
    are varied by changing the object's orientation.
    """

    def __init__(
        self,
        world_file: str,
        obj_pos: tuple[float, float, float] = (0, 0, -0.7),
        cam_res: tuple[int, int] = (300, 300),
        cam_fov: int = 45,
        cam_depth: bool = False,
    ):
        """
        Args:
            world_file: scene file forwarded to ImageSampler.
            obj_pos: object position; may be fractional (default z = -0.7).
            cam_res: rendering resolution forwarded to ImageSampler.
            cam_fov: field-of-view value forwarded to ImageSampler.
            cam_depth: render depth images instead of colour images.
        """
        super().__init__(
            world_file=world_file,
            obj_pos=obj_pos,
            cam_pos=[0, 0, 0],  # camera is at the origin
            # Identity rotation: camera axes aligned with world axes, i.e. it
            # looks along the z axis — presumably towards -z, given the default
            # obj_pos z = -0.7; confirm against CameraSimulator's convention.
            cam_rot=np.identity(3),
            cam_res=cam_res,
            cam_fov=cam_fov,
            cam_depth=cam_depth,
        )

In [None]:
# Interactively browse renders of the mug under random object orientations.
# Press any key for the next view; press 'q' or Esc to stop early.
# NOTE(review): OrientSampler keeps the camera at the origin with identity
# rotation and defaults obj_pos to (0, 0, -0.7); an object at (0, 1, 0) may be
# outside the field of view — confirm this placement is intended.
viewer = OrientSampler(world_file="data/world_mug.xml", obj_pos=(0, 1, 0), cam_depth=False)

QUIT_KEYS = {ord("q"), 27}  # 27 is the Esc key code
random_orientations = np.random.uniform(0, 2 * np.pi, size=(1000, 3))
for orient in random_orientations:
    im = viewer.get_view(orient)
    # NOTE(review): cv.imshow treats 3-channel arrays as BGR; if the renderer
    # returns RGB, convert with cv.cvtColor(im, cv.COLOR_RGB2BGR) first.
    cv.imshow("img", im)
    # waitKey(0) blocks until a key press; keep only the low byte because some
    # platforms set extra modifier bits in the returned code.
    if (cv.waitKey(0) & 0xFF) in QUIT_KEYS:
        break

In [None]:
# Close every OpenCV HighGUI window opened by the preview loop (e.g. "img").
cv.destroyAllWindows()