In [None]:
from manten.data.dataset_maniskill import ManiSkillDataset

# Camera names passed as `rgb_modality_keys`; the plotting cells iterate over the same list.
# Single-view alternative: cameras = ["camera1"]
cameras = ["camera1", "gripper1"]

# All loader settings collected in one mapping so they can be inspected or tweaked
# before the dataset is built. Alternatives that were tried are kept as comments.
dataset_kwargs = {
    "simulated_length": 10000000,
    "test_ratio": 0.05,
    "task": "PegInsertionSide-v1",  # alternative: "PickCube-v1"
    "pack_root": "/home/i53/student/yagmurlu/code/manten/data/maniskill2/packed_demos",
    "obs_horizon": 2,
    "pred_horizon": 16,
    "obs_mode": "pointcloud",
    "state_modality_keys": ["tcp_pose"],
    "rgb_modality_keys": cameras,
    "control_mode": "pd_ee_target_delta_pose",  # alternative: "pd_ee_delta_pose"
    "use_mmap": True,  # alternative: False
    "rotation_transform": "rotation_6d",
    "load_count": 2,
}
dataset = ManiSkillDataset(**dataset_kwargs)

# print(dataset[0])

# Dataset-level metadata; the "pcd_stats" entry is consumed by the scaler in a later cell.
dataset_info = dataset.get_dataset_info()

print(dataset_info)

In [None]:
import optree

# Take the first sample and display the shape of every leaf in its nested structure.
fi = dataset[0]

optree.tree_map(lambda leaf: leaf.shape, fi)

In [None]:
import einops
import matplotlib.pyplot as plt
import torch
from torchvision.transforms.functional import to_pil_image

# Episode 4 with every leaf of the nested structure converted to a torch tensor.
data = optree.tree_map(torch.tensor, dataset.get_episode(4))

# Per-camera point-cloud mask; multiplying by it zeroes every masked-out pixel.
mask = data["observations"]["pcd_mask"]


def _zero_masked_out(values, keep):
    """Multiply an observation leaf by its mask leaf."""
    return values * keep


masked_zeroed_rgb_obs = optree.tree_map(
    _zero_masked_out, data["observations"]["rgb_obs"], mask
)
masked_zeroed_pcd_obs = optree.tree_map(
    _zero_masked_out, data["observations"]["pcd_obs"], mask
)

# Number of frames available for "camera1" in this episode.
print(len(masked_zeroed_rgb_obs["camera1"]))

# Three side-by-side panels: RGB, point-cloud image, and mask.
fig, axs = plt.subplots(1, 3, figsize=(15, 10))
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)  # Remove margins
# Frame stride used by `update` (panel index -> episode frame index).
step = 10


def update(ix):
    """Draw episode frame ``ix * step`` into the three module-level axes.

    The panels show, left to right, the masked RGB image, the masked point-cloud
    image and the mask itself. For each panel the images of all ``cameras`` are
    stacked along the height axis.

    Args:
        ix: Animation frame index; multiplied by the global ``step`` to obtain
            the index into the episode.

    Returns:
        The module-level ``axs`` array that was drawn into.
    """
    frame = ix * step

    def stack_cameras(per_camera):
        # Gather this frame from every camera and tile the images vertically.
        return einops.rearrange(
            [per_camera[cam][frame] for cam in cameras],
            "cam c h w -> c (cam h) w",
        )

    rgb_obs = to_pil_image(stack_cameras(masked_zeroed_rgb_obs))
    # Channels-last for imshow; the affine map presumably brings values from
    # [-1, 1] into [0, 1] for display — TODO confirm the stored value range.
    pcd_obs = stack_cameras(masked_zeroed_pcd_obs).permute(1, 2, 0) * 0.5 + 0.5
    pcd_mask = to_pil_image(stack_cameras(mask).float())

    for ax, image in zip(axs, (rgb_obs, pcd_obs, pcd_mask)):
        ax.imshow(image)
    return axs


