In [1]:
import gymnasium as gym
import numpy as np
import mediapy as media
import torch
import torchvision
# torch.multiprocessing.set_start_method('spawn')
import gym_lite6.env, gym_lite6.pickup_task
# %env MUJOCO_GL=egl # Had to export this before starting jupyter server
# import mujoco
import time


env: MUJOCO_GL=egl # Had to export this before starting jupyter server


In [2]:

class MLPPolicy(torch.nn.Module):
  def __init__(self, hidden_layer_dims, state_dims=9):
    """
    state_dims: 6 for arm, 3 for gripper
    """
    super().__init__()

    # self.img_feature_extractor = torchvision.models.detection.backbone_utils.resnet_fpn_backbone('resnet18', )
    
    self.img_feature_extractor = self._create_img_feature_extractor()
    # Resnet output is 1x512, 2 bits for gripper
    self.actor = self._create_actor(512 + state_dims, hidden_layer_dims, state_dims)

    self.sigmoid = torch.nn.Sigmoid()
  
  def _create_actor(self, input_size, hidden_layer_dims, output_size):
    actor = []
    actor.append(torch.nn.Linear(input_size, hidden_layer_dims[0]))
    actor.append(torch.nn.ReLU())
    for i in range(len(hidden_layer_dims) - 1):
      actor.append(torch.nn.Linear(hidden_layer_dims[i], hidden_layer_dims[i+1]))
      actor.append(torch.nn.ReLU())
    actor.append(torch.nn.Linear(hidden_layer_dims[-1], output_size))
    return torch.nn.Sequential(*actor)

  def _create_img_feature_extractor(self, frozen=False):
    """
    ResNet18 backbone with last fc layer chopped off
    Weights frozen
    Ouput shape [1, 512, 1, 1]
    """
    resnet = torchvision.models.resnet18(weights='DEFAULT')
    modules = list(resnet.children())[:-1]
    backbone = torch.nn.Sequential(*modules)
    backbone.requires_grad_(not frozen)
    return backbone

  def forward(self, state, image):
    img_features = torch.squeeze(self.img_feature_extractor(image), dim=[2, 3])
    input = torch.hstack((state, img_features))
    out = self.actor(input)
    # Gripper sigmoid
    out[:, 6:8] = self.sigmoid(out[:, 6:8])
    return out

  
  def predict(self, state, image, episode_start=None, deterministic=None):
    return self.forward(state, image)


In [3]:

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")

policy = MLPPolicy([128, 128]).to(device)


