# Qualitative Visualization  

This notebook demonstrates how to create the qualitative visualizations shown in Figure 9 and Figure 10 of the paper.

## Step-1: Generate Data

You can use `pusht_human.py` to find a preferred layout and save both the images and the pointer locations by pressing the `s` key.

Then, you can draw the trajectory with `draw_human_trajectory.ipynb`, just use your image as the background instead of the default "000.jpg".

You can download our data from the `datasets/visualization` folder (in box). We assume that you have already downloaded the `datasets/visualization` folder (in box) and placed it under `outputs` (local, in the `pusht` folder).

## Step-2: Estimate Likelihood and conditioned Trajectories

First, import libraries and define some functions.

In [1]:
from omegaconf import OmegaConf
import os.path as osp
import torch
torch.set_grad_enabled(False)
import os
import hydra
from copy import deepcopy
from torchvision.transforms.functional import to_pil_image, to_tensor
from diffusion_policy.common.pytorch_util import dict_apply
import torch.nn.functional as F
import pathlib
from scipy import interpolate
from PIL import Image, ImageDraw
import numpy as np
import json
from tqdm.auto import tqdm

def load_json(file_path):
    """Read the file at ``file_path`` and return its parsed JSON content."""
    with open(file_path, 'r') as handle:
        return json.load(handle)

def smooth_curve(points, window_size=3):
    """Smooth a polyline with a centred moving average.

    Each output point is the mean of the input points inside a window of
    ``window_size`` centred on it; near either end the window is truncated
    to the points that exist.

    Args:
        points: Sequence of points (anything ``np.array`` accepts).
        window_size: Width of the averaging window, in points.

    Returns:
        A list with one averaged NumPy value per input point.
    """
    pts = np.array(points)
    count = len(pts)
    half = window_size // 2
    return [
        np.mean(pts[max(0, idx - half):min(count, idx + half + 1)], axis=0)
        for idx in range(count)
    ]

def add_circle(draw, end_point, radius=5, fill=(255, 0, 0)):
    """Draw a filled circle of ``radius`` centred on ``end_point``.

    Args:
        draw: Drawing object exposing ``ellipse(bounding_box, fill=...)``.
        end_point: ``(x, y)`` centre of the circle.
        radius: Circle radius.
        fill: Fill colour passed through to ``draw.ellipse``.
    """
    cx, cy = end_point
    left, top = cx - radius, cy - radius
    right, bottom = cx + radius, cy + radius
    draw.ellipse([left, top, right, bottom], fill=fill)

def draw_smooth_line_with_dot(img, points, line_color=(255, 0, 0), line_width=3, dot_radius=5, dot_color=(255, 0, 0)):
    """Draw a smoothed polyline through ``points`` and mark its last point.

    Args:
        img: Image to draw on; it is converted to RGBA first.
        points: Sequence of ``(x, y)`` points describing the trajectory.
        line_color: Colour of the polyline.
        line_width: Width of the polyline.
        dot_radius: Radius of the end marker.
        dot_color: Colour of the end marker.

    Returns:
        The RGBA image that was drawn on.
    """
    canvas = img.convert('RGBA')
    pen = ImageDraw.Draw(canvas)

    # Smooth first, then truncate to integer pixel coordinates for drawing.
    vertices = [tuple(int(coord) for coord in pt) for pt in smooth_curve(points)]
    pen.line(vertices, fill=line_color, width=line_width)

    # The end marker is only drawn when there is an actual segment.
    if len(vertices) >= 2:
        add_circle(pen, vertices[-1], radius=dot_radius, fill=dot_color)
    return canvas

In [2]:
# Index of the device selected for this notebook.
# NOTE(review): `device` is not referenced by any other cell visible here; the
# models and tensors below are never moved to it — confirm whether that is intended.
device = 0 # selecting the device

Next, we load two ARP models. They are simpler than the one used for the main results: they use a flat action sequence. One of them treats every action as a continuous value, and the other one discretizes the action space.

