In [20]:
import pytorch_lightning as pl
import torch
from numerize import numerize


from diffusion.unet import SimpleUNet
from diffusion.lightning_modules import DiffusionWithModel
from diffusion.diffusion import Diffusion
from diffusion.vgg5 import VGG5
from diffusion.vae import SimpleVAE

import numpy as np
import pandas as pd
import tabulate

import os

In [9]:
model_dir = r"C:\Users\niels\local_data\bachelor\eval_models"
model_path = lambda p: os.path.join(model_dir, f'{p}.ckpt')

classifier           = VGG5.load_from_checkpoint(model_path('classifier'))
cond_ddpm_combined   = DiffusionWithModel.load_from_checkpoint(model_path('cond_ddpm'))
uncond_ddpm_combined = DiffusionWithModel.load_from_checkpoint(model_path('uncond_ddpm'))
vae                  = SimpleVAE.load_from_checkpoint(model_path('vae'))

cond_unet, cond_diffusion = cond_ddpm_combined.extract_models()
uncond_unet, uncond_diffusion = uncond_ddpm_combined.extract_models()

  rank_zero_warn(


In [26]:
table = list()
for name, parameter in classifier.named_parameters():
    if not parameter.requires_grad: continue
    params = parameter.numel()
    table.append([name, params])

df = pd.DataFrame(table, columns=['Modules', 'Num params'])

# splits module names on '.'
df['Modules'] = df['Modules'].apply(lambda x: x.split('.')[0])
df.groupby('Modules').sum().reset_index()

Unnamed: 0,Modules,Num params
0,block_1,9696
1,block_2,55680
2,block_3,369792
3,block_4,1476864
4,dense,792586


In [68]:
def list_parameters(model: torch.nn.Module|pl.LightningModule, format='simple', group_block=0, decimals=0) -> str:
    header = ["Modules", "Num params"]
    table = list()
    
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.append([name, params])
        
    humanize = lambda x: _str[:-1] + ' ' + _str[-1] if (_str := numerize.numerize(x, decimals=decimals))[-1] in ('K', 'M') else _str + '  '
    
    df = pd.DataFrame(table, columns=['Modules', 'Num params'])
    df['Modules'] = df['Modules'].apply(lambda x: '.'.join(x.split('.')[:group_block+1]))
    df_grouped = df.groupby('Modules').sum().reset_index()
    df_grouped['Num params'] = df_grouped['Num params'].apply(humanize)
    table = df_grouped.to_numpy().tolist()
    table.append(['Total', humanize(int(df['Num params'].sum()))])
    
    return tabulate.tabulate(table, headers=header, tablefmt=format, intfmt='g')

print(list_parameters(uncond_unet, decimals=1, group_block=1, format='latex'))

\begin{tabular}{ll}
\hline
 Modules          & Num params   \\
\hline
 input\_blocks.0   & 160          \\
 input\_blocks.1   & 6.9 K        \\
 input\_blocks.10  & 25 K         \\
 input\_blocks.11  & 25 K         \\
 input\_blocks.12  & 25 K         \\
 input\_blocks.13  & 25 K         \\
 input\_blocks.14  & 25 K         \\
 input\_blocks.15  & 25 K         \\
 input\_blocks.16  & 25 K         \\
 input\_blocks.17  & 78.7 K       \\
 input\_blocks.18  & 95 K         \\
 input\_blocks.19  & 95 K         \\
 input\_blocks.2   & 6.9 K        \\
 input\_blocks.20  & 95 K         \\
 input\_blocks.21  & 95 K         \\
 input\_blocks.22  & 95 K         \\
 input\_blocks.23  & 95 K         \\
 input\_blocks.24  & 95 K         \\
 input\_blocks.3   & 6.9 K        \\
 input\_blocks.4   & 6.9 K        \\
 input\_blocks.5   & 6.9 K        \\
 input\_blocks.6   & 6.9 K        \\
 input\_blocks.7   & 6.9 K        \\
 input\_blocks.8   & 6.9 K        \\
 input\_blocks.9   & 20.9 K       \\
 labe