# Perturbation saliency on LBC agents

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [3]:
import torch
torch.__version__ # Get PyTorch and CUDA version
torch.cuda.is_available() # Check that CUDA works
torch.cuda.device_count() # Check how many CUDA capable devices you have

# Print device human readable names
torch.cuda.get_device_name(0)


'NVIDIA GeForce RTX 2080 Ti'

In [4]:
from __future__ import print_function
import warnings ; warnings.filterwarnings('ignore') # mute warnings, live dangerously

import matplotlib.pyplot as plt
import matplotlib as mpl #; mpl.use("Agg") # turn this on when only saving and not showing
import matplotlib.animation as manimation

import gym, os, sys, time, argparse
import numpy as np
from scipy.ndimage.filters import gaussian_filter
import cv2
# sys.path.append('..')
# from visualize_dreyevr_agent_saliency.saliency import get_mask
# from saliency import get_env_meta

In [5]:
from PIL import Image
import ast

In [6]:
from pathlib import Path
import sys
LBC_root = Path("/scratch/abhijatb/Bosch22/LbC_DReyeVR/")
CARLA_ROOT = Path("/scratch/abhijatb/Bosch22/carla.harp_p13bd/")
sys.path.insert(0, str(CARLA_ROOT / 'PythonAPI/carla'))
sys.path.insert(0, str(CARLA_ROOT / 'PythonAPI/carla/dist/carla-0.9.13-py3.6-linux-x86_64.egg'))
sys.path.insert(0, str(LBC_root))
sys.path.insert(0, str(LBC_root / 'leaderboard'))
sys.path.insert(0, str(LBC_root / 'leaderboard/team_code'))
sys.path.insert(0, str(LBC_root / 'scenario_runner'))
sys.path.insert(0, str(CARLA_ROOT / 'PythonAPI/examples')) # for DReyeVR_utils

from image_agent import ImageAgent
import torch, torchvision
torch.cuda.set_device(0)
from copy import deepcopy

## Load data

In [7]:
def get_data(route_data_path, sampled_dataidx):
    """Load one recorded frame: three camera images plus its measurements.

    Args:
        route_data_path: route directory (``pathlib.Path``) containing the
            ``rgb``, ``rgb_left``, ``rgb_right`` and ``measurements`` folders.
        sampled_dataidx: frame number; files are named ``NNNN.png`` /
            ``NNNN.json`` (zero-padded to 4 digits).

    Returns:
        The dict parsed with ``ast.literal_eval`` from the first line of the
        measurements file, with the opened PIL images stored under
        ``'rgb'``, ``'rgb_left'`` and ``'rgb_right'``.
    """
    frame_stem = '{:04d}'.format(sampled_dataidx)
    camera_names = ('rgb', 'rgb_left', 'rgb_right')

    # Open every camera image first (same order as before), then the measurements.
    opened = {}
    for camera in camera_names:
        opened[camera] = Image.open(str(route_data_path / camera / (frame_stem + '.png')))

    measurement_file = route_data_path / 'measurements' / (frame_stem + '.json')
    with open(measurement_file) as handle:
        # Only the first line of the file is parsed.
        first_line = handle.readline()
        input_data = ast.literal_eval(first_line)

    for camera in camera_names:
        input_data[camera] = opened[camera]
    return input_data

In [8]:
data_root = Path("/scratch/abhijatb/Bosch22/dreyevr_recordings_sensors/LBC_config")
routedir_iterator = data_root.iterdir()
for route_dir in routedir_iterator:
    break
route_data_path = route_dir
route_rgb_path = route_data_path / 'rgb'
final_datapt_idx = int(sorted(list(route_rgb_path.glob('*.png')))[-1].stem)
datapt_idcs = np.arange(final_datapt_idx)+1 # starts from 0001.png

sampled_dataidx = np.random.choice(datapt_idcs)
input_data = get_data(route_data_path, sampled_dataidx)

In [9]:
# globals/setup for saliency goes here
post_h=144
post_w=256
# prepro = lambda img: cv2.resize(img, (post_h,post_w)).astype(np.float32).reshape(post_h,post_w)/255.
# prepro = lambda img: img.astype(np.float32)/255.
def prepro(img):
    """Cast an image array to float32 and rescale 0-255 intensities into [0, 1]."""
    as_float = img.astype(np.float32)
    return as_float / 255.
