In [None]:
# Load dataset and model if not continuing from before
import torch
import torchvision

from datasets import load_from_disk
from tqdm import tqdm
import datetime
from pathlib import Path
import time
from lerobot.common.datasets.utils import hf_transform_to_torch
import numpy as np
import mediapy as media


# Load the on-disk Hugging Face dataset; items go through lerobot's
# hf_transform_to_torch on access (presumably converting them to torch tensors).
dataset_path = "../datasets/byol/grasp_100_2024-09-06_17-03-47.hf"
dataset_name = Path(dataset_path).stem
dataset = load_from_disk(dataset_path)
dataset.set_transform(hf_transform_to_torch)

# Pick the best available device. Only fall back to "mps" when that backend
# actually exists and reports itself available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# map_location="cpu" lets the checkpoint deserialize on machines that lack the
# accelerator it was saved from; the weights are moved to `device` below.
checkpoint = torch.load(
    "../VINN/ckpts/resnet_byol_grasp_100_2024-09-06_17-03-47_2024-10-12_21-08/epoch_10.pt",
    map_location="cpu",
)

# Rebuild the backbone: every child of ResNet-18 except the last one
# (the fc/classification layer), chained in a Sequential.
resnet = torchvision.models.resnet18()
modules = list(resnet.children())[:-1]
backbone = torch.nn.Sequential(*modules)
net = backbone
net.load_state_dict(checkpoint["policy_state_dict"])

# eval() makes BatchNorm use its stored running statistics instead of the
# statistics of the single image fed in later, and stops forward passes from
# updating them. Assigned (not a bare expression) so the cell prints nothing.
net = net.to(device).eval()


In [None]:
# Draw a random sample index and print it so the same frame can be found again.
idx = np.random.randint(len(dataset))
print(idx)

# Take the side-camera frame of that sample, add a leading axis (used as the
# batch dimension by the network) and move it to the model's device.
side_frame = dataset[idx]["observation.pixels.side"]
img = side_frame.unsqueeze(0).to(device)

In [None]:

# Bare expression: the cell output is the module's repr, listing the layers
# kept in the Sequential backbone.
net

In [None]:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

# get_graph_node_names returns the traced node names for train mode and for
# eval mode; only the first list is printed here.
train_nodes, _ = get_graph_node_names(net)
print(train_nodes)

# Children 4, 5 and 6 of the Sequential backbone; in torchvision's ResNet
# child order these correspond to layer1, layer2 and layer3.
return_nodes = ["4", "5", "6"]
model2 = create_feature_extractor(net, return_nodes=return_nodes)

# Pure inference: without no_grad() every activation would keep an autograd
# graph alive for no benefit.
with torch.no_grad():
  intermediate_outputs = model2(img)

print(intermediate_outputs)
for k in intermediate_outputs:
  print(intermediate_outputs[k].shape)


In [None]:
# Children 4, 5 and 6 of the Sequential backbone are the nodes whose outputs
# are captured.
return_nodes = ["4", "5", "6"]
model2 = create_feature_extractor(net, return_nodes=return_nodes)

# Pure inference: skip autograd bookkeeping so the activations do not hold on
# to a computation graph.
with torch.no_grad():
  intermediate_outputs = model2(img)

print(intermediate_outputs)

In [None]:
# Report the tensor shape captured at each requested node, in insertion order.
for feature_map in intermediate_outputs.values():
  print(feature_map.shape)

In [None]:
# Detach the input from autograd, copy it to the CPU and drop size-1 axes;
# permute then moves the first remaining axis (assumed to be channels) to the
# end before display.
frame_on_cpu = img.detach().cpu().squeeze()
media.show_image(frame_on_cpu.permute(1, 2, 0))

In [None]:
import matplotlib.pyplot as plt

# One figure per captured node; each subplot shows one entry along axis 0 of
# the feature map, laid out in rows of at most 8.
for key, feature_map in intermediate_outputs.items():
  print(f"Plot {key}")
  # Drop only the leading (batch) axis. An argument-less squeeze() would also
  # collapse 1x1 spatial maps and break imshow.
  plt_data = feature_map.detach().cpu().squeeze(0).numpy()
  ncols = min(plt_data.shape[0], 8)
  nrows = int(np.ceil(plt_data.shape[0] / ncols))
  # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, so iterating
  # over axs.flat always works.
  fig, axs = plt.subplots(ncols=ncols, nrows=nrows, sharex=True, sharey=True, figsize=(ncols*1+0.5, nrows*1+0.5), constrained_layout=True, squeeze=False)
  # Hide every axis first so unused cells in the last row stay blank.
  for ax in axs.flat:
      ax.set_axis_off()
  for ax, channel in zip(axs.flat, plt_data):
      ax.imshow(channel)
  fig.suptitle(f"{key}")