# Optional animation over the whole episode (disabled).
# NOTE(review): `animation` is not imported in the visible cells — presumably
# `from matplotlib import animation` is required before re-enabling this.
# ani = animation.FuncAnimation(fig, update, frames=len(masked_zeroed_rgb_obs["camera1"])//step, blit=False)
# from IPython.display import HTML
# HTML(ani.to_jshtml())

# Draw only the first frame and show it as a static figure.
update(0)
plt.show()

In [None]:
import plotly.graph_objects as go
import torch.nn.functional as F


def inv_scale_action(action, low, high):
    """Inverse of `clip_and_scale_action` without clipping."""
    return (action - 0.5 * (high + low)) / (0.5 * (high - low))


def get_scattered_points(points, n_points=None, std=None):
    """Replace every point by a small Gaussian cloud of jittered copies.

    Args:
        points: Tensor of shape ``(b, c)``.
        n_points: Number of copies per input point; defaults to 10 when ``None``.
        std: Standard deviation of the added Gaussian noise; defaults to 0.001
            when ``None``.

    Returns:
        Tensor of shape ``(b * n_points, c)`` in which the copies of each input
        point are stored consecutively.
    """
    copies = 10 if n_points is None else n_points
    spread = 0.001 if std is None else std

    tiled = einops.repeat(points, "b c -> (b n_points) c", n_points=copies)
    jitter = torch.randn_like(tiled) * spread

    return jitter + tiled


def get_scatter_trace(
    pcd,
    rgb,
    mask,
    special_points=None,
    special_point_colors=None,
    special_point_scatter_n_points=None,
    special_point_scatter_std=None,
):
    """Turn one masked point-cloud frame into plotly-ready coordinates and colours.

    Args:
        pcd: ``(n, 3+)`` tensor of point coordinates; columns 0..2 are x, y, z.
        rgb: ``(n, 3+)`` tensor of colours; columns 0..2 are multiplied by 255
            and truncated to integers.
        mask: Boolean tensor of shape ``(n,)`` selecting the points to keep.
        special_points: Optional ``(k, 3)`` tensor of highlight points. Each one
            is expanded into a small jittered cloud via ``get_scattered_points``
            and appended after the regular points.
        special_point_colors: ``(k, 3)`` tensor with one colour per special
            point; required whenever ``special_points`` is non-empty.
        special_point_scatter_n_points: Copies per special point (forwarded).
        special_point_scatter_std: Jitter standard deviation (forwarded).

    Returns:
        ``(x, y, z, color)`` where ``x``, ``y``, ``z`` are 1-D tensors and
        ``color`` is a list of ``"rgb(r, g, b)"`` strings of the same length.

    Raises:
        ValueError: If special points are given without a matching number of
            colours.
    """
    pcd = pcd[mask]
    rgb = rgb[mask]

    # Sort by z so that the regular points are emitted in ascending height.
    z, order = torch.sort(pcd[:, 2])
    x = pcd[order, 0]
    y = pcd[order, 1]
    colors = rgb[order, :3]

    # An empty set of special points is treated like "no special points";
    # previously it led to a division by zero.
    if special_points is not None and len(special_points) > 0:
        if special_point_colors is None or len(special_point_colors) != len(special_points):
            raise ValueError(
                "special_point_colors must provide exactly one colour per special point"
            )

        scattered = get_scattered_points(
            special_points, special_point_scatter_n_points, special_point_scatter_std
        )

        x = torch.cat([x, scattered[:, 0]])
        y = torch.cat([y, scattered[:, 1]])
        z = torch.cat([z, scattered[:, 2]])

        # Integer division: the repeat count is an axis length and must be an int
        # (the previous true division produced a float).
        copies_per_point = len(scattered) // len(special_points)
        repeated_colors = einops.repeat(
            special_point_colors, "b c -> (b n_points) c", n_points=copies_per_point
        )
        colors = torch.cat([colors, repeated_colors[:, :3]])

    # Vectorized conversion; truncates toward zero exactly like int(value * 255).
    channels = (colors * 255).long().tolist()
    color = [f"rgb({red}, {green}, {blue})" for red, green, blue in channels]

    return x, y, z, color