# searchlight = lambda I, mask: I*mask + gaussian_filter(I, sigma=3)*(1-mask) # choose an area NOT to blur
# occlude = lambda I, mask: I*(1-mask) + gaussian_filter(I, sigma=3)*mask # choose an area to blur
def occlude(I, mask, sigma=3):
    """Blend ``I`` toward a Gaussian-blurred copy of itself where ``mask`` is high.

    Computes ``I * (1 - mask) + gaussian_filter(I, sigma) * mask``: pixels with
    mask 0 are unchanged and pixels with mask 1 take the blurred value.

    Args:
        I: 2-D image plane (float).
        mask: weights in [0, 1], broadcastable against ``I``.
        sigma: blur strength of the occluder. Defaults to 3, the previously
            hard-coded value, so existing ``interp_func(I, mask)`` calls are
            unaffected.

    Returns:
        The blended array.
    """
    return I*(1-mask) + gaussian_filter(I, sigma=sigma)*mask
# what else goes here

In [10]:
def fwd_img_model(dreyevr_img_agent, input_data, mask=None):
    """Run the agent's image network on one already-ticked sample.

    Args:
        dreyevr_img_agent: agent whose ``net.forward_w_logit`` is called.
        input_data: dict as produced by ``offline_tick``; only its ``'image'``
            and ``'target'`` entries are read, and nothing is modified.
        mask: unused; kept so existing callers that pass it keep working.

    Returns:
        The network logits with their last two (spatial) dims flattened
        into one.
    """
    # input_data is only read, so no defensive deepcopy is needed. This runs
    # once per perturbation (hundreds of times per frame), so the copy was
    # pure overhead.
    # Image channel order is 'rgb', 'rgb_left', 'rgb_right'.
    img = torchvision.transforms.functional.to_tensor(input_data['image'])
    img = img[None].cuda()  # add batch dim

    target = torch.from_numpy(input_data['target'])
    target = target[None].cuda()
    _, (_, _), logits = dreyevr_img_agent.net.forward_w_logit(img, target)
    # Per the original note this is (1 x 4 x HW), 4 being the number of
    # predicted intermediate points.
    return logits.view(logits.shape[:-2] + (-1,))
    

def get_mask(center, size, r):
    """Build a soft perturbation mask centred on ``center``.

    A 5-pixel "plus" (all offsets with dx^2 + dy^2 <= 1) is set to 1, blurred
    with a Gaussian of sigma ``r``, and normalised so the peak is 1.

    Args:
        center: ``[row, col]`` of the mask centre.
        size: ``[height, width]`` of the mask.
        r: sigma of the Gaussian blur.

    Returns:
        float array of shape ``size`` with maximum 1.
    """
    row0, col0 = center[0], center[1]
    dy, dx = np.ogrid[-row0:size[0] - row0, -col0:size[1] - col0]
    inside = dx * dx + dy * dy <= 1

    seed = np.zeros(size)
    seed[inside] = 1
    blurred = gaussian_filter(seed, sigma=r)
    peak = blurred.max()
    return blurred / peak

def apply_mask(input_data, mask, interp_func, channel=0):
    """Return a copy of ``input_data`` with one camera image perturbed under ``mask``.

    ``input_data['image']`` is sliced along its last axis in groups of three
    (channel 0: rgb, 1: rgb_left, 2: rgb_right). Each colour plane of the
    selected camera is scaled to [0, 1] with ``prepro``, perturbed with
    ``interp_func(plane, mask)``, and written back on the 0-255 scale.

    Args:
        input_data: ticked sample dict containing an ``'image'`` array; it is
            not modified.
        mask: 2-D weights with the image's height and width.
        interp_func: perturbation ``f(I, mask) -> I'``, e.g. ``occlude``.
        channel: camera index, 0, 1 or 2.

    Returns:
        A deep copy of ``input_data`` whose image has the perturbed camera.

    Raises:
        ValueError: if ``channel`` is not 0, 1 or 2.
    """
    # Validate before copying. Negative channels used to slip past the old
    # `channel < 3` check and fail later with a confusing IndexError.
    if not 0 <= channel < 3:
        raise ValueError("channel must be 0-2: rgb/left/right")
    masked_data = deepcopy(input_data)
    lo, hi = channel * 3, channel * 3 + 3
    img = masked_data['image'][..., lo:hi]
    # perturb input I -> I' (all three planes are read before writing back)
    planes = [interp_func(prepro(img[..., c]).squeeze(), mask) for c in range(3)]
    perturbed = np.stack(planes, axis=2) * 255
    # Round rather than truncate: x/255*255 in float32 can land just below the
    # integer, and truncation would shift unmasked pixels down by 1 compared
    # with the unperturbed baseline. Clip so an overshooting interp_func
    # cannot wrap a uint8 image.
    masked_data['image'][..., lo:hi] = np.clip(np.rint(perturbed), 0, 255).astype(int)
    return masked_data