In [186]:
class Trainer:
  def __init__(self, params) -> None:
    # self.env = env # This breaks caching of preprocess_data

    self.params = params

  def normalize_qpos(self, qpos):
    return (qpos - self.params["normalize_qpos"]["bounds_centre"]) / self.params["normalize_qpos"]["bounds_range"] + 0.5

  def unnormalize_qpos(self, qpos):
    return (qpos - 0.5) * self.params["normalize_qpos"]["bounds_range"] + self.params["normalize_qpos"]["bounds_centre"]
  
  def embed_gripper(self, gripper):
    """
    Convert from (-1, 1) to one hot encoded
    One hot needs them as 1d
    """
    return torch.nn.functional.one_hot(gripper.flatten() + 1, num_classes=3)

  def decode_gripper(self, gripper):
    """
    Convert from one hot encoded to column vector in range (-1, 1)
    """
    return (torch.argmax(gripper, dim=1) - 1).unsqueeze(1).to(int)

  def preprocess_data(self, batch):
    """
    Take a batch of data and put it in a suitable tensor format for the model
    """
    out = {}
    
    observation_qpos = torch.tensor(batch["observation.state.qpos"], dtype=torch.float32)
    action_qpos = torch.tensor(batch["action.qpos"], dtype=torch.float32)

    observation_gripper = self.embed_gripper(torch.tensor(batch["observation.state.gripper"], dtype=int)).to(torch.float32)
    action_gripper = self.embed_gripper(torch.tensor(batch["action.gripper"], dtype=int)).to(torch.float32)

    if self.params["normalize_qpos"] is not False:
      observation_qpos = self.normalize_qpos(observation_qpos)
      action_qpos = self.normalize_qpos(action_qpos)

    out["preprocessed.observation.state"] = torch.hstack((observation_qpos, observation_gripper))
    out["preprocessed.action.state"] = torch.hstack((action_qpos, action_gripper))
    
    # Convert to float32 with image from channel first in [0,255]
    tf = torchvision.transforms.ToTensor()
    out["preprocessed.observation.image"] = torch.stack([tf(x) for x in batch["observation.pixels.side"]])

    return out
  
  def lerobot_preprocess(self, batch):
    """
    Take a batch of data and put it in a suitable tensor format for the model
    Batches here are as a list
    """
    out = {}

    idxs = range(len(batch[list(batch.keys())[0]]))
    
    if "observation.state.gripper" in batch and "observation.state.qpos" in batch:
      observation_gripper = [self.embed_gripper(batch["observation.state.gripper"][x]).to(torch.float32) for x in idxs ]
      if self.params["normalize_qpos"] is not False:
        batch["observation.state.qpos"] = [self.normalize_qpos(batch["observation.state.qpos"])[x] for x in idxs if "observation.state.gripper" in batch]
      out["preprocessed.observation.state"] = [torch.hstack((batch["observation.state.qpos"][x], observation_gripper[x].flatten())) for x in idxs]
    
    if "action.gripper" in batch and "action.qpos" in batch:
      action_gripper = [self.embed_gripper(batch["action.gripper"][x]).to(torch.float32) for x in idxs if "action.gripper" in batch]
      if self.params["normalize_qpos"] is not False:
        batch["action.qpos"] = [self.normalize_qpos(batch["action.qpos"])[x] for x in idxs if "action.qpos" in batch]
      out["preprocessed.action.state"] = [torch.hstack((batch["action.qpos"][x], action_gripper[x].flatten())) for x in idxs]

    
    # Convert to float32 with image from channel first in [0,255]
    # tf = torchvision.transforms.ToTensor()
    # out["preprocessed.observation.image"] = torch.stack([tf(x) for x in batch["observation.pixels.side"]])
    batch.update(out)

    return batch
  
  def evaluate_policy(self, env, policy, n):
    avg_reward = 0
    for i in range(n):
      numpy_observation, info = env.reset()

      # Prepare to collect every rewards and all the frames of the episode,
      # from initial state to final state.
      rewards = []
      frames = []
      action = {}

      # Render frame of the initial state
      frames.append(env.render())

      step = 0
      done = False
      while not done and len(frames) < 300:
        # Prepare observation for the policy running in Pytorch
        # Get qpos in range (-1, 1), gripper is already in range (-1, 1)
        qpos = torch.from_numpy(numpy_observation["state"]["qpos"]).unsqueeze(0)
        gripper = self.embed_gripper(torch.tensor(numpy_observation["state"]["gripper"]))
        if self.params["normalize_qpos"] is not False:
          qpos = self.normalize_qpos(qpos)
        state = torch.hstack((qpos, gripper))
        image = torch.from_numpy(numpy_observation["pixels"]["side"])
        
        # Convert to float32 with image from channel first in [0,255]
        # to channel last in [0,1]
        state = state.to(torch.float32)
        image = image.to(torch.float32) / 255
        image = image.permute(2, 0, 1)

        # Add extra (empty) batch dimension, required to forward the policy
        # state = state.unsqueeze(0)
        image = image.unsqueeze(0)

        # Send data tensors from CPU to GPU
        state = state.to(device, non_blocking=True)
        image = image.to(device, non_blocking=True)

        # Predict the next action with respect to the current observation
        with torch.inference_mode():
          raw_action = policy.predict(state, image).to("cpu")
        
        action["qpos"] = raw_action[:, :6]
        if self.params["normalize_qpos"] is not False:
          action["qpos"] = self.unnormalize_qpos(action["qpos"])
        
        action["qpos"] = action["qpos"].flatten().numpy()
        action["gripper"] = self.decode_gripper(raw_action[:, 6:8]).item()
        
        # print(action)
        # numpy_action = np.hstack((action["qpos"], action["gripper"]))

        # Step through the environment and receive a new observation
        numpy_observation, reward, terminated, truncated, info = env.step(action)
        # Keep track of all the rewards and frames
        rewards.append(reward)
        frames.append(env.render())

        # The rollout is considered done when the success state is reach (i.e. terminated is True),
        # or the maximum number of iterations is reached (i.e. truncated is True)
        done = terminated | truncated | done
        step += 1
      
      avg_reward += rewards[-1]/n
    
      return avg_reward, frames



