# (??) NVAE -- play

**Motivation**: Play around with the official implementation and figure it out. <br>

In [1]:
# HIDE CODE


import os, sys
from copy import deepcopy as dc
from os.path import join as pjoin
from IPython.display import display, IFrame, HTML

# tmp & extras dir
git_dir = pjoin(os.environ['HOME'], 'Dropbox/git')
extras_dir = pjoin(git_dir, 'jb-MTMST/_extras')
fig_base_dir = pjoin(git_dir, 'jb-MTMST/figs')
tmp_dir = pjoin(git_dir, 'jb-MTMST/tmp')

sys.path.insert(0, pjoin(git_dir, 'NVAE'))
import utils

In [2]:
# HIDE CODE


import re
import os
import json
import h5py
import torch
import pickle
import joblib
import shutil
import random
import pathlib
import inspect
import logging
import argparse
import warnings
import operator
import functools
import itertools
import collections
import numpy as np
import pandas as pd
from tqdm import tqdm
from rich import print
from datetime import datetime
from os.path import join as pjoin
from prettytable import PrettyTable
from scipy import linalg as sp_lin
from scipy import signal as sp_sig
from scipy import stats as sp_stats
from scipy import ndimage as sp_img
from scipy.spatial import distance as sp_dist
from sklearn.preprocessing import Normalizer
from numpy.ma import masked_where as mwh
from typing import *
import torch
from torch import nn


def print_num_params(module: nn.Module):
    t = PrettyTable(['Module Name', 'Num Params'])

    for name, m in module.named_modules():
        total_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
        x = total_params // 1e6
        y = total_params // 1e3
        if x > 0:
            num = f"{np.round(total_params / 1e6, 2):1.1f} M"
        elif y > 0:
            num = f"{np.round(total_params / 1e3, 2):1.1f} K"
        else:
            num = total_params

        if '.' not in name:
            if isinstance(m, type(module)):
                t.add_row(["{}".format(m.__class__.__name__), "{}".format(num)])
                t.add_row(['---', '---'])
            else:
                t.add_row([name, "{}".format(num)])
    print(t, '\n\n')

In [3]:
# HIDE CODE


ArgsNVAE = collections.namedtuple(
    typename='args',
    field_names=[
        'data',
        'dataset',
        'batch_size',
        'learning_rate',
        'learning_rate_min',
        'weight_decay',
        'weight_decay_norm',
        'weight_decay_norm_init',
        'weight_decay_norm_anneal',
        'epochs',
        'warmup_epochs',
        'fast_adamax',
        'arch_instance',
        'kl_anneal_portion',
        'kl_const_portion',
        'kl_const_coeff',
        'num_nf',
        'num_x_bits',
        'num_latent_scales',
        'num_groups_per_scale',
        'num_latent_per_group',
        'ada_groups',
        'min_groups_per_scale',
        'num_channels_enc',
        'num_preprocess_blocks',
        'num_preprocess_cells',
        'num_cell_per_cond_enc',
        'num_channels_dec',
        'num_postprocess_blocks',
        'num_postprocess_cells',
        'num_cell_per_cond_dec',
        'num_mixture_dec',
        'use_se',
        'res_dist',
        'cont_training',
        'distributed',
    ])

args_default = {
    'data': tmp_dir,
    'dataset': 'mnist',
    'batch_size': 200,
    'learning_rate': 1e-2,
    'learning_rate_min': 1e-4,
    'weight_decay': 3e-4,
    'weight_decay_norm': 0,
    'weight_decay_norm_init': 10,
    'weight_decay_norm_anneal': False,
    'epochs': 200,
    'warmup_epochs': 5,
    'fast_adamax': False,
    'arch_instance': 'res_mbconv',
    'kl_anneal_portion': 0.3,
    'kl_const_portion': 0.0001,
    'kl_const_coeff': 0.0001,
    'num_nf': 0,
    'num_x_bits': 8,
    'num_latent_scales': 1,
    'num_groups_per_scale': 10,
    'num_latent_per_group': 20,
    'ada_groups': False,
    'min_groups_per_scale': 1,
    'num_channels_enc': 32,
    'num_preprocess_blocks': 2,
    'num_preprocess_cells': 3,
    'num_cell_per_cond_enc': 1,
    'num_channels_dec': 32,
    'num_postprocess_blocks': 2,
    'num_postprocess_cells': 3,
    'num_cell_per_cond_dec': 1,
    'num_mixture_dec': 10,
    'use_se': False,
    'res_dist': False,
    'cont_training': False,
    'distributed': False,
}

