# VINN - Relative velocity

In [None]:
import torch
import torchvision

from datasets import load_from_disk
from lerobot.common.datasets.utils import hf_transform_to_torch
# from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, CODEBASE_VERSION
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import datetime
from pathlib import Path
import time


In [None]:

# Location of the saved Hugging Face dataset used for representation learning.
dataset_path = "../datasets/grasp_100_2024-09-06_17-03-47.hf"
dataset = load_from_disk(dataset_path)
# Disabled: rebuilds the per-episode "from"/"to" index ranges from frame_index == 0 rows.
# if "from" not in dataset.column_names:
#   first_frames=dataset.filter(lambda example: example['frame_index'] == 0)
#   from_idxs = torch.tensor(first_frames['index'])
#   to_idxs = torch.tensor(first_frames['index'][1:] + [len(dataset)])
#   episode_data_index={"from": from_idxs, "to": to_idxs}
    
# Apply lerobot's transform lazily whenever rows are accessed
# (presumably converts images/values to torch tensors — confirm against lerobot).
dataset.set_transform(hf_transform_to_torch)
# dataset = dataset.with_format("torch", device=device)


In [None]:
# Batched loader over the dataset; shuffle=True randomises the sample order on every pass.
dataloader = DataLoader(
        dataset,
        num_workers=4,
        batch_size=256,
        shuffle=True,
        # sampler=sampler,
        # pin_memory=params["device"].type != "cpu",
        drop_last=False,
    )


# Build a per-run name from the dataset file stem and the current minute so that
# each run gets its own checkpoint and TensorBoard directory.
curr_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
dataset_name = Path(dataset_path).stem
CKPT_DIR=f'ckpts/resnet_byol_{dataset_name}_{curr_time}'
TENSORBOARD_DIR=f'runs/resnet_byol_{dataset_name}_{curr_time}'
# Only the checkpoint directory is created here.
Path(CKPT_DIR).mkdir(parents=True, exist_ok=True)


In [None]:
from byol_pytorch import BYOL
import torch

# Pick the best available accelerator. Fall back to the CPU instead of assuming
# MPS exists, so the notebook also runs on machines with neither CUDA nor MPS.
if torch.cuda.is_available():
  device = torch.device("cuda")
elif torch.backends.mps.is_available():
  device = torch.device("mps")
else:
  device = torch.device("cpu")

# Remove the fc/classification layer so the network outputs pooled features.
resnet = torchvision.models.resnet18(weights='DEFAULT')
modules = list(resnet.children())[:-1]
backbone = torch.nn.Sequential(*modules)
net = backbone

# BYOL wraps `net` for self-supervised training.
# hidden_layer=8 selects the child module at index 8 of the Sequential as the
# representation layer (presumably the final pooling layer — TODO confirm).
learner = BYOL(
    net,
    image_size=240,
    hidden_layer=8,
).to(device)


In [None]:

# Train the BYOL learner for 100 epochs, logging loss and per-step timings to
# TensorBoard and writing a checkpoint at the end of every 10th epoch.
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

writer = SummaryWriter(log_dir=TENSORBOARD_DIR)
step = 0
last_loss = None  # scalar loss of the most recent step, stored in checkpoints
for epoch in range(1, 101):
  print(f"epoch {epoch}")
  # `end` marks the end of the previous step; it is refreshed after every
  # step so the timing scalars are per-step rather than cumulative per epoch.
  end = time.time()
  for batch in tqdm(dataloader):
    data_load_time = time.time()
    step += 1

    # images = torch.cat((batch['observation.pixels.side'], batch['observation.pixels.gripper']), dim=0).to(device)
    images = batch['observation.pixels.gripper'].to(device)
    gpu_load_time = time.time()

    loss = learner(images)
    pred_time = time.time()

    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder
    train_time = time.time()

    last_loss = loss.item()
    writer.add_scalar("Loss/train", last_loss, step)
    writer.add_scalar("Time/data_load", data_load_time - end, step)
    writer.add_scalar("Time/gpu_transfer", gpu_load_time - data_load_time, step)
    writer.add_scalar("Time/pred_time", pred_time - gpu_load_time, step)
    writer.add_scalar("Time/train_time", train_time - pred_time, step)
    step_end = time.time()
    writer.add_scalar("Time/step_time", step_end - end, step)
    end = step_end

  # Checkpoint once per qualifying epoch (not once per batch). The loss is
  # stored as a plain float so the file holds no device-bound loss tensor.
  if epoch % 10 == 0 and last_loss is not None:
    torch.save({'policy_state_dict': net.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'loss': last_loss,
                'epoch': epoch,
                'step': step,
                }, CKPT_DIR + f'/epoch_{epoch}.pt')

