In [None]:
import numpy as np
from PIL import Image
from typing import Optional, List
import matplotlib.pyplot as plt
import cv2
import yaml


# Size of the visualization canvas as (width, height): index 0 bounds the
# horizontal pixel coordinate / x-axis, index 1 the vertical one / y-axis.
VIZ_IMAGE_SIZE = (640, 480)
# Color triples with components in {0, 1}; used as default matplotlib
# `color=` values for trajectories and points.
RED = np.array([1, 0, 0])
GREEN = np.array([0, 1, 0])
BLUE = np.array([0, 0, 1])
CYAN = np.array([0, 1, 1])

YELLOW = np.array([1, 1, 0])
MAGENTA = np.array([1, 0, 1])

In [None]:
def project_points(
    xy: np.ndarray,
    camera_height: float,
    camera_x_offset: float,
    camera_matrix: np.ndarray,
    dist_coeffs: np.ndarray,
):
    """
    Project planar (x, y) points into image pixel coordinates.

    Every point is given the z value ``-camera_height``, shifted by
    ``camera_x_offset`` along x, re-ordered into the axis order passed to
    OpenCV (``(y, -z, x)``) and projected with ``cv2.projectPoints`` using a
    zero rotation and a zero translation.

    Args:
        xy: array of shape (batch_size, horizon, 2) holding (x, y) coordinates
        camera_height: height of the camera above the ground (in meters)
        camera_x_offset: offset of the camera from the center of the car (in meters)
        camera_matrix: 3x3 matrix of the camera's intrinsic parameters
        dist_coeffs: vector of distortion coefficients

    Returns:
        uv: array of shape (batch_size, horizon, 2) holding (u, v) image coordinates
    """
    n_batch, n_steps, _ = xy.shape

    # Append a z column so that every point sits camera_height below the camera.
    z_column = -camera_height * np.ones(tuple(xy.shape[:-1]) + (1,))
    xyz = np.concatenate([xy, z_column], axis=-1)

    # Shift along x by the camera mounting offset, then permute the axes into
    # the (y, -z, x) order handed to OpenCV.
    shifted_x = xyz[..., 0] + camera_x_offset
    points_cv = np.stack([xyz[..., 1], -xyz[..., 2], shifted_x], axis=-1)

    # No extra rotation/translation is applied: both vectors are zero.
    zero_vec = (0, 0, 0)
    projected = cv2.projectPoints(
        points_cv.reshape(n_batch * n_steps, 3),
        zero_vec,
        zero_vec,
        camera_matrix,
        dist_coeffs,
    )[0]

    return projected.reshape(n_batch, n_steps, 2)

def get_pos_pixels(
    points: np.ndarray,
    camera_height: float,
    camera_x_offset: float,
    camera_matrix: np.ndarray,
    dist_coeffs: np.ndarray,
    clip: Optional[bool] = False,
):
    """
    Project a single sequence of (x, y) points to image pixels for plotting.

    The points are projected with ``project_points`` (as a batch of one), the
    horizontal pixel coordinate is mirrored about the image width, and
    out-of-frame pixels are then either clamped or discarded.

    Both branches return pixels in the same ``(u, v)`` column order, and the
    in-frame test compares ``u`` against the image width and ``v`` against the
    image height.

    Args:
        points: array of shape (horizon, 2) holding (x, y) coordinates
        camera_height: height of the camera above the ground (in meters)
        camera_x_offset: offset of the camera from the center of the car (in meters)
        camera_matrix: 3x3 matrix of the camera's intrinsic parameters
        dist_coeffs: vector of distortion coefficients
        clip: if True, clamp every pixel into
            ``[0, VIZ_IMAGE_SIZE[0]] x [0, VIZ_IMAGE_SIZE[1]]`` and keep all of
            them; if False, drop every pixel that is not strictly inside the
            image.

    Returns:
        pixels: array of shape (n, 2) holding (u, v) image coordinates, where
        ``n == horizon`` when ``clip`` is True and ``n <= horizon`` otherwise
        (shape ``(0, 2)`` when nothing is inside the image).
    """
    pixels = project_points(
        points[np.newaxis], camera_height, camera_x_offset, camera_matrix, dist_coeffs
    )[0]

    width, height = VIZ_IMAGE_SIZE

    # Mirror horizontally (presumably because the input +y axis points left
    # while image u grows to the right -- confirm against the data convention).
    pixels[:, 0] = width - pixels[:, 0]

    if clip:
        # Clamp column 0 (u) to [0, width] and column 1 (v) to [0, height].
        return np.clip(pixels, [0, 0], [width, height])

    # Keep only pixels strictly inside the image; column order stays (u, v).
    in_frame = np.all((pixels > 0) & (pixels < [width, height]), axis=1)
    return pixels[in_frame]