In [None]:
# Baseline: a ResNet-18 built without requesting weights. It is placed on
# `device` (where `img` lives) rather than a hard-coded "cuda", and put in
# eval mode so BatchNorm is not updated by the forward pass.
resnet_random = torchvision.models.resnet18().to(device).eval()
return_nodes = ['layer1', 'layer2', 'layer3']
resnet_random2 = create_feature_extractor(resnet_random, return_nodes=return_nodes)

# Pure inference: no autograd graph is needed for visualisation.
with torch.no_grad():
  intermediate_outputs = resnet_random2(img)

for k in intermediate_outputs:
  print(intermediate_outputs[k].shape)

# One figure per captured layer; each subplot shows one entry along axis 0 of
# the feature map, laid out in rows of at most 8.
for key, feature_map in intermediate_outputs.items():
  print(f"Plot {key}")
  # Drop only the leading (batch) axis so spatial axes are never collapsed.
  plt_data = feature_map.detach().cpu().squeeze(0).numpy()
  ncols = min(plt_data.shape[0], 8)
  nrows = int(np.ceil(plt_data.shape[0] / ncols))
  # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid.
  fig, axs = plt.subplots(ncols=ncols, nrows=nrows, sharex=True, sharey=True, figsize=(ncols*1+0.5, nrows*1+0.5), constrained_layout=True, squeeze=False)
  # Hide every axis first so unused cells in the last row stay blank.
  for ax in axs.flat:
      ax.set_axis_off()
  for ax, channel in zip(axs.flat, plt_data):
      ax.imshow(channel)
  fig.suptitle(f"{key}")

In [None]:
# Reference: ResNet-18 with torchvision's default pretrained weights. It is
# placed on `device` (where `img` lives) rather than a hard-coded "cuda".
# eval() is required so BatchNorm uses the pretrained running statistics
# instead of single-image batch statistics (and does not overwrite them).
resnet_default = torchvision.models.resnet18(weights="DEFAULT").to(device).eval()
return_nodes = ['layer1', 'layer2', 'layer3']
resnet_default2 = create_feature_extractor(resnet_default, return_nodes=return_nodes)

# Pure inference: no autograd graph is needed for visualisation.
with torch.no_grad():
  intermediate_outputs = resnet_default2(img)

for k in intermediate_outputs:
  print(intermediate_outputs[k].shape)

# One figure per captured layer; each subplot shows one entry along axis 0 of
# the feature map, laid out in rows of at most 8.
for key, feature_map in intermediate_outputs.items():
  print(f"Plot {key}")
  # Drop only the leading (batch) axis so spatial axes are never collapsed.
  plt_data = feature_map.detach().cpu().squeeze(0).numpy()
  ncols = min(plt_data.shape[0], 8)
  nrows = int(np.ceil(plt_data.shape[0] / ncols))
  # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid.
  fig, axs = plt.subplots(ncols=ncols, nrows=nrows, sharex=True, sharey=True, figsize=(ncols*1+0.5, nrows*1+0.5), constrained_layout=True, squeeze=False)
  # Hide every axis first so unused cells in the last row stay blank.
  for ax in axs.flat:
      ax.set_axis_off()
  for ax, channel in zip(axs.flat, plt_data):
      ax.imshow(channel)
  fig.suptitle(f"{key}")

In [None]:

# Pick the best available device. Only fall back to "mps" when that backend
# actually exists and reports itself available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# map_location="cpu" lets the checkpoint deserialize on machines that lack the
# accelerator it was saved from; the weights are moved to `device` below.
checkpoint = torch.load(
    "../VINN/ckpts/resnet_byol_grasp_100_2024-09-06_17-03-47_2024-10-15_21-08/epoch_100.pt",
    map_location="cpu",
)