# Flush pending events to disk and release the event file.
writer.close()





In [None]:
# modules = list(net.children())[:-1]
# backbone = torch.nn.Sequential(*modules).to(device)

# Save embeddings of all images
img_cols = {"observation.pixels.side": "observation.vector.side", "observation.pixels.gripper" : "observation.vector.gripper"}
vec_columns = {vec_col: [] for vec_col in img_cols.values()}

# The embeddings are appended row by row and then attached as new columns, so
# they must be produced in dataset order: use a dedicated non-shuffled loader.
embed_loader = DataLoader(
        dataset,
        num_workers=4,
        batch_size=256,
        shuffle=False,
        drop_last=False,
    )

# Embed in eval mode (the later inference cells also call net.eval()), then
# restore whatever mode the network was in.
was_training = net.training
net.eval()
for batch in tqdm(embed_loader):
    with torch.inference_mode():
      for img_col, vec_col in img_cols.items():
        # Batch inference to get embeddings
        embedding = net(batch[img_col].to(device)).to("cpu")
        # flatten(start_dim=1) always keeps the batch dimension; squeeze() would
        # also drop it for a final batch holding a single image.
        vec_columns[vec_col].extend(e.numpy() for e in embedding.flatten(start_dim=1))
net.train(was_training)

# vec_columns = {k: [r.numpy() for r in v] for k, v in vec_columns.items()}
for col in vec_columns:
  assert(len(vec_columns[col]) == len(dataset))
  dataset = dataset.add_column(name=col, column=vec_columns[col])

# `epoch` is the last value left by the training loop in this session.
dataset.set_transform(None)
dataset.save_to_disk(f"../datasets/byol/{dataset_name}_epoch_{epoch}.hf")
dataset.set_format("torch")


# Embed images in a dataset
Run from here if you already have a pretrained model and want to embed a new dataset.

In [None]:
import torch
import torchvision

from lerobot.common.datasets.utils import hf_transform_to_torch
from datasets import load_from_disk
from tqdm import tqdm
import datetime
from pathlib import Path
import time

from torch.utils.data import DataLoader

# Embed every image of a dataset with a previously trained backbone and save
# the dataset with the embeddings attached as new columns.
dataset_path = "../datasets/grasp/grasp_ee_vel_random_50_2024-10-15_18-56-26.hf"
dataset_name = Path(dataset_path).stem
dataset = load_from_disk(dataset_path)
dataset.set_transform(hf_transform_to_torch)

# Pick the best available accelerator, falling back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
elif torch.backends.mps.is_available():
  device = torch.device("mps")
else:
  device = torch.device("cpu")

# map_location lets a checkpoint saved on another device type load here.
checkpoint = torch.load("../VINN/ckpts/resnet_byol_grasp_100_2024-09-06_17-03-47_2024-10-15_21-08/epoch_100.pt", map_location=device)
# This cell is a documented entry point, so `epoch` may not exist yet; take it
# from the checkpoint (stored under the 'epoch' key) for the output file name.
epoch = checkpoint["epoch"]

# Remove the fc/classification layer
resnet = torchvision.models.resnet18().to(device)
modules = list(resnet.children())[:-1]
backbone = torch.nn.Sequential(*modules)
net = backbone
net.load_state_dict(checkpoint["policy_state_dict"])
net = net.to(device)

net.eval()
# Save embeddings of all images
img_cols = {"observation.pixels.side": "observation.vector.side", "observation.pixels.gripper" : "observation.vector.gripper"}
# img_cols = {"observation.pixels.gripper" : "observation.vector.gripper"}
vec_columns = {vec_col: [] for vec_col in img_cols.values()}
# NOTE: batched inference was observed to give different results from
# single-image inference. One possible cause is handled below (squeeze()
# dropping the batch dimension of a one-image batch); confirm with the
# reproducibility check before relying on the embeddings.
dataloader = DataLoader(
        dataset,
        num_workers=4,
        batch_size=128,
        shuffle=False,  # embeddings must stay aligned with dataset rows
        # sampler=sampler,
        # pin_memory=params["device"].type != "cpu",
        drop_last=False,
    )
