In [1]:
# Import necessary libraries
import os
import cv2
import numpy as np
from ur3e_dataloader import UR3EDataset
from ur3e_bc.models import UR3EBCModel
from torch.utils.data import Dataset, DataLoader

import time




In [2]:
def visualize_batch(batch, batch_idx=0, wait_ms=0):
    """
    Visualizes one sample of a dataloader batch in a single OpenCV window.

    The window shows a grid with one row per camera (front, side, hand) and
    one column per frame of the sample's history, newest frame first. The
    sample's numerical fields are printed to stdout, after which the call
    waits in ``cv2.waitKey``.

    Args:
        batch (dict): A batch of data from the dataloader. The keys
            "front_cam", "side_cam", "hand_cam", "ee_pose", "ee_velocity",
            "hole_pose" and "state" are read; every value must support
            ``value[batch_idx].cpu().numpy()``.
        batch_idx (int): Index within the batch to visualize (default: 0).
        wait_ms (int): Delay handed to ``cv2.waitKey``. The default 0 keeps
            the original behaviour of blocking until a key is pressed; a
            positive value waits at most that many milliseconds.
    """
    # Per-camera frame stacks for the selected sample.
    # Expected layout is (frames, H, W, 3) with the newest frame last -- the
    # same ordering the pose printout below relies on.
    front_cam = batch["front_cam"][batch_idx].cpu().numpy()
    side_cam = batch["side_cam"][batch_idx].cpu().numpy()
    hand_cam = batch["hand_cam"][batch_idx].cpu().numpy()

    def preprocess_img(img):
        """
        Convert a single (H, W, 3) RGB frame into a uint8 BGR image.
        """
        # assumes pixel values are floats in [0, 1] -- TODO confirm against the dataset
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
        # OpenCV draws and displays images in BGR channel order.
        return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    def build_row(frames):
        """Lay one camera's frames out side by side, newest frame first."""
        return np.concatenate([preprocess_img(frame) for frame in frames[::-1]], axis=1)

    all_images = np.vstack([build_row(front_cam), build_row(side_cam), build_row(hand_cam)])

    # Titles for rows and columns. Column labels follow the newest-first
    # layout produced by build_row, so they are derived from the frame count
    # instead of being fixed at three entries.
    num_frames = front_cam.shape[0]
    row_titles = ["Front Cam", "Side Cam", "Hand Cam"]
    col_titles = ["t"] + [f"t-{age}" for age in range(1, num_frames)]

    # Grid cell size in pixels
    H, W, _ = all_images.shape
    step_x = W // num_frames  # Width of each column
    step_y = H // len(row_titles)  # Height of each row

    font_scale = 0.6
    thickness = 1
    text_color = (0, 255, 0)

    # Row titles: left edge, vertically centred in each row
    for i, title in enumerate(row_titles):
        y_pos = step_y * i + step_y // 2
        cv2.putText(all_images, title, (5, y_pos), cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale, text_color, thickness, cv2.LINE_AA)

    # Column titles: top edge, roughly centred in each column
    for i, title in enumerate(col_titles):
        x_pos = step_x * i + step_x // 2 - 15
        cv2.putText(all_images, title, (x_pos, 20), cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale, text_color, thickness, cv2.LINE_AA)

    # Display the images in one window. The title lists the columns in the
    # order they are actually drawn (the previous title named them reversed).
    window_title = "Camera Views: Front, Side, Hand ({})".format(", ".join(col_titles))
    cv2.imshow(window_title, all_images)

    # Numerical data for the current sample
    ee_pose = batch["ee_pose"][batch_idx].cpu().numpy()  # expected (frames, 7)
    ee_velocity = batch["ee_velocity"][batch_idx].cpu().numpy()  # expected (6,)
    hole_pose = batch["hole_pose"][batch_idx].cpu().numpy()  # expected (7,)
    state = batch["state"][batch_idx].cpu().numpy()  # expected single element

    print("image size: {}".format(batch["front_cam"].shape))
    print("\nEnd Effector Pose Over Time:")
    # Poses are stored oldest first, so entry i is (count - 1 - i) steps old.
    pose_count = len(ee_pose)
    for i, pose in enumerate(ee_pose):
        age = pose_count - 1 - i
        label = "t:" if age == 0 else f"t-{age}:"
        print(f"  {label:<4} {pose}")

    print("\nEnd Effector Velocity:")
    print(f"  [vx, vy, vz, wx, wy, wz]: {ee_velocity}")

    print("\nHole Position and Orientation:")
    print(f"  [x, y, z, qx, qy, qz, qw]: {hole_pose}")

    # Print the current state
    state_labels = {0: "Move", 1: "Insert", 2: "Rotate"}
    print(f"\nCurrent State: {state_labels.get(state.item(), 'Unknown')} ({state.item()})")

    # The window is left open on purpose so consecutive calls reuse it; the
    # caller is responsible for cv2.destroyAllWindows() when finished.
    cv2.waitKey(wait_ms)