In [None]:
def plot_trajs_and_points_on_image(
    ax: plt.Axes,
    img: np.ndarray,
    dataset_name: str,
    list_trajs: list,
    list_points: list,
    traj_colors: list = [CYAN, MAGENTA],
    point_colors: list = [RED, GREEN],
):
    """
    Plot trajectories and points on an image.

    If the dataset's configuration lacks any of the required camera settings
    (``camera_height``, ``camera_x_offset``, ``camera_matrix``,
    ``dist_coeffs`` under ``camera_metrics``), the image is plotted as is and
    nothing is overlaid.

    Args:
        ax: matplotlib axis
        img: image to plot
        dataset_name: name of the dataset found in data_config.yaml (e.g. "recon")
        list_trajs: list of trajectories, each trajectory is a numpy array of shape (horizon, 2) (if there is no yaw) or (horizon, 4) (if there is yaw)
        list_points: list of points, each point is a numpy array of shape (2,)
        traj_colors: list of colors for trajectories
        point_colors: list of colors for points

    Returns:
        The axis that was drawn on.

    Raises:
        AssertionError: if there are fewer colors than trajectories/points, or
            if ``dataset_name`` is not a key of the global ``data_config``.
    """
    # The list defaults are only read, never mutated, so sharing them between
    # calls is safe.
    assert len(list_trajs) <= len(traj_colors), "Not enough colors for trajectories"
    assert len(list_points) <= len(point_colors), "Not enough colors for points"
    assert (
        dataset_name in data_config
    ), f"Dataset {dataset_name} not found in data/data_config.yaml"

    ax.imshow(img)

    dataset_cfg = data_config[dataset_name]
    camera_metrics = (
        dataset_cfg["camera_metrics"] if "camera_metrics" in dataset_cfg else None
    )
    # Every key that is read below must be present; camera_x_offset is part
    # of this check so a config without it falls back to the plain image
    # instead of raising KeyError.
    required_keys = ("camera_height", "camera_x_offset", "camera_matrix", "dist_coeffs")
    if not camera_metrics or any(key not in camera_metrics for key in required_keys):
        return ax

    camera_height = camera_metrics["camera_height"]
    camera_x_offset = camera_metrics["camera_x_offset"]

    intrinsics = camera_metrics["camera_matrix"]
    camera_matrix = gen_camera_matrix(
        intrinsics["fx"], intrinsics["fy"], intrinsics["cx"], intrinsics["cy"]
    )

    distortion = camera_metrics["dist_coeffs"]
    # 8-element coefficient vector; the last three entries are fixed to zero.
    dist_coeffs = np.array(
        [
            distortion["k1"],
            distortion["k2"],
            distortion["p1"],
            distortion["p2"],
            distortion["k3"],
            0.0,
            0.0,
            0.0,
        ]
    )

    # At most this many projected pixels are drawn per trajectory / point set.
    max_plot_pixels = 250

    for i, traj in enumerate(list_trajs):
        xy_coords = traj[:, :2]  # (horizon, 2); drops yaw columns if present
        traj_pixels = get_pos_pixels(
            xy_coords, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=False
        )

        # With clip=False the result may not be 2-D when no pixel lies
        # inside the image; skip drawing in that case.
        if len(traj_pixels.shape) == 2:
            ax.plot(
                traj_pixels[:max_plot_pixels, 0],
                traj_pixels[:max_plot_pixels, 1],
                color=traj_colors[i],
                lw=2.5,
            )

    for i, point in enumerate(list_points):
        if len(point.shape) == 1:
            # add a dimension to the front of point
            point = point[None, :2]
        else:
            point = point[:, :2]
        pt_pixels = get_pos_pixels(
            point, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=True
        )
        ax.plot(
            pt_pixels[:max_plot_pixels, 0],
            pt_pixels[:max_plot_pixels, 1],
            color=point_colors[i],
            marker="o",
            markersize=10.0,
        )

    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.set_xlim((0.5, VIZ_IMAGE_SIZE[0] - 0.5))
    # Inverted limits keep the image origin at the top-left corner.
    ax.set_ylim((VIZ_IMAGE_SIZE[1] - 0.5, 0.5))
    return ax

