# (29) jacobian — bench

**Setup**: host = `mach`, device = `cuda:2` <br>

In [1]:
# HIDE CODE


import os, sys
from IPython.display import display

# tmp & extras dir
git_dir = os.path.join(os.environ['HOME'], 'Dropbox/git')
extras_dir = os.path.join(git_dir, 'jb-vae/_extras')
fig_base_dir = os.path.join(git_dir, 'jb-vae/figs')
tmp_dir = os.path.join(git_dir, 'jb-vae/tmp')

# GitHub
sys.path.insert(0, os.path.join(git_dir, '_IterativeVAE'))
from figures.fighelper import *
from vae.train_vae import *

# warnings, tqdm, & style
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
from rich.jupyter import print
%matplotlib inline
set_style()

In [2]:
from base.utils_model import load_quick

# CUDA device handed to the VAE trainer below (host `mach`).
device_idx = 2
device = 'cuda:' + str(device_idx)

In [3]:
def get_perf(first, first_descriptor, second, second_descriptor):
    """Print the relative speed-up of the faster of two benchmark results.

    Parameters
    ----------
    first, second :
        Objects exposing a ``times`` sequence of per-run timings, e.g.
        ``torch.utils.benchmark.Measurement``. The median of ``times`` is
        compared, so measurements with several samples are not judged on
        their first sample alone. For a single-sample measurement (one entry
        in ``times``) the median is that entry, which matches comparing
        ``times[0]`` directly.
    first_descriptor, second_descriptor : str
        Labels used in the printed message.

    Notes
    -----
    The printed percentage is the run time saved relative to the *slower*
    benchmark: ``100 * (slower - faster) / slower``. Ties are attributed to
    ``second``. Nothing is returned, so calling cells do not gain an extra
    ``Out[]`` display.
    """
    import statistics

    first_time = statistics.median(first.times)
    second_time = statistics.median(second.times)

    if first_time < second_time:
        winner, faster, slower = first_descriptor, first_time, second_time
    else:
        winner, faster, slower = second_descriptor, second_time, first_time

    # Two zero-duration runs would otherwise raise ZeroDivisionError.
    gain = 100 * (slower - faster) / slower if slower > 0 else 0.0
    print(f"Performance delta: {gain:.2f} percent improvement with '{winner}'")

## jacrev and jacfwd

In [5]:
def predict(x):
    """Toy decoder: an affine map followed by tanh.

    Reads the module-level ``weight`` and ``bias`` at call time, so it uses
    whichever tensors the most recent configuration cell created.
    """
    affine = F.linear(x, weight, bias)
    return affine.tanh()

In [6]:
# Synthetic problem, "tall" case: fewer latents than pixels (n_latents < n_pix).
batch_size = 123
n_latents = 16
n_pix = 256

# Seed torch's global RNG so the synthetic inputs are reproducible on
# Restart & Run All.
torch.manual_seed(0)

# NOTE(review): these tensors live on torch's default device (presumably CPU),
# not on `device` — the synthetic benchmark below therefore times CPU work.
weight = torch.randn(n_pix, n_latents)
print(f"weight shape = {weight.shape}")

bias = torch.randn(n_pix)
z = torch.randn(batch_size, n_latents)

In [7]:
# Per-sample Jacobian of `predict` via forward-mode AD, vmapped over dim 0 of z.
compute_batch_jacobian = torch.func.vmap(torch.func.jacfwd(predict), 0)
jacobian = compute_batch_jacobian(z)
jacobian.shape

torch.Size([123, 256, 16])

In [9]:
from torch.utils.benchmark import Timer

# Same batched-Jacobian computation: forward-mode vs reverse-mode AD.
stmt_fwd = "torch.func.vmap(torch.func.jacfwd(predict, argnums=0), in_dims=0)(z)"
stmt_rev = "torch.func.vmap(torch.func.jacrev(predict, argnums=0), in_dims=0)(z)"

# Statements are evaluated against the namespace passed as `globals`, so
# `predict`, `z`, `weight` and `bias` must already be defined.
timer_fwd = Timer(stmt=stmt_fwd, globals=globals())
timer_rev = Timer(stmt=stmt_rev, globals=globals())

# 500 executions each; forward mode is measured first.
fwd_times, rev_times = timer_fwd.timeit(500), timer_rev.timeit(500)