In [17]:
def fwd_img_model_batch(dreyevr_img_agent, batch_data):
    """Run the agent's network on a batch of ticked (possibly perturbed) samples.

    Args:
        dreyevr_img_agent: agent whose ``net.forward_w_logit`` is called.
        batch_data: iterable (iterated twice) of dicts, each holding an
            ``'image'`` array and a numpy ``'target'``.

    Returns:
        Logits with their last two dims flattened, one entry per sample along
        the first axis.
    """
    to_tensor = torchvision.transforms.functional.to_tensor
    image_batch = torch.stack([to_tensor(sample['image']) for sample in batch_data], 0)
    image_batch = image_batch.cuda()

    target_batch = torch.stack([torch.from_numpy(sample['target']) for sample in batch_data], 0)
    target_batch = target_batch.cuda()

    _points, (_cam, _extra), logits = dreyevr_img_agent.net.forward_w_logit(image_batch, target_batch)
    return logits.view(logits.shape[:-2] + (-1,))

In [11]:
def _get_masked_data_parallel(ijk, other_data):
    """Unpack one job tuple and forward it to ``get_and_apply_mask``.

    Shaped for ``Pool.starmap``-style callers that pass a per-job tuple plus a
    tuple of arguments shared by all jobs.

    Args:
        ijk: ``(row, col, channel)`` for this job.
        other_data: ``(input_data, r, interp_func)`` shared across jobs.

    Returns:
        The perturbed sample from ``get_and_apply_mask``.
    """
    row, col, channel = ijk
    shared_input, radius, perturb = other_data
    return get_and_apply_mask([row, col], shared_input, radius, perturb, channel)

def get_and_apply_mask(center, input_data, r, interp_func, channel):
    """Build a blurred mask at ``center`` and apply it to one camera image.

    Args:
        center: ``[row, col]`` of the mask centre in pixels.
        input_data: ticked sample dict with an ``'image'`` array (at least 9
            channels on its last axis, presumably H x W x 9).
        r: sigma of the mask blur (see ``get_mask``).
        interp_func: perturbation ``f(I, mask)`` forwarded to ``apply_mask``.
        channel: camera index 0-2 (rgb / rgb_left / rgb_right).

    Returns:
        A perturbed copy of ``input_data`` (see ``apply_mask``).
    """
    H, W = input_data['image'][..., (channel*3)].shape[:2]
    mask = get_mask(center=center, size=[H, W], r=r)
    # Forward the caller's interp_func; it used to be ignored in favour of a
    # hard-coded `occlude`.
    return apply_mask(input_data, mask, interp_func=interp_func, channel=channel)

In [12]:
def score_frame(dreyevr_img_agent, input_data, r, d, interp_func, pt_aggregate="leading"):
    """Compute a perturbation-saliency map for one frame, one mask at a time.

    For every grid point (stride ``d`` over the ``post_h`` x ``post_w`` image)
    and each of the three camera images, the image is perturbed with
    ``interp_func`` under a blurred mask, the network is re-run, and the score
    is half the squared L2 change of the flattened logits.

    Args:
        dreyevr_img_agent: agent providing ``offline_tick`` and ``net``.
        input_data: raw sample (e.g. from ``get_data``); passed through
            ``offline_tick`` first.
        r: sigma of the Gaussian used to blur the perturbation mask.
        d: grid stride in pixels (d == 1 scores every pixel; d == 2 every
            other one, i.e. 25% of the pixels of a 2-D image).
        interp_func: perturbation ``f(I, mask)`` applied per colour plane,
            e.g. ``occlude``.
        pt_aggregate: ``"leading"`` compares only the first two logit maps,
            ``"all"`` compares all of them.

    Returns:
        int array of shape (post_h, post_w, 3); channel k is the saliency for
        camera image k (0: rgb, 1: rgb_left, 2: rgb_right).

    Raises:
        ValueError: if ``pt_aggregate`` is not ``"leading"`` or ``"all"``.
    """
    # Fail before the expensive loop rather than after the first forward pass.
    if pt_aggregate not in ("leading", "all"):
        raise ValueError("only 'leading'(first 2) and 'all' aggregations are supported")

    input_data = dreyevr_img_agent.offline_tick(input_data)
    L = fwd_img_model(dreyevr_img_agent, input_data)  # unperturbed logits

    rows = range(0, post_h, d)
    cols = range(0, post_w, d)
    # Exactly one cell per grid point. Sizing with int(post_h/d)+1 left an
    # all-zero row/column whenever d divides post_h or post_w, which the
    # resize below then smeared into the map.
    scores = np.zeros((len(rows), len(cols), 3))  # saliency scores S(t,i,j)

    for gi, i in enumerate(rows):
        for gj, j in enumerate(cols):
            # The mask depends only on (i, j), so build it once for all 3 cameras.
            mask = get_mask(center=[i, j], size=[post_h, post_w], r=r)
            for k in range(3):  # camera: rgb / left / right
                masked_data = apply_mask(input_data, mask, interp_func, channel=k)
                l = fwd_img_model(dreyevr_img_agent, masked_data)  # masked logits
                diff = L - l
                if pt_aggregate == "leading":
                    diff = diff[:, :2, :]
                scores[gi, gj, k] = diff.pow(2).sum().mul_(.5).data.item()

    # Upsample to image size, then rescale so the map's peak equals the
    # largest raw score (linear resizing can miss the exact grid maximum).
    pmax = scores.max()
    scores = cv2.resize(scores, dsize=(post_w, post_h), interpolation=cv2.INTER_LINEAR).astype(np.float32)
    smax = scores.max()
    if smax <= 0:
        # No perturbation changed the output: avoid 0/0 -> NaN -> garbage ints.
        return np.zeros(scores.shape, dtype=int)
    smap = pmax * scores / smax
    # NOTE: the int cast is kept for compatibility; it truncates scores < 1 to 0.
    return smap.astype(int)

