In [1]:
import os
from pathlib import Path
import torch


In [3]:
# Path of the checkpoint file to prune; read by the pruning cell below.
p = '/workspace/mnt/models/pontcha.ckpt'

In [7]:
# Prune the checkpoint at `p`: drop the top-level "optimizer_states" entry and,
# optionally, replace the `model.*` weights with their EMA copies. The result is
# written next to the input as `<stem>-pruned.ckpt` or `<stem>-ema-pruned.ckpt`.
keep_only_ema = False  # True: overwrite model.* with EMA weights and drop the model_ema.* entries

print(f"prunin' in path: {p}")
size_initial = os.path.getsize(p)

# NOTE(security): depending on the torch version's `weights_only` default,
# torch.load may unpickle arbitrary objects; only run this on trusted files.
ckpt = torch.load(p, map_location="cpu")
print(ckpt.keys())

# Keep every top-level entry except the optimizer states.
nsd = {k: v for k, v in ckpt.items() if k != "optimizer_states"}
# Report the removal only when there was something to remove (a `for ... else`
# without `break` would print this unconditionally).
if "optimizer_states" in ckpt:
    print(f"removing optimizer states for path {p}")
if "global_step" in ckpt:
    print(f"This is global step {ckpt['global_step']}.")
# Drop the last reference to the full checkpoint so the optimizer states can be
# freed before saving.
del ckpt

sd = nsd["state_dict"].copy()
if keep_only_ema:
    # infer ema keys: "model.a.b" -> "model_ema.ab" (prefix swapped, dots removed)
    ema_keys = {k: "model_ema." + k[6:].replace(".", "") for k in sd if k.startswith("model.")}

    # Fail early with a readable message instead of a bare KeyError mid-loop.
    missing_ema = [ema_k for ema_k in ema_keys.values() if ema_k not in sd]
    if missing_ema:
        raise KeyError(
            f"{len(missing_ema)} EMA entries missing from state_dict, e.g. {missing_ema[0]!r}"
        )

    new_sd = dict()
    for k in sd:
        if k in ema_keys:
            # Store the EMA tensor under the regular model key.
            print(k, ema_keys[k])
            new_sd[k] = sd[ema_keys[k]]
        elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
            # Non-EMA entries are kept as-is; of the model_ema.* entries only
            # the two bookkeeping keys survive.
            new_sd[k] = sd[k]

    # Exactly one model_ema.* entry must have been dropped per model.* key.
    assert len(new_sd) == len(sd) - len(ema_keys), "unexpected number of keys after EMA pruning"
else:
    # Nothing to filter: keep a plain-dict copy of the state dict.
    new_sd = dict(sd)
nsd["state_dict"] = new_sd

stem = os.path.splitext(p)[0]
fn = f"{stem}-ema-pruned.ckpt" if keep_only_ema else f"{stem}-pruned.ckpt"
print(f"saving pruned checkpoint at: {fn}")
torch.save(nsd, fn)

newsize = os.path.getsize(fn)
MSG = (
    f"New ckpt size: {newsize*1e-9:.2f} GB. "
    f"Saved {(size_initial - newsize)*1e-9:.2f} GB by removing optimizer states"
)
if keep_only_ema:
    MSG += " and non-EMA weights"
print(MSG)

prunin' in path: /workspace/mnt/models/pontcha.ckpt
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers'])
removing optimizer states for path /workspace/mnt/models/pontcha.ckpt
This is global step 800.
saving pruned checkpoint at: /workspace/mnt/models/pontcha-pruned.ckpt
New ckpt size: 4.27 GB. Saved 7.86 GB by removing optimizer states


In [6]:
# Display the model -> EMA key mapping. `ema_keys` is only assigned by the pruning
# cell when keep_only_ema is True; otherwise this cell raises NameError on a fresh kernel.
ema_keys

{'model.diffusion_model.time_embed.0.weight': 'model_ema.diffusion_modeltime_embed0weight',
 'model.diffusion_model.time_embed.0.bias': 'model_ema.diffusion_modeltime_embed0bias',
 'model.diffusion_model.time_embed.2.weight': 'model_ema.diffusion_modeltime_embed2weight',
 'model.diffusion_model.time_embed.2.bias': 'model_ema.diffusion_modeltime_embed2bias',
 'model.diffusion_model.input_blocks.0.0.weight': 'model_ema.diffusion_modelinput_blocks00weight',
 'model.diffusion_model.input_blocks.0.0.bias': 'model_ema.diffusion_modelinput_blocks00bias',
 'model.diffusion_model.input_blocks.1.0.in_layers.0.weight': 'model_ema.diffusion_modelinput_blocks10in_layers0weight',
 'model.diffusion_model.input_blocks.1.0.in_layers.0.bias': 'model_ema.diffusion_modelinput_blocks10in_layers0bias',
 'model.diffusion_model.input_blocks.1.0.in_layers.2.weight': 'model_ema.diffusion_modelinput_blocks10in_layers2weight',
 'model.diffusion_model.input_blocks.1.0.in_layers.2.bias': 'model_ema.diffusion_modeli