In [3]:
# Build both flat-action-sequence policies from their configs and restore the
# trained weights: P1 from the `flat` config/checkpoint, P2 from `flat_dis`.
#
# `map_location='cpu'` lets the checkpoints load on machines that lack the
# device they were saved from. Every input tensor built later in this notebook
# is created on the CPU, so the models are expected to live there as well.
cfg1 = OmegaConf.load('outputs/visualization/flat.yaml')
P1 = hydra.utils.instantiate(cfg1.policy)
P1.load_state_dict(
    torch.load(
        'outputs/visualization/flat_epoch=0500-test_mean_score=0.822.ckpt',
        map_location='cpu',
    )['state_dicts']['model']
)

cfg2 = OmegaConf.load('outputs/visualization/flat_dis.yaml')
P2 = hydra.utils.instantiate(cfg2.policy)
P2.load_state_dict(
    torch.load(
        'outputs/visualization/flat_dis_epoch=1500-test_mean_score=0.815.ckpt',
        map_location='cpu',
    )['state_dicts']['model']
)



using obs modality: rgb with keys: ['image']
using obs modality: depth with keys: []
using obs modality: scan with keys: []
using obs modality: low_dim with keys: []






using obs modality: rgb with keys: ['image']
using obs modality: depth with keys: []
using obs modality: scan with keys: []
using obs modality: low_dim with keys: []


<All keys matched successfully>

### Likelihood Inference

First, load all human data into `all_batches`.

In [4]:
# Data Preparation
# Build one batch dict per human-annotated scene: the resized observation
# image, the original picture (kept for drawing), the agent position and the
# resampled human trajectories.
all_batches = []

for scene_idx in range(4):
    fig = Image.open(f'./outputs/visualization/human-data/{scene_idx:03d}.jpg')
    pic = to_tensor(fig)
    state = load_json(f'./outputs/visualization/human-data/{scene_idx:03d}.json')
    action = load_json(f'./outputs/visualization/human-data/{scene_idx:03d}.action.json')

    agent_pos = torch.as_tensor(state['pos_agent']).reshape(1, 2)

    # Resample every drawn stroke to exactly 16 (x, y) points.
    resampled = []
    for stroke in action:
        coords = torch.as_tensor(stroke).float().reshape(1, -1, 2)
        # F.interpolate works along the last axis, so move the point axis there and back.
        coords = F.interpolate(coords.permute(0, 2, 1), size=(16,), mode='linear')
        coords = coords.permute(0, 2, 1)
        # Rescale by 512 / 256 and keep the result inside [0, 511].
        # NOTE(review): presumably image pixels (256) -> environment coordinates (512) — confirm.
        resampled.append((coords / 256 * 512).clamp(0, 511))
    action_tensor = torch.cat(resampled, 0)

    # Resize to 96x96; the two `[None, ...]` add two leading singleton axes.
    small_image = F.interpolate(pic[None, ...], size=(96, 96), mode='bilinear', align_corners=False)
    batch = {
        'image': small_image[None, ...],
        'origin_image': fig,
        'agent_pos': agent_pos,
        'action': action_tensor
    }
    all_batches.append(batch)

The following code estimates the likelihood of each human-drawn trajectory and saves the results to `./outputs/visualization/regenerated_likelihood.{images, states}`.

In [7]:
# Likelihood inference.
#
# For every (scene, model, human trajectory) triple: compute the log-probability
# the model assigns to the trajectory, draw the trajectory on the original
# picture, and save both the picture and the raw values.
models = [('P1', P1), ('P2', P2)]
batch_size = 1  # one trajectory is scored at a time

total = len(models) * sum(len(scene['action']) for scene in all_batches)