In [13]:
def score_frame_batched(dreyevr_img_agent, input_data, r, d, interp_func, pt_aggregate="leading", batch_size=64):
    """Batched variant of ``score_frame``: same kind of map, fewer forward calls.

    Perturbed samples are generated lazily and pushed through the network
    ``batch_size`` at a time, so only one batch of perturbed copies is held in
    memory at once.

    Args:
        dreyevr_img_agent: agent providing ``offline_tick`` and ``net``.
        input_data: raw sample (e.g. from ``get_data``); passed through
            ``offline_tick`` first.
        r: sigma of the Gaussian used to blur the perturbation mask.
        d: grid stride in pixels over the ``post_h`` x ``post_w`` grid.
        interp_func: perturbation ``f(I, mask)``, e.g. ``occlude``.
        pt_aggregate: ``"leading"`` (first two logit maps) or ``"all"``.
        batch_size: number of perturbed samples per forward pass (>= 1).

    Returns:
        int array of shape (post_h, post_w, 3); channel k is the saliency for
        camera image k (0: rgb, 1: rgb_left, 2: rgb_right).

    Raises:
        ValueError: if ``pt_aggregate`` is unknown or ``batch_size`` < 1.
    """
    import itertools

    # Unknown modes used to be silently treated as "all"; reject them like score_frame.
    if pt_aggregate not in ("leading", "all"):
        raise ValueError("only 'leading'(first 2) and 'all' aggregations are supported")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    input_data = dreyevr_img_agent.offline_tick(input_data)
    L = fwd_img_model(dreyevr_img_agent, input_data)  # unperturbed logits

    rows = range(0, post_h, d)
    cols = range(0, post_w, d)
    # The mask is sized from the ticked image itself.
    H, W = input_data['image'].shape[:2]

    def _masked_samples():
        # Yields perturbed samples in (row, col, camera) C-order, matching the
        # flat layout of `scores`. One mask per grid point is reused for all 3 cameras.
        for i in rows:
            for j in cols:
                mask = get_mask(center=[i, j], size=[H, W], r=r)
                for k in range(3):
                    yield apply_mask(input_data, mask, interp_func, channel=k)

    # One cell per grid point: int(post_h/d)+1 over-allocated when d divides
    # the image size, leaving unfilled (None) entries that crashed the forward.
    scores = np.zeros(len(rows) * len(cols) * 3)
    samples = _masked_samples()
    start = 0
    while True:
        # islice never yields an empty batch mid-way, unlike the old
        # int(n/batch_size)+1 count when batch_size divided n.
        batch_data = list(itertools.islice(samples, batch_size))
        if not batch_data:
            break
        flat_logits = fwd_img_model_batch(dreyevr_img_agent, batch_data)
        diff = L - flat_logits
        if pt_aggregate == "leading":
            diff = diff[:, :2, :]
        scores[start:start + len(batch_data)] = diff.pow(2).sum(dim=[1, 2]).mul_(.5).data.tolist()
        start += len(batch_data)
    scores = scores.reshape(len(rows), len(cols), 3)

    # Upsample to image size, then rescale so the map's peak equals the
    # largest raw score (linear resizing can miss the exact grid maximum).
    pmax = scores.max()
    scores = cv2.resize(scores, dsize=(post_w, post_h), interpolation=cv2.INTER_LINEAR).astype(np.float32)
    smax = scores.max()
    if smax <= 0:
        # No perturbation changed the output: avoid 0/0 -> NaN -> garbage ints.
        return np.zeros(scores.shape, dtype=int)
    smap = pmax * scores / smax
    # NOTE: the int cast is kept for compatibility; it truncates scores < 1 to 0.
    return smap.astype(int)

## Make saliency movie

In [14]:
import matplotlib.pyplot as plt
import matplotlib as mpl ; mpl.use("Agg")
import matplotlib.animation as manimation