for batch in tqdm(dataloader):
    with torch.inference_mode():
      for img_col, vec_col in img_cols.items():
        # Batch inference to get embeddings
        embedding = net(batch[img_col].to(device)).to("cpu")
        # flatten(start_dim=1) always keeps the batch dimension, even when the
        # final batch holds a single image.
        vec_columns[vec_col].extend(e.numpy() for e in embedding.flatten(start_dim=1))

# Single-image fallback, kept for debugging:
# for datum in tqdm(dataset):
#     with torch.inference_mode():
#       for img_col in img_cols:
#         # Batch inference to get embeddings
#         embedding = net(datum[img_col].unsqueeze(0).to(device)).cpu().squeeze()
#         # Add the embedding (size 512) for each image to the column
#         vec_columns[img_cols[img_col]].append(embedding.numpy())

# vec_columns = {k: [r.numpy() for r in v] for k, v in vec_columns.items()}
for col in vec_columns:
  assert(len(vec_columns[col]) == len(dataset))
  dataset = dataset.add_column(name=col, column=vec_columns[col])

dataset.set_transform(None)
dataset.save_to_disk(f"../datasets/byol/{dataset_name}_epoch_{epoch}.hf")
dataset.set_format("torch")


In [None]:
# Attach the computed embeddings to the dataset as new columns.
# Columns that already exist are skipped so this cell can be re-run (or run
# after a cell that already added them) without a duplicate-column error.
for col in vec_columns:
  if col in dataset.column_names:
    continue
  assert(len(vec_columns[col]) == len(dataset))
  dataset = dataset.add_column(name=col, column=vec_columns[col])

In [None]:
# Clear the on-access transform, write the dataset to disk, then switch to torch formatting.
# NOTE(review): `epoch` must already be defined in the session (it is the training
# loop variable) — confirm it exists when running this cell on its own.
dataset.set_transform(None)
dataset.save_to_disk(f"../datasets/byol/{dataset_name}_epoch_{epoch}.hf")
dataset.set_format("torch")

In [None]:
# Display the current dataset name as the cell output.
dataset_name

# Run VINN
Start from here if you already have a vision model and a dataset with embeddings

In [None]:
# Load dataset and model if not continuing from before
import torch
import torchvision

from lerobot.common.datasets.utils import hf_transform_to_torch
from datasets import load_from_disk
from tqdm import tqdm
import datetime
from pathlib import Path
import time


# Load the dataset that already contains embedding columns, and rebuild the
# trained backbone used to embed live observations.
dataset_path = "../datasets/byol/grasp_ee_vel_random_50_2024-10-15_18-56-26_epoch_100.hf"
dataset_name = Path(dataset_path).stem
dataset = load_from_disk(dataset_path)
# dataset.set_transform(hf_transform_to_torch)
# dataset = dataset.with_format("torch", device=device)
dataset.set_format("torch")

# Pick the best available accelerator, falling back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
elif torch.backends.mps.is_available():
  device = torch.device("mps")
else:
  device = torch.device("cpu")

# map_location lets a checkpoint saved on another device type load here.
checkpoint = torch.load("../VINN/ckpts/resnet_byol_grasp_100_2024-09-06_17-03-47_2024-10-15_21-08/epoch_100.pt", map_location=device)

# Remove the fc/classification layer
resnet = torchvision.models.resnet18().to(device)
modules = list(resnet.children())[:-1]
backbone = torch.nn.Sequential(*modules)
net = backbone
net.load_state_dict(checkpoint["policy_state_dict"])
net = net.to(device)
net.eval()


In [None]:
import gym_lite6.env, gym_lite6.pickup_task, gym_lite6.utils
import gymnasium as gym
import numpy as np
import mediapy as media


# Build the grasp-and-lift task and the simulated pickup environment, then show the first frame.
# The four strings are passed straight to GraspAndLiftTask
# (presumably model body/geom names — confirm against gym_lite6).
task = gym_lite6.pickup_task.GraspAndLiftTask('gripper_left_finger', 'gripper_right_finger', 'box', 'floor')
env = gym.make(
    "UfactoryCubePickup-v0",
    task=task,
    obs_type="pixels_state",
    action_type="qvel",
    max_episode_steps=300,
    visualization_width=320,
    visualization_height=240
)
observation, info = env.reset()
media.show_image(env.render())


