# (30) Fit -- cuda2

**Motivation**: Fitting notebook: builds the VAE and trains it on `cuda:2`. <br>

In [1]:
# HIDE CODE


import os, sys
from IPython.display import display

# tmp & extras dir
git_dir = os.path.join(os.environ['HOME'], 'Dropbox/git')
extras_dir = os.path.join(git_dir, 'jb-MTMST/_extras')
fig_base_dir = os.path.join(git_dir, 'jb-MTMST/figs')
tmp_dir = os.path.join(git_dir, 'jb-MTMST/tmp')

# GitHub
sys.path.insert(0, os.path.join(git_dir, '_MTMST'))
from model.train_vae import TrainerVAE, ConfigTrain
from model.vae2d import VAE, ConfigVAE
from analysis.opticflow import *
from figures.fighelper import *

# warnings, tqdm, & style
warnings.filterwarnings('ignore', category=DeprecationWarning)
from tqdm.notebook import tqdm
from rich.jupyter import print
%matplotlib inline
set_style()

## Trainer

In [2]:
# Model: each ConfigVAE option is listed on its own line for easy diffing between runs.
vae = VAE(
    ConfigVAE(
        n_latent_scales=2,
        n_groups_per_scale=20,
        n_latent_per_group=7,
        scale_init=False,
        residual_kl=True,
        ada_groups=True,
    )
)

# Trainer: wraps the model with the training hyper-parameters and targets CUDA device index 2.
tr = TrainerVAE(
    model=vae,
    cfg=ConfigTrain(
        lr=0.003,
        batch_size=512,
        epochs=2000,
        grad_clip=1000,
        lambda_anneal=True,
        lambda_init=1e-7,
        lambda_norm=1e-3,
        kl_beta=0.25,
        kl_anneal_cycles=1,
        scheduler_kws={'T_max': 660.0, 'eta_min': 1e-05},
        optimizer='adamax_fast',
    ),
    device='cuda:2',
)

# Cell output: the total latent count as reported by the model config.
vae.cfg.total_latents()

210

In [3]:
# Call the model's own print() helper, then show its `scales` attribute as the cell output.
vae.print()
vae.scales

[8, 4]

In [4]:
# Cell output: sizes of the model's `all_conv_layers` and `all_log_norm` collections.
len(vae.all_conv_layers), len(vae.all_log_norm)

(291, 228)

In [5]:
# Dump every attribute stored on the training config, including values that were
# not passed explicitly when the ConfigTrain was constructed.
vars(tr.cfg)

{'lr': 0.003,
 'epochs': 2000,
 'batch_size': 512,
 'warmup_portion': 0.02,
 'lambda_anneal': True,
 'lambda_init': 1e-07,
 'lambda_norm': 0.001,
 'kl_beta': 0.25,
 'kl_beta_min': 0.0001,
 'kl_balancer': 'equal',
 'kl_anneal_cycles': 1,
 'kl_anneal_portion': 0.3,
 'kl_const_portion': 0.001,
 'optimizer': 'adamax_fast',
 'optimizer_kws': {'betas': (0.9, 0.999),
  'weight_decay': 0.0001,
  'eps': 1e-08},
 'scheduler_type': 'cosine',
 'scheduler_kws': {'T_max': 660.0, 'eta_min': 1e-05},
 'spectral_reg': False,
 'ema_rate': 0.999,
 'grad_clip': 1000,
 'chkpt_freq': 50,
 'eval_freq': 5,
 'log_freq': 30,
 'use_amp': False}

In [6]:
# Display the optimizer object held by the trainer.
tr.optim

Adamax (
Parameter Group 0
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 0.003
    lr: 0.003
    weight_decay: 0.0001
)

## Review train options

In [7]:
# Show the model-config name and the train-config name on separate lines.
# `print` here is the one imported from rich.jupyter in the setup cell, not the builtin.
print(f"{vae.cfg.name()}\n{tr.cfg.name()}")

## Train

In [None]:
# Run tag = hand-written notes followed by the auto-generated train-config name.
# NOTE(review): the tag says ClipVal:500 while the ConfigTrain in this notebook
# sets grad_clip=1000 — confirm whether these refer to the same setting.
comment = (
    "ClipVal:500_AdamaxFast(eps:1e-8)_"
    f"{tr.cfg.name()}"
)
# Launch training; the tag is passed through to the trainer as `comment`.
tr.train(comment=comment)

epoch # 1390, avg loss: 15.067790:  70%|██████▉   | 1390/2000 [22:28:01<9:18:11, 54.90s/it]   

### Also: this run used plain (non-smooth) L1