In [1]:
import os
from os.path import join, dirname, abspath
import sys
# Put the directory two levels above the current working directory at the front
# of sys.path so the project-local imports below can be found.
# NOTE(review): this relies on os.getcwd(), i.e. on where the kernel was
# started - presumably the notebook's own folder; confirm before running from
# another location.
CURRENT_DIR = os.getcwd()
sys.path.insert(0, join(CURRENT_DIR, '../..'))  # Import local models

from agents.PickPlaceAgent import PickPlaceAgent
from cliport.dataset import RavensDataset
import torch
from src.utils import get_affordance_map_from_formatted_input
import matplotlib.pyplot as plt
import numpy as np
from src.utils import convert_angle_to_channel

In [10]:
# Single source of truth for the checkpoint file used by both sub-models.
CHECKPOINT_PATH = "/home/ubuntu/VLM/checkpoints/checkpoint_PairPack_epoch390.pth"

agent = PickPlaceAgent(num_rotations=12, lr=1e-4, device='cuda')

# Read the checkpoint from disk once and reuse it for both state dicts instead
# of deserialising the same file twice. map_location='cpu' keeps the temporary
# tensors off the GPU; load_state_dict copies them onto each model's own device.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
agent.pick_model.load_state_dict(checkpoint['pick_state_dict'])
# Kept as the last expression so the cell still displays the key-matching report.
agent.place_model.load_state_dict(checkpoint['place_state_dict'])

<All keys matched successfully>

In [11]:
# Options passed to RavensDataset; all of them live under the "dataset" key.
train_dataset_cfg = {
    "dataset": {
        "type": "single",
        "images": True,
        "cache": False,
        "augment": {"theta_sigma": 60},
        "cache_size": 350,
    },
}

# Build the dataset from the demos stored on disk (n_demos=100, augment=False).
# NOTE(review): the directory name ends in "-val" although the variables are
# called train_* - confirm this is the intended split.
train_dataset = RavensDataset(
    '/home/ubuntu/cliport/data/packing-boxes-pairs-full-val',
    train_dataset_cfg,
    n_demos=100,
    augment=False,
)

In [88]:
import torchvision
from torchvision.utils import draw_segmentation_masks