In [116]:

from datasets import load_from_disk
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import datetime
from pathlib import Path
import argparse


In [6]:

# parser = argparse.ArgumentParser(
#                   prog='Train Lite6 BC-MLP-MSE',
#                   description='Train BC-MLP-MSE on Ufactory Lite6')
# parser.add_argument('checkpoint')

# args = parser.parse_args()

# %%
task = gym_lite6.pickup_task.PickupTask('gripper_left_finger', 'gripper_right_finger', 'box', 'floor')
env = gym.make(
    "UfactoryCubePickup-v0",
    task=task,
    obs_type="pixels_state",
    max_episode_steps=350,
    visualization_width=320,
    visualization_height=240,
)
observation, info = env.reset()
# media.show_image(env.render(), width=400, height=400)


In [17]:
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)

In [187]:

params = {}

jnt_range_low = env.unwrapped.model.jnt_range[:6, 0]
jnt_range_high = env.unwrapped.model.jnt_range[:6, 1]
bounds_centre = torch.tensor((jnt_range_low + jnt_range_high) / 2, dtype=torch.float32)
bounds_range = torch.tensor(jnt_range_high - jnt_range_low, dtype=torch.float32)
# params["normalize_qpos"] = {"bounds_centre": bounds_centre, "bounds_range": bounds_range}
params["normalize_qpos"] = False

trainer = Trainer(params)

dataset = load_from_disk("BC-MLP-MSE/datasets/pickup/scripted_trajectories_50_2024-08-02_12-49-56.hf")
if "from" not in dataset.column_names:
  first_frames=dataset.filter(lambda example: example['frame_index'] == 0)
  from_idxs = torch.tensor(first_frames['index'])
  to_idxs = torch.tensor(first_frames['index'][1:] + [len(dataset)])
  episode_data_index={"from": from_idxs, "to": to_idxs}
    
# dataset.set_transform(hf_transform_to_torch)
dataset.set_transform(lambda x: trainer.lerobot_preprocess(hf_transform_to_torch(x)))
# dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=2)


In [188]:
dataset[0]