def gen_camera_matrix(fx: float, fy: float, cx: float, cy: float) -> np.ndarray:
    """
    Assemble a 3x3 pinhole camera intrinsic matrix.

    Args:
        fx: focal length in x direction
        fy: focal length in y direction
        cx: principal point x coordinate
        cy: principal point y coordinate
    Returns:
        3x3 array ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]``
    """
    x_row = [fx, 0.0, cx]
    y_row = [0.0, fy, cy]
    homogeneous_row = [0.0, 0.0, 1.0]
    return np.array([x_row, y_row, homogeneous_row])

In [None]:
# Directory holding the recorded observations, waypoints and camera config.
# NOTE: this is a machine-specific absolute path; it is the single place to
# edit when pointing the notebook at another recording.
path = '/nfs/nfs1/users/riadoshi/octo_files/files/04_06'

# SECURITY: allow_pickle=True lets np.load unpickle arbitrary Python objects,
# which can execute code. Only load .npy files from a trusted source.
obses = np.load(f'{path}/obs_imgs.npy', allow_pickle=True)
waypoints = np.load(f'{path}/waypoints.npy', allow_pickle=True)

# Load data_config.yaml from the same directory as the arrays, so the camera
# configuration always matches the data selected by `path`.
with open(f'{path}/data_config.yaml') as f:
    data_config = yaml.safe_load(f)

In [None]:

# plt.plot(waypoints[0,:, 0], waypoints[0, :, 1])
# # scale axis to have same range
# plt.gca().set_aspect('equal', adjustable='box')

In [None]:
# Display the dimensions of the loaded observation array as a sanity check.
obses.shape

In [None]:
from PIL import Image
import cv2


# Index of the sampled action sequence that `plot` highlights in blue and
# draws last; all other sequences are drawn in yellow.
chosen_idx = 0

def plot(i, ax=None):
    """
    Overlay the sampled action trajectories of sample ``i`` on its observation.

    The trajectory at the global ``chosen_idx`` is drawn in blue and last (so
    it is not hidden by the others); every other trajectory is yellow. The
    observation is resized to ``VIZ_IMAGE_SIZE`` before plotting so that it
    matches the pixel frame used for the projected trajectories.

    Args:
        i: index into the global ``obses`` / ``waypoints`` arrays
        ax: matplotlib axis to draw on; defaults to the current axis
            (``plt.gca()``) so that ``plot(i)`` works right after
            ``plt.subplots()``
    """
    if ax is None:
        ax = plt.gca()

    obs, sampled_actions = obses[i], waypoints[i]
    obs_pil = Image.fromarray(obs)

    # Move the chosen index to the back so its trajectory is drawn on top.
    num_samples = len(sampled_actions)
    order = list(range(num_samples))
    order.remove(chosen_idx)
    order.append(chosen_idx)
    sampled_actions = sampled_actions[order]

    # Colors follow the new order: everything yellow except the chosen
    # trajectory, which now sits in the last slot.
    colors = ["yellow"] * (num_samples - 1) + ["blue"]

    plot_trajs_and_points_on_image(
        ax,
        obs_pil.resize(VIZ_IMAGE_SIZE, Image.Resampling.NEAREST),
        "recon",
        sampled_actions,
        [],
        colors,
        [],
    )
    

