# (06) timer run

**Motivation**: time the training run and its data loading on host = `mach`, device = `cuda:1`. <br>

In [1]:
# HIDE CODE


import os, sys
from IPython.display import display

# Working directories: everything hangs off $HOME/Dropbox/git
git_dir = os.path.join(os.environ['HOME'], 'Dropbox/git')
# extras / figures / scratch folders inside the `jb-vae` directory
extras_dir, fig_base_dir, tmp_dir = (
    os.path.join(git_dir, f'jb-vae/{sub}')
    for sub in ('_extras', 'figs', 'tmp')
)

# GitHub
# sys.path.insert(0, os.path.join(git_dir, '_PoissonVAE'))
sys.path.insert(0, os.path.join(git_dir, '_IterativeVAE'))
from figures.fighelper import *
from vae.train_vae import *

# warnings, tqdm, & style
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
from rich.jupyter import print
%matplotlib inline
set_style()

In [2]:
from base.utils_model import load_quick
from figures.analysis import plot_convergence
from figures.imgs import plot_weights

# Index of the CUDA device to run on, and the matching device string.
device_idx = 1
device = 'cuda:' + str(device_idx)

# Report the chosen device together with this machine's node name.
print(f"device: {device}  ———  host: {os.uname().nodename}")

## MNIST

### Get configs

In [3]:
model_type = 'poisson'
cfg_vae, cfg_tr = default_configs('MNIST', model_type, 'conv+b|conv+b')

# Model overrides applied on top of the defaults
# (assumes the configs are plain dicts — they are item-assigned and **-unpacked elsewhere).
cfg_vae.update(
    n_latents=[128],
    init_scale=1e-4,
    seq_len=10,
)

# Trainer overrides applied on top of the defaults
cfg_tr.update(
    lr=2e-3,
    epochs=500,
    batch_size=200,
    kl_beta=5.0,
    kl_balancer=None,
)

### Make model + trainer

In [4]:
# Build the model from the config class registered for `model_type`,
# then wrap it in a trainer placed on `device`.
vae_config = CFG_CLASSES[model_type](**cfg_vae)
vae = IPVAE(vae_config)
trainer_config = ConfigTrainVAE(**cfg_tr)
tr = TrainerVAE(vae, trainer_config, device=device, verbose=True)

In [5]:
# Display the trainer's `n_iters` attribute.
# NOTE(review): presumably the total number of training iterations — confirm.
tr.n_iters

150000

In [6]:
# Start training with the trainer's default arguments.
tr.train()

  0%|                                                   | 0/500 [00:00<?, ?it/s]

epoch # 1, avg loss: 53.247894:   0%|         | 1/500 [01:01<8:32:50, 61.66s/it]

KeyboardInterrupt



In [7]:
# Bind the trainer to the name `self`; the cells below read
# `self.dl_trn`, `self.cfg`, `self.temperatures` and `self.model`.
self = tr

In [33]:
import time

# Measure how long the training dataloader takes to yield each of its first
# three batches. Every duration is the gap between two consecutive yields, so
# no batch's time leaks into another's (the previous arithmetic only
# subtracted the second batch's time from the third measurement, leaving the
# first batch's time included).
n_timed = 3
batch_times = []
start_time = time.perf_counter()  # monotonic clock, suited to interval timing
previous = start_time
for i, (x, *_) in enumerate(self.dl_trn):
    now = time.perf_counter()
    batch_times.append(now - previous)
    previous = now
    if i + 1 == n_timed:
        break

# Keep the original variable names; a batch that was never loaded is None.
first_batch_time, second_batch_time, third_batch_time = (
    batch_times + [None] * (n_timed - len(batch_times))
)

# Report only the batches that were actually loaded.
for label, elapsed in zip(('first', 'second', 'third'), batch_times):
    print(f"Time to load {label} batch: {elapsed:.3f} seconds")

In [35]:
from torch.profiler import profile, ProfilerActivity

num_epochs = 100
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

# Profile plain iteration over the training dataloader: the loop body is
# empty, so no model step runs inside the profiled region.
with profile(activities=activities, record_shapes=True) as prof:
    for epoch in range(num_epochs):
        for x, *_ in self.dl_trn:
            pass

# Write the collected trace to `trace.json` in the tmp directory; the file
# can be viewed in Chrome's tracing tools.
trace_path = pjoin(tmp_dir, "trace.json")
prof.export_chrome_trace(trace_path)


STAGE:2024-09-05 20:07:35 70558:70558 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-09-05 20:10:55 70558:70558 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-09-05 20:11:06 70558:70558 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


In [36]:
# Index into `self.temperatures` used by the cells below.
# NOTE(review): presumably a global training-step counter; 1245 looks arbitrary — confirm.
gstep = 1245