In [18]:
# set metadata for movie to be written
def saliency_movie(path_to_conf_file, route_data_path):
    # model load
#     path_to_conf_file = LBC_root / "checkpoints/imgmodel_17trainseqs_CC-1_LR-4_FT/epoch=27.ckpt"
    # "checkpoints/imgmodel_17trainseqs_CC-2_LR-4_FT/epoch=47.ckpt"
    dreyevr_img_agent = ImageAgent(str(path_to_conf_file))

    r, d = 10, 10
    aggregate_method ="leading"
    batch_size = 96
    
    model_name = path_to_conf_file.parents[0].stem
    movie_title = model_name \
                    + "_" + route_data_path.stem.split('_')[1] \
                    + "_" + aggregate_method  + ".mp4"

    save_dir = "/scratch/abhijatb/Bosch22/LbC_DReyeVR/saliency_movies/"
    if os.path.exists(save_dir + movie_title):
        print(movie_title, "already exists")
        return
    else:
        print(movie_title, "being processed")
    resolution=150

    start = time.time()
    FFMpegWriter = manimation.writers['ffmpeg']
    metadata = dict(title=movie_title, artist='ajdroid', comment='dreyevrLBC-saliency-video')
    writer = FFMpegWriter(fps=2, metadata=metadata)
    
    final_datapt_idx = int(sorted(list((route_data_path/'rgb').glob('*.png')))[-1].stem)
    datapt_idcs = np.arange(final_datapt_idx)+1 # starts from 0001.png
    prog = '' ; total_frames = datapt_idcs[-1]
    
    f = plt.figure(figsize=[10*3.2, 10], dpi=resolution)

    with writer.saving(f, save_dir + movie_title, resolution):
        for i in datapt_idcs:
            input_data = get_data(route_data_path, i)
            smap = score_frame_batched(dreyevr_img_agent, input_data, r, d, 
                               interp_func=occlude, pt_aggregate=aggregate_method, batch_size=batch_size)
            # image order is rgb, rgb_left, rgb_right
            salmap_3stack = np.hstack([smap[:,:,1], smap[:,:,0], smap[:,:,2]])
            img_3stack = np.hstack([input_data['image'][...,3:6],
                        input_data['image'][...,0:3],
                        input_data['image'][...,6:9]])            

            plt.imshow(img_3stack, alpha=1)
            plt.imshow(salmap_3stack, alpha=0.5, cmap=plt.get_cmap('Reds'))
            plt.axis('off')
            writer.grab_frame() ; f.clear()

            tstr = time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start))
            print('\ttime: {} | progress: {:.1f}%'.format(tstr, 100*i/total_frames), end='\r')
    print('\nfinished.')

In [None]:
# data load
data_root = Path("/scratch/abhijatb/Bosch22/dreyevr_recordings_sensors/LBC_config")
routedir_iterator = data_root.iterdir()
# routedir_iterator = data_root.glob("*11*/")
path_to_conf_file = LBC_root / "checkpoints/pre_trained/epoch_24.ckpt"

paths_to_conf_files = [LBC_root / "checkpoints/pre_trained/epoch_24.ckpt",
                      LBC_root / "checkpoints/imgmodel_17trainseqs_CC-1_LR-4_FT/epoch=27.ckpt"]

for route_data_path in routedir_iterator:
    for path_to_conf_file in paths_to_conf_files:
        saliency_movie(path_to_conf_file, route_data_path)    

<All keys matched successfully>
pre_trained_swapnil32_leading.mp4 being processed
	time: 01h 15m 46s | progress: 100.0%
finished.
<All keys matched successfully>
imgmodel_17trainseqs_CC-1_LR-4_FT_swapnil32_leading.mp4 already exists
<All keys matched successfully>
pre_trained_brady54_leading.mp4 being processed
	time: 01h 13m 41s | progress: 100.0%
finished.
<All keys matched successfully>
imgmodel_17trainseqs_CC-1_LR-4_FT_brady54_leading.mp4 already exists
<All keys matched successfully>
pre_trained_tab32_leading.mp4 being processed
	time: 01h 15m 37s | progress: 100.0%
finished.
<All keys matched successfully>
imgmodel_17trainseqs_CC-1_LR-4_FT_tab32_leading.mp4 already exists
<All keys matched successfully>
pre_trained_dexter54_leading.mp4 being processed
	time: 01h 13m 13s | progress: 100.0%
finished.
<All keys matched successfully>
imgmodel_17trainseqs_CC-1_LR-4_FT_dexter54_leading.mp4 already exists
<All keys matched successfully>
pre_trained_tab54_leading.mp4 being processed
	tim

In [None]:
route_data_path = route_dir

