In [1]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

from matplotlib import pyplot as plt
import torch

from grconvnet.dataloading.datasets import YCBSimulationData, CornellDataset
from grconvnet.preprocessing import Preprocessor
from grconvnet.postprocessing import (
    Postprocessor,
    Img2WorldConverter,
    GraspHeightAdjuster,
    Img2WorldCoordConverter,
    Decropper,
)
from grconvnet.utils.processing import End2EndProcessor
from grconvnet.utils import visualization as vis
from grconvnet.utils.export import Exporter

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Location of the YCB simulation dataset. The default is machine-specific; on another machine,
# set the YCB_SIM_DATA_PATH environment variable instead of editing this cell.
dataset_path = Path(
    os.environ.get("YCB_SIM_DATA_PATH", "/home/moritz/Documents/ycb_sim_data_1")
)
# Fail early with an actionable message rather than deep inside the dataset class.
assert dataset_path.exists(), (
    f"Dataset not found at {dataset_path}; set YCB_SIM_DATA_PATH to the dataset location."
)
dataset = YCBSimulationData(dataset_path)

In [3]:
# Camera parameters and the image size are read from the first sample only.
# NOTE(review): this assumes every sample in the dataset shares the same camera setup — confirm.
sample = dataset[0]
cam_intrinsics, cam_pos, cam_rot = sample.cam_intrinsics, sample.cam_pos, sample.cam_rot

# Drop the leading dimension of the RGB shape (presumably channels, (C, H, W) -> (H, W) — TODO confirm).
image_size = sample.rgb.shape[1:]

In [4]:
# Shared flag: passed to both Preprocessor(resize=...) and Decropper(resized_in_preprocess=...)
# in the pipeline cell so the two stay consistent.
resize: bool = True

In [5]:
# Assemble the end-to-end pipeline from its components: preprocessing, postprocessing,
# and conversion of image-space grasps to world coordinates.
preprocessor = Preprocessor(
    # White is passed as the "negative" mask color and black as the "positive" one;
    # their exact meaning is defined by Preprocessor.
    mask_rgb_neg_color=torch.Tensor([255, 255, 255]),
    mask_rgb_pos_color=torch.Tensor([0, 0, 0]),
    resize=resize,
)
postprocessor = Postprocessor(n_grasps=3)

# Camera parameters come from the first sample (see the camera-parameter cell above).
coord_converter = Img2WorldCoordConverter(cam_intrinsics, cam_rot, cam_pos)
# The Decropper reads the same `resize` flag as the Preprocessor.
decropper = Decropper(resized_in_preprocess=resize, original_img_size=image_size)
height_adjuster = GraspHeightAdjuster(min_height=0.01, target_grasp_depth=0.04)
img2world_converter = Img2WorldConverter(
    coord_converter=coord_converter,
    decropper=decropper,
    height_adjuster=height_adjuster,
)
# NOTE(review): an earlier version passed `None` here when any camera parameter was missing.
# The export cell reads process_data["grasps_world"], so the converter is always required.

e2e_processor = End2EndProcessor(
    preprocessor=preprocessor,
    postprocessor=postprocessor,
    img2world_converter=img2world_converter,
)


In [None]:
# Run the pipeline on every sample and export the intermediate and final results.
# NOTE(review): `Path.cwd().parent` makes the export location depend on the directory the
# kernel was started from — confirm this resolves to the repository root as intended.
export_dir = Path.cwd().parent / "grconvnet" / "results"
# The exporter depends only on the fixed export directory, so build it once, outside the loop.
exporter = Exporter(export_dir=export_dir)

# A distinct loop name avoids clobbering `sample` from the camera-parameter cell.
for current_sample in dataset:
    process_data = e2e_processor(current_sample)
    preprocessor_data = process_data["preprocessor"]
    postprocessor_data = process_data["postprocessor"]

    export_data = {
        "rgb_cropped": preprocessor_data["rgb_cropped"],
        "depth_cropped": preprocessor_data["depth_cropped"],
        "rgb_masked": preprocessor_data["rgb_masked"],
        "q_img": postprocessor_data["q_img"],
        "angle_img": postprocessor_data["angle_img"],
        "width_img": postprocessor_data["width_img"],
        "grasps_img": process_data["grasps_img"],
        "grasps_world": process_data["grasps_world"],
        "model_input": process_data["model_input"],
        # "overview" entry removed: it exported `fig`, which is never defined in this
        # notebook, so a fresh Restart & Run All failed with NameError.
        # TODO: build the overview figure per sample (e.g. with `vis`) and re-add it here,
        # closing it with plt.close(...) after export to avoid accumulating figures.
    }

    export_path = exporter(export_data, f"_{process_data['sample'].name}")