def render_masked_3d_scatter(x, y, z, color, frames=None):
    """Build a plotly 3-D scatter figure, optionally with animation frames.

    Args:
        x, y, z: Point coordinates of the initially displayed trace.
        color: Per-point marker colours of the initially displayed trace.
        frames: Optional iterable of ``(x, y, z, color)`` tuples. When given,
            each tuple becomes one animation frame, the axes are fixed to
            ``[-1, 1]`` with a cube aspect, and Play / Pause buttons are added.

    Returns:
        The assembled ``go.Figure``.
    """

    def marker_trace(xs, ys, zs, colors):
        # One scatter trace with the marker styling shared by all traces.
        return go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode="markers",
            marker={"size": 3, "color": colors, "opacity": 0.8},
        )

    def fixed_axis(title):
        # Axis pinned to [-1, 1] so the view does not rescale between frames.
        return {"title": title, "range": [-1, 1], "autorange": False}

    scatter = marker_trace(x, y, z, color)

    if frames is None:
        static_layout = go.Layout(
            scene={
                "xaxis": {"title": "X Axis"},
                "yaxis": {"title": "Y Axis"},
                "zaxis": {"title": "Z Axis"},
            },
            margin={"l": 0, "r": 0, "b": 0, "t": 0},
            height=750,
        )
        return go.Figure(data=[scatter], layout=static_layout)

    play_button = {"args": [None], "label": "Play", "method": "animate"}
    pause_button = {
        "args": [
            [None],
            {
                "frame": {"duration": 0, "redraw": False},
                "mode": "immediate",
                "transition": {"duration": 0},
            },
        ],
        "label": "Pause",
        "method": "animate",
    }
    animated_layout = go.Layout(
        scene={
            "xaxis": fixed_axis("X Axis"),
            "yaxis": fixed_axis("Y Axis"),
            "zaxis": fixed_axis("Z Axis"),
        },
        scene_aspectmode="cube",
        margin={"l": 0, "r": 0, "b": 0, "t": 0},
        height=750,
        updatemenus=[{"buttons": [play_button, pause_button]}],
    )
    animation_frames = [
        go.Frame(data=[marker_trace(fx, fy, fz, fcolor)]) for fx, fy, fz, fcolor in frames
    ]
    return go.Figure(data=[scatter], layout=animated_layout, frames=animation_frames)


# Episode 4 with every leaf of the nested structure converted to a torch tensor.
data = optree.tree_map(torch.tensor, dataset.get_episode(4))
# data = dataset[0]

# Goal markers are disabled: empty placeholders are used so that the later
# `torch.cat([... goal_pos[:1]])` calls still work.
# goal_pos = data["observations"]["state_obs"]  # shape: (obs_horizon, 3) (position)
# goal_rgb = torch.tensor([0.0, 1.0, 0.0]).expand(goal_pos.shape)
goal_pos = torch.zeros((0,))
goal_rgb = torch.zeros((0,))

# First three components of the TCP pose, coloured red.
# Original note: tcp_pose is (obs_horizon, 7) = position + quaternion — TODO confirm,
# since the dataset is built with rotation_transform="rotation_6d".
tcp_pos = data["observations"]["state_obs"]["tcp_pose"][..., :3]
tcp_rgb = torch.tensor([1.0, 0.0, 0.0]).expand(tcp_pos.shape)

# Translational part of the actions, scaled by 0.1.
delta_trajectory = data["actions"][..., :3] * 0.1
# delta_trajectory = data['actions'][..., :3]
# delta_trajectory = inv_scale_action(data["actions"][..., :3], -1, 1)

# Integrate the deltas starting from the first TCP position, coloured blue.
# cumsum only works for !target! delta, not delta with current position
trajectory = torch.cumsum(torch.cat((tcp_pos[:1, :3], delta_trajectory), dim=0), dim=0)
trajectory_rgb = torch.tensor([0.0, 0.0, 1.0]).expand(trajectory.shape)
# trajectory = torch.zeros((0,))
# trajectory_rgb = torch.zeros((0,))

mask, pcd, rgb = (
    data["observations"][key] for key in ("pcd_mask", "pcd_obs", "rgb_obs")
)

