In [1]:
import os
import torch
import numpy as np

import cv2
from matplotlib import cm
import matplotlib.pyplot as plt

from pointnav_vo.vo import VisualOdometryTransformerActEmbed
from pointnav_vo.config.vo_config.default import get_config as get_vo_config

In [2]:
# Root of the PointNav-VO checkout. Set the POINTNAV_VO_ROOT environment
# variable to run the notebook elsewhere instead of editing this absolute path.
PROJECT_ROOT = os.environ.get("POINTNAV_VO_ROOT", "/datasets/home/memmel/PointNav-VO")

# Directory holding the stacked RGB-D observation images to visualise.
image_load_path = os.path.join(PROJECT_ROOT, "img_attn_single")
images = ["forward.png", "left.png", "right.png"]
# Action id fed to the model for each image; the first sweep zips these with
# `images` by position. NOTE(review): presumably 1=forward, 2=left, 3=right,
# matching the file names — confirm against the agent's action space.
actions = [1, 2, 3]

# Directory of the VO ViT experiment configs (trailing separator kept).
config_path = os.path.join(PROJECT_ROOT, "configs", "vo", "vit_baselines", "")
configs = [
    "vo_vit_b_dino_act_rgbd.yaml",
    "vo_vit_b_dino_act_rgbd_freeze.yaml",
    "vo_vit_b_in21k_act_rgbd.yaml",
    "vo_vit_b_in21k_act_rgbd_freeze.yaml",
    "vo_vit_b_mmae_act_rgbd.yaml",
    "vo_vit_b_mmae_act_rgbd_freeze.yaml",
]

In [3]:
def load_model(config, model_load_path):
    """Build a ``VisualOdometryTransformerActEmbed`` from ``config`` and load its weights.

    Args:
        config: VO config; fields under ``VO.MODEL``, ``VO.VIS_SIZE_W``/``VO.VIS_SIZE_H``
            and ``VO.TRAIN.depth_aux_loss`` are read.
        model_load_path: checkpoint file whose ``"model_states"`` entry is a sequence
            of state dicts; the last one is loaded.

    Returns:
        The model with weights loaded, switched to eval mode for inference.
    """
    model = VisualOdometryTransformerActEmbed(
        observation_space=config.VO.MODEL.visual_type,
        observation_size=(config.VO.VIS_SIZE_W, config.VO.VIS_SIZE_H),
        hidden_size=config.VO.MODEL.hidden_size,
        backbone=config.VO.MODEL.visual_backbone,
        normalize_visual_inputs=True,
        output_dim=3,
        dropout_p=config.VO.MODEL.dropout_p,
        discretized_depth_channels=0,
        top_down_view_pair_channel=0,
        cls_action=config.VO.MODEL.cls_action,
        train_backbone=config.VO.MODEL.train_backbone,
        pretrain_backbone=config.VO.MODEL.pretrain_backbone,
        custom_model_path=config.VO.MODEL.custom_model_path,
        depth_aux_loss=bool(config.VO.TRAIN.depth_aux_loss),
    )

    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(model_load_path, map_location=torch.device('cpu'))

    # Strip the DataParallel "module." prefix and skip the ViT CLS token, which is
    # deliberately not restored; strict=False tolerates the resulting missing key
    # (and, note, any other key mismatch).
    model_state = {
        key.split("module.")[-1]: value
        for key, value in checkpoint['model_states'][-1].items()
        if 'vit.cls_token' not in key
    }
    model.load_state_dict(model_state, strict=False)

    # Attention visualisation is inference: switch off train-mode behaviour
    # (the model is built with dropout_p) so maps are deterministic.
    model.eval()

    return model

# model = load_model(config, model_load_path)

# obs_size = model.obs_size
# obs_size_single = model.obs_size_single