# Assuming you have a predefined `plot` function that takes an index and an `ax` (axis) as arguments
def display_plots_side_by_side(start, end, cols=5):
    """
    Show ``plot(idx)`` for every ``idx`` in ``[start, end]`` on one grid figure.

    Args:
        start: first sample index to plot (inclusive)
        end: last sample index to plot (inclusive)
        cols: number of grid columns; rows are added as needed

    Raises:
        ValueError: if the range is empty or ``cols`` is not positive
    """
    num_plots = end - start + 1
    if num_plots <= 0:
        raise ValueError(f"Empty range: start={start} must be <= end={end}")
    if cols <= 0:
        raise ValueError(f"cols must be positive, got {cols}")

    rows = (num_plots + cols - 1) // cols  # ceil(num_plots / cols)

    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid.
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5), squeeze=False)

    for offset, ax in enumerate(axs.flat):
        if offset < num_plots:
            plot(start + offset, ax=ax)
        else:
            ax.axis('off')  # hide the unused trailing cells of the grid

    plt.tight_layout()
    plt.show()

# plot actions
def display_actions(idx, ax=None):
    """
    Plot the first waypoint sequence of sample ``idx`` as an x/y line.

    Args:
        idx: index into the global ``waypoints`` array
        ax: matplotlib axis to draw on. When omitted, a new 10x10 inch figure
            is created. A supplied axis is drawn on directly, so the function
            can fill the cells of a subplot grid.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(10, 10))

    # First entry of the sample; columns 0 and 1 are used as x and y.
    actions = waypoints[idx][0]
    ax.plot(actions[:, 0], actions[:, 1])
    # Equal scaling on the target axis itself rather than on whichever axis
    # happens to be current.
    ax.set_aspect('equal', adjustable='box')


def display_actions_sbs(start, end, cols=5):
    """
    Call ``display_actions(idx, ax=...)`` for every ``idx`` in ``[start, end]``,
    handing each call one cell of a shared grid figure.

    Args:
        start: first sample index to plot (inclusive)
        end: last sample index to plot (inclusive)
        cols: number of grid columns; rows are added as needed

    Raises:
        ValueError: if the range is empty or ``cols`` is not positive
    """
    num_plots = end - start + 1
    if num_plots <= 0:
        raise ValueError(f"Empty range: start={start} must be <= end={end}")
    if cols <= 0:
        raise ValueError(f"cols must be positive, got {cols}")

    rows = (num_plots + cols - 1) // cols  # ceil(num_plots / cols)

    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid.
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5), squeeze=False)

    for offset, ax in enumerate(axs.flat):
        if offset < num_plots:
            display_actions(start + offset, ax=ax)
        else:
            ax.axis('off')  # hide the unused trailing cells of the grid

    plt.tight_layout()
    plt.show()








In [None]:
# Visualize the raw waypoint actions for samples 30 through 80 (inclusive).
display_actions_sbs(30,80)

In [None]:
# Overlay the sampled trajectories on the observations for samples 30 through
# 80 (inclusive); change the two arguments to view a different range.
display_plots_side_by_side(30, 80)

In [None]:
# ax=None: no existing axis is supplied, so display_actions draws the
# waypoints of sample 30 on a figure of its own.
display_actions(30, ax=None)

In [None]:
# ax=None: no existing axis is supplied, so display_actions draws the
# waypoints of sample 42 on a figure of its own.
display_actions(42, ax=None)

In [None]:
# Draw sample 10 on the axis created here; the axis is passed explicitly.
fig, ax = plt.subplots()
plot(10, ax)
plt.show()

In [None]:
# Draw sample 40 on the axis created here; the axis is passed explicitly.
fig, ax = plt.subplots()
plot(40, ax)
plt.show()