In [1]:
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import torch
import h5py
import os
import numpy as np
from typing import List
from PIL import Image

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

class GradRAM(BaseCAM):
    """Grad-CAM over a policy whose output is scored by a single callable target.

    ``forward`` applies ``targets[0]`` to the *whole* model output and
    back-propagates the resulting scalar loss. Presumably this is because the policy
    returns a distribution-like object (``LogProbOutputTarget`` calls
    ``model_output.log_prob``) rather than a per-sample logits tensor.
    Channel weights are the spatially averaged gradients (``get_cam_weights``).
    """

    def __init__(self, model, target_layers,
                 reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

    def get_cam_weights(self,
                        input_tensor,
                        target_layer,
                        target_category,
                        activations,
                        grads):
        """Return per-channel weights: ``grads`` averaged over axes 2 and 3.

        Assumes ``grads`` is a 4-D numpy array, presumably (batch, channels, H, W)
        for conv target layers -- TODO confirm for other layer types.
        """
        return np.mean(grads, axis=(2, 3))

    def forward(self,
                input_tensor: torch.Tensor,
                targets: List[torch.nn.Module],
                eigen_smooth: bool = False) -> np.ndarray:
        """Run the model, back-propagate ``targets[0](outputs)`` and build the CAMs.

        Args:
            input_tensor: Batched input image tensor; moved to ``self.device``.
            targets: Non-empty list of callables mapping the model output to a
                scalar loss. Only the first entry is used.
            eigen_smooth: Forwarded to ``compute_cam_per_layer``.

        Returns:
            The aggregated CAM array from ``aggregate_multi_layers``.

        Raises:
            ValueError: If gradients are needed and ``targets`` is None or empty.
        """
        input_tensor = input_tensor.to(self.device)

        if self.compute_input_gradient:
            input_tensor = torch.autograd.Variable(input_tensor,
                                                   requires_grad=True)

        # Enable parameter gradients *before* the forward pass. Setting this after
        # the pass cannot change the graph that was already recorded.
        self.model.requires_grad_(True)
        self.outputs = outputs = self.activations_and_grads(input_tensor)

        if self.uses_gradients:
            if not targets:
                raise ValueError(
                    "GradRAM requires an explicit non-empty `targets` list "
                    "(e.g. [LogProbOutputTarget(action)]).")
            self.model.zero_grad()
            loss = targets[0](outputs)
            # retain_graph so the graph survives for any later backward calls.
            loss.backward(retain_graph=True)

        cam_per_layer = self.compute_cam_per_layer(input_tensor,
                                                   targets,
                                                   eigen_smooth)
        return self.aggregate_multi_layers(cam_per_layer)


class LogProbOutputTarget:
    """Grad-CAM target that scores a model output by its mean negative log-likelihood.

    The wrapped ``target_value`` (e.g. the recorded action) is evaluated under the
    model output via ``model_output.log_prob``. The mean of the log-probabilities is
    negated to give a scalar loss.
    """

    def __init__(self, target_value):
        self.target_value = target_value

    def __call__(self, model_output):
        mean_log_likelihood = model_output.log_prob(self.target_value).mean()
        return mean_log_likelihood.neg()

In [2]:
# Read every camera stream, the low-dimensional robot state and the actions of
# demo_0 fully into memory; the file is closed when the `with` block exits.
with h5py.File('lift/depth84.hdf5', 'r') as f:
    actions = f['data/demo_0/actions'][:]

    front_images = f['data/demo_0/obs/frontview_image'][:]
    agent_images = f['data/demo_0/obs/agentview_image'][:]
    side_images = f['data/demo_0/obs/sideview_image'][:]
    eyeinhand_images = f['data/demo_0/obs/robot0_eye_in_hand_image'][:]

    eef_poses = f['data/demo_0/obs/robot0_eef_pos'][:]
    eef_quats = f['data/demo_0/obs/robot0_eef_quat'][:]
    gripper_qposes = f['data/demo_0/obs/robot0_gripper_qpos'][:]

### Define the models and run Grad-CAM (via the `GradRAM` class) on the eye-in-hand view

#### Baseline

In [3]:
# Hyper-parameters shared by the baseline and "ours" PiNetwork constructors.
image_latent_dim = 256
mlp_hidden_dims = [1024, 1024]
action_dim = 7
input_shape = [512, 3, 3]  # presumably the ResNet-18 feature map (C, H, W) -- TODO confirm
# Low-dim observation = robot0_eef_pos (3) + robot0_eef_quat (4) + robot0_gripper_qpos (2).
low_dim_input_dim = sum([3, 4, 2])

In [5]:
from models.lift.baseline.resnet18_gmmmlp_view13rgb_model_ver1 import PiNetwork
from models.lift.baseline.resnet18_gmmmlp_view13rgb_model_ver1 import Imageonly_Model

model = PiNetwork(input_shape, image_latent_dim, action_dim, low_dim_input_dim, mlp_hidden_dims)
model.to(device)
model.float()

# Assemble the policy from parts of two separately trained checkpoints.
baseline_ckpt_dir = '/home/generalroboticslab/Desktop/gradcam-view-stitching/models/lift/baseline/models'
vision1_encoder_path = os.path.join(baseline_ckpt_dir, 'bc_robomimic_ver1_vision1robot0_eye_in_hand_vision2frontview_anchors256_lr0.0001_seed101_model.pt')
vision2_encoder_path = os.path.join(baseline_ckpt_dir, 'bc_robomimic_ver1_vision1agentview_vision2sideview_anchors256_lr0.0001_seed101_model.pt')
# The policy head comes from the same checkpoint as the vision-1 encoder.
gmm_mlp_path = vision1_encoder_path

# NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
# Each checkpoint is indexed [0], [1], [2] below; presumably
# [vision1 encoder, vision2 encoder, policy head] state dicts -- TODO confirm.
data1 = torch.load(vision1_encoder_path, map_location=device)  # data1 for vision1_encoder
data2 = torch.load(vision2_encoder_path, map_location=device)  # data2 for vision2_encoder
# Reuse the already-loaded checkpoint instead of unpickling the same file twice.
data3 = data1 if gmm_mlp_path == vision1_encoder_path else torch.load(gmm_mlp_path, map_location=device)  # data3 for gmm_mlp

model.RGBView1ResnetEmbed.load_state_dict(data1[0])
model.RGBView3ResnetEmbed.load_state_dict(data2[1])
model.Probot.load_state_dict(data3[2])


<All keys matched successfully>

In [6]:
def get_input(traj_idx, front_image, agent_image, side_image, eyeinhand_image, eef_pos, eef_quat, gripper_qpos, actions):
    """Slice timestep ``traj_idx`` out of the demo arrays and wrap ``model`` for Grad-CAM.

    Despite their singular names, every array parameter is a whole per-demo
    array (the notebook passes ``front_images``, ``eef_poses``, ...). Each one
    is indexed with ``traj_idx`` -- the arguments are used, not module globals.

    Args:
        traj_idx: Timestep index applied to every array.
        front_image, agent_image, side_image, eyeinhand_image: Image arrays,
            presumably channels-last ``[t, H, W, C]`` (permuted to CHW here).
        eef_pos, eef_quat, gripper_qpos: Low-dimensional robot-state arrays.
        actions: Action array.

    Returns:
        ``(onlyimage_model, front_image, agent_image, side_image, eyeinhand_image,
        eef_pos, eef_quat, gripper_qpos, action)``. Images are CHW float tensors
        divided by 255, vectors are float tensors, and all tensors are on the
        global ``device``. ``onlyimage_model`` wraps the global ``model`` with
        the batched low-dim state and side-view image as fixed extra inputs.
    """
    def frame(images):
        # HWC -> CHW float tensor on `device`, scaled by 1/255 (presumably uint8 pixels).
        return torch.from_numpy(images[traj_idx]).permute(2, 0, 1).to(device).float() / 255

    def vector(values):
        return torch.from_numpy(values[traj_idx]).to(device).float()

    front = frame(front_image)
    agent = frame(agent_image)
    side = frame(side_image)
    eyeinhand = frame(eyeinhand_image)
    pos = vector(eef_pos)
    quat = vector(eef_quat)
    gripper = vector(gripper_qpos)
    action = vector(actions)

    # Fixed (non-explained) inputs, each given a leading batch axis of 1.
    input_args = [pos.unsqueeze(0),
                  quat.unsqueeze(0),
                  gripper.unsqueeze(0),
                  side.unsqueeze(0)]
    onlyimage_model = Imageonly_Model(model, input_args)

    return onlyimage_model, front, agent, side, eyeinhand, pos, quat, gripper, action

In [23]:
# Grad-RAM heatmap for every frame of demo_0, overlaid on the eye-in-hand image.
folder = 'results/baseline_demo0'
os.makedirs(folder, exist_ok=True)  # create once, not re-checked on every frame

for traj_idx in range(len(front_images)):
    (onlyimage_model, front_image, agent_image, side_image, eyeinhand_image,
     eef_pos, eef_quat, gripper_qpos, action) = get_input(traj_idx,
                                                          front_images,
                                                          agent_images,
                                                          side_images,
                                                          eyeinhand_images,
                                                          eef_poses,
                                                          eef_quats,
                                                          gripper_qposes,
                                                          actions)
    policy = onlyimage_model.original_model
    # NOTE(review): RGBView3ResnetEmbed holds the second checkpoint's encoder and
    # presumably encodes the fixed side-view image passed via get_input, yet its
    # CAMs are averaged into the heatmap drawn on the eye-in-hand frame -- confirm
    # that mixing both encoders' layers is intended.
    target_layers = [
        policy.RGBView3ResnetEmbed.resnet18_base_model.layer4[1].conv2,
        policy.RGBView3ResnetEmbed.resnet18_base_model.layer4[1].conv1,
        policy.RGBView1ResnetEmbed.resnet18_base_model.layer4[1].conv2,
        policy.RGBView1ResnetEmbed.resnet18_base_model.layer4[1].conv1,
    ]
    cam = GradRAM(model=onlyimage_model, target_layers=target_layers)

    input_tensor = eyeinhand_image.unsqueeze(0).to(device).float()  # add batch axis
    # Explain the negative log-likelihood of the recorded action.
    targets = [LogProbOutputTarget(action.unsqueeze(0).to(device).float())]

    grayscale_cam = cam(input_tensor=input_tensor, targets=targets, aug_smooth=True, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]  # single image in the batch

    # CHW -> HWC float image (already divided by 255 in get_input) for the overlay.
    rgb_frame = input_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    visualization = show_cam_on_image(rgb_frame, grayscale_cam, use_rgb=True)
    Image.fromarray(visualization).save(os.path.join(folder, f'gradram{traj_idx}.png'))

#### Ours

In [4]:
from models.lift.ours.resnet18_gmmmlp_view13rgb_rel_model_low_dim_layer import PiNetwork
from models.lift.ours.resnet18_gmmmlp_view13rgb_rel_model_low_dim_layer import Imageonly_Model

ours_root = '/home/generalroboticslab/Desktop/gradcam-view-stitching/models/lift/ours'

# Anchor images passed to PiNetwork for each vision encoder.
vision1_anchors = np.load(os.path.join(ours_root, 'robot0_eye_in_hand_256anchor_images_from_agentview_idx.npy'))
vision2_anchors = np.load(os.path.join(ours_root, 'sideview_256anchor_images_from_agentview_idx.npy'))

# Presumably (N, H, W, C) pixel arrays -> (N, C, H, W) float32 on `device`, scaled by 1/255.
vision1_anchors_tensor = torch.tensor(vision1_anchors, dtype=torch.float32).to(device).permute(0, 3, 1, 2) / 255.0
vision2_anchors_tensor = torch.tensor(vision2_anchors, dtype=torch.float32).to(device).permute(0, 3, 1, 2) / 255.0

model = PiNetwork(input_shape, vision1_anchors_tensor, vision2_anchors_tensor, image_latent_dim, action_dim, low_dim_input_dim, mlp_hidden_dims)
model.to(device)
model.float()

# Assemble the policy from parts of two separately trained checkpoints.
ours_ckpt_dir = os.path.join(ours_root, 'models')
vision1_encoder_path = os.path.join(ours_ckpt_dir, 'bc_rel_ver1_vision1robot0_eye_in_hand_vision2agentview_anchors256_lr0.0001_seed101_model.pt')
vision2_encoder_path = os.path.join(ours_ckpt_dir, 'bc_rel_ver1_vision1agentview_vision2sideview_anchors256_lr0.0001_seed101_model.pt')
# The policy head comes from the same checkpoint as the vision-1 encoder.
gmm_mlp_path = vision1_encoder_path

# NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
data1 = torch.load(vision1_encoder_path, map_location=device)  # data1 for vision1_encoder
data2 = torch.load(vision2_encoder_path, map_location=device)  # data2 for vision2_encoder
# Reuse the already-loaded checkpoint instead of unpickling the same file twice.
data3 = data1 if gmm_mlp_path == vision1_encoder_path else torch.load(gmm_mlp_path, map_location=device)  # data3 for gmm_mlp

model.RGBView1ResnetEmbed.load_state_dict(data1[0])
model.RGBView3ResnetEmbed.load_state_dict(data2[1])
model.Probot.load_state_dict(data3[2])

<All keys matched successfully>

In [5]:
def get_input(traj_idx, front_image, agent_image, side_image, eyeinhand_image, eef_pos, eef_quat, gripper_qpos, actions):
    """Slice timestep ``traj_idx`` out of the demo arrays and wrap ``model`` for Grad-CAM.

    Despite their singular names, every array parameter is a whole per-demo
    array (the notebook passes ``front_images``, ``eef_poses``, ...). Each one
    is indexed with ``traj_idx`` -- the arguments are used, not module globals.

    Args:
        traj_idx: Timestep index applied to every array.
        front_image, agent_image, side_image, eyeinhand_image: Image arrays,
            presumably channels-last ``[t, H, W, C]`` (permuted to CHW here).
        eef_pos, eef_quat, gripper_qpos: Low-dimensional robot-state arrays.
        actions: Action array.

    Returns:
        ``(onlyimage_model, front_image, agent_image, side_image, eyeinhand_image,
        eef_pos, eef_quat, gripper_qpos, action)``. Images are CHW float tensors
        divided by 255, vectors are float tensors, and all tensors are on the
        global ``device``. ``onlyimage_model`` wraps the global ``model`` with
        the batched low-dim state and side-view image as fixed extra inputs.
    """
    def frame(images):
        # HWC -> CHW float tensor on `device`, scaled by 1/255 (presumably uint8 pixels).
        return torch.from_numpy(images[traj_idx]).permute(2, 0, 1).to(device).float() / 255

    def vector(values):
        return torch.from_numpy(values[traj_idx]).to(device).float()

    front = frame(front_image)
    agent = frame(agent_image)
    side = frame(side_image)
    eyeinhand = frame(eyeinhand_image)
    pos = vector(eef_pos)
    quat = vector(eef_quat)
    gripper = vector(gripper_qpos)
    action = vector(actions)

    # Fixed (non-explained) inputs, each given a leading batch axis of 1.
    input_args = [pos.unsqueeze(0),
                  quat.unsqueeze(0),
                  gripper.unsqueeze(0),
                  side.unsqueeze(0)]
    onlyimage_model = Imageonly_Model(model, input_args)

    return onlyimage_model, front, agent, side, eyeinhand, pos, quat, gripper, action

In [10]:
# Grad-RAM heatmap for every frame of demo_0, overlaid on the eye-in-hand image.
folder = 'results/ours_demo0'
os.makedirs(folder, exist_ok=True)  # create once, not re-checked on every frame

for traj_idx in range(len(front_images)):
    (onlyimage_model, front_image, agent_image, side_image, eyeinhand_image,
     eef_pos, eef_quat, gripper_qpos, action) = get_input(traj_idx,
                                                          front_images,
                                                          agent_images,
                                                          side_images,
                                                          eyeinhand_images,
                                                          eef_poses,
                                                          eef_quats,
                                                          gripper_qposes,
                                                          actions)
    policy = onlyimage_model.original_model
    # NOTE(review): RGBView3ResnetEmbed holds the second checkpoint's encoder and
    # presumably encodes the fixed side-view image passed via get_input, yet its
    # CAMs are averaged into the heatmap drawn on the eye-in-hand frame -- confirm
    # that mixing both encoders' layers is intended.
    target_layers = [
        policy.RGBView3ResnetEmbed.resnet18_base_model.layer4[1].conv2,
        policy.RGBView3ResnetEmbed.resnet18_base_model.layer4[1].conv1,
        policy.RGBView1ResnetEmbed.resnet18_base_model.layer4[1].conv2,
        policy.RGBView1ResnetEmbed.resnet18_base_model.layer4[1].conv1,
    ]
    cam = GradRAM(model=onlyimage_model, target_layers=target_layers)

    input_tensor = eyeinhand_image.unsqueeze(0).to(device).float()  # add batch axis
    # Explain the negative log-likelihood of the recorded action.
    targets = [LogProbOutputTarget(action.unsqueeze(0).to(device).float())]

    # Test-time augmentation smoothing and eigen smoothing are both enabled.
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets, aug_smooth=True, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]  # single image in the batch

    # CHW -> HWC float image (already divided by 255 in get_input) for the overlay.
    rgb_frame = input_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    visualization = show_cam_on_image(rgb_frame, grayscale_cam, use_rgb=True)
    Image.fromarray(visualization).save(os.path.join(folder, f'gradram{traj_idx}.png'))