print(fwd_times)
print(rev_times)

In [13]:
print(f"latent dim: {n_latents}  ———  num pixels: {n_pix}")
get_perf(
    first=fwd_times, first_descriptor='jacfwd',
    second=rev_times, second_descriptor='jacrev',
)

In [15]:
# Synthetic problem, "wide" case: more latents than pixels (n_latents > n_pix).
batch_size = 123
n_latents = 512
n_pix = 256

# Re-seed so this configuration's inputs are reproducible on re-run.
torch.manual_seed(0)

# NOTE(review): created on torch's default device (presumably CPU), not `device`.
weight = torch.randn(n_pix, n_latents)
bias = torch.randn(n_pix)
z = torch.randn(batch_size, n_latents)

In [16]:
from torch.utils.benchmark import Timer

# Repeat the jacfwd vs jacrev timing for the current (wide) weight/bias/z.
stmt_fwd = "torch.func.vmap(torch.func.jacfwd(predict, argnums=0), in_dims=0)(z)"
stmt_rev = "torch.func.vmap(torch.func.jacrev(predict, argnums=0), in_dims=0)(z)"

# Both statements resolve names from the notebook namespace given as `globals`.
timer_fwd = Timer(stmt=stmt_fwd, globals=globals())
timer_rev = Timer(stmt=stmt_rev, globals=globals())

# 500 executions each, forward mode first.
fwd_times, rev_times = timer_fwd.timeit(500), timer_rev.timeit(500)

print(fwd_times)
print(rev_times)

In [17]:
print(f"latent dim: {n_latents}  ———  num pixels: {n_pix}")
get_perf(
    first=fwd_times, first_descriptor='jacfwd',
    second=rev_times, second_descriptor='jacrev',
)

## Explore with real IP-VAE

### (option 1) latent_dim < n_pix

In [4]:
model_type = 'poisson'
cfg_vae, cfg_tr = default_configs('vH16', model_type, 'jac|mlp')

# Overrides for option 1 (latent_dim < n_pix): 100 latents.
# Assumes cfg_vae is a plain dict (it is item-assigned and **-unpacked).
cfg_vae.update({
    'n_latents': [100],
    'init_scale': 1e-2,
    'seq_len': 20,
})

In [5]:
# Build the model from the option-1 config and wrap it in a trainer on `device`.
vae = HIPVAE(
    CFG_CLASSES[model_type](**cfg_vae),
)
tr = TrainerVAE(
    vae,
    ConfigTrainVAE(**cfg_tr),
    device=device,
)

In [6]:
# First validation batch (element 0 of what the loader yields — TODO confirm
# that is the input tensor), flattened from dim 1 onward.
x = next(iter(tr.dl_vld))[0].flatten(start_dim=1)

# Work directly on the input layer; the timed statements refer to `self`.
self = tr.model.input_layer
self.reset_state(x.shape[0])

In [7]:
# One generative pass; the meaning of the positional args (0.2, True) is
# defined in the project and not visible here.
dist, spks, pred = self.generate(0.2, True)
residual = x.sub(pred)

In [8]:
# Section premise: latent dim is below input_sz ** 2 (presumably the number of
# input pixels). Fail loudly on Run All instead of silently printing False.
assert self.dim < self._cfg.input_sz ** 2, (self.dim, self._cfg.input_sz ** 2)
print(self.dim < self._cfg.input_sz ** 2)

In [9]:
# Per-sample Jacobian of `self.decode` w.r.t. its first argument via
# reverse-mode AD, vmapped over the leading (batch) dimension.
# NOTE(review): the torch.func rule of thumb prefers jacfwd for "tall" Jacobians
# (more outputs than inputs), i.e. this latent_dim < n_pix case — confirm that
# jacrev here is deliberate (cf. the synthetic benchmark above).
func = torch.func.jacrev(self.decode, 0)
jacobian = torch.func.vmap(func, 0, chunk_size=None)

In [10]:
from torch.utils.benchmark import Timer