In [None]:
import heapq

def dist_metric(x, y):
  """
  Distance between x and y: the norm of their difference, reduced over
  dimension 1 (broadcasting applies when one argument is a single vector).
  """
  difference = x - y
  return difference.norm(dim=1)

def calculate_nearest_neighbours(query, data, k=1, verbose=False):
  """
  Find the k rows of `data` closest to `query` under dist_metric.

  query: 1-D embedding tensor.
  data: 2-D tensor with one embedding per row.
  k: number of neighbours wanted; if `data` has fewer rows, all are returned.
  verbose: print how long the distance and selection steps took.

  Returns a list of (distance, row_index) tuples, nearest first.
  """
  t0 = time.time()
  assert data.dim() == 2, f"Invalid shape, data should be 2D (rows of 1D vectors): {data.shape}"
  # One bulk transfer to Python floats instead of one .item() call per row.
  dists = dist_metric(query, data).cpu().tolist()

  t1 = time.time()
  # nsmallest returns the k smallest (distance, index) pairs in ascending
  # order, the same order repeated heappop produced, without failing when k
  # exceeds the number of rows.
  out = heapq.nsmallest(k, zip(dists, range(len(dists))))
  t2 = time.time()

  if verbose:
    # Label kept from the heap-based version; t2-t1 is the selection time.
    print(f"Times: get dists: {t1-t0}, heapify: {t2-t1}")
  return out
        
def calculate_action(dists, dataset, key='action.qpos'):
  """
  Blend the neighbours' stored values for `key` into one action.

  dists is a sequence of (distance, dataset_index) pairs. A single pair
  returns that row's value unchanged; several pairs are combined with
  softmin weights over the distances, so closer neighbours count more.
  """
  if len(dists) <= 1:
    return dataset[dists[0][1]][key]

  distances = torch.tensor([pair[0] for pair in dists])
  weights = torch.nn.Softmin(dim=0)(distances).tolist()
  return sum([weight * dataset[pair[1]][key] for weight, pair in zip(weights, dists)])

In [None]:
# Roll out the nearest-neighbour policy for up to 100 steps using the brute-force
# neighbour search, recording observations, actions and comparison frames.
observation, info = env.reset()
k=5
frames = []
nn_frames = [[] for _ in range(k)]
action = {}
ep_dict = {
      "action.qpos": [], "action.qvel": [], "action.gripper": [],"action.ee_vel": [], "action.ee_ang_vel": [],
      "observation.state.qpos": [], "observation.state.qvel": [], "observation.state.gripper": [], "observation.pixels.side": [], "observation.pixels.gripper": [],
      "observation.ee_pose.pos": [], "observation.ee_pose.quat": [], "observation.ee_pose.vel": [], "observation.ee_pose.ang_vel": [],
      "reward": [], "timestamp": [], "frame_index": [],
      }
# vector_data = dataset["observation.vector.side"].to(device)
vector_data = dataset["observation.vector.gripper"].to(device)

# (ep_dict key, observation group, field within the group)
OBSERVATION_FIELDS = (
  ("observation.state.qpos", "state", "qpos"),
  ("observation.state.qvel", "state", "qvel"),
  ("observation.state.gripper", "state", "gripper"),
  ("observation.pixels.side", "pixels", "side"),
  ("observation.pixels.gripper", "pixels", "gripper"),
  ("observation.ee_pose.pos", "ee_pose", "pos"),
  ("observation.ee_pose.quat", "ee_pose", "quat"),
  ("observation.ee_pose.vel", "ee_pose", "vel"),
  ("observation.ee_pose.ang_vel", "ee_pose", "ang_vel"),
)

def record_observation(ep_dict, observation, timestamp):
  """Append every logged observation field and the timestamp to ep_dict."""
  for ep_key, group, field in OBSERVATION_FIELDS:
    ep_dict[ep_key].append(observation[group][field])
  ep_dict["timestamp"].append(timestamp)