In [25]:
def make_vid(image_folder, video_path, fps=10):
    """Stitch ``gradram<N>.png`` frames from ``image_folder`` into an MP4 video.

    Frames are ordered by their numeric index ``N``, not lexicographically.
    Indices missing from ``0..max(N)``, and files OpenCV cannot decode, are
    reported and skipped. Other files in the folder are ignored.

    Args:
        image_folder: Directory containing the ``gradram<N>.png`` frames.
        video_path: Output video file path.
        fps: Frames per second of the output video (default 10).

    Raises:
        RuntimeError: If no readable frame is found, or the video writer
            cannot be opened for ``video_path``.
    """
    import re
    import cv2

    frame_name = re.compile(r'gradram(\d+)\.png')
    by_index = {}
    for name in os.listdir(image_folder):
        match = frame_name.fullmatch(name)
        if match:
            by_index[int(match.group(1))] = name

    images = []
    for i in range(max(by_index, default=-1) + 1):
        filename = f'gradram{i}.png'
        if i not in by_index:
            print(f"Image {filename} not found. Skipping.")
            continue
        img = cv2.imread(os.path.join(image_folder, by_index[i]))
        if img is None:  # cv2.imread signals an unreadable file by returning None
            print(f"Image {filename} could not be read. Skipping.")
            continue
        images.append(img)

    if not images:
        raise RuntimeError("No images found. Check your file paths.")

    height, width = images[0].shape[:2]

    # Define the codec and create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    if not video.isOpened():
        raise RuntimeError(f"Could not open video writer for {video_path}")
    try:
        for img in images:
            video.write(img)
    finally:
        video.release()

In [None]:
# Read the frames from the same relative folder the "Ours" Grad-RAM loop writes to,
# so this cell works regardless of where the repository is checked out.
image_folder = os.path.join('results', 'ours_demo0')
video_path = os.path.join('results', 'video_ours_eye.mp4')  # replace with your desired output path

make_vid(image_folder, video_path)

In [26]:
# Read the frames from the same relative folder the baseline Grad-RAM loop writes to,
# so this cell works regardless of where the repository is checked out.
image_folder = os.path.join('results', 'baseline_demo0')
video_path = os.path.join('results', 'video_baseline_eye.mp4')  # replace with your desired output path

make_vid(image_folder, video_path)