# (04) PyTorch profiler

**Motivation**: `@torch.jit.script` seemed to be slowing training down, so this notebook times and profiles the training step. <br>

In [1]:
# HIDE CODE


import os, sys
from IPython.display import display

# tmp & extras dir
# All project directories hang off <home>/Dropbox/git.
# expanduser('~') resolves to $HOME when it is set, and still works when it is
# not (where os.environ['HOME'] would raise KeyError).
git_dir = os.path.join(os.path.expanduser('~'), 'Dropbox/git')
extras_dir = os.path.join(git_dir, 'jb-MTMST/_extras')
fig_base_dir = os.path.join(git_dir, 'jb-MTMST/figs')
tmp_dir = os.path.join(git_dir, 'jb-MTMST/tmp')

# GitHub
sys.path.insert(0, os.path.join(git_dir, '_MTMST'))
from model.train_vae import TrainerVAE, ConfigTrainVAE
from model.vae2d import VAE, ConfigVAE
from analysis.opticflow import *
from figures.fighelper import *

# warnings, tqdm, & style
warnings.filterwarnings('ignore', category=DeprecationWarning)
from tqdm.notebook import tqdm
from rich.jupyter import print
%matplotlib inline
set_style()

## Trainer

In [2]:
# Model configuration
cfg_vae = ConfigVAE(
    n_latent_scales=2,
    n_groups_per_scale=20,
    n_latent_per_group=7,
    scale_init=False,
    residual_kl=True,
    ada_groups=True,
    # separable=False,
)
vae = VAE(cfg_vae)

# Training configuration
cfg_train = ConfigTrainVAE(
    lr=0.002,
    batch_size=500,
    epochs=2000,
    grad_clip=None,
    lambda_anneal=True,
    lambda_init=1e-7,
    lambda_norm=1e-2,
    kl_beta=0.25,
    kl_anneal_cycles=1,
    scheduler_kws={'T_max': 660.0, 'eta_min': 1e-05},
    optimizer='adamax_fast',
)

# The trainer is bound to the name `self`; later cells in this notebook refer
# to it as `self.model`, `self.dl_trn`, `self.cfg`, ...
# (presumably so code pasted from TrainerVAE methods runs unmodified).
self = TrainerVAE(model=vae, cfg=cfg_train, device='cuda:1')

# Displayed as the cell output
vae.cfg.total_latents()

210

In [3]:
# Call the model's own print() helper, then display its `scales` attribute
# as the cell output.
vae.print()
vae.scales

[8, 4]

In [4]:
# Sizes of the model's conv-layer and log-norm collections, shown as a tuple
n_conv_layers = len(vae.all_conv_layers)
n_log_norm = len(vae.all_log_norm)
n_conv_layers, n_log_norm

(291, 287)

In [5]:
from tqdm import tqdm

In [None]:
# Neither epe nor _normalize is jit-scripted here; furthermore, _normalize uses linalg.vector_norm.

In [6]:
%%time

# Baseline: time a 100-epoch training run (the %%time report is in the output below).
# NOTE(review): the positional args are presumably (epochs, comment, save), going by
# the keyword calls elsewhere in this notebook — confirm against TrainerVAE.train.
self.train(100, 'test', False)

epoch # 100, avg loss: 4.861284: 100%|██████████| 100/100 [1:49:38<00:00, 65.79s/it]


CPU times: user 1h 47min 2s, sys: 3min 38s, total: 1h 50min 41s
Wall time: 1h 49min 39s


In [6]:
%%time

# Benchmark: forward pass only, 5 sweeps over the training loader, model in train mode.
self.model.train()

for _ in range(5):
    # tqdm can wrap the loader directly; the explicit iter() was redundant.
    for x, norm in tqdm(self.dl_trn):
        self.model(x)

# CUDA kernels are launched asynchronously. Wait for all queued work to finish
# so that %%time measures completed computation, not just kernel launches.
if torch.device(self.device).type == 'cuda':
    torch.cuda.synchronize(self.device)

100%|██████████| 80/80 [00:18<00:00,  4.25it/s]
100%|██████████| 80/80 [00:16<00:00,  4.93it/s]
100%|██████████| 80/80 [00:17<00:00,  4.71it/s]
100%|██████████| 80/80 [00:14<00:00,  5.49it/s]
100%|██████████| 80/80 [00:14<00:00,  5.67it/s]

CPU times: user 1min 21s, sys: 1.53 s, total: 1min 22s
Wall time: 1min 20s





In [7]:
# The output below was recorded with a bunch of jit-scripted code in Normal

100%|██████████| 80/80 [00:21<00:00,  3.81it/s]
100%|██████████| 80/80 [00:15<00:00,  5.09it/s]
100%|██████████| 80/80 [00:17<00:00,  4.57it/s]
100%|██████████| 80/80 [00:14<00:00,  5.35it/s]
100%|██████████| 80/80 [00:15<00:00,  5.00it/s]

CPU times: user 1min 26s, sys: 1.65 s, total: 1min 27s
Wall time: 1min 25s





In [6]:
# The output below was recorded with torch.linalg.vector_norm

