# MNIST Point Set Visualization
- Run this notebook to render point sets as images.
## To run this code...
- First, prepare the summary file by running `sample_and_summarize.py` with a trained checkpoint.
- Install the following libraries:
    - matplotlib
    - open3d
    - numpy
    - torch
    - torchvision
    - tqdm

In [None]:
import os
import random
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import open3d as o3d

import numpy as np
import torch
from torchvision.utils import save_image, make_grid

from draw import draw

# Inference-only notebook: turn autograd off globally. A bare
# `torch.no_grad()` call only constructs and discards a context manager and
# has no effect, whereas `set_grad_enabled(False)` actually disables autograd.
torch.set_grad_enabled(False)

## Set directories
1. `summary_name`: path to the summary file.
2. `save_dir`: directory in which to save the images.

In [None]:
# Root output directory and the experiment whose summary is visualized.
save_dir = 'images'
experiment_name = 'mnist/camera-ready'
# Summary file written by sample_and_summarize.py for this experiment.
summary_name = os.path.join('../checkpoints/gen/', experiment_name, 'summary.pth')

# One output sub-directory per image kind: ground truth, reconstructions,
# and generated samples.
imgdir = os.path.join(save_dir, experiment_name)
imgdir_gt, imgdir_recon, imgdir_gen = (
    os.path.join(imgdir, sub) for sub in ('gt', 'recon', 'gen')
)

for _dir in (save_dir, imgdir_gt, imgdir_recon, imgdir_gen):
    os.makedirs(_dir, exist_ok=True)

In [None]:
# Load the summary file and print a quick overview of each entry: its shape
# when it has one, otherwise its length, otherwise its value.
summary = torch.load(summary_name)
for k, v in summary.items():
    try:
        desc = v.shape
    except AttributeError:
        try:
            desc = len(v)
        except TypeError:
            # Scalars and other plain objects have neither shape nor length;
            # show the value instead of crashing the overview.
            desc = repr(v)
    print(f"{k}: {desc}")

## Select the samples to visualize
- Select samples by index.
- By default, the code below visualizes all samples. **Warning: this requires a huge amount of memory.**

In [None]:
# Every sample is selected by default; slice these lists (e.g. `[:16]`) to
# restrict the selection and reduce memory use.
recon_targets = list(range(len(summary['gt_mask'])))
gen_targets = list(range(len(summary['smp_mask'])))

In [None]:
# Restrict every per-sample entry of the summary to the selected indices.
# Kept for any later cell that relies on it.
len_att = len(summary['dec_att'])
gt = summary['gt_set'][recon_targets]
gt_mask = summary['gt_mask'][recon_targets]

recon = summary['recon_set'][recon_targets]
recon_mask = summary['recon_mask'][recon_targets]

# Attention entries are sequences (presumably one tensor per layer) whose
# samples are selected along dim 2. Iterate each sequence directly instead of
# reusing the decoder's length, so encoder/decoder/sampler lists of different
# depth are neither truncated nor indexed out of range.
dec_att = [att[:, :, recon_targets] for att in summary['dec_att']]
enc_att = [att[:, :, recon_targets] for att in summary['enc_att']]

gen = summary['smp_set'][gen_targets]
gen_mask = summary['smp_mask'][gen_targets]
gen_att = [att[:, :, gen_targets] for att in summary['smp_att']]

## Visualize

In [None]:
def visualize(gt, gt_mask):
    """Render a batch of point sets to images with the project's ``draw``.

    Args:
        gt: point sets to render (used for ground truth, reconstructions and
            generated samples alike).
        gt_mask: mask paired with ``gt``, passed through to ``draw``.

    Returns:
        Whatever ``draw`` returns. Callers index it per sample and divide by
        255 before saving, so it is presumably a 0-255 image batch.
    """
    rendered = draw(gt, gt_mask)
    return rendered

### Visualize Recon

In [None]:
# Render the selected reconstructions and save each as <data index>.png.
recon_imgs = visualize(recon, recon_mask)
for idx, data_idx in enumerate(recon_targets):
    # Images are divided by 255 before saving, so they are presumably on a
    # 0-255 scale; normalize first so the blank check compares against pure
    # white (1.0) rather than a near-black raw value of 1.
    # NOTE(review): confirm draw()'s output range and background colour.
    img = recon_imgs[idx].float() / 255.
    # Skip the image when every pixel's channel mean is white (presumably
    # nothing was drawn).
    if not (img.mean(0) != 1).any():
        print("SKIP")
        continue
    save_image(img, os.path.join(imgdir_recon, f'{data_idx}.png'))
del recon_imgs

### Visualize GT

In [None]:
# Render the selected ground-truth sets and save each as <data index>.png.
gt_imgs = visualize(gt, gt_mask)
for idx, data_idx in enumerate(recon_targets):
    # Images are divided by 255 before saving, so they are presumably on a
    # 0-255 scale; normalize first so the blank check compares against pure
    # white (1.0) rather than a near-black raw value of 1.
    # NOTE(review): confirm draw()'s output range and background colour.
    img = gt_imgs[idx].float() / 255.
    # Skip the image when every pixel's channel mean is white (presumably
    # nothing was drawn).
    if not (img.mean(0) != 1).any():
        print("SKIP")
        continue
    save_image(img, os.path.join(imgdir_gt, f'{data_idx}.png'))
del gt_imgs

### Visualize Generated Samples

In [None]:
# Render the selected generated samples and save each as <sample index>.png.
gen_imgs = visualize(gen, gen_mask)
for idx, data_idx in enumerate(gen_targets):
    # Images are divided by 255 before saving, so they are presumably on a
    # 0-255 scale; normalize first so the blank check compares against pure
    # white (1.0) rather than a near-black raw value of 1.
    # NOTE(review): confirm draw()'s output range and background colour.
    img = gen_imgs[idx].float() / 255.
    # Skip the image when every pixel's channel mean is white (presumably
    # nothing was drawn).
    if not (img.mean(0) != 1).any():
        print("SKIP")
        continue
    save_image(img, os.path.join(imgdir_gen, f'{data_idx}.png'))
del gen_imgs

In [None]:
# Signal that the visualization cells have finished.
print("Done")