# Cardinality-Generalization Visualization
- To run this code, you need summary.pth and the corresponding model checkpoint *.pt file.
- This code will use GPU:0.

In [None]:
import os
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
from argparse import Namespace
from glob import glob

import numpy as np
import torch
from torchvision.utils import save_image

from draw import draw, draw_attention, draw_open3d

# Make the parent of this code's directory importable, so that the project
# modules imported below (models.networks, args) can be found.
import inspect
import sys

_this_file = os.path.abspath(inspect.getfile(inspect.currentframe()))
currentdir = os.path.dirname(_this_file)
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)  # front of the search path: takes precedence over other entries

from models.networks import SetVAE
from args import get_parser

# Expose only GPU 0 to this process and address it through `device`.
# NOTE(review): the variable is set after `import torch`; this presumably only
# takes effect if CUDA has not been initialised by an earlier import - confirm.
os.environ.update({"CUDA_VISIBLE_DEVICES": '0'})
device = torch.device('cuda')

In [None]:
# Show the working directory: the '../checkpoints/gen/' paths used below are relative to it.
!pwd

# Choose Directory and Model
- save_dir: directory to save images
- experiment_name: directory under ../checkpoints/gen/ that contains summary.pth and the model checkpoint(s) (*.pt)

In [None]:
# Output root for the rendered images and the experiment to visualise.
save_dir = 'images_cardinality_generalization_final'
experiment_name = 'shapenet15k-car/camera-ready'

# Both the summary file and the checkpoints live in the experiment directory.
experiment_dir = os.path.join('../checkpoints/gen/', experiment_name)
summary_name = os.path.join(experiment_dir, 'summary.pth')

# Pick the checkpoint whose file name sorts last. This is plain string order,
# so e.g. 'epoch-999.pt' sorts after 'epoch-1000.pt'.
checkpoint_candidates = sorted(glob(os.path.join(experiment_dir, '*.pt')))
if not checkpoint_candidates:
    # Fail with an actionable message instead of an IndexError on an empty list.
    raise FileNotFoundError(f"No '*.pt' checkpoint found in {experiment_dir!r}")
checkpoint_name = checkpoint_candidates[-1]
print(checkpoint_name)

# One sub-directory per image kind under <save_dir>/<experiment_name>.
imgdir = os.path.join(save_dir, experiment_name)
imgdir_gt = os.path.join(imgdir, 'gt')
imgdir_recon = os.path.join(imgdir, 'recon')
imgdir_gen = os.path.join(imgdir, 'gen')

for out_dir in (imgdir_gt, imgdir_recon, imgdir_gen):
    os.makedirs(out_dir, exist_ok=True)

## Load Summary file to use fixed latents.

In [None]:
# Load the summary file and list its entries.
# map_location keeps deserialisation independent of the device the file was
# saved from (only GPU 0 is visible here); tensors are moved to `device`
# explicitly at the point where they are used.
summary = torch.load(summary_name, map_location=torch.device('cpu'))
for k, v in summary.items():
    try:
        # Entries exposing `.shape` are reported by shape ...
        print(f"{k}: {v.shape}")
    except AttributeError:
        # ... everything else by its length.
        print(f"{k}: {len(v)}")

# Select your model configuration
- If you use the configuration we provided, run exactly one of the cells below, matching your dataset type (each cell overwrites `args`).
- If you use your own customization, build a configuration dictionary in the same format.

In [None]:
# ShapeNet
# Settings for the provided ShapeNet model; every entry is written onto the
# namespace of parser defaults, replacing the default of the same name.
argsdict = dict(
    input_dim=3,
    max_outputs=2500,
    init_dim=32,
    n_mixtures=4,
    z_dim=16,
    z_scales=[1, 1, 2, 4, 8, 16, 32],
    hidden_dim=64,
    num_heads=4,
    fixed_gmm=True,
    train_gmm=True,
    slot_ln=False,
    slot_mlp=False,
    slot_att=True,
    ln=True,
    seed=42,
    dataset_type='shapenet15k',
    num_workers=4,
    eval=True,
    gpu=0,
    batch_size=32,
)

# Parse an empty command line to obtain the parser defaults.
args = get_parser().parse_args([])
for option, value in argsdict.items():
    setattr(args, option, value)

In [None]:
# MultiMNIST
# Settings for the provided MultiMNIST model; every entry is written onto the
# namespace of parser defaults, replacing the default of the same name.
# NOTE(review): 'multimnist' is assumed to be the project's dataset_type name
# for MultiMNIST - confirm against the choices declared in args.py.
argsdict = {'input_dim': 2, 'max_outputs': 600, 'init_dim': 64, 'n_mixtures': 16,
            'z_dim': 16, 'z_scales': [2, 4, 8, 16, 32], 'hidden_dim': 64, 'num_heads': 4,
            'slot_att': True, 'ln': True, 'shared_ip': False, 'seed': 42,
            'dataset_type': 'multimnist', 'num_workers': 4, 'eval': True,
            'gpu': 0, 'batch_size': 32}