In [8]:
def test():
    """Build the `hard` keyword dict for the step indexed by the global `gstep`.

    Reads the globals `self` and `gstep`. `hard` is the model config's
    `hard_fwd` flag combined (via `and`) with whether the temperature at
    `gstep` equals the trainer config's `temp_stop`.

    Returns:
        dict with the single key 'hard'.
    """
    final_temp = self.cfg.temp_stop
    current_temp = self.temperatures[gstep]
    annealed = final_temp == current_temp
    return {'hard': self.model.cfg.hard_fwd and annealed}

In [9]:
# Micro-benchmark a single call of `test()`.
%timeit test()

257 ns ± 42.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [16]:
from torch.profiler import profile, record_function, ProfilerActivity

# Profile the annealing check under the label "annealing_and_hard".
# Here "annealing is done" means the temperature at `gstep` equals the
# minimum of the whole temperature list.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof, record_function("annealing_and_hard"):
    lowest_temp = min(self.temperatures)
    annealing_is_done = lowest_temp == self.temperatures[gstep]
    hard = self.model.cfg.hard_fwd and annealing_is_done
    kws = {'hard': hard}

STAGE:2024-09-05 17:46:58 61702:61702 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-09-05 17:46:59 61702:61702 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-09-05 17:46:59 61702:61702 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


In [17]:
# Print the averaged profiler events, sorted by total CPU time, at most 10 rows.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

### Fit model

In [9]:
# torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.benchmark_limit = 0

In [9]:
# Train, passing a fit name made of the 'vmap_' prefix plus the trainer config's name().
tr.train(fit_name=f"vmap_{tr.cfg.name()}")

  0%|                                                   | 0/500 [00:00<?, ?it/s]

epoch # 1, avg loss: 42.677650:   0%|         | 1/500 [00:26<3:41:24, 26.62s/it]

epoch # 2, avg loss: 34.329868:   0%|         | 2/500 [00:50<3:28:01, 25.06s/it]

epoch # 3, avg loss: 31.267898:   1%|         | 3/500 [01:15<3:25:55, 24.86s/it]

epoch # 4, avg loss: 35.803611:   1%|         | 4/500 [01:39<3:24:48, 24.78s/it]

epoch # 5, avg loss: 41.684516:   1%|         | 5/500 [02:05<3:26:26, 25.02s/it]

epoch # 6, avg loss: 34.035676:   1%|         | 6/500 [02:30<3:25:59, 25.02s/it]

epoch # 7, avg loss: 28.336998:   1%|▏        | 7/500 [03:03<3:35:17, 26.20s/it]


KeyboardInterrupt: 

In [None]:
# Flattened copies of the model's log-rate and layer-bias parameters
log_rate = tonp(tr.model.log_rate).ravel()
bias = tonp(tr.model.layer.bias).ravel()

fig, axes = create_figure(1, 2, (10, 2))
ax_log, ax_lin = axes[0], axes[1]

# Styling shared by all three histograms
hist_style = dict(fill=True, lw=3, alpha=0.3, element='step')

# Left panel: log-rates
sns.histplot(log_rate, color='C0', label=r'$\log r$', ax=ax_log, **hist_style)

# Right panel: exponentiated log-rates, overlaid with the biases
sns.histplot(np.exp(log_rate), color='C0', label='rate', ax=ax_lin, **hist_style)
sns.histplot(bias, color='C8', label='bias', ax=ax_lin, **hist_style)

ax_lin.set(ylabel='')
add_legend(axes)

plt.show()

In [None]:
# Boolean mask over the flattened log-rates: True where log_rate exceeds 1.
# NOTE(review): the threshold 1 is a magic number — confirm it is the intended cutoff.
dead = log_rate > 1
# (number flagged, number not flagged, flagged fraction)
dead.sum(), (~dead).sum(), dead.sum() / len(dead)

In [None]:
%%time

kws = dict(
    seq_total=1000,
    seq_batch_sz=1000,
    n_data_batches=10,
    active=~dead,
    full_data=True,
    return_recon=True,
)
results = {
    name: tr.analysis(dl_name=name, **kws)
    for name in ['trn', 'vld'] # , 'tst']
}

In [None]:
colors = {'trn': 'C9', 'vld': 'C0', 'tst': 'k'}
rule_width = 110

for name, d in results.items():
    # Vertical gap before every split except 'trn'
    is_train = name == 'trn'
    if not is_train:
        print('\n\n\n')
    # Two-line banner followed by the split name
    print('_' * rule_width)
    print('-' * rule_width)
    print(name)

    plot_convergence(d, color=colors[name])

In [None]:
%%time