In [3]:
# Directory holding the episode zip files. The original local path remains the
# default; set the UR3E_DATA_DIR environment variable to point the notebook at
# a different location without editing this cell.
zip_dir = os.environ.get(
    "UR3E_DATA_DIR",
    "/home/tanakrit-ubuntu/ur3e_mujoco_tasks/scripts/data",
)

# Number of samples per batch
batch_size = 16

# Instantiate the dataset and show which pickle files each zip archive contains
dataset = UR3EDataset(zip_dir)
print(dataset.episode_dict)

# Create a dataloader. shuffle=False keeps samples in dataset order.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)


{'20250407_134614_0.zip': ['20250407_134614_0_0.pkl', '20250407_134614_0_1.pkl', '20250407_134614_0_2.pkl', '20250407_134614_0_3.pkl', '20250407_134614_0_4.pkl', '20250407_134614_0_5.pkl', '20250407_134614_0_6.pkl', '20250407_134614_0_7.pkl', '20250407_134614_0_8.pkl', '20250407_134614_0_9.pkl', '20250407_134614_0_10.pkl', '20250407_134614_0_11.pkl', '20250407_134614_0_12.pkl', '20250407_134614_0_13.pkl', '20250407_134614_0_14.pkl', '20250407_134614_0_15.pkl', '20250407_134614_0_16.pkl', '20250407_134614_0_17.pkl', '20250407_134614_0_18.pkl', '20250407_134614_0_19.pkl', '20250407_134614_0_20.pkl', '20250407_134614_0_21.pkl', '20250407_134614_0_22.pkl', '20250407_134614_0_23.pkl', '20250407_134614_0_24.pkl', '20250407_134614_0_25.pkl', '20250407_134614_0_26.pkl', '20250407_134614_0_27.pkl', '20250407_134614_0_28.pkl', '20250407_134614_0_29.pkl', '20250407_134614_0_30.pkl', '20250407_134614_0_31.pkl', '20250407_134614_0_32.pkl', '20250407_134614_0_33.pkl', '20250407_134614_0_34.pkl', '20

In [4]:
# for batch in dataloader:
#     for i in range(batch_size):
#         visualize_batch(batch, batch_idx=i) 
#         time.sleep(0.2)

# cv2.destroyAllWindows()

In [5]:
# Build a UR3EBCModel using its constructor defaults (no arguments are passed).
# NOTE(review): no checkpoint is loaded in this cell, so the outputs printed
# further down presumably come from freshly initialised weights -- confirm.
model = UR3EBCModel()

In [6]:
# Smoke test: pull the first batch from the dataloader and push it through the
# model once.
batch = next(iter(dataloader))

# Positional inputs of model.forward, in the order it expects them.
model_inputs = [
    batch[key]
    for key in ("front_cam", "side_cam", "hand_cam", "ee_pose", "joints")
]
a, b, c = model.forward(*model_inputs)

# Print the three returned values one after another.
for output in (a, b, c):
    print(output)

tensor([[-2.6283e-01, -3.0827e-01,  2.1476e-02, -1.5383e-01,  7.4507e-02,
          1.3801e-01],
        [-2.6021e-01, -3.0007e-01,  2.2410e-02, -1.7239e-01,  6.7023e-02,
          1.2863e-01],
        [-2.6575e-01, -2.9006e-01,  2.5076e-02, -1.7940e-01,  6.3177e-02,
          1.3938e-01],
        [-2.8135e-01, -2.9773e-01,  1.0582e-02, -1.8806e-01,  8.2221e-02,
          1.5915e-01],
        [-2.6998e-01, -2.9437e-01,  1.3473e-02, -1.7271e-01,  7.4538e-02,
          1.4063e-01],
        [-2.8534e-01, -3.0243e-01,  4.7331e-03, -1.7423e-01,  8.5061e-02,
          1.6027e-01],
        [-2.9356e-01, -3.0987e-01, -7.6245e-03, -1.6713e-01,  8.1294e-02,
          1.4879e-01],
        [-2.8728e-01, -3.0409e-01, -4.0146e-03, -1.6332e-01,  8.3099e-02,
          1.6551e-01],
        [-2.9451e-01, -3.0653e-01, -1.0540e-02, -1.6423e-01,  7.2028e-02,
          1.4777e-01],
        [-2.8709e-01, -3.0277e-01, -1.6554e-02, -1.6471e-01,  8.0071e-02,
          1.4753e-01],
        [-2.9642e-01, -3.1126e