# Optional spatial downscaling; 0 disables it.
scale_factor = 0
if scale_factor:

    def _downscale(image):
        """Bilinearly shrink an image batch by `scale_factor`."""
        return F.interpolate(image, scale_factor=1 / scale_factor, mode="bilinear")

    rgb = optree.tree_map(_downscale, rgb)
    pcd = optree.tree_map(_downscale, pcd)
    # Negated max-pool of the negated mask is a min-pool: a downscaled pixel
    # stays valid only if every pixel in its window was valid.
    mask = optree.tree_map(
        lambda m: -F.max_pool2d(-m.float(), kernel_size=scale_factor) > (1 / 2),
        mask,
    )


def _to_point_list(image):
    """Flatten the spatial grid: (b, c, h, w) -> (b, h*w, c)."""
    return einops.rearrange(image, "b c h w -> b (h w) c")


mask = optree.tree_map(lambda m: einops.rearrange(m, "b 1 h w -> b (h w)"), mask)
pcd = optree.tree_map(_to_point_list, pcd)
rgb = optree.tree_map(_to_point_list, rgb)

# Concatenate all cameras along the point axis.
combined_pcd = torch.cat(tuple(pcd[cam] for cam in pcd), dim=1)
combined_rgb = torch.cat(tuple(rgb[cam] for cam in rgb), dim=1)
combined_mask = torch.cat(tuple(mask[cam] for cam in mask), dim=1)

The next two cells build and render the animated point cloud; the third cell renders a single static frame.

In [None]:
# frames = []
# for idx in range(0, len(pcd["camera1"]), 5):
#     x, y, z, color = get_scatter_trace(
#         combined_pcd[idx],
#         combined_rgb[idx],
#         # torch.ones_like(mask[idx]),
#         combined_mask[idx],
#         # special_points=torch.cat([tcp_pos, trajectory, goal_pos[:1]], dim=0),
#         # special_point_colors=torch.cat([tcp_rgb, trajectory_rgb, goal_rgb[:]], dim=0),
#     )
#     frames.append((x, y, z, color))

In [None]:
# fig = render_masked_3d_scatter(*frames[0], frames=frames[::1])
# fig.show()

In [None]:
# Render a single frame of the combined (all-camera) point cloud as a static figure.
idx = 0
x, y, z, color = get_scatter_trace(
    pcd=combined_pcd[idx],
    rgb=combined_rgb[idx],
    # use torch.ones_like(mask[idx]) here to show unmasked points as well
    mask=combined_mask[idx],
    # special_points=torch.cat([tcp_pos, trajectory, goal_pos[:1]], dim=0),
    # special_point_colors=torch.cat([tcp_rgb, trajectory_rgb, goal_rgb[:1]], dim=0),
)
fig = render_masked_3d_scatter(x=x, y=y, z=z, color=color)
fig.show()

In [None]:
def distance_weights(points, mask, center, inverse_power=2):
    """Weight points by inverse distance to a per-batch centre, normalised to peak 1.

    Args:
        points: ``(b, n_points, 3)`` tensor of coordinates.
        mask: ``(b, n_points)`` tensor; weights are multiplied by it, so masked-out
            points get weight 0.
        center: Tensor reshaped to ``(-1, 1, 3)``, i.e. one centre per batch
            element (or a single centre broadcast to all).
        inverse_power: Exponent applied to the distance before inversion.

    Returns:
        ``(b, n_points)`` tensor of weights; in every row the largest weight is 1.
        A row whose mask is entirely zero yields all zeros (previously NaN from 0/0).
    """
    # Euclidean distance of each point from its centre -> (b, n_points)
    distances = torch.linalg.vector_norm(points - center.reshape(-1, 1, 3), dim=2)

    # Inverse-distance weights; the small constant avoids division by zero.
    weights = 1 / (distances**inverse_power + 1e-8)

    weights = mask * weights

    # Normalise by the row maximum; clamp so a fully masked row divides by a tiny
    # positive number instead of zero.
    peak = weights.amax(dim=1, keepdim=True)
    weights = weights / peak.clamp_min(torch.finfo(weights.dtype).tiny)

    return weights