final_datapt_idx = int(sorted(list(route_rgb_path.glob('*.png')))[-1].stem)
datapt_idcs = np.arange(final_datapt_idx)+1 # starts from 0001.png

sampled_dataidx = np.random.choice(datapt_idcs)
input_data = get_data(route_data_path, sampled_dataidx)

## Load model to introspect (interactive testing)

In [None]:
!ls /scratch/abhijatb/Bosch22/LbC_DReyeVR/checkpoints/

In [None]:
path_to_conf_file = LBC_root / "checkpoints/imgmodel_17trainseqs_CC-1_LR-4_FT/epoch=27.ckpt"
dreyevr_img_agent = ImageAgent(str(path_to_conf_file))

In [None]:
d=10
r=10
aggregate_method="leading"
smap = score_frame(dreyevr_img_agent, input_data, r, d, interp_func=occlude, pt_aggregate=aggregate_method)

In [None]:
# show salience across images, laid out as left | center | right.
# smap channel k and image slice [..., 3k:3k+3] are the same camera
# (0: rgb/center, 1: rgb_left, 2: rgb_right), so both stacks must use the same
# [1, 0, 2] permutation. The image stack previously used [0, 1, 2], which drew
# each camera's saliency over the wrong image.
# NOTE(review): input_data['image'] is expected to have been added by the
# offline_tick call inside score_frame - confirm if this cell raises KeyError.
salmap_3stack = np.hstack([smap[:,:,1], smap[:,:,0], smap[:,:,2]])
img_3stack = np.hstack([input_data['image'][...,3:6],
                        input_data['image'][...,0:3],
                        input_data['image'][...,6:9]])

In [None]:
f = plt.figure(figsize=[10*3.2, 10], dpi=150)
plt.imshow(img_3stack, alpha=1)
plt.imshow(salmap_3stack, alpha=0.5, cmap=plt.get_cmap('Reds'))
plt.show()
f

In [None]:
f = plt.figure()
plt.imshow(smap[:,:,0])
plt.show()
f

In [None]:
input_img =  input_data['image'][...,k*3:k*3+3] if k <3 else input_data['image'][...,-1]
f = plt.figure()
plt.imshow(input_img)
plt.show()
f

In [None]:
masked_img = masked_data['image'][...,k*3:k*3+3] if k <3 else masked_data['image'][...,-1]
f = plt.figure()
plt.imshow(masked_img)
plt.show()
f

In [None]:
for i in range(0,post_h,d):
    for j in range(0,post_w,d):
        for k in range(0,3): # this is for the channel rgb/left/right/target
            break
            mask = get_mask(center=[i,j], size=[post_h,post_w], r=r) # perturbation mask
            masked_data = apply_mask(input_data, mask, interp_func, channel=k)
            l = fwd_img_model(dreyevr_img_agent, masked_data)
            scores[int(i/d),int(j/d), k] = (L-l).sum()
# avoid range artifacts while resizing
pmax = scores.max()
scores = cv2.resize(scores, dsize=(post_h,post_w), interpolation=cv2.INTER_LINEAR).astype(np.float32)
# return pmax * scores / scores.max()

In [None]:
logits.shape

In [None]:
points_world

In [None]:
angle = np.degrees(np.pi / 2 - np.arctan2(aim[1], aim[0])) / 90
steer = self._turn_controller.step(angle)
steer = np.clip(steer, -1.0, 1.0)

In [None]:
flat_logits = logits.view(logits.shape[:-2] + (-1,))
flat_logits.shape

In [None]:
logits.shape

In [None]:
points.shape