step = 0
done = False
net.eval()
# Stop at the step cap or as soon as the episode ends; the environment should
# not be stepped again after it reports terminated/truncated.
while not done and step < 100:

  t0 = time.time()
  # image_side = (torch.from_numpy(observation["pixels"]["side"]).permute(2, 0, 1).unsqueeze(0) / 255).to(device)
  # HWC image -> scaled 1xCxHxW float tensor.
  image_gripper = (torch.from_numpy(observation["pixels"]["gripper"]).permute(2, 0, 1).unsqueeze(0) / 255).to(device)

  with torch.inference_mode():
    embedding = net(image_gripper).squeeze()
  t1 = time.time()

  nn = calculate_nearest_neighbours(embedding, vector_data, k=k)

  t2 = time.time()

  # Save the nearest image in the dataset for comparison
  # stacked_frame = np.hstack([observation["pixels"]["side"]] + [dataset[nn[i][1]]["observation.pixels.side"].permute(1, 2, 0).cpu().numpy() for i in range(k)])
  neighbour_images = [dataset[nn[i][1]]["observation.pixels.gripper"].permute(1, 2, 0).cpu().numpy() for i in range(k)]
  stacked_frame = np.hstack([observation["pixels"]["gripper"]] + neighbour_images)
  frames.append(stacked_frame)

  # Blend the neighbours' end-effector velocities, then convert them to joint
  # velocities with the environment's IK solver.
  vel = calculate_action(nn, dataset, key='action.ee_vel').numpy()
  ang_vel = calculate_action(nn, dataset, key='action.ee_ang_vel').numpy()
  action["gripper"] = round(calculate_action(nn, dataset, key='action.gripper').item())
  action["qvel"] = env.unwrapped.solve_ik_vel(vel, ang_vel, ref_frame='end_effector', local=True)

  record_observation(ep_dict, observation, env.unwrapped.data.time)

  # Step through the environment and receive a new observation
  observation, reward, terminated, truncated, info = env.step(action)

  ep_dict["reward"].append(reward)
  # ep_dict["action.qpos"].append(action["qpos"])
  ep_dict["action.qvel"].append(action["qvel"])
  ep_dict["action.ee_vel"].append(vel)
  ep_dict["action.ee_ang_vel"].append(ang_vel)
  ep_dict["action.gripper"].append(action["gripper"])

  done = bool(done or terminated or truncated)
  step += 1

  print(f"Timing: inference: {t1-t0}, nn: {t2-t1}, rest: {time.time()-t2}")


# Log the final observation, which has no action associated with it.
record_observation(ep_dict, observation, env.unwrapped.data.time)

# Append dummy nans to the actions so we have the same number of samples can plot it
# ep_dict["action.qpos"].append(np.array([np.nan] * len(action["qpos"])))
ep_dict["action.gripper"].append(np.array([np.nan]))
ep_dict["action.qvel"].append(np.array([np.nan] * len(action["qvel"])))
ep_dict["action.ee_vel"].append(np.array([np.nan] * len(vel)))
ep_dict["action.ee_ang_vel"].append(np.array([np.nan] * len(ang_vel)))
# ep_dict["reward"].append(np.array([np.nan]))

avg_reward = sum(ep_dict["reward"])/len(ep_dict["reward"])
print(f"Avg reward: {avg_reward}")
media.show_video(frames)



In [None]:
# Display the last action dict sent to the environment.
action

In [None]:
# Plot the selected episode series against the timestamps.
# NOTE(review): the rollout cell never appends to ep_dict["action.qpos"] (those
# lines are commented out), so that series is empty here — confirm that
# plot_dict_of_arrays copes with it or pick a populated key.
gym_lite6.utils.plot_dict_of_arrays(ep_dict, "timestamp", keys=["action.qpos", "observation.state.qpos"], sharey=True)

In [None]:
from sklearn.decomposition import PCA

# Project the gripper-camera embeddings onto their first two principal components.
pca = PCA(n_components=2)
X_pca = pca.fit_transform(dataset['observation.vector.gripper'])

In [None]:
import matplotlib.pyplot as plt

# Scatter plot of the two PCA components, one point per dataset row.
plt.scatter(x=X_pca[:, 0], y=X_pca[:, 1], marker='.')
plt.show()

In [None]:
from sklearn.manifold import TSNE