In [None]:
def calculate_bounding_box_volume(pcd, keep_mask=None):
    """Volume of the axis-aligned bounding box of each point cloud in a batch.

    Args:
        pcd: ``(b, n_points, 3)`` tensor of coordinates.
        keep_mask: Optional boolean tensor selecting the points that count;
            excluded points are replaced by +inf / -inf so they can never be
            the minimum / maximum.

    Returns:
        ``(b,)`` tensor with the product of the box side lengths.
    """
    lower_source = pcd
    upper_source = pcd
    if keep_mask is not None:
        dropped = ~keep_mask
        lower_source = pcd.clone()
        lower_source[dropped] = float("inf")
        upper_source = pcd.clone()
        upper_source[dropped] = float("-inf")

    # Per-axis extent of the box -> (b, 3)
    extents = torch.amax(upper_source, dim=1) - torch.amin(lower_source, dim=1)

    # Multiply the three side lengths -> (b,)
    return torch.prod(extents, dim=1)


def camera_volume_weights(pcd, separate_mask=None):
    """Weight each camera by the relative bounding-box volume of its point cloud.

    Args:
        pcd: Mapping ``camera -> (b, n_points, 3)`` tensor.
        separate_mask: Optional mapping ``camera -> (b, n_points)`` boolean mask
            with the same keys as ``pcd``. When given, only kept points define
            the bounding box and masked-out points receive weight 0.

    Returns:
        With ``separate_mask``: the per-camera ``mask * volume`` tensors
        concatenated along dim 1 (one weight per point).
        Without it: the normalised volumes concatenated along dim 1
        (one weight per camera).
    """
    # `separate_mask` is optional, so only index it when it was provided
    # (indexing None here used to raise a TypeError for the default argument).
    volumes = {
        cam: calculate_bounding_box_volume(
            pcd[cam], None if separate_mask is None else separate_mask[cam]
        ).reshape(-1, 1)
        for cam in pcd
    }

    # Normalise by the largest camera volume in each batch element.
    stacked_volumes = torch.cat([volumes[cam] for cam in pcd], dim=1).reshape(-1, len(pcd), 1)
    largest = stacked_volumes.amax(dim=1) + 1e-6  # epsilon avoids division by zero
    volumes = {cam: volumes[cam] / largest for cam in pcd}

    if separate_mask is not None:
        combined_weights = torch.cat(
            [separate_mask[cam] * volumes[cam] for cam in pcd], dim=1
        )
    else:
        combined_weights = torch.cat([volumes[cam] for cam in pcd], dim=1)
    return combined_weights

In [None]:
import torch_fpsample

"""
Conclusion: torch_fpsample is pretty efficient, but right now it only works on the CPU.
                    CPU     GPU
    DGL:             27s     1s
    torch_fpsample: 0.4s     NA

So it seems we would have to move the data to the CPU in the middle of training (?).
"""


def run_fps(pcd, n_samples, mask=None):
    """Sample ``n_samples`` point indices from ``pcd`` with ``torch_fpsample``.

    Args:
        pcd: Point-cloud tensor; it is copied to CUDA and the caller's tensor is
            left untouched.
        n_samples: Number of indices to sample.
        mask: Optional boolean tensor; points where it is False are set to 0
            before sampling. They are not removed, so they can presumably still
            be selected — verify against the sampler.

    Returns:
        The index tensor returned by ``torch_fpsample.sample``.

    NOTE(review): the note preceding this function says torch_fpsample only
    works on the CPU, yet the tensors are moved to "cuda" here — confirm.
    """
    device = "cuda"
    points = pcd.to(device).clone()
    if mask is not None:
        # Setting masked points to float("inf") works with dgl but not with
        # quickfps; zeroing them works with both.
        points[~mask.to(device)] = 0

    # DGL alternative:
    # sampled_inds = dgl_geo.farthest_point_sampler(
    #     points,
    #     n_samples,
    #     0,
    # ).long()

    _, sampled_inds = torch_fpsample.sample(points, n_samples)

    return sampled_inds

In [None]:
import torch