kws = dict(
    seq_total=3000,
    seq_batch_sz=1000,
    n_data_batches=1,
    active=~dead,
    full_data=True,
    return_recon=True,
)
results_to_plot = {
    name: tr.analysis(dl_name=name, **kws)
    for name in ['trn', 'vld'] # , 'tst']
}

In [None]:
num = 16
side = tr.model.cfg.input_sz
shape = (side, side)

for name, d in results_to_plot.items():
    # Vertical gap before every split except 'trn'
    if name != 'trn':
        print('\n\n\n')
    # Two-line banner followed by the split name
    for banner_line in ('_' * 110, '-' * 110, name):
        print(banner_line)

    # First `num` entries of d['x'] (index 0 on axis 1), followed by the first
    # `num` entries of d['y'] (last index on axis 1) reshaped to images.
    # NOTE(review): presumably inputs vs. final reconstructions — confirm.
    first_half = d['x'][:num, 0]
    second_half = d['y'][:num, -1].reshape(-1, *shape)
    x2p = np.concatenate([first_half, second_half])
    _ = plot_weights(x2p, nrows=2)

In [None]:
# Reset the model state for 64 entries, then draw from the layer.
# NOTE(review): 64 is presumably the number of generated samples and 0.0 the
# argument controlling generation — confirm against the model code.
tr.model.reset_states(64)
dist, z, pred = tr.model.layer.generate(0.0)

# Run the model's generate over 300 sequence indices and stack the result
output = tr.model.generate(pred, seq=range(300)).stack()

# Per-entry sums over the last axis
loss_kl = tonp(torch.sum(output['loss_kl'], -1))
u = tr.model.log_rate.expand(len(pred), -1)
desc_len = tonp(torch.sum(tr.model.layer.loss_kl(u=u), dim=-1))

# Show the keys available in `output`
list(output)

In [None]:
# Square image shape taken from the model's configured input size
shape = (tr.model.cfg.input_sz, ) * 2

In [None]:
# Grid with one panel per entry of `pred`, 4 rows.
nrows = 4
ncols = int(np.ceil(len(pred) / nrows))
fig, axes = create_figure(nrows, ncols, (1.1 * ncols, 1.3 * nrows), 'all', 'all')
# The grid can hold more panels than there are entries when len(pred) is not
# a multiple of nrows; zip with the range stops at the last valid index so
# `pred` is never indexed out of bounds. Surplus axes are handled by trim_axs.
for sample_i, ax in zip(range(len(pred)), axes.flat):
    x2p = tonp(pred[sample_i].reshape(shape))
    ax.imshow(x2p, cmap='Greys_r')
    ax.set_title(f"i = {sample_i}")
trim_axs(axes, len(pred))
remove_ticks(axes)
plt.show()

In [None]:
# Index of the generated sample to inspect in the next cell.
sample_i = 27

In [None]:
# 4 x 30 grid: panel `idx` shows output['recon'][sample_i, idx].
nrows, ncols = 4, 30
fig, axes = create_figure(nrows, ncols, (1.3 * ncols, 1.3 * nrows), 'all', 'all')
for idx, ax in enumerate(axes.flat):
    x2p = tonp(output['recon'][sample_i, idx].reshape(shape))
    ax.imshow(x2p, cmap='Greys_r')
remove_ticks(axes)
# Explicit show, consistent with the other plotting cells
plt.show()

In [None]:
# Sample indices sorted by ascending `desc_len`.
order = np.argsort(desc_len)

In [None]:
# One figure per entry of `pred`, visited in ascending `desc_len` order.
nrows, ncols = 4, 30  # loop-invariant grid size, set once
for rank in range(len(pred)):
    sample_i = order[rank]
    fig, axes = create_figure(nrows, ncols, (1.3 * ncols, 1.3 * nrows), 'all', 'all')
    tit = ' ——— '.join([
        f"i = {rank}, sample # {sample_i}",
        f"desc len = {desc_len[sample_i]:0.2g}",
    ])
    fig.suptitle(tit, fontsize=25, y=1.14)

    # First panel: pred[sample_i]
    x2p = tonp(pred[sample_i].reshape(shape))
    axes[0, 0].imshow(x2p, cmap='Greys_r')

    # Remaining panels: output['recon'][sample_i, 0], [sample_i, 1], ...
    # Reshaped with `shape` (as for `pred` above) rather than a hard-coded size.
    for idx, ax in enumerate(axes.flat[1:], start=1):
        x2p = tonp(output['recon'][sample_i, idx - 1].reshape(shape))
        ax.imshow(x2p, cmap='Greys_r')
    remove_ticks(axes)
    plt.show()
    # Release the figure explicitly so figures do not pile up across iterations
    plt.close(fig)