In [4]:
args = ArgsNVAE(**args_default)

In [5]:
arch_instance = utils.get_arch_cells(args.arch_instance)
arch_instance

{'normal_enc': ['res_bnswish', 'res_bnswish'],
 'down_enc': ['res_bnswish', 'res_bnswish'],
 'normal_dec': ['mconv_e6k5g0'],
 'up_dec': ['mconv_e6k5g0'],
 'normal_pre': ['res_bnswish', 'res_bnswish'],
 'down_pre': ['res_bnswish', 'res_bnswish'],
 'normal_post': ['mconv_e3k5g0'],
 'up_post': ['mconv_e3k5g0'],
 'ar_nn': ['']}

In [6]:
from model import AutoEncoder
model = AutoEncoder(args, utils.Writer(0, 'exp'), utils.get_arch_cells(args.arch_instance))

len log norm: 128
len bn: 92


In [7]:
print_num_params(model)

In [25]:
args_vanilla = args_default.copy()
args_vanilla['num_channels_enc'] = 32
args_vanilla['num_channels_dec'] = 32
args_vanilla['num_postprocess_cells'] = 3
args_vanilla['num_preprocess_cells'] = 3
args_vanilla['num_latent_scales'] = 1
args_vanilla['num_latent_per_group'] = 20
args_vanilla['num_cell_per_cond_enc'] = 2
args_vanilla['num_cell_per_cond_dec'] = 2
args_vanilla['num_preprocess_blocks'] = 2
args_vanilla['num_postprocess_blocks'] = 2
args_vanilla['num_groups_per_scale'] = 1
args_vanilla['num_nf'] = 0
args_vanilla['min_groups_per_scale'] = 4
args_vanilla['weight_decay_norm_anneal'] = True
args_vanilla['weight_decay_norm_init'] = 10
args_vanilla['use_se'] = True
args_vanilla['res_dist'] = True
args_vanilla['ada_groups'] = True
args_vanilla['fast_adamax'] = True
args_vanilla = ArgsNVAE(**args_vanilla)

In [26]:
vanilla = AutoEncoder(args_vanilla, utils.Writer(0, 'exp'), utils.get_arch_cells(args_vanilla.arch_instance))

len log norm: 45
len bn: 36


In [27]:
print_num_params(mnist)

In [28]:
print_num_params(vanilla)

In [29]:
print(vanilla)

In [24]:
print(mnist)

In [22]:
args_celeb = args_default.copy()
args_celeb['num_channels_enc'] = 30
args_celeb['num_channels_dec'] = 30
args_celeb['num_postprocess_cells'] = 2
args_celeb['num_preprocess_cells'] = 2
args_celeb['num_latent_scales'] = 5
args_celeb['num_latent_per_group'] = 20
args_celeb['num_cell_per_cond_enc'] = 2
args_celeb['num_cell_per_cond_dec'] = 2
args_celeb['num_preprocess_blocks'] = 1
args_celeb['num_postprocess_blocks'] = 1
args_celeb['num_groups_per_scale'] = 16
args_celeb['num_nf'] = 2
args_celeb['min_groups_per_scale'] = 4
args_celeb['ada_groups'] = True
args_celeb['weight_decay_norm_anneal'] = True
args_celeb['weight_decay_norm_init'] = 10
args_celeb['use_se'] = True
args_celeb['res_dist'] = True
args_celeb['fast_adamax'] = True
args_celeb = ArgsNVAE(**args_celeb)

In [10]:
arch_instance = utils.get_arch_cells(args_celeb.arch_instance)
vae = AutoEncoder(args_celeb, utils.Writer(0, 'exp'), arch_instance)

len log norm: 986
len bn: 460


In [11]:
print_num_params(vae)

In [12]:
vae.image_conditional

Sequential(
  (0): ELU(alpha=1.0)
  (1): Conv2D(30, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [13]:
print(vae.dec_sampler)

In [14]:
print(vae.enc_sampler)

In [15]:
print(vae.enc_tower)

In [16]:
vae.enc0

Sequential(
  (0): ELU(alpha=1.0)
  (1): Conv2D(960, 960, kernel_size=(1, 1), stride=(1, 1))
  (2): ELU(alpha=1.0)
)

In [30]:
a = torch.Generator()

In [34]:
torch._C.Generator

torch._C.Generator

In [38]:
x = torch.zeros(10)

In [45]:
x.normal_(generator=None)

tensor([ 0.1537, -0.3420, -0.0035,  0.9074, -0.6348,  0.6950, -0.9392, -0.5921,
         0.6371,  0.3373])

In [None]:
from 

In [46]:
now()

NameError: name 'now' is not defined

NameError: name 'vae' is not defined

In [11]:
device = torch.device(f"cuda:1")
vae.eval().to(device);

In [12]:
print_num_params(vae)

In [13]:
print_num_params(vae.pre_process)

In [17]:
arch_instance

{'normal_enc': ['res_bnswish', 'res_bnswish'],
 'down_enc': ['res_bnswish', 'res_bnswish'],
 'normal_dec': ['mconv_e6k5g0'],
 'up_dec': ['mconv_e6k5g0'],
 'normal_pre': ['res_bnswish', 'res_bnswish'],
 'down_pre': ['res_bnswish', 'res_bnswish'],
 'normal_post': ['mconv_e3k5g0'],
 'up_post': ['mconv_e3k5g0'],
 'ar_nn': ['']}

In [18]:
{k: v for k, v in arch_instance.items() if k in ['normal_enc', 'normal_pre', 'down_enc', 'down_pre']}

{'normal_enc': ['res_bnswish', 'res_bnswish'],
 'down_enc': ['res_bnswish', 'res_bnswish'],
 'normal_pre': ['res_bnswish', 'res_bnswish'],
 'down_pre': ['res_bnswish', 'res_bnswish']}

In [19]:
{k: v for k, v in arch_instance.items() if k not in ['normal_enc', 'normal_pre', 'down_enc', 'down_pre']}

{'normal_dec': ['mconv_e6k5g0'],
 'up_dec': ['mconv_e6k5g0'],
 'normal_post': ['mconv_e3k5g0'],
 'up_post': ['mconv_e3k5g0'],
 'ar_nn': ['']}

In [None]:
args_celeb['num_late']

In [23]:
import datasets

In [24]:
train_queue, valid_queue, num_classes = datasets.get_loaders(args_mnist)

In [25]:
train_queue

<torch.utils.data.dataloader.DataLoader at 0x7fbc12f97dc0>

In [26]:
for step, (x, y) in enumerate(train_queue):
    x = x.to(device)
    print(step, x.shape, x.device)
    break

In [27]:
logits, log_q, log_p, kl_all, kl_diag = vae(x)

In [28]:
logits.shape

torch.Size([8, 1, 32, 32])

In [31]:
len(kl_all), len(kl_diag)

(15, 15)

In [50]:
np.unique([e.shape for e in kl_all])

array([8])

In [51]:
np.unique([e.shape for e in kl_diag])

array([20])

In [48]:
ctr_enc, ctr_dec = 0, 0
for m in vae.modules():
    if m.__class__.__name__ == 'EncCombinerCell':
        ctr_enc += 1
    elif m.__class__.__name__ == 'DecCombinerCell':
        ctr_dec += 1
ctr_enc, ctr_dec

(14, 15)

In [127]:
self = vae

In [128]:
x.shape

torch.Size([8, 1, 32, 32])

In [129]:
s = self.stem(2 * x - 1.0)
s.shape

torch.Size([8, 32, 32, 32])

In [130]:
for cell in self.pre_process:
    s = cell(s)
    print(s.shape)

In [131]:
combiner_cells_enc = []
combiner_cells_s = []
for cell in self.enc_tower:
    if cell.cell_type == 'combiner_enc':
        combiner_cells_enc.append(cell)
        combiner_cells_s.append(s)
    else:
        s = cell(s)
        print(cell.cell_type, s.shape)

In [132]:
combiner_cells_enc.reverse()
combiner_cells_s.reverse()

In [133]:
idx_dec = 0
ftr = self.enc0(s)                            # this reduces the channel dimension
ftr.shape

torch.Size([8, 256, 4, 4])

In [134]:
param0 = self.enc_sampler[idx_dec](ftr)
param0.shape

torch.Size([8, 40, 4, 4])

In [135]:
mu_q, log_sig_q = torch.chunk(param0, 2, dim=1)
mu_q.shape, log_sig_q.shape

(torch.Size([8, 20, 4, 4]), torch.Size([8, 20, 4, 4]))

In [136]:
from distributions import Normal, DiscMixLogistic, NormalDecoder

In [137]:
dist = Normal(mu_q, log_sig_q)   # for the first approx. posterior
z, _ = dist.sample()
log_q_conv = dist.log_p(z)

In [138]:
z.shape

torch.Size([8, 20, 4, 4])

In [139]:
log_q_conv.shape

torch.Size([8, 20, 4, 4])

In [140]:
all_q = [dist]
all_log_q = [log_q_conv]

In [141]:
dist = Normal(mu=torch.zeros_like(z), log_sigma=torch.zeros_like(z))
log_p_conv = dist.log_p(z)
all_p = [dist]
all_log_p = [log_p_conv]

In [142]:
self.prior_ftr0.shape

torch.Size([256, 4, 4])

In [143]:
idx_dec = 0
s = self.prior_ftr0.unsqueeze(0)
s.shape

torch.Size([1, 256, 4, 4])

In [144]:
batch_size = z.size(0)
s = s.expand(batch_size, -1, -1, -1)
batch_size, s.shape

(8, torch.Size([8, 256, 4, 4]))

In [145]:
for cell in self.dec_tower:
    if cell.cell_type == 'combiner_dec':
        if idx_dec > 0:
            # form prior
            param = self.dec_sampler[idx_dec - 1](s)
            mu_p, log_sig_p = torch.chunk(param, 2, dim=1)

            # form encoder
            ftr = combiner_cells_enc[idx_dec - 1](combiner_cells_s[idx_dec - 1], s)
            print(cell.cell_type, ftr.shape)
            param = self.enc_sampler[idx_dec](ftr)
            mu_q, log_sig_q = torch.chunk(param, 2, dim=1)
            dist = Normal(mu_p + mu_q, log_sig_p + log_sig_q) if self.res_dist else Normal(mu_q, log_sig_q)
            z, _ = dist.sample()
            log_q_conv = dist.log_p(z)
            all_log_q.append(log_q_conv)
            all_q.append(dist)

            # evaluate log_p(z)
            dist = Normal(mu_p, log_sig_p)
            log_p_conv = dist.log_p(z)
            all_p.append(dist)
            all_log_p.append(log_p_conv)

        # 'combiner_dec'
        s = cell(s, z)
        idx_dec += 1
    else:
        s = cell(s)
        print(cell.cell_type, s.shape)

In [146]:
for cell in self.post_process:
    s = cell(s)
    print(cell.cell_type, s.shape)

In [147]:
vae.stem

Conv2D(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [148]:
vae.stem.bias

Parameter containing:
tensor([-0.0274,  0.0243,  0.2942, -0.0110,  0.2739,  0.0504, -0.1710,  0.2377,
        -0.0871,  0.1767, -0.0623, -0.0808,  0.0468,  0.0126, -0.0354,  0.2536,
        -0.1658,  0.0034,  0.2498, -0.1310,  0.2838, -0.1309,  0.1926,  0.2250,
         0.0148, -0.2844, -0.1706, -0.1666,  0.2665,  0.0490, -0.3323, -0.2233],
       device='cuda:1', requires_grad=True)

In [151]:
vae.pre_process[0].cell_type

'normal_pre'

In [153]:
vae.pre_process[2].cell_type

'down_pre'

In [154]:
for cell in vae.pre_process:
    print(cell.cell_type)

In [156]:
print(vae.pre_process[0].se)

In [161]:
print(vae.pre_process[2])

In [363]:
x.shape

torch.Size([13, 32, 24, 16, 16])

In [283]:
f.ops[0](x[..., 1:, :, 1:]).shape

torch.Size([13, 1, 12, 8, 8])

In [318]:
base2(8)

'1000'

In [311]:
%time base2(4)

CPU times: user 18 µs, sys: 0 ns, total: 18 µs
Wall time: 20.7 µs


(1, 0, 0)

In [294]:
len(np.base_repr(125284, base=2))

17

In [426]:
vae.num_latent_per_group

20

In [427]:
vae.enc0

Sequential(
  (0): ELU(alpha=1.0)
  (1): Conv2D(256, 256, kernel_size=(1, 1), stride=(1, 1))
  (2): ELU(alpha=1.0)
)

In [428]:
vae.vanilla_vae

False

In [429]:
vae.num_cell_per_cond_dec

2

In [430]:
vae.num_cell_per_cond_enc

2

In [448]:
device = torch.device('cuda:1')

In [449]:
f = vae.dec_tower[1]._ops[0].to(device)
print(f)

In [459]:
for n, layer in vae.dec_tower.named_modules():
    try:
        if layer.upsample:
            break
    except:
        continue

In [461]:
f = layer.to(device)

In [54]:
a = torch.empty((5, 10))

In [462]:
x = torch.randn(13, 256, 4, 4).to(device)
x.shape

torch.Size([13, 256, 4, 4])

In [463]:
y = f(x)
y.shape

torch.Size([13, 128, 8, 8])

In [43]:
rng = torch.Generator(random)

In [58]:
seed = 0
rng = torch.Generator()
_ = rng.manual_seed(seed)

In [59]:
rng.get_state()

tensor([0, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8)

In [60]:
a.normal_(generator=rng)

tensor([[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920, -0.3160, -2.1152,
          0.3223, -1.2633],
        [ 0.3500,  0.3081,  0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959,
          0.5667,  0.7935],
        [ 0.5988, -1.5551, -0.3414,  1.8530,  0.7502, -0.5855, -0.1734,  0.1835,
          1.3894,  1.5863],
        [ 0.9463, -0.8437, -0.6136,  0.0316,  1.0554,  0.1778, -0.2303, -0.3918,
          0.5433, -0.3952],
        [ 0.2055, -0.4503,  1.5210,  3.4105, -1.5312, -1.2341,  1.8197, -0.5515,
         -1.3253,  0.1886]])

In [70]:
@torch.jit.script
def sample_normal_jit(mu: torch.Tensor, sigma: torch.Tensor):
	eps = mu.mul(0).normal_()
	z = eps.mul_(sigma).add_(mu)
	return z, eps

In [71]:
def sample_normal(mu: torch.Tensor, sigma: torch.Tensor):
	eps = mu.mul(0).normal_()
	z = eps.mul_(sigma).add_(mu)
	return z, eps

In [76]:
%timeit sample_normal_jit(torch.zeros(100000), torch.zeros(100000))

861 µs ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [77]:
%timeit sample_normal(torch.zeros(100000), torch.zeros(100000))

851 µs ± 1.43 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [435]:
print({
    k: v for k, v in vae.arch_instance.items()
    if 'dec' in k
})

In [None]:
padding = dilation * (k - 1) // 2