with tqdm(total=total) as pbar:
    for bi, batch in enumerate(all_batches):
        for model_name, model in models:
            model.eval()

            # Token-type ids for the flattened (x, y, x, y, ...) action sequence.
            tk_names = ['x', 'y'] * model.horizon
            tk_types = torch.as_tensor(
                [model.policy.token_name_2_ids[tname] for tname in tk_names]
            ).reshape(1, -1, 1).repeat(batch_size, 1, 1)

            output_folder_images = f'./outputs/visualization/regenerated_likelihood.images/{model_name}/{bi}/'
            output_folder_states = f'./outputs/visualization/regenerated_likelihood.states/{model_name}/{bi}/'
            os.makedirs(output_folder_images, exist_ok=True)
            os.makedirs(output_folder_states, exist_ok=True)

            for action_i in range(len(batch['action'])):
                # Normalise OUT OF PLACE. `batch` is reused for every trajectory,
                # so in-place `-=` / `/=` on its tensors would re-normalise the
                # same image and agent position again on each iteration.
                nobs = {
                    'image': (batch['image'] - 0.5) / 0.5,
                    'agent_pos': (batch['agent_pos'] - 256) / 256.,
                }
                # `[[action_i]]` keeps the leading axis, giving a single-trajectory batch.
                action = (batch['action'][[action_i]] - 256) / 256

                # Fold the observation-step axis into the batch axis for the encoder.
                this_nobs = dict_apply(
                    nobs,
                    lambda x: x[:, :model.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
                nobs_features = model.obs_encoder(this_nobs)
                nobs_features = model.obs_feat_linear(nobs_features)
                nobs_features = nobs_features.reshape(batch_size, model.n_obs_steps, model.policy.cfg.n_embd)

                # Each token is a (value, type-id) pair.
                tk_vals = action.flatten(1).unsqueeze(-1)
                seq = torch.cat([tk_vals, tk_types], dim=-1)
                loss_dict, log_prob = model.policy.compute_loss(seq, contexts={'visual-token': nobs_features}, log_prob=True)
                log_prob = log_prob.sum().item()

                # Undo the normalisation, then rescale by 256 / 512 for drawing.
                pts = (action * 256 + 256) / 512 * 256
                fig = batch['origin_image'].copy()
                fig = draw_smooth_line_with_dot(fig, pts[0], line_color=(255, 0, 0, 208), line_width=4, dot_radius=4, dot_color=(255, 0, 0, 208))

                fig.save(osp.join(output_folder_images, f'{action_i}_{log_prob:.02f}.png'))
                torch.save({'pts': pts, 'log_prob': log_prob}, osp.join(output_folder_states, f'{action_i}.pt'))
                pbar.update()

  0%|          | 0/44 [00:00<?, ?it/s]

### Predict with Human Guidance

The following code generates trajectories conditioned on the first half of each human trajectory and saves them to `./outputs/visualization/regenerated_guide.{images, states}`.

In [8]:
# Prediction with human guidance.
#
# For every (scene, model, human trajectory) triple, sample several completions:
# the first half of the trajectory is given to the model as a prompt and the
# second half is generated. Each sample is drawn on the original picture
# (prompt half in red, generated half in blue) and saved with its points.
models = [('P1', P1), ('P2', P2)]
num_samples = 10
batch_size = 1  # one trajectory is completed at a time

total = len(models) * num_samples * sum(len(scene['action']) for scene in all_batches)

with tqdm(total=total) as pbar:
    for bi, batch in enumerate(all_batches):
        for model_name, model in models:
            model.eval()
            half = model.horizon // 2

            # Description of the tokens to generate: the second half of the sequence.
            future_tk_types = ['x', 'y'] * half
            future_chk_ids = list(range(model.horizon, model.horizon * 2))
            future_tk_chk_ids = [
                {'tk_id': model.policy.token_name_2_ids[tk_type], 'chk_id': chk_id}
                for chk_id, tk_type in zip(future_chk_ids, future_tk_types)
            ]

            output_folder_images = f'./outputs/visualization/regenerated_guide.images/{model_name}/{bi}/'
            output_folder_states = f'./outputs/visualization/regenerated_guide.states/{model_name}/{bi}/'
            os.makedirs(output_folder_images, exist_ok=True)
            os.makedirs(output_folder_states, exist_ok=True)

            for action_i in range(len(batch['action'])):
                # `[[action_i]]` keeps the leading axis; the arithmetic is out of place.
                action = (batch['action'][[action_i]] - 256) / 256

                for sample_i in range(num_samples):
                    # Normalise OUT OF PLACE. `batch` is reused for every trajectory
                    # and sample, so in-place `-=` / `/=` on its tensors would
                    # re-normalise them on each iteration; the inputs then grow
                    # geometrically, a likely source of NaN failures in generation.
                    nobs = {
                        'image': (batch['image'] - 0.5) / 0.5,
                        'agent_pos': (batch['agent_pos'] - 256) / 256.,
                    }

                    # Fold the observation-step axis into the batch axis for the encoder.
                    this_nobs = dict_apply(
                        nobs,
                        lambda x: x[:, :model.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
                    nobs_features = model.obs_encoder(this_nobs)
                    nobs_features = model.obs_feat_linear(nobs_features)
                    nobs_features = nobs_features.reshape(batch_size, model.n_obs_steps, model.policy.cfg.n_embd)

                    # Prompt: (value, type-id) tokens for the first half of the trajectory.
                    # NOTE(review): the hard-coded ids 0/1 assume token_name_2_ids maps
                    # 'x' -> 0 and 'y' -> 1 — confirm against the policy.
                    seq = torch.zeros(batch_size, model.horizon, 2)
                    seq[:, :, 1] = torch.as_tensor([0, 1] * half)
                    seq[:, :, 0] = action[0, :half, :].flatten()

                    try:
                        action_pred = model.policy.generate(seq, future_tk_chk_ids, contexts={'visual-token': nobs_features}, sample=True)
                    except Exception as exc:
                        # Best effort: report the failed sample (with the real error) and
                        # move on, still advancing the bar so it can reach 100%.
                        tqdm.write(f'error at {bi} {action_i}, possibly get an NaN on out-of-distribution data: {exc!r}')
                        pbar.update()
                        continue

                    # Keep the value channel, regroup into (x, y) points, undo the normalisation.
                    action_pred = action_pred[..., 0].reshape(-1, model.horizon, 2)
                    action_pred = (action_pred * 256. + 256.).clamp(0, 511)

                    # Rescale by 256 / 512 for drawing.
                    pts = action_pred[0] / 512 * 256
                    fig = batch['origin_image'].copy()
                    # The two halves share one point so the drawn line is continuous.
                    fig = draw_smooth_line_with_dot(fig, pts[:half + 1], line_color=(255, 0, 0, 245), line_width=4, dot_radius=2, dot_color=(255, 0, 0, 208))
                    fig = draw_smooth_line_with_dot(fig, pts[half:], line_color=(0, 0, 255, 208), line_width=4, dot_radius=4, dot_color=(0, 0, 255, 208))

                    fig.save(osp.join(output_folder_images, f'{action_i}-{sample_i}.png'))
                    torch.save(pts, osp.join(output_folder_states, f'{action_i}-{sample_i}.pt'))
                    pbar.update()


  0%|          | 0/440 [00:00<?, ?it/s]

error at 3 6, possibly get an NaN on out-of-distribution data
error at 3 6, possibly get an NaN on out-of-distribution data
error at 3 6, possibly get an NaN on out-of-distribution data
error at 3 6, possibly get an NaN on out-of-distribution data
error at 3 6, possibly get an NaN on out-of-distribution data
error at 3 6, possibly get an NaN on out-of-distribution data
error at 3 6, possibly get an NaN on out-of-distribution data
error at 3 6, possibly get an NaN on out-of-distribution data
error at 3 6, possibly get an NaN on out-of-distribution data


## Step-3: Visualization

Unzip `outputs/visualization/visualized_data.zip` into the `outputs/visualization` folder. It contains our likelihood and prediction-with-guidance results (we use the discrete model at a different epoch). Here we recreate the visualizations from the paper.


## Creating Colormap

We use the `fast` colormap from https://www.kennethmoreland.com/color-advice/. 

In [17]:
from colormath.color_objects import *
from colormath.color_conversions import convert_color

# Load the colormap control points and precompute their Lab representation.
import pandas

# `with` makes sure the file handle is closed once the JSON has been parsed.
with open('outputs/visualization/fast.colormap.json', 'r') as file_descriptor:
    raw_color_data = json.load(file_descriptor)[0]

# `RGBPoints` is a flat list read in groups of four: [scalar, r, g, b, scalar, r, g, b, ...].
rgb_points = raw_color_data['RGBPoints']
group_starts = range(0, len(rgb_points), 4)

scalar = [rgb_points[i] for i in group_starts]
rgb_values = [
    sRGBColor(rgb_points[i + 1], rgb_points[i + 2], rgb_points[i + 3])
    for i in group_starts
]

# One row per control point; the Lab column is what the lookup interpolates in.
data = pandas.DataFrame({'scalar': scalar, 'rgb_values': rgb_values})
data['lab_values'] = data['rgb_values'].apply(lambda rgb: convert_color(rgb, LabColor))


def color_lookup_sRGBColor(x):
    """Map the scalar ``x`` to a colour by interpolating the colormap in Lab space.

    Negative values return black. Otherwise the first pair of neighbouring
    control points whose upper scalar is not below ``x`` is blended linearly
    in Lab and converted back to sRGB. Values above the last control point
    return white.
    """
    if x < 0:
        return sRGBColor(0, 0, 0)

    scalars = data['scalar']
    labs = data['lab_values']
    for lower in range(data.index.size - 1):
        upper = lower + 1
        if x > scalars[upper]:
            continue

        lab_from, lab_to = labs[lower], labs[upper]
        weight = (x - scalars[lower]) / (scalars[upper] - scalars[lower])

        def blend(start, stop):
            return weight * (stop - start) + start

        blended = LabColor(blend(lab_from.lab_l, lab_to.lab_l),
                           blend(lab_from.lab_a, lab_to.lab_a),
                           blend(lab_from.lab_b, lab_to.lab_b),
                           observer=lab_from.observer,
                           illuminant=lab_from.illuminant)
        return convert_color(blended, sRGBColor)

    return sRGBColor(1, 1, 1)

def color_lookup(x):
    """Return the colormap colour for ``x`` as the sRGB colour's value tuple."""
    srgb = color_lookup_sRGBColor(x)
    return srgb.get_value_tuple()

def color_lookup_upscaled(x):
    """Return the colormap colour for ``x`` as the sRGB colour's upscaled value tuple."""
    srgb = color_lookup_sRGBColor(x)
    return srgb.get_upscaled_value_tuple()

def to_color(v):
    """Return the colormap colour for ``v`` as a tuple of ints (each channel times 255, truncated)."""
    channels = color_lookup(v)[:3]
    return tuple(int(channel * 255) for channel in channels)

## Predict with Human Guidance

The following dicts are used to select which images / trajectories to visualize. You will generally want to choose the best-looking ones.

In [9]:
# Hand-picked trajectory indices per image, grouped into 'good' and 'bad'.
# Outer keys are image ids (0-3).
# NOTE(review): `assignments` is not read by any cell visible in this notebook.
assignments ={
0: {
    'bad': [0,1,2],
    'good': [3,4]
},
1: {
    'bad': [2,3,4,5],
    'good': [0,1]
},
2: {
    'good': [0,1],
    'bad': [2,3]
},
3: {
    'good': [0, 1, 2, 3, 6],
    'bad': [4, ]
}
}

# Hand-picked guided predictions per image, grouped into 'good' and 'bad'.
# Outer keys are image ids (0-3). Each tuple is (aid, sid): the pair that
# selects the saved state file `{aid}-{sid}.pt` loaded by the guidance
# visualization code.
guide_assignments ={
0: {
    'bad': [(0, 0), (2, 0)],
    'good': [(3, 4), (4, 0)]
},
1: {
    'bad': [(2, 8), (3, 5) , (4, 9), (5, 1)],
    'good': [(0, 0), (1, 7)]
},
2: {
    'good': [(0, 0), (1, 9), (2, 3)],
    'bad': [(3, 9)]
},
3: {
    'good': [(0, 9), (1, 9), (2, 5), (5, 9)],
    'bad': [(3, 3), (4, 2), (6, 5)]
}
}

The following code generates the prediction-with-guidance visualization images and saves them to `./outputs/visualization/results/guide`.

In [16]:
# Render the selected guided predictions.
#
# For every image and for each of the 'good' / 'bad' groups this writes:
#   * one combined picture with all selected trajectories overlaid, and
#   * one "unrolled" picture per trajectory, drawn on a clean copy.
model_v = 'P1'
horizon = 16

os.makedirs('outputs/visualization/results/guide/unroll', exist_ok=True)


def _draw_guided(canvas, pts):
    """Overlay one trajectory: first half in red, second half in blue."""
    c1, c2 = (255, 0, 0, 200), (0, 0, 255, 200)
    # The two halves share one point so the drawn line is continuous.
    canvas = draw_smooth_line_with_dot(canvas, pts[:horizon//2+1], line_color=c1, line_width=4, dot_radius=2, dot_color=c1)
    return draw_smooth_line_with_dot(canvas, pts[horizon//2:], line_color=c2, line_width=4, dot_radius=4, dot_color=c2)


for img_id in range(4):
    origin_fig = Image.open(f'./outputs/visualization/human-data/{str(img_id).zfill(3)}.jpg')

    # Load every selected sample up front, keyed by group and then by `aid`.
    guide_pts = {
        k: {
            aid: torch.load(f'./outputs/visualization/guide.states/{model_v}/{img_id}/{aid}-{sid}.pt')
            for aid, sid in picks
        }
        for k, picks in guide_assignments[img_id].items()
    }

    for quality in ('good', 'bad'):
        fig = origin_fig.copy()
        # Sorted so the overlay order and the `cid` in file names are
        # deterministic for both groups.
        for cid, aid in enumerate(sorted(guide_pts[quality])):
            # Hand-picked exclusion from the 'bad' figure of image 1.
            if quality == 'bad' and img_id == 1 and aid == 3:
                continue
            pts = guide_pts[quality][aid]
            fig = _draw_guided(fig, pts)
            _fig = _draw_guided(origin_fig.copy(), pts)
            _fig.save(f'outputs/visualization/results/guide/unroll/{img_id}.{cid}.{quality}.png')

        fig.save(f'outputs/visualization/results/guide/{img_id}.{quality}.png')

dict_keys([0, 2])
dict_keys([2, 3, 4, 5])
dict_keys([3])
dict_keys([3, 4, 6])


## Likelihood Inference

The following code generates the likelihood inference visualization images and saves them to `./outputs/visualization/results/likelihood`. Note that for image `1`, we use the discrete model instead of the continuous one.

In [19]:
# Render the human trajectories coloured by their (normalised) log-probability.
#
# For every image this writes one combined picture with all trajectories
# overlaid and one "unrolled" picture per trajectory on a clean copy.
os.makedirs('outputs/visualization/results/likelihood/unroll', exist_ok=True)

horizon = 16  # not used below; kept so the notebook-level name stays defined
# Lowest colormap position used: log-probs are rescaled into [BASELINE, 1].
BASELINE = 0.1

for img_id in range(4):
    # Image 1 uses model P2; every other image uses P1.
    model_v = 'P2' if img_id == 1 else 'P1'
    state_dir = f'./outputs/visualization/likelihood.states/{model_v}/{img_id}/'

    origin_fig = Image.open(f'./outputs/visualization/human-data/{str(img_id).zfill(3)}.jpg')

    likelihood = {}
    for s in os.listdir(state_dir):
        # Ignore stray entries (hidden files, checkpoints folders, ...) whose
        # names would not parse as `<index>.pt`.
        if not s.endswith('.pt'):
            continue
        sid = int(s.split('.')[0])
        likelihood[sid] = torch.load(f'{state_dir}{s}')

    log_probs = [v['log_prob'] for v in likelihood.values()]
    ma, mi = max(log_probs), min(log_probs)
    span = ma - mi
    for v in likelihood.values():
        # With a single trajectory (or identical scores) the span is zero;
        # map everything to the top of the range instead of dividing by zero.
        relative = (v['log_prob'] - mi) / span if span > 0 else 1.0
        v['normed_log_prob'] = relative * (1 - BASELINE) + BASELINE  # some color tweaks

    fig = origin_fig.copy()
    for i, s in sorted(likelihood.items()):
        # Hand-picked exclusions for image 3.
        if img_id == 3 and i in [3, 5]:
            continue

        nlp = round(s['normed_log_prob'], 2)
        color = to_color(nlp)
        fig = draw_smooth_line_with_dot(fig, s['pts'][0], line_color=color, line_width=4, dot_radius=4, dot_color=color)
        _fig = origin_fig.copy()
        _fig = draw_smooth_line_with_dot(_fig, s['pts'][0], line_color=color, line_width=4, dot_radius=4, dot_color=color)
        _fig.save(f'outputs/visualization/results/likelihood/unroll/{img_id}.{i}.png')

    fig.save(f'outputs/visualization/results/likelihood/{img_id}.png')
        