{'action.qpos': tensor([-2.3405,  0.1678,  1.0883, -1.8791,  0.1592, -0.4317]),
 'action.gripper': tensor(0),
 'observation.state.qpos': tensor([-2.3405,  0.1681,  1.0880, -1.8791,  0.1592, -0.4317]),
 'observation.state.qvel': tensor([-3.7407e-06,  8.3225e-03, -6.5398e-03,  1.8246e-04, -1.5866e-04,
          1.7253e-06]),
 'observation.state.gripper': tensor(0),
 'observation.pixels.side': tensor([[[0.1608, 0.1608, 0.1608,  ..., 0.1608, 0.1608, 0.1608],
          [0.1608, 0.1608, 0.1608,  ..., 0.1608, 0.1608, 0.1608],
          [0.1608, 0.1608, 0.1608,  ..., 0.1608, 0.1608, 0.1608],
          ...,
          [0.2353, 0.2353, 0.2353,  ..., 0.1137, 0.1137, 0.1137],
          [0.2353, 0.2353, 0.2353,  ..., 0.1137, 0.1137, 0.1137],
          [0.2353, 0.2353, 0.2353,  ..., 0.1137, 0.1137, 0.1137]],
 
         [[0.2706, 0.2706, 0.2706,  ..., 0.2706, 0.2706, 0.2706],
          [0.2706, 0.2706, 0.2706,  ..., 0.2706, 0.2706, 0.2706],
          [0.2706, 0.2706, 0.2706,  ..., 0.2706, 0.2706, 0.27

In [189]:
dataset.select_columns("timestamp")[0]

KeyError: 'action.gripper'

In [184]:
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, CODEBASE_VERSION
lerobot_dataset = LeRobotDataset.from_preloaded(root=Path("BC-MLP-MSE/datasets/scripted_trajectories_50_2024-08-02_12-49-56.hf"),
        split="train",
        # transform=trainer.lerobot_preprocess,
        delta_timestamps={"action.qpos": [0, 0.1]},
        # additional preloaded attributes
        hf_dataset=dataset,
        episode_data_index=episode_data_index,
        # episode_data_index=dataset["timestamp"],
        info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": env.metadata["render_fps"]
        })

In [185]:
lerobot_dataset[0]

KeyError: 'observation.state.gripper'

In [None]:

optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

curr_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
hidden_layer_dims = '_'.join([str(x.out_features) for x in policy.actor[:-1] if 'out_features' in x.__dict__])
OUTPUT_FOLDER=f'../ckpts/lite6_pick_place_h{hidden_layer_dims}_{curr_time}'
Path(OUTPUT_FOLDER).mkdir(parents=True, exist_ok=True)

writer = SummaryWriter(log_dir=f"../runs/lite6_pick_place/{curr_time}")

n_epoch = 20
step = 0
for epoch in range(n_epoch):
  policy.train()
  end = time.time()
  for batch in tqdm(dataloader):
    data_load_time = time.time()

    # Send data tensors from CPU to GPU
    state = batch["preprocessed.observation.state"].to(device, non_blocking=True)
    image = batch["preprocessed.observation.image"].to(device, non_blocking=True)
    a_hat = batch["preprocessed.action.state"].to(device, non_blocking=True)

    gpu_load_time = time.time()

    a_pred = policy.predict(state, image)

    pred_time = time.time()

    loss = loss_fn(a_pred, a_hat)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_time = time.time()

    writer.add_scalar("Loss/train", loss.item(), step)
    writer.add_scalar("Time/data_load", data_load_time - end, step)
    writer.add_scalar("Time/gpu_transfer", gpu_load_time - data_load_time, step)
    writer.add_scalar("Time/pred_time", pred_time - gpu_load_time, step)
    writer.add_scalar("Time/train_time", train_time - pred_time, step)
    writer.add_scalar("Time/step_time", time.time() - end, step)

    step += 1
    end = time.time()
  
  if epoch % 2 == 0 or epoch == n_epoch-1:
    # Evaluate
    policy.eval()
    print(f"Epoch: {epoch+1}/{n_epoch}, steps: {step}, loss: {loss.item()}")
    avg_reward, frames = trainer.evaluate_policy(env, policy, 5)
    media.write_video(OUTPUT_FOLDER + f"/epoch_{epoch}.mp4", frames, fps=env.metadata["render_fps"])
    print("avg reward: ", avg_reward)
    writer.add_scalar("Reward/val", avg_reward, step)
    # _, frames = evaluate_policy(policy, env, 1, visualise=True)
    writer.add_images("Image", np.stack([frames[x].transpose(2, 0, 1) for x in range(0, len(frames), 50)], axis=0), step)
  
    writer.add_scalar("Time/eval_time", time.time() - end, step)


  if epoch % 10 == 0 or epoch == n_epoch-1:
    torch.save({
            'epoch': epoch,
            'params': params,
            'policy_state_dict': policy.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, OUTPUT_FOLDER + f'/epoch_{epoch}.pt')
  
writer.flush()

writer.close()