# Parse an empty command line to obtain the parser defaults.
args = get_parser().parse_args('')
for k, v in argsdict.items():
    setattr(args, k, v)

In [None]:
# MNIST
# Settings for the provided MNIST model; every entry is written onto the
# namespace of parser defaults, replacing the default of the same name.
argsdict = dict(
    input_dim=2,
    max_outputs=400,
    init_dim=32,
    n_mixtures=4,
    z_dim=16,
    z_scales=[2, 4, 8, 16, 32],
    hidden_dim=64,
    num_heads=4,
    slot_att=True,
    ln=True,
    seed=42,
    dataset_type='mnist',
    num_workers=4,
    eval=True,
    gpu=0,
    batch_size=32,
)

# Parse an empty command line to obtain the parser defaults.
args = get_parser().parse_args([])
for option, value in argsdict.items():
    setattr(args, option, value)

# Model Setup

In [None]:
def _count_trainable(module):
    """Return the total number of elements in `module`'s parameters that require grad."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Build the network from `args` and restore its weights; the checkpoint is
# deserialised onto the CPU and the model is moved to `device` afterwards.
model = SetVAE(args)
checkpoint = torch.load(checkpoint_name, map_location=torch.device('cpu'))
# strict=True: the checkpoint keys must match the model's keys exactly.
model.load_state_dict(checkpoint['model'], strict=True)

# Move to the target device and switch to evaluation mode.
model.to(device).eval()

# Report the size of the whole model and of its decoder.
n_parameters = _count_trainable(model)
n_gen_parameters = _count_trainable(model.decoder)
print(f"Full: {n_parameters}, Gen: {n_gen_parameters}")

In [None]:
# Entries stored in the summary under 'smp_set' / 'smp_mask'
# (presumably the sampled sets and their masks - confirm against the writer of summary.pth).
gen, gen_mask = summary['smp_set'], summary['smp_mask']

# Cardinality Generalization
- Define the cardinalities you want to visualize

In [None]:
# Set sizes to generate, kept on the same device as the model.
cardinality = torch.tensor([100, 500, 2048, 3000, 10000, 100000]).to(device)

# Raise the output limit of the model and of its `init_set` sub-module to the
# largest requested cardinality.
max_outputs = cardinality.max().item()
model.max_outputs = model.init_set.max_outputs = max_outputs

In [None]:
@torch.no_grad()
def generate_all(gen_idx, cardinality):
    """Sample one set per requested cardinality from the same stored latents and draw them.

    Args:
        gen_idx: index of the stored latent to use; the slice
            ``[gen_idx:gen_idx+1]`` is taken from every tensor in
            ``summary['priors']``.
        cardinality: tensor of set sizes. Its first dimension determines how
            many times each latent slice is repeated; the tensor itself is
            forwarded to ``model.sample``.

    Returns:
        The value returned by ``draw`` for the sampled sets and their masks.
    """
    n_sizes = cardinality.size(0)

    z = []
    for prior in summary['priors']:
        picked = prior[gen_idx:gen_idx + 1]
        # Repeat along the first dimension so every cardinality is sampled
        # from the same latent (assumes 3-D prior tensors - TODO confirm).
        z.append(picked.repeat(n_sizes, 1, 1).to(device))

    sampled = model.sample(cardinality, given_latents=z)
    return draw(sampled['set'], sampled['set_mask'])

In [None]:
# Indices of the stored latents to visualise. A `gen_targets` defined in an
# earlier cell is kept; otherwise fall back to every latent in the summary
# (assumes summary['priors'] is an indexable sequence of tensors - confirm).
try:
    gen_targets
except NameError:
    gen_targets = range(summary['priors'][0].size(0))

for gen_idx in tqdm(gen_targets):
    gen_imgs = generate_all(gen_idx, cardinality)

    # Check whether the dataset is MNIST or ShapeNet.
    if gen_imgs.dtype == torch.float32:
        # Crop the white margins shared by all images (only for ShapeNet):
        # a position counts as content when its mean over the first two
        # dimensions differs from 1.
        content = torch.nonzero(gen_imgs.mean(0).mean(0) != 1)
        # Skip cropping when nothing differs from white. Checking explicitly
        # avoids depending on which exception an empty reduction raises.
        if content.numel() > 0:
            pos_min = content.min(0)[0]
            pos_max = content.max(0)[0]
            gen_imgs = gen_imgs[:, :, pos_min[0]:pos_max[0]+1, pos_min[1]:pos_max[1]+1]

    for imgs, c in zip(gen_imgs, cardinality):
        if imgs.dtype == torch.uint8:
            # Convert 8-bit images to floats in [0, 1] before saving.
            imgs = imgs.float() / 255.
        # int(c): put the plain number, not a tensor repr, into the file name.
        save_image(imgs[2], os.path.join(imgdir_gen, f'{gen_idx}_{int(c)}.png'))
    # Drop the reference to the rendered batch before the next iteration.
    del gen_imgs

In [None]:
# Report the output root; the images written by this notebook are in its 'gen' sub-directory.
print(f'Everything saved under {imgdir}')