# Compute a 2-D t-SNE projection of the gripper-camera embeddings.
tsne = TSNE(n_components=2)
X_tsne = tsne.fit_transform(dataset['observation.vector.gripper'])
# Show the fitted model's KL divergence as the cell output.
tsne.kl_divergence_

In [None]:
# Scatter plot of the two t-SNE components, one point per dataset row.
# Relies on `plt` having been imported by an earlier cell.
plt.scatter(x=X_tsne[:, 0], y=X_tsne[:, 1], marker='.')
plt.show()

In [None]:
# Display the shape of the t-SNE output.
X_tsne.shape

In [None]:
# Display the shape of the stored gripper embedding column.
dataset['observation.vector.gripper'].shape

In [None]:
# Replace vector_data with the 2-D t-SNE projection, moved to `device`.
# NOTE(review): rollout code compares embeddings from `net` against vector_data;
# a 2-D projection presumably no longer matches their dimensionality — looks
# experimental, confirm before running a rollout after this cell.
vector_data = torch.from_numpy(X_tsne).to(device)


In [None]:
# Query the vector index with the first dataset row's stored embedding, asking for 2 matches.
# NOTE(review): `index` is only created in a later cell of this notebook, so this
# cell raises NameError when the notebook is run top to bottom.
matches = index.search(dataset[0]['observation.vector.gripper'].numpy(), 2)


In [None]:
# Roll out the nearest-neighbour policy until the episode ends, using a usearch
# vector index for the neighbour lookup instead of the brute-force search.
from usearch.index import Index
observation, info = env.reset()
k=5
frames = []
side_frames = []
nn_frames = [[] for _ in range(k)]
action = {}
ep_dict = {
      "action.qpos": [], "action.qvel": [], "action.gripper": [],"action.ee_vel": [], "action.ee_ang_vel": [],
      "observation.state.qpos": [], "observation.state.qvel": [], "observation.state.gripper": [], "observation.pixels.side": [], "observation.pixels.gripper": [],
      "observation.ee_pose.pos": [], "observation.ee_pose.quat": [], "observation.ee_pose.vel": [], "observation.ee_pose.ang_vel": [],
      "reward": [], "timestamp": [], "frame_index": [],
      }
# Load the embedding column once and reuse it for the device copy and the index.
gripper_vectors = dataset["observation.vector.gripper"]
vector_data = gripper_vectors.to(device)
# NOTE(review): no metric is passed, so usearch's default applies — confirm it
# is the distance you want (dist_metric in this notebook uses the L2 norm).
index = Index(ndim=gripper_vectors.shape[1])
# keys=None lets usearch assign the keys; the code below uses them as dataset
# row indices (presumably sequential from 0 — confirm against usearch).
index.add(vectors=gripper_vectors.numpy(), keys=None)
# vector_data = torch.from_numpy(X_tsne).to(device)

# (ep_dict key, observation group, field within the group)
OBSERVATION_FIELDS = (
  ("observation.state.qpos", "state", "qpos"),
  ("observation.state.qvel", "state", "qvel"),
  ("observation.state.gripper", "state", "gripper"),
  ("observation.pixels.side", "pixels", "side"),
  ("observation.pixels.gripper", "pixels", "gripper"),
  ("observation.ee_pose.pos", "ee_pose", "pos"),
  ("observation.ee_pose.quat", "ee_pose", "quat"),
  ("observation.ee_pose.vel", "ee_pose", "vel"),
  ("observation.ee_pose.ang_vel", "ee_pose", "ang_vel"),
)

def record_observation(ep_dict, observation, timestamp):
  """Append every logged observation field and the timestamp to ep_dict."""
  for ep_key, group, field in OBSERVATION_FIELDS:
    ep_dict[ep_key].append(observation[group][field])
  ep_dict["timestamp"].append(timestamp)