def show_images(images, affordances, mask_threshold=1.1, blend_alpha=0.5):
    """Plot images with their affordance maps in a three-row grid.

    One column is drawn per (image, affordance) pair:

    * row 0 - the image with every pixel whose affordance exceeds
      ``mask_threshold`` painted green;
    * row 1 - the affordance map on its own, quantised to 8 bits;
    * row 2 - the image blended with a jet-coloured heatmap of the affordance.

    Args:
        images: A tensor or a list of tensors. Each one is scaled by 255 and
            cast to uint8 before being handed to ``draw_segmentation_masks``,
            so it is expected to be a channel-first float image in [0, 1]
            (assumption - confirm against the caller).
        affordances: A tensor or a list of tensors, one 2-D map per image with
            the same spatial size as the image.
        mask_threshold: Affordance value above which a pixel is painted green
            in row 0. With the default of 1.1, a map that has been divided by
            its own maximum never exceeds it, so row 0 shows the plain image;
            pass a smaller value to highlight the strongest responses.
        blend_alpha: Weight of the heatmap in row 2 (0 = image only,
            1 = heatmap only).

    Returns:
        None. The figure is created through ``plt.subplots`` and left to the
        active matplotlib backend to display.
    """
    to_pil = torchvision.transforms.functional.to_pil_image

    if not isinstance(images, list):
        images = [images]
    if not isinstance(affordances, list):
        affordances = [affordances]

    # Extra entries in the longer list are ignored.
    ncols = min(len(images), len(affordances))
    if ncols == 0:
        # plt.subplots rejects a zero-column grid; there is nothing to draw.
        return

    fig, axs = plt.subplots(nrows=3, ncols=ncols, squeeze=False, figsize=(10, 15))
    for i in range(ncols):
        image = images[i].detach().cpu()
        affordance = affordances[i].detach().cpu()

        # Row 0: thresholded affordance drawn as an opaque green mask.
        image_uint8 = (image * 255.0).to(torch.uint8)
        overlaid_affordance = draw_segmentation_masks(
            image_uint8, masks=(affordance > mask_threshold), colors="green", alpha=1.0
        )
        overlaid_np = np.asarray(to_pil(overlaid_affordance))

        # 8-bit versions of the inputs; the affordance is converted only once
        # and reused for both the grayscale panel and the heatmap.
        image_np = np.asarray(to_pil(image))
        affordance_np = np.asarray(to_pil(affordance))

        # Integer input indexes the colormap's lookup table directly; the alpha
        # channel of the RGBA result is dropped.
        heatmap = (plt.cm.jet(affordance_np) * 255).astype(np.uint8)[:, :, 0:3]

        # Row 2: linear blend image + alpha * (heatmap - image), done in NumPy.
        # The earlier Image.blend call needed PIL's Image module, which is not
        # imported by this notebook and raised NameError on a fresh kernel.
        blend_np = image_np + blend_alpha * (heatmap.astype(np.float32) - image_np)
        blend_np = np.clip(blend_np, 0, 255).astype(np.uint8)

        for row, panel in enumerate((overlaid_np, affordance_np, blend_np)):
            axs[row, i].imshow(panel)
            axs[row, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

# Pick Outputs

In [128]:
# Draw one sample from the dataset through a fresh iterator. Only the first
# element of the pair is kept: the dict whose 'img' and 'lang_goal' entries are
# fed to the agent in the cells below.
batch_data = next(iter(train_dataset))
inp, _ = batch_data

In [None]:
# Query the agent with the sample's own language goal.
act, (pick_affordances, place_affordances) = agent.act(inp['img'], inp['lang_goal'])

# Softmax over dim 1, move to the CPU, then reshape into a single 320x160 map.
pick_affordances = (
    torch.nn.functional.softmax(pick_affordances, dim=1)
    .detach()
    .cpu()
    .view(320, 160)
)

# Same post-processing for the place head, reshaped into 12 maps of 320x160
# (12 matches the num_rotations the agent was built with).
place_affordances = (
    torch.nn.functional.softmax(place_affordances, dim=1)
    .detach()
    .cpu()
    .view(12, 320, 160)
)

print(act)
print(inp['lang_goal'])

# The first three image channels (presumably RGB) are reordered to
# channel-first with the two spatial axes swapped and scaled by 1/255; the pick
# map is transposed to match and divided by its maximum.
show_images(
    torch.tensor(inp['img'][:, :, 0:3].transpose((2, 1, 0)) / 255),
    pick_affordances.T / pick_affordances.max(),
)

In [None]:
# Changing the language command: same image, different instruction.
new_lang_goal = "Insert the green squares into that brown box"
act, (pick_affordances, place_affordances) = agent.act(inp['img'], new_lang_goal)

# Softmax over dim 1, move to the CPU, then reshape into a single 320x160 map.
pick_affordances = (
    torch.nn.functional.softmax(pick_affordances, dim=1)
    .detach()
    .cpu()
    .view(320, 160)
)

# Same post-processing for the place head, reshaped into 12 maps of 320x160.
place_affordances = (
    torch.nn.functional.softmax(place_affordances, dim=1)
    .detach()
    .cpu()
    .view(12, 320, 160)
)

print(act)
print(new_lang_goal)

# Image goes in channel-first with the spatial axes swapped and scaled by
# 1/255; the pick map is transposed to match and divided by its maximum.
show_images(
    torch.tensor(inp['img'][:, :, 0:3].transpose((2, 1, 0)) / 255),
    pick_affordances.T / pick_affordances.max(),
)

# Place Outputs

In [124]:
# Draw a sample for the place visualisation through a fresh iterator. Only the
# first element of the pair is kept: the dict whose 'img' and 'lang_goal'
# entries are fed to the agent in the cells below.
batch_data = next(iter(train_dataset))
inp, _ = batch_data

In [None]:
# Query the agent with the sample's own language goal.
act, (pick_affordances, place_affordances) = agent.act(inp['img'], inp['lang_goal'])

# Softmax over dim 1, move to the CPU, then reshape into a single 320x160 map.
pick_affordances = (
    torch.nn.functional.softmax(pick_affordances, dim=1)
    .detach()
    .cpu()
    .view(320, 160)
)

# Same post-processing for the place head, reshaped into 12 maps of 320x160
# (one per rotation channel).
place_affordances = (
    torch.nn.functional.softmax(place_affordances, dim=1)
    .detach()
    .cpu()
    .view(12, 320, 160)
)

# Keep only the map for the rotation channel of the chosen place action:
# act['place'][2] is converted to one of the 12 channel indices.
best_place_rotation_idx = convert_angle_to_channel(act['place'][2], 12)
place_affordances = place_affordances[best_place_rotation_idx]

print(act)
print(inp['lang_goal'])

# Image goes in channel-first with the spatial axes swapped and scaled by
# 1/255; the selected place map is transposed to match and divided by its
# maximum.
show_images(
    torch.tensor(inp['img'][:, :, 0:3].transpose((2, 1, 0)) / 255),
    place_affordances.T / place_affordances.max(),
)

In [None]:
# Changing the language command: same image, different instruction.
new_lang_goal = "Insert the pink looking square into that brown box"
act, (pick_affordances, place_affordances) = agent.act(inp['img'], new_lang_goal)

# Softmax over dim 1, move to the CPU, then reshape into a single 320x160 map.
pick_affordances = (
    torch.nn.functional.softmax(pick_affordances, dim=1)
    .detach()
    .cpu()
    .view(320, 160)
)

# Same post-processing for the place head, reshaped into 12 maps of 320x160.
place_affordances = (
    torch.nn.functional.softmax(place_affordances, dim=1)
    .detach()
    .cpu()
    .view(12, 320, 160)
)

print(act)
print(new_lang_goal)

# NOTE(review): this cell plots the *pick* map and never selects a rotation
# channel from place_affordances, unlike the other place-visualisation cell
# that selects one via convert_angle_to_channel - confirm whether the place
# map was intended here.
show_images(
    torch.tensor(inp['img'][:, :, 0:3].transpose((2, 1, 0)) / 255),
    pick_affordances.T / pick_affordances.max(),
)