In [4]:
def pre_process_input(image_load_path, action):
    """Load a stacked RGB-D observation-pair image into a model input dict.

    The image is read with ``plt.imread``. Its top half supplies RGB (first three
    channels) and its bottom half supplies depth (channel 0). Each half holds two
    frames stacked vertically, which are moved side by side onto the channel axis
    and given a leading batch dimension of 1.

    RGB is scaled by 255 and cast to uint8. This presumably assumes imread returned
    floats in [0, 1], as it does for PNG files. Depth keeps imread's dtype.

    Returns:
        dict with keys "rgb" (1, rows, W, 6), "depth" (1, rows, W, 2) and
        "actions" (``action`` unchanged).
    """
    raw = plt.imread(image_load_path)
    split_row = raw.shape[0] // 2
    rgb_half = np.uint8(raw[:split_row, :, :3] * 255)
    depth_half = raw[split_row:, :, 0]

    def _frames_to_channels(stacked):
        # (2*rows, W, C) -> (1, rows, W, 2*C): top frame first, bottom frame second.
        mid = stacked.shape[0] // 2
        return torch.cat((stacked[:mid], stacked[mid:]), dim=2).unsqueeze(0)

    return {
        "rgb": _frames_to_channels(torch.tensor(rgb_half)),
        "depth": _frames_to_channels(torch.tensor(depth_half).unsqueeze(-1)),
        "actions": action,
    }

# action = torch.tensor([1])
# batch_pairs = pre_process_input(image_load_path, action)

In [5]:
# model.observation_strip = ["rgb"]

In [6]:
def pre_process_img(batch_pairs, config, obs_size, obs_size_single):
    """Build a displayable uint8 image from one randomly chosen batch sample.

    For each active modality, the sample's two frames (stored side by side along
    the last, channel axis) are stacked vertically and resized so each frame is
    ``obs_size_single`` (rows, cols). When both RGB and depth are shown, the RGB
    stack is placed above the depth stack.

    Args:
        batch_pairs: dict with "rgb" and/or "depth" tensors laid out as
            (batch, H, W, C), C holding both frames.
        config: VO config; ``VO.MODEL.visual_type`` and ``VO.TRAIN.depth_aux_loss``
            are read.
        obs_size: unused; kept for call compatibility.
        obs_size_single: (rows, cols) of a single frame after resizing.

    Returns:
        (img, plot_idx): ``img`` is a uint8 array (rows_total, cols, 3) and
        ``plot_idx`` is the batch index that was drawn.

    Raises:
        ValueError: if neither RGB nor depth is to be visualised (previously this
            surfaced as an UnboundLocalError).
    """
    visual_type = config.VO.MODEL.visual_type
    rgb_check = "rgb" in visual_type
    # don't visualize depth when the auxiliary depth loss is used
    depth_check = "depth" in visual_type and not config.VO.TRAIN.depth_aux_loss

    if rgb_check:
        index_source = "rgb"
    elif depth_check:
        index_source = "depth"
    else:
        raise ValueError(
            f"nothing to visualize for visual_type={visual_type!r}: need 'rgb', "
            "or 'depth' without depth_aux_loss"
        )
    plot_idx = torch.randint(0, batch_pairs[index_source].shape[0], (1, 1)).item()

    target_size = (obs_size_single[0] * 2, obs_size_single[1])

    def _stack_frames(obs):
        # (1, H, W, 2C) -> frames stacked along height -> (1, C, 2H, W), resized.
        half = obs.shape[-1] // 2
        obs = torch.cat((obs[..., :half], obs[..., half:]), dim=1)
        obs = obs.permute(0, 3, 1, 2).contiguous()
        return torch.nn.functional.interpolate(obs, size=target_size)

    if rgb_check:
        rgb = _stack_frames(batch_pairs["rgb"][plot_idx].unsqueeze(0))
        img = rgb
        if visual_type.count("rgb") == 2:
            img = torch.cat((rgb, rgb), dim=2)
    if depth_check:
        depth = _stack_frames(batch_pairs["depth"][plot_idx].unsqueeze(0))
        # Single-channel depth -> 3 grey channels; presumably depth lies in [0, 1].
        depth = depth.expand(-1, 3, -1, -1) * 255.
        img = depth
        if visual_type.count("depth") == 2:
            img = torch.cat((depth, depth), dim=2)

    if rgb_check and depth_check:
        img = torch.cat((rgb, depth), dim=2)

    # log plain image: (1, 3, H, W) -> (H, W, 3) uint8
    img = img.cpu().numpy().squeeze()
    img = np.uint8(img).transpose(1, 2, 0)

    return img, plot_idx
    