100%|██████████| 80/80 [00:14<00:00,  5.57it/s]
100%|██████████| 80/80 [00:13<00:00,  6.07it/s]
100%|██████████| 80/80 [00:11<00:00,  6.92it/s]
100%|██████████| 80/80 [00:12<00:00,  6.27it/s]
100%|██████████| 80/80 [00:11<00:00,  6.88it/s]

CPU times: user 1min 4s, sys: 1.64 s, total: 1min 6s
Wall time: 1min 3s





In [7]:
from model.utils_model import kl_balancer

In [9]:
# Attach a progress bar to the trainer by hand.
# NOTE(review): presumably self.iteration() reports through self.pbar, which
# self.train() would normally create — confirm against TrainerVAE.
self.pbar = tqdm(range(10))

  0%|          | 0/10 [00:00<?, ?it/s]

In [10]:
%%time

# Time 5 calls of the trainer's own iteration() with no warm-up iterations.
# The first argument is always 0 here.
# NOTE(review): presumably that argument is the epoch index — confirm against TrainerVAE.iteration.
for e in range(5): 
    self.iteration(0, n_iters_warmup=0)

gstep # 61, nelbo: 15.907, grad: 228.0:   0%|          | 0/10 [05:06<?, ?it/s]      

KeyboardInterrupt: 

In [9]:
%%time

# Benchmark: the training step written out by hand (forward pass, losses,
# backward pass) with no optimizer update, 5 sweeps over the training loader.
self.model.train()

n_batches = len(self.dl_trn)
for e in range(5):
    # Wrap the loader itself rather than the enumerate() generator, so tqdm
    # knows the number of batches and shows a real progress bar / ETA.
    for i, (x, norm) in enumerate(tqdm(self.dl_trn)):
        gstep = e * n_batches + i
        # No optimizer step is taken here, so clear the grads explicitly;
        # otherwise each backward() adds onto all previous iterations' grads.
        self.model.zero_grad(set_to_none=True)
        y, _, q, p = self.model(x)
        epe = self.model.loss_recon(x=x, y=y, w=1/norm)
        kl_all, kl_diag = self.model.loss_kl(q, p)
        # balance kl
        balanced_kl, gamma, kl_vals = kl_balancer(
            kl_all=kl_all,
            alpha=self.alphas,
            coeff=self.betas[gstep],
            beta=self.cfg.kl_beta,
        )
        loss = torch.mean(epe + balanced_kl)
        # add regularization
        loss_w = self.model.loss_weight()
        if self.wd_coeffs[gstep] > 0:
            loss += self.wd_coeffs[gstep] * loss_w
        # Spectral regularization applies only when it is enabled in the trainer
        # config and the model is not already using spectral norm.
        cond_reg_spectral = self.cfg.lambda_norm > 0 \
            and self.cfg.spectral_reg and \
            not self.model.cfg.spectral_norm
        if cond_reg_spectral:
            loss_sr = self.model.loss_spectral(
                device=self.device, name='w')
            loss += self.wd_coeffs[gstep] * loss_sr
        else:
            loss_sr = None
        loss.backward()

# CUDA kernels are launched asynchronously. Wait for all queued work to finish
# so that %%time measures completed computation, not just kernel launches.
if torch.device(self.device).type == 'cuda':
    torch.cuda.synchronize(self.device)

80it [00:54,  1.48it/s]
80it [00:43,  1.84it/s]
80it [00:44,  1.78it/s]
80it [00:48,  1.65it/s]
80it [00:44,  1.79it/s]

CPU times: user 3min 51s, sys: 6.77 s, total: 3min 58s
Wall time: 3min 55s





In [7]:
# The output below was recorded with torch.linalg.vector_norm

80it [00:51,  1.54it/s]
80it [00:46,  1.72it/s]
80it [00:42,  1.88it/s]
80it [00:46,  1.73it/s]
80it [00:48,  1.63it/s]

CPU times: user 3min 50s, sys: 8.81 s, total: 3min 59s
Wall time: 3min 56s





In [9]:
# Profile a single training step (forward pass, losses, backward pass) on CPU and CUDA.
# `x`, `norm` and `gstep` are not defined in this cell: they are whatever the
# benchmark loop in the earlier cell left behind.
activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
kws = {
    'activities': activities,
    'record_shapes': True,
    'with_stack': True,
}
with torch.profiler.profile(**kws) as prof:
    y, _, q, p = self.model(x)
    epe = self.model.loss_recon(x=x, y=y, w=1 / norm)
    kl_all, kl_diag = self.model.loss_kl(q, p)

    # KL balancing
    balanced_kl, gamma, kl_vals = kl_balancer(
        kl_all=kl_all, alpha=self.alphas,
        coeff=self.betas[gstep], beta=self.cfg.kl_beta)
    loss = torch.mean(epe + balanced_kl)

    # Weight regularization
    loss_w = self.model.loss_weight()
    if self.wd_coeffs[gstep] > 0:
        loss += self.wd_coeffs[gstep] * loss_w

    # Spectral regularization, only when enabled and spectral norm is not in use
    cond_reg_spectral = (
        self.cfg.lambda_norm > 0
        and self.cfg.spectral_reg
        and not self.model.cfg.spectral_norm
    )
    if not cond_reg_spectral:
        loss_sr = None
    else:
        loss_sr = self.model.loss_spectral(device=self.device, name='w')
        loss += self.wd_coeffs[gstep] * loss_sr

    loss.backward()