step = 0
done = False
net.eval()
while not done:

  t0 = time.time()
  # image_side = (torch.from_numpy(observation["pixels"]["side"]).permute(2, 0, 1).unsqueeze(0) / 255).to(device)
  # HWC image -> scaled 1xCxHxW float tensor.
  image_gripper = (torch.from_numpy(observation["pixels"]["gripper"]).permute(2, 0, 1).unsqueeze(0) / 255).to(device)

  with torch.inference_mode():
    embedding = net(image_gripper).squeeze()
  t1 = time.time()

  matches = index.search(embedding.cpu().numpy(), count=k)
  # Pair each returned distance with its key; iterating over what was actually
  # returned avoids assuming the index produced exactly k matches.
  nn = [(dist, int(key)) for dist, key in zip(matches.distances, matches.keys)]

  t2 = time.time()

  # Save the nearest image in the dataset for comparison
  # stacked_frame = np.hstack([observation["pixels"]["side"]] + [dataset[nn[i][1]]["observation.pixels.side"].permute(1, 2, 0).cpu().numpy() for i in range(k)])
  neighbour_images = [dataset[row]["observation.pixels.gripper"].permute(1, 2, 0).cpu().numpy() for _, row in nn]
  stacked_frame = np.hstack([observation["pixels"]["gripper"]] + neighbour_images)
  frames.append(stacked_frame)
  side_frames.append(observation["pixels"]["side"])

  # Blend the neighbours' end-effector velocities, then convert them to joint
  # velocities with the environment's IK solver.
  vel = calculate_action(nn, dataset, key='action.ee_vel').numpy()
  ang_vel = calculate_action(nn, dataset, key='action.ee_ang_vel').numpy()
  action["gripper"] = round(calculate_action(nn, dataset, key='action.gripper').item())
  action["qvel"] = env.unwrapped.solve_ik_vel(vel, ang_vel, ref_frame='end_effector', local=True)

  record_observation(ep_dict, observation, env.unwrapped.data.time)

  # Step through the environment and receive a new observation
  observation, reward, terminated, truncated, info = env.step(action)

  ep_dict["reward"].append(reward)
  # ep_dict["action.qpos"].append(action["qpos"])
  ep_dict["action.qvel"].append(action["qvel"])
  ep_dict["action.ee_vel"].append(vel)
  ep_dict["action.ee_ang_vel"].append(ang_vel)
  ep_dict["action.gripper"].append(action["gripper"])

  done = bool(done or terminated or truncated)
  step += 1

  print(f"Timing: inference: {t1-t0}, nn: {t2-t1}, rest: {time.time()-t2}")


# Log the final observation, which has no action associated with it.
record_observation(ep_dict, observation, env.unwrapped.data.time)

# Append dummy nans to the actions so we have the same number of samples can plot it
# ep_dict["action.qpos"].append(np.array([np.nan] * len(action["qpos"])))
ep_dict["action.gripper"].append(np.array([np.nan]))
ep_dict["action.qvel"].append(np.array([np.nan] * len(action["qvel"])))
ep_dict["action.ee_vel"].append(np.array([np.nan] * len(vel)))
ep_dict["action.ee_ang_vel"].append(np.array([np.nan] * len(ang_vel)))
# ep_dict["reward"].append(np.array([np.nan]))

avg_reward = sum(ep_dict["reward"])/len(ep_dict["reward"])
print(f"Avg reward: {avg_reward}")
media.show_video(frames)
media.show_video(side_frames)



In [None]:
# Plot the commanded end-effector velocities and the observed joint positions against the timestamps.
gym_lite6.utils.plot_dict_of_arrays(ep_dict, "timestamp", keys=["action.ee_vel", "action.ee_ang_vel", "observation.state.qpos"], sharey=True)


In [None]:
# Display the (distance, dataset index) pairs from the last rollout step.
nn

In [None]:
# Important and useful test to make sure embeddings can be reproduced
idx = 20
img = (dataset[idx]["observation.pixels.gripper"]/255).unsqueeze(0).to(device)
dataset_embedding = dataset[idx]["observation.vector.gripper"].numpy()

net.eval()
with torch.inference_mode():
  embedding = net(img).squeeze()
t1 = time.time()
  
matches_dataset = index.search(dataset_embedding, count=2)
matches = index.search(embedding.cpu().numpy(), count=2)
print(matches_dataset.keys, matches_dataset.distances)
print(matches.keys, matches.distances)

media.show_images([dataset[idx]["observation.pixels.gripper"].permute(1, 2, 0)])
media.show_images([dataset[int(i)]["observation.pixels.gripper"].permute(1, 2, 0) for i in matches_dataset.keys])
media.show_images([dataset[int(i)]["observation.pixels.gripper"].permute(1, 2, 0) for i in matches.keys])

In [None]:
# Display the stored gripper image for row `idx` after dividing by 255.
dataset[idx]["observation.pixels.gripper"]/255