# MNIST Attention Visualization
- Running this notebook produces images of Set-MNIST samples whose points are color-coded by attention.
## To run this code...
- First prepare the summary file by running sample_and_summarize.py with a trained checkpoint.
- Then install the libraries below.
    - matplotlib
    - open3d
    - numpy
    - torch
    - torchvision
    - tqdm

In [None]:
import os
import random
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import open3d as o3d

import numpy as np
import torch
# A bare `torch.no_grad()` call only builds a context manager and throws it
# away, so it has no effect. Calling set_grad_enabled as a function actually
# switches gradient tracking off from this point on.
torch.set_grad_enabled(False)
from torchvision.utils import save_image, make_grid

from draw import draw, draw_attention

## Set directories
1. summary file path: summary_name
2. path to save images: save_dir

In [None]:
# Root folder for rendered images and the experiment whose summary is visualized.
save_dir = 'images_attn'
experiment_name = 'mnist/camera-ready'

# The summary file is looked up under the experiment's checkpoint directory.
summary_name = os.path.join('../checkpoints/gen/', experiment_name, 'summary.pth')

# One sub-directory per image family: ground truth, reconstruction, generation.
imgdir = os.path.join(save_dir, experiment_name)
imgdir_gt, imgdir_recon, imgdir_gen = (
    os.path.join(imgdir, kind) for kind in ('gt', 'recon', 'gen')
)

# Build the output tree; directories that already exist are left untouched.
for out_dir in (save_dir, imgdir_gt, imgdir_recon, imgdir_gen):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
# Load the summary file from the path configured above.
# NOTE(review): torch.load unpickles its input, so only open summaries you trust.
summary = torch.load(summary_name)

# Report every entry: its shape when it exposes one, otherwise its length.
for key, value in summary.items():
    if hasattr(value, 'shape'):
        print(f"{key}: {value.shape}")
    else:
        print(f"{key}: {len(value)}")

## Select the samples to visualize
- Select the samples by index.
- The default code below visualizes all samples. **Warning: this requires a huge amount of memory.**

In [None]:
# Indices of the samples to render. Narrow the trailing slice (e.g. [:16])
# to visualize only a subset; the default [:] keeps every sample.
recon_targets = [*range(len(summary['gt_mask']))][:]
gen_targets = [*range(len(summary['smp_mask']))][:]

In [None]:
# Number of attention layers, read from the decoder attention list and
# reused for the encoder and sampling attention lists below.
len_att = len(summary['dec_att'])
layer_ids = range(len_att)

# Ground-truth sets and their masks, restricted to the chosen samples.
gt, gt_mask = summary['gt_set'][recon_targets], summary['gt_mask'][recon_targets]

# Reconstructions of those same samples.
recon, recon_mask = summary['recon_set'][recon_targets], summary['recon_mask'][recon_targets]

# Per-layer attention for the reconstruction samples; the chosen sample
# indices are applied along the third axis of every layer's tensor.
dec_att = [summary['dec_att'][layer][:, :, recon_targets] for layer in layer_ids]
enc_att = [summary['enc_att'][layer][:, :, recon_targets] for layer in layer_ids]

# Generated sets, their masks and the attention recorded while sampling.
gen, gen_mask = summary['smp_set'][gen_targets], summary['smp_mask'][gen_targets]
gen_att = [summary['smp_att'][layer][:, :, gen_targets] for layer in layer_ids]

## Visualize Attention
- lidx: index of layer
- projection: each ISAB has two attentions, a projection attention and a back-projection attention.
    - 0: projection, 1: back-projection

In [None]:
def attention_selector(gt, gt_mask, att, lidx=0, projection=0,
                       color_opt='gist_rainbow', dot_size=300):
    """Render attention-colored images for one layer and one projection.

    Args:
        gt: sets to draw; forwarded to draw_attention unchanged.
        gt_mask: masks matching ``gt``; forwarded unchanged.
        att: per-layer attention; ``att[lidx][projection]`` is what gets drawn.
        lidx: index of the layer to visualize.
        projection: 0 for the projection attention, 1 for back-projection.
        color_opt: color option handed to draw_attention.
        dot_size: dot size handed to draw_attention; use 300 for multimnist,
            700 for mnist.

    Returns:
        Whatever draw_attention returns for the selected attention.
    """
    selected = att[lidx][projection]
    return draw_attention(gt, gt_mask, selected, color_opt=color_opt, dot_size=dot_size)

### Visualize Encoder Attention on GT samples

In [None]:
# Encoder layers are visited from the last one to the first, so `topdown`
# counts layers starting from the top of the encoder.
for topdown in tqdm(range(len(enc_att))):
    lidx = len(enc_att) - 1 - topdown
    for projection in [0]:
        gt_imgs = attention_selector(gt, gt_mask, enc_att, lidx, projection)
        # Rescale by 1/255 before saving (presumably 0-255 pixel values — confirm in draw.py).
        gt_imgs = [i/255. for i in gt_imgs]
        # Head count is read from the very attention tensor that was drawn.
        for head in range(enc_att[lidx][projection].shape[0]):
            for idx, data_idx in enumerate(recon_targets):
                gt_img = gt_imgs[idx][head]
                save_image(gt_img, os.path.join(imgdir_gt, f'{topdown}_{projection}_{head}_{data_idx}.png'))
        # Free this layer's images before the next layer is rendered so two
        # layers' worth of images are never alive at once; deleting here also
        # avoids a NameError when there are no layers to draw.
        del gt_imgs
print('gt DONE')

### Visualize Decoder Attention on Reconstructed samples

In [None]:
# Decoder attention on the reconstructions. The layer count and the head
# count both come from dec_att, the attention that is actually drawn here.
for topdown in tqdm(range(len(dec_att))):
    for projection in [1]:
        recon_imgs = attention_selector(recon, recon_mask, dec_att, topdown, projection)
        # Rescale by 1/255 before saving (presumably 0-255 pixel values — confirm in draw.py).
        recon_imgs = [i/255. for i in recon_imgs]
        for head in range(dec_att[topdown][projection].shape[0]):
            for idx, data_idx in enumerate(recon_targets):
                recon_img = recon_imgs[idx][head]
                save_image(recon_img, os.path.join(imgdir_recon, f'{topdown}_{projection}_{head}_{data_idx}.png'))
        # Free this layer's images before the next layer is rendered so two
        # layers' worth of images are never alive at once; deleting here also
        # avoids a NameError when there are no layers to draw.
        del recon_imgs
print('recon DONE')

### Visualize Decoder Attention on Generated samples

In [None]:
# Attention recorded while sampling, drawn on the generated sets. The layer
# count and the head count both come from gen_att, the attention drawn here.
for topdown in tqdm(range(len(gen_att))):
    for projection in [1]:
        gen_imgs = attention_selector(gen, gen_mask, gen_att, topdown, projection)
        # Rescale by 1/255 before saving (presumably 0-255 pixel values — confirm in draw.py).
        gen_imgs = [i/255. for i in gen_imgs]
        for head in range(gen_att[topdown][projection].shape[0]):
            for idx, data_idx in enumerate(gen_targets):
                gen_img = gen_imgs[idx][head]
                # Cast to float32 before handing the image to save_image.
                save_image(gen_img.float(), os.path.join(imgdir_gen, f'{topdown}_{projection}_{head}_{data_idx}.png'))
        # Free this layer's images before the next layer is rendered.
        del gen_imgs
print('gen DONE')

In [None]:
# Final marker printed at the end of the notebook.
print('DONE')