In [None]:
def fwd_for_saliency(self, input_data):
    """Scratch copy of an agent step that ticks offline and returns a vehicle control.

    Takes ``self`` like an agent method (presumably an ImageAgent). It runs
    ``self.net.forward`` on the ticked image/target, converts the predicted
    points to pixels and then to world coordinates with
    ``self.converter.cam_to_world``, and turns them into steer / throttle /
    brake via ``self._turn_controller`` and ``self._speed_controller``.

    NOTE(review): ``carla``, ``DEBUG`` and ``debug_display`` are not defined in
    this notebook; this cell only runs if they are brought into scope.

    Args:
        input_data: raw sample dict, passed to ``self.offline_tick``.

    Returns:
        ``carla.VehicleControl`` with ``steer``, ``throttle`` and ``brake`` set.
    """
    tick_data = self.offline_tick(input_data)

    img = torchvision.transforms.functional.to_tensor(tick_data['image'])
    img = img[None].cuda()  # add batch dim

    target = torch.from_numpy(tick_data['target'])
    target = target[None].cuda()

    points, (target_cam, _) = self.net.forward(img, target)
    # (p + 1) / 2 * size maps [-1, 1] onto [0, size]: x by width, y by height.
    points_cam = points.clone().cpu()
    points_cam[..., 0] = (points_cam[..., 0] + 1) / 2 * img.shape[-1]
    points_cam[..., 1] = (points_cam[..., 1] + 1) / 2 * img.shape[-2]
    points_cam = points_cam.squeeze()
    points_world = self.converter.cam_to_world(points_cam).numpy()

    # Aim at the midpoint of the first two predicted points.
    aim = (points_world[1] + points_world[0]) / 2.0
    # Angle (degrees) between the aim vector and the +y axis, divided by 90.
    angle = np.degrees(np.pi / 2 - np.arctan2(aim[1], aim[0])) / 90
    steer = self._turn_controller.step(angle)
    steer = np.clip(steer, -1.0, 1.0)

    # Target speed: twice the distance between the first two predicted points.
    desired_speed = np.linalg.norm(points_world[0] - points_world[1]) * 2.0
    # desired_speed *= (1 - abs(angle)) ** 2

    speed = tick_data['speed']

    # `or` short-circuits, so the division only runs when desired_speed >= 0.4.
    brake = desired_speed < 0.4 or (speed / desired_speed) > 1.1

    delta = np.clip(desired_speed - speed, 0.0, 0.25)
    throttle = self._speed_controller.step(delta)
    throttle = np.clip(throttle, 0.0, 0.75)
    throttle = throttle if not brake else 0.0

    control = carla.VehicleControl()
    control.steer = steer
    control.throttle = throttle
    control.brake = float(brake)

    if DEBUG:
        debug_display(
                tick_data, target_cam.squeeze(), points.cpu().squeeze(),
                steer, throttle, brake, desired_speed,
                self.step)

    return control


In [None]:
points_cam[..., 0]

In [None]:
def run_step(self, input_data, timestamp):
    """Scratch copy of the agent's online step: sensors in, vehicle control out.

    Calls ``self._init()`` on first use, ticks the sensors with ``self.tick``,
    runs ``self.net.forward``, maps the predicted points to world coordinates,
    and derives steer / throttle / brake from the turn and speed controllers.

    NOTE(review): ``carla`` is not imported in this notebook; this cell only
    runs if it is brought into scope. ``timestamp`` is not used.

    Args:
        input_data: raw sensor dict, passed to ``self.tick``.
        timestamp: unused here.

    Returns:
        ``carla.VehicleControl`` with ``steer``, ``throttle`` and ``brake`` set.
    """
    if not self.initialized:
        self._init()

    tick_data = self.tick(input_data)

    img = torchvision.transforms.functional.to_tensor(tick_data['image'])
    img = img[None].cuda()  # add batch dim

    target = torch.from_numpy(tick_data['target'])
    target = target[None].cuda()

    points, (target_cam, _) = self.net.forward(img, target)
    # (p + 1) / 2 * size maps [-1, 1] onto [0, size]: x by width, y by height.
    points_cam = points.clone().cpu()
    points_cam[..., 0] = (points_cam[..., 0] + 1) / 2 * img.shape[-1]
    points_cam[..., 1] = (points_cam[..., 1] + 1) / 2 * img.shape[-2]
    points_cam = points_cam.squeeze()
    points_world = self.converter.cam_to_world(points_cam).numpy()

    # Aim at the midpoint of the first two predicted points.
    aim = (points_world[1] + points_world[0]) / 2.0
    # Angle (degrees) between the aim vector and the +y axis, divided by 90.
    angle = np.degrees(np.pi / 2 - np.arctan2(aim[1], aim[0])) / 90
    steer = self._turn_controller.step(angle)
    steer = np.clip(steer, -1.0, 1.0)

    # Target speed: twice the distance between the first two predicted points.
    desired_speed = np.linalg.norm(points_world[0] - points_world[1]) * 2.0
    # desired_speed *= (1 - abs(angle)) ** 2

    speed = tick_data['speed']

    # `or` short-circuits, so the division only runs when desired_speed >= 0.4.
    brake = desired_speed < 0.4 or (speed / desired_speed) > 1.1

    delta = np.clip(desired_speed - speed, 0.0, 0.25)
    throttle = self._speed_controller.step(delta)
    throttle = np.clip(throttle, 0.0, 0.75)
    throttle = throttle if not brake else 0.0

    control = carla.VehicleControl()
    control.steer = steer
    control.throttle = throttle
    control.brake = float(brake)
    return control

In [None]:
target_cam

In [None]:
def run_fwd_imgagent(img_agent)

In [None]:
# put in forward hooks


## Load agent, build environment, play an episode

In [None]:
env_name = 'Breakout-v0'
save_dir = 'figures/'

print("set up dir variables and environment...")
load_dir = 'pretrained/{}/'.format(env_name.lower())
meta = get_env_meta(env_name)
env = gym.make(env_name) ; env.seed(1)