from manten.agents.utils.normalization import T3DMinMaxScaler

# Scaler built from the dataset's point-cloud statistics.
pcd_a_s = dataset_info["pcd_stats"]
scaler = T3DMinMaxScaler(**pcd_a_s, preserve_aspect_ratio=True)

# First three components of the TCP pose, coloured red.
# Original note: tcp_pose is (obs_horizon, 7) = position + quaternion — TODO confirm.
tcp_pos = data["observations"]["state_obs"]["tcp_pose"][..., :3]
tcp_rgb = torch.tensor([1.0, 0.0, 0.0]).expand(tcp_pos.shape)

# Define the specific point S (reference for the disabled distance weighting below).
# It is bound here, right after `tcp_pos` is re-read from `data`, so that it always
# refers to the unscaled position; binding it before this point made it depend on
# whatever `tcp_pos` was left in the kernel (scaled after a re-run of this cell).
# S = torch.tensor([0.0, 0.0, 0.0])
S = tcp_pos

delta_trajectory = data["actions"][..., :3] * 0.1
# delta_trajectory = data['actions'][..., :3]
# delta_trajectory = inv_scale_action(data["actions"][..., :3], -1, 1)

# Bring point cloud, TCP position and deltas into the scaler's coordinate frame;
# deltas are displacements, so they are scaled without the translation.
scaled_pcd = scaler.scale(combined_pcd)
tcp_pos = scaler.scale(tcp_pos)
delta_trajectory = scaler.scale_without_translation(delta_trajectory)

# # cumsum only works for !target! delta, not delta with current position
trajectory = torch.cat([tcp_pos[:1, :3], delta_trajectory], dim=0).cumsum(dim=0)


# Alternative (disabled): weighted random sampling instead of farthest-point sampling.
# volume_w = camera_volume_weights(pcd, mask)
# # volume_w = camera_volume_weights(pcd, optree.tree_map(lambda x: torch.ones_like(x), mask))
# distance_w = distance_weights(combined_pcd, combined_mask, S, inverse_power=0.5)
# one = torch.ones_like(distance_w)

# volume_normed_distance_w = (volume_w * 1) / (volume_w * 1).sum(dim=1, keepdim=True)

# # sample_indices = torch.multinomial(volume_normed_distance_w, 128*128, replacement=False)
# sample_indices = torch.multinomial(volume_normed_distance_w, 2048, replacement=False)

# Sample 2048 point indices per frame and bring them back to the CPU.
sample_indices = run_fps(combined_pcd, 2048, mask=combined_mask)
sample_indices = sample_indices.to("cpu")

# Boolean (b, n_points) mask that is True exactly at the sampled indices.
sample_mask = torch.zeros(combined_pcd.shape[:-1]).scatter(1, sample_indices, 1).bool()

# f_mask = combined_mask
# f_mask = combined_mask & sample_mask
f_mask = sample_mask

# frames = []
# for idx in range(0, len(pcd["camera1"]), 15):
#     frames.append(get_scatter_trace(
#         scaled_pcd[idx],
#         combined_rgb[idx],
#         f_mask[idx],
#         special_points=torch.cat([tcp_pos, trajectory, goal_pos[:1]], dim=0),
#         special_point_colors=torch.cat([tcp_rgb, trajectory_rgb, goal_rgb[:]], dim=0),
#     ))
# fig = render_masked_3d_scatter(*frames[0], frames=frames[::1])

# Static view of frame 0: sampled points plus TCP positions (red) and trajectory (blue).
# `trajectory_rgb`, `goal_pos` and `goal_rgb` come from the earlier preparation cell.
idx = 0
x, y, z, color = get_scatter_trace(
    scaled_pcd[idx],
    combined_rgb[idx],
    # torch.ones_like(f_mask[idx]),
    f_mask[idx],
    special_points=torch.cat([tcp_pos, trajectory, goal_pos[:1]], dim=0),
    special_point_colors=torch.cat([tcp_rgb, trajectory_rgb, goal_rgb[:1]], dim=0),
)
fig = render_masked_3d_scatter(x, y, z, color)

fig.show()
0