# Rebuild the backbone: every child of ResNet-18 except the last one
# (the fc/classification layer), chained in a Sequential.
resnet = torchvision.models.resnet18()
modules = list(resnet.children())[:-1]
backbone = torch.nn.Sequential(*modules)
net = backbone
net.load_state_dict(checkpoint["policy_state_dict"])
net = net.to(device)
# Switch to inference mode. As the last expression of the cell, the returned
# module is also displayed.
net.eval()

In [None]:
# Detach the current input from autograd, copy it to the CPU and drop size-1
# axes; permute then moves the first remaining axis (assumed to be channels)
# to the end before display.
frame_on_cpu = img.detach().cpu().squeeze()
media.show_image(frame_on_cpu.permute(1, 2, 0))


In [None]:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

# get_graph_node_names returns the traced node names for train mode and for
# eval mode; only the first list is printed here.
train_nodes, _ = get_graph_node_names(net)
print(train_nodes)

# Children 4, 5 and 6 of the Sequential backbone; in torchvision's ResNet
# child order these correspond to layer1, layer2 and layer3.
return_nodes = ["4", "5", "6"]
model2 = create_feature_extractor(net, return_nodes=return_nodes)

# Pure inference: without no_grad() every activation would keep an autograd
# graph alive for no benefit.
with torch.no_grad():
  intermediate_outputs = model2(img)

print(intermediate_outputs)
for k in intermediate_outputs:
  print(intermediate_outputs[k].shape)


In [None]:
import matplotlib.pyplot as plt

# One figure per captured node; each subplot shows one entry along axis 0 of
# the feature map, laid out in rows of at most 8.
for key, feature_map in intermediate_outputs.items():
  print(f"Plot {key}")
  # Drop only the leading (batch) axis. An argument-less squeeze() would also
  # collapse 1x1 spatial maps and break imshow.
  plt_data = feature_map.detach().cpu().squeeze(0).numpy()
  ncols = min(plt_data.shape[0], 8)
  nrows = int(np.ceil(plt_data.shape[0] / ncols))
  # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, so iterating
  # over axs.flat always works.
  fig, axs = plt.subplots(ncols=ncols, nrows=nrows, sharex=True, sharey=True, figsize=(ncols*1+0.5, nrows*1+0.5), constrained_layout=True, squeeze=False)
  # Hide every axis first so unused cells in the last row stay blank.
  for ax in axs.flat:
      ax.set_axis_off()
  for ax, channel in zip(axs.flat, plt_data):
      ax.imshow(channel)
  fig.suptitle(f"{key}")

In [None]:
# Same sample index as before, but the gripper-camera frame: add a leading
# axis (used as the batch dimension by the network) and move it to `device`.
img = dataset[idx]["observation.pixels.gripper"].unsqueeze(0).to(device)

# get_graph_node_names returns the traced node names for train mode and for
# eval mode; only the first list is printed here.
train_nodes, _ = get_graph_node_names(net)
print(train_nodes)

# Children 4, 5 and 6 of the Sequential backbone are the nodes whose outputs
# are captured.
return_nodes = ["4", "5", "6"]
model2 = create_feature_extractor(net, return_nodes=return_nodes)

# Pure inference: skip autograd bookkeeping for the captured activations.
with torch.no_grad():
  intermediate_outputs = model2(img)

print(intermediate_outputs)
for k in intermediate_outputs:
  print(intermediate_outputs[k].shape)

# One figure per captured node; each subplot shows one entry along axis 0 of
# the feature map, laid out in rows of at most 8.
for key, feature_map in intermediate_outputs.items():
  print(f"Plot {key}")
  # Drop only the leading (batch) axis so spatial axes are never collapsed.
  plt_data = feature_map.detach().cpu().squeeze(0).numpy()
  ncols = min(plt_data.shape[0], 8)
  nrows = int(np.ceil(plt_data.shape[0] / ncols))
  # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid.
  fig, axs = plt.subplots(ncols=ncols, nrows=nrows, sharex=True, sharey=True, figsize=(ncols*1+0.5, nrows*1+0.5), constrained_layout=True, squeeze=False)
  # Hide every axis first so unused cells in the last row stay blank.
  for ax in axs.flat:
      ax.set_axis_off()
  for ax, channel in zip(axs.flat, plt_data):
      ax.imshow(channel)
  fig.suptitle(f"{key}")