# plot_img, plot_idx = pre_process_img(batch_pairs, config, obs_size, obs_size_single)
# plt.imshow(plot_img)

In [7]:
def get_attention(model, batch_pairs, action, reduce="max", color_map=None,
                  config=None, plot_idx=None, obs_size=None, patch_size=16):
    """Run ``model`` and return a colour-mapped CLS-to-patch attention image.

    The model is called as ``model(batch_pairs, actions, return_attention=True)``
    and must return ``(features, attn)``. ``attn`` is indexed as
    (batch, heads, query_token, key_token). The CLS token's attention to every
    patch is laid out on the patch grid, upsampled to pixels, reduced across
    heads, normalised by its maximum and passed through ``color_map``.

    Args:
        model: VO model (see call convention above).
        batch_pairs: model input dict. It is mutated: "actions" and
            "self_attention" are set.
        action: action tensor fed to the model.
        reduce: "max" or "min" across heads; any other value averages.
        color_map: callable mapping [0, 1] values to RGBA (alpha is dropped).
            Defaults to ``matplotlib.cm.inferno``.
        config: VO config (``VO.MODEL.pretrain_backbone`` is read). Defaults to
            the notebook-global ``config`` for backward compatibility.
        plot_idx: batch index to visualise. Defaults to the notebook-global
            ``plot_idx`` if one exists, else 0.
        obs_size: (rows, cols) of the model input. Defaults to ``model.obs_size``.
        patch_size: ViT patch side length in pixels.

    Returns:
        uint8 array of shape (obs_size[0] // patch_size * patch_size,
        obs_size[1] // patch_size * patch_size, 3).

    Raises:
        ValueError: if no config is passed and no global ``config`` exists.
    """
    if color_map is None:
        color_map = cm.inferno
    if config is None:
        # Backward compatibility: earlier callers relied on the notebook-global `config`.
        config = globals().get("config")
        if config is None:
            raise ValueError("get_attention() needs a VO config; pass config=...")
    if plot_idx is None:
        plot_idx = globals().get("plot_idx", 0)
    if obs_size is None:
        obs_size = model.obs_size

    batch_pairs["actions"] = action
    # Inference only: no autograd graph is needed to inspect attention.
    with torch.no_grad():
        _, attn = model(batch_pairs, batch_pairs["actions"], return_attention=True)
    batch_pairs["self_attention"] = attn

    num_heads = attn.shape[1]
    # Keep only the CLS token's attention to the image patches.
    if config.VO.MODEL.pretrain_backbone == 'mmae':
        # last token is cls token
        cls_attn = attn[plot_idx, :, -1, :-1]
    else:
        # first token is cls token
        cls_attn = attn[plot_idx, :, 0, 1:]

    grid = cls_attn.reshape(num_heads, obs_size[0] // patch_size, obs_size[1] // patch_size)
    # Each patch score is copied onto its patch_size x patch_size pixel block.
    grid = torch.nn.functional.interpolate(
        grid.unsqueeze(0), scale_factor=patch_size, mode="nearest"
    )[0]
    heads = grid.detach().cpu().numpy()

    if reduce == "max":
        attn_agg = np.max(heads, axis=0)
    elif reduce == "min":
        attn_agg = np.min(heads, axis=0)
    else:
        attn_agg = np.mean(heads, axis=0)

    peak = attn_agg.max()
    # An all-zero map would otherwise divide by zero and feed NaNs to the colour map.
    if peak > 0:
        attn_agg = attn_agg / peak

    return np.uint8(255 * color_map(attn_agg)[:, :, :-1])
        
# attn_agg_max = get_attention(action)
# plt.imshow(attn_agg_max)

In [8]:
def plot_img_attention(img, attn, action, img_name, obs_size_single, save_path):
    """Save per-frame attention, image and overlay PNGs for one observation.

    ``attn`` and ``img`` are cut into horizontal strips of ``obs_size_single[0]``
    rows. The number of strips is taken from ``attn``. For each strip, three files
    are written into ``save_path`` (created if missing): the attention strip, the
    image strip, and their ``cv2.addWeighted`` blend (attention weight 0.8, image
    weight 0.6). File names embed ``img_name``, ``action.item()`` and the strip
    index.
    """
    os.makedirs(save_path, exist_ok=True)
    strip_rows = obs_size_single[0]
    num_strips = attn.shape[0] // strip_rows

    for strip in range(num_strips):
        top = strip * strip_rows
        bottom = top + strip_rows

        attn_strip = attn[top:bottom]
        plt.imsave(os.path.join(save_path, f'{img_name}_attn_act_{action.item()}_{strip}.png'), attn_strip)

        img_strip = img[top:bottom]
        plt.imsave(os.path.join(save_path, f'{img_name}_img_act_{action.item()}_{strip}.png'), img_strip)

        blended = cv2.addWeighted(attn_strip, 0.8, img_strip, 0.6, 0.0)
        plt.imsave(os.path.join(save_path, f'{img_name}_overlay_act_{action.item()}_{strip}.png'), blended)
        
# plot_img_attention(plot_img, attn_agg_max, action, obs_size_single, save_path)

In [19]:
# Sweep: for every config, load its best checkpoint and save attention maps for
# each (image, action) pair.
# NOTE(review): get_attention is called without config/plot_idx, so it picks up
# the module-level `config` and `plot_idx` assigned in this loop — keep these
# names global and unchanged.
for cfg in configs:
    
    config_yaml = os.path.join(config_path, cfg)
    config = get_vo_config(config_yaml, [])
    
    # Experiment name = config file name without extension; it selects both the
    # checkpoint directory and the output directory.
    exp = config_yaml.split('/')[-1].split('.')[0]
    # NOTE(review): this path and save_path are relative to the current working
    # directory, unlike the absolute image/config paths.
    model_load_path = os.path.join("train_log/final/vit/unique/", exp, "checkpoints/best_vo.pth")
    
    save_path = os.path.join("img_attn_single", exp)
    
    model = load_model(config, model_load_path)
    obs_size = model.obs_size
    obs_size_single = model.obs_size_single
    
    # Images and actions are paired by position (images[i] with actions[i]).
    for img, act in zip(images, actions):
        act = torch.tensor(act)
        img_name = img.split('.')[0]
        
        image_path = os.path.join(image_load_path, img)
        batch_pairs = pre_process_input(image_path, act)
        plot_img, plot_idx = pre_process_img(batch_pairs, config, obs_size, obs_size_single)
        attn_agg_max = get_attention(model, batch_pairs, act)
        plot_img_attention(plot_img, attn_agg_max, act, img_name, obs_size_single, save_path)
    

In [9]:
# Narrow the sweep: one observation image, one config, every action id (1..3).
# These rebind the module-level names used by the sweep loops.
images = ["right.png"]
actions = list(range(1, 4))
configs = ["vo_vit_b_mmae_act_rgbd.yaml"]

In [11]:
# Sweep every action for every image (unlike the zip-paired sweep), loading
# each config's best checkpoint once.
# NOTE(review): get_attention is called without config/plot_idx, so it picks up
# the module-level `config` and `plot_idx` assigned in this loop — keep these
# names global and unchanged.
for cfg in configs:
    
    config_yaml = os.path.join(config_path, cfg)
    config = get_vo_config(config_yaml, [])
    
    # Experiment name = config file name without extension.
    exp = config_yaml.split('/')[-1].split('.')[0]
    # NOTE(review): relative to the current working directory.
    model_load_path = os.path.join("train_log/final/vit/unique/", exp, "checkpoints/best_vo.pth")
    
    save_path = os.path.join("img_attn_single", exp)
    
    model = load_model(config, model_load_path)
    obs_size = model.obs_size
    obs_size_single = model.obs_size_single
    
    for img in images:
        for act in actions:
            # The action id is wrapped as a tensor; the same tensor goes into the
            # input dict and the output file names.
            act = torch.tensor(act)
            img_name = img.split('.')[0]

            image_path = os.path.join(image_load_path, img)
            batch_pairs = pre_process_input(image_path, act)
            plot_img, plot_idx = pre_process_img(batch_pairs, config, obs_size, obs_size_single)
            attn_agg_max = get_attention(model, batch_pairs, act)
            plot_img_attention(plot_img, attn_agg_max, act, img_name, obs_size_single, save_path)
    