In [10]:
# Top 10 entries ranked by self CUDA time, grouped by the top 5 stack-trace entries
averages_by_stack = prof.key_averages(group_by_stack_n=5)
print(averages_by_stack.table(sort_by="self_cuda_time_total", row_limit=10))

In [11]:
# Top 10 entries ranked by self CPU time, grouped by the top 5 stack-trace entries
averages_by_stack = prof.key_averages(group_by_stack_n=5)
print(averages_by_stack.table(sort_by="self_cpu_time_total", row_limit=10))

### Save to txt?

In [12]:
# Write the top-100 profiler tables to text files in the working directory,
# one file per sort key. key_averages() is computed once and reused for both.
averages = prof.key_averages(group_by_stack_n=5)
for sort_key, path in [
    ("self_cuda_time_total", "profile_cuda.txt"),
    ("self_cpu_time_total", "profile_cpu.txt"),
]:
    output = averages.table(sort_by=sort_key, row_limit=100)
    # Explicit encoding so the file does not depend on the platform default
    with open(path, "w", encoding="utf-8") as f:
        f.write(output)

**Conclusion**: `@torch.jit.script` was slowing things down.

In [None]:
# Profile one complete training epoch through the trainer's own train() method,
# recording CPU and CUDA activity; nothing is saved (save=False).
activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
kws = {
    'activities': activities,
    'record_shapes': True,
    'with_stack': True,
}
with torch.profiler.profile(**kws) as prof:
    self.train(1, save=False)

epoch # 1, avg loss: 57.204149: 100%|██████████| 1/1 [02:20<00:00, 140.16s/it]


In [None]:
# Write the top-100 profiler tables to text files in the working directory,
# one file per sort key. key_averages() is computed once and reused for both.
# NOTE(review): these are the same file names the earlier "Save to txt?" cell
# writes, so that profile gets overwritten — confirm this is intended.
averages = prof.key_averages(group_by_stack_n=5)
for sort_key, path in [
    ("self_cuda_time_total", "profile_cuda.txt"),
    ("self_cpu_time_total", "profile_cpu.txt"),
]:
    output = averages.table(sort_by=sort_key, row_limit=100)
    # Explicit encoding so the file does not depend on the platform default
    with open(path, "w", encoding="utf-8") as f:
        f.write(output)

## Compile

In [3]:
# Same model / trainer setup as in the Trainer section (grad_clip=1000 here),
# but with the model wrapped in torch.compile.
compile_device = 'cuda:1'

# The recorded run of the next cell failed inside the compiler backend with
# "expected device cuda:1 but got cuda:0", i.e. something was placed on the
# default GPU. Make the target GPU the current CUDA device before compiling
# so that no work falls back to cuda:0.
# NOTE(review): workaround for that recorded failure — confirm it resolves the
# error on the installed PyTorch version.
torch.cuda.set_device(torch.device(compile_device))

vae = VAE(ConfigVAE(
    n_latent_scales=2, n_groups_per_scale=20, n_latent_per_group=7,
    scale_init=False, residual_kl=True, ada_groups=True, # separable=False,
))
tr = TrainerVAE(
    model=torch.compile(vae),
    cfg=ConfigTrainVAE(
        lr=0.002, batch_size=500, epochs=2000, grad_clip=1000,
        lambda_anneal=True, lambda_init=1e-7, lambda_norm=1e-2,
        kl_beta=0.25, kl_anneal_cycles=1,
        scheduler_kws={'T_max': 660.0, 'eta_min': 1e-05},
        optimizer='adamax_fast',
    ),
    device=compile_device,
)
# Displayed as the cell output
vae.cfg.total_latents()

210

In [4]:
%%time

# Benchmark: forward pass only through the compiled model, 5 sweeps over the
# training loader, model in train mode.
tr.model.train()

for _ in range(5):
    # tqdm can wrap the loader directly; the explicit iter() was redundant.
    for x, norm in tqdm(tr.dl_trn):
        tr.model(x)

# CUDA kernels are launched asynchronously. Wait for all queued work to finish
# so that %%time measures completed computation, not just kernel launches.
if torch.device(tr.device).type == 'cuda':
    torch.cuda.synchronize(tr.device)

  0%|          | 0/80 [00:00<?, ?it/s]

BackendCompilerFailed: debug_wrapper raised RuntimeError: Function ConvolutionBackward0 returned an invalid gradient at index 1 - expected device cuda:1 but got cuda:0

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True


In [None]:
# Start a training run with the compiled model, tagged 'compiled_test'.
# No epoch count is passed, so the number of epochs is whatever train() defaults to
# (presumably cfg.epochs — confirm against TrainerVAE.train).
tr.train(comment='compiled_test')