print("initialize agent and try to load saved weights...")
model = NNPolicy(channels=1, num_actions=env.action_space.n)
_ = model.try_load(load_dir, checkpoint='*.tar') ; torch.manual_seed(1)

print("get a rollout of the policy...")
history = rollout(model, env, max_ep_len=3e3)

In [None]:
f = plt.figure(figsize=[3,3*1.3])
# frame_ix = 1404
frame_ix=1404
plt.imshow(history['ins'][frame_ix])
for a in f.axes: a.get_xaxis().set_visible(False) ; a.get_yaxis().set_visible(False)
plt.show(f)
f

## Get Jacobian saliency map

In [None]:
def jacobian(model, layer, top_dh, X):
    """Backpropagate ``top_dh`` from ``layer``'s output to the input ``X[0]``.

    Runs ``model(X)`` with a forward hook on ``layer`` to capture its output,
    then calls ``torch.autograd.backward(output, top_dh)`` so that
    ``X[0].grad`` receives the vector-Jacobian product.

    Args:
        model: module called as ``model(X)``.
        layer: submodule of ``model`` whose output is the backprop origin.
        top_dh: gradient seed with the same shape as ``layer``'s output.
        X: tuple whose first element is the input tensor (requires_grad=True).

    Returns:
        ``(grad, value)``: numpy copies of ``X[0].grad`` and ``X[0]``.
        ``X[0].grad`` accumulates across calls unless the caller zeroes it.

    Raises:
        RuntimeError: if ``layer`` was not called during ``model(X)``.
    """
    captured = {}

    def hook_top_h(module, inputs, output):
        captured['top_h'] = output.clone()

    # A closure replaces the old module-level `top_h_` global, and try/finally
    # removes the hook even if the forward pass raises.
    hook = layer.register_forward_hook(hook_top_h)
    try:
        model(X)  # do a forward pass so the forward hook can record the output
    finally:
        hook.remove()
    if 'top_h' not in captured:
        raise RuntimeError("layer was not called during model(X); nothing to backpropagate")

    # backprop the seed signal
    torch.autograd.backward(captured['top_h'], top_dh.clone(), retain_graph=True)
    return X[0].grad.data.clone().numpy(), X[0].data.clone().numpy()

In [None]:
# derivative is simply the output policy distribution
top_dh_actor = torch.Tensor(history['logits'][frame_ix]).view(1,-1)
top_dh_critic = torch.Tensor(history['values'][frame_ix]).view(1,-1).fill_(1)

# get input
tens_state = torch.Tensor(prepro(history['ins'][frame_ix]))
state = Variable(tens_state.unsqueeze(0), requires_grad=True)
hx = Variable(torch.Tensor(history['hx'][frame_ix-1]).view(1,-1))
cx = Variable(torch.Tensor(history['cx'][frame_ix-1]).view(1,-1))
X = (state, (hx, cx))

actor_jacobian, _ = jacobian(model, model.actor_linear, top_dh_actor, X)

state.grad.mul_(0) ; X = (state, (hx, cx))
critic_jacobian, _ = jacobian(model, model.critic_linear, top_dh_critic, X)

## Get perturbation saliency map

In [None]:
radius = 5
density = 5

actor_saliency = score_frame(model, history, frame_ix, radius, density, interp_func=occlude, mode='actor')
critic_saliency = score_frame(model, history, frame_ix, radius, density, interp_func=occlude, mode='critic')

In [None]:
# upsample jacobian saliencies
frame = history['ins'][frame_ix].squeeze().copy()
frame = saliency_on_atari_frame((actor_jacobian**2).squeeze(), frame, fudge_factor=1, channel=2, sigma=0)
jacobian_map = saliency_on_atari_frame((critic_jacobian**2).squeeze(), frame, fudge_factor=15, channel=0, sigma=0)

# upsample perturbation saliencies
frame = history['ins'][frame_ix].squeeze().copy()
frame = saliency_on_atari_frame(actor_saliency, frame, fudge_factor=200, channel=2)
perturbation_map = saliency_on_atari_frame(critic_saliency, frame, fudge_factor=100, channel=0)

## Plot side-by-side

In [None]:
f = plt.figure(figsize=[11, 5*1.3], dpi=75)

plt.subplot(1,2,1)
plt.imshow(jacobian_map)
plt.title('Jacobian', fontsize=30)

plt.subplot(1,2,2)
plt.imshow(perturbation_map)
plt.title('Ours', fontsize=30)

for a in f.axes: a.get_xaxis().set_visible(False) ; a.get_yaxis().set_visible(False)
plt.show() #; f.savefig('./figures/jacobian-vs-perturb.png', bbox_inches='tight')
f