# Baseline: the layer's own encoder applied to the residual.
stmt_main = "self.encode(residual, spks)"
# Alternative: contract the batched decoder Jacobian with the input over m
# (presumably the pixel axis), leaving one value per k.
# NOTE(review): the baseline uses `residual` but this uses `x`; since
# residual = x - pred the shapes presumably match, so timings are comparable —
# confirm whether `residual` was intended.
# NOTE(review): `.squeeze()` drops every size-1 dim (including batch if B == 1).
stmt_alt = "torch.einsum('bmk, bm -> bk', jacobian(spks).squeeze(), x)"

timer_main = Timer(stmt=stmt_main, globals=globals())
timer_alt = Timer(stmt=stmt_alt, globals=globals())

# The Timer names are rebound to their Measurement results (read by get_perf).
timer_main, timer_alt = timer_main.timeit(500), timer_alt.timeit(500)

print(timer_main)
print(timer_alt)

In [11]:
get_perf(
    first=timer_main, first_descriptor='main',
    second=timer_alt, second_descriptor='alt',
)

### (option 2) latent_dim > n_pix

In [12]:
model_type = 'poisson'
cfg_vae, cfg_tr = default_configs('vH16', model_type, 'jac|mlp')

# Overrides for option 2 (latent_dim > n_pix): 512 latents.
# Assumes cfg_vae is a plain dict (it is item-assigned and **-unpacked).
cfg_vae.update({
    'n_latents': [512],
    'init_scale': 1e-2,
    'seq_len': 20,
})

In [13]:
# Fresh model from the option-2 config, wrapped in a trainer on `device`.
vae = HIPVAE(
    CFG_CLASSES[model_type](**cfg_vae),
)
tr = TrainerVAE(
    vae,
    ConfigTrainVAE(**cfg_tr),
    device=device,
)

In [14]:
# First validation batch of the new trainer (element 0 — TODO confirm it is the
# input tensor), flattened from dim 1 onward.
x = next(iter(tr.dl_vld))[0].flatten(start_dim=1)

# Rebind `self` to the new model's input layer and size its state to the batch.
self = tr.model.input_layer
self.reset_state(x.shape[0])

In [15]:
# Generative pass with the same positional args as option 1.
dist, spks, pred = self.generate(0.2, True)
residual = x.sub(pred)

In [16]:
# Section premise: latent dim exceeds input_sz ** 2 (presumably the number of
# input pixels). The check now matches the section heading instead of printing
# the inverted `<` comparison.
assert self.dim > self._cfg.input_sz ** 2, (self.dim, self._cfg.input_sz ** 2)
print(self.dim > self._cfg.input_sz ** 2)

In [17]:
# Per-sample Jacobian of `self.decode` w.r.t. its first argument via
# forward-mode AD, vmapped over the leading (batch) dimension.
# NOTE(review): the torch.func rule of thumb prefers jacrev for "wide" Jacobians
# (more inputs than outputs), i.e. this latent_dim > n_pix case — confirm that
# jacfwd here is deliberate (cf. the synthetic benchmark above).
func = torch.func.jacfwd(self.decode, 0)
jacobian = torch.func.vmap(func, 0, chunk_size=None)

In [18]:
from torch.utils.benchmark import Timer

# Baseline: the layer's own encoder applied to the residual.
stmt_main = "self.encode(residual, spks)"
# Alternative: einsum of the batched Jacobian with the input, summed over m.
# NOTE(review): uses `x` rather than `residual` (as in option 1) — confirm intent.
# NOTE(review): `.squeeze()` drops every size-1 dim (including batch if B == 1).
stmt_alt = "torch.einsum('bmk, bm -> bk', jacobian(spks).squeeze(), x)"

timer_main = Timer(stmt=stmt_main, globals=globals())
timer_alt = Timer(stmt=stmt_alt, globals=globals())

# Rebind the Timer names to their Measurement results, main first.
timer_main, timer_alt = timer_main.timeit(500), timer_alt.timeit(500)

print(timer_main)
print(timer_alt)

In [19]:
get_perf(
    first=timer_main, first_descriptor='main',
    second=timer_alt, second_descriptor='alt',
)

## When implementing the MLP decoder

In [54]:
# Print the decoder produced by the project helper `_build_mlp_dec` for
# i = 0..4. The positional args (100, 784, 'swish', i) are passed as-is;
# NOTE(review): presumably (latent size, output size, activation, depth) —
# confirm against `_build_mlp_dec`'s definition.
for i in range(5):
    print(_build_mlp_dec(100, 784, 'swish', i))