# utils

> Utilities for saving and loading model checkpoints


In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os 
import torch

In [None]:
#| export 
def save_checkpoint(model, epoch, val_loss, cfg, optimizer=None, save_every=25, n_keep=5, verbose=True, tag=""):
    """Save a checkpoint; keep the best one plus the `n_keep` most recent periodic ones.

    `model` may be a single module or a list/tuple of modules. Each module gets its
    own files under `checkpoints/`, named from its class name and `tag`.

    Args:
        model: module, or list/tuple of modules, to save. A wrapper exposing
            `_orig_mod` is unwrapped before its state dict is taken.
        epoch: current epoch. A periodic checkpoint is written when it is a
            multiple of `save_every`.
        val_loss: validation loss. A `_best` file is written when it is lower than
            the lowest loss seen so far.
        cfg: config mapping, stored in the checkpoint as `dict(cfg)`.
        optimizer: if given, its state dict is stored in the first model's file only.
        save_every: periodic save interval in epochs. A falsy value (0 or None)
            disables periodic saves instead of raising.
        n_keep: number of most recent periodic checkpoints to keep per model;
            0 keeps none.
        verbose: print the path of each periodic checkpoint written.
        tag: extra string included in the file names.

    The lowest loss seen so far lives on the function attribute
    `save_checkpoint.best_val_loss`, so it is shared by every call in the process.
    """
    if not hasattr(save_checkpoint, 'best_val_loss'):
        save_checkpoint.best_val_loss = float('inf')

    # A falsy interval means "never save periodically" (and avoids modulo by zero).
    is_periodic = bool(save_every) and epoch % save_every == 0
    if not is_periodic and val_loss >= save_checkpoint.best_val_loss: return
    is_best = val_loss < save_checkpoint.best_val_loss

    os.makedirs('checkpoints', exist_ok=True)
    models = model if isinstance(model, (list, tuple)) else [model]
    for i, m in enumerate(models):
        base = getattr(m, '_orig_mod', m)  # unwrap so saved keys carry no wrapper prefix
        ckpt = {'epoch': epoch, 'model_state_dict': base.state_dict(), 'config': dict(cfg), 'val_loss': val_loss}
        # Only the first model on the list gets the optimizer in its ckpt file.
        if optimizer is not None and i == 0: ckpt['optimizer_state_dict'] = optimizer.state_dict()

        mtag = f"{base.__class__.__name__}_{tag}"
        prefix = f'{mtag}ckpt_epoch'
        if is_periodic:
            path = f'checkpoints/{prefix}{epoch}.pt'
            if verbose: print(f"Saving checkpoint to {path}")
            torch.save(ckpt, path)
        if is_best:
            torch.save(ckpt, f'checkpoints/{mtag}_best.pt')

        # Delete periodic checkpoints older than the n_keep most recent ones.
        # The count is computed explicitly because `ckpts[:-n_keep]` is empty when n_keep == 0.
        ckpts = sorted((f for f in os.listdir('checkpoints') if f.startswith(prefix)),
                       key=lambda f: _ckpt_order_key(f, prefix))
        n_drop = max(len(ckpts) - n_keep, 0)
        for old in ckpts[:n_drop]: os.remove(f'checkpoints/{old}')

    # Updated only after the loop so every model in the list gets its `_best` file.
    if val_loss <= save_checkpoint.best_val_loss:
        save_checkpoint.best_val_loss = val_loss


def _ckpt_order_key(fname, prefix):
    "Sort key for periodic checkpoint files: modification time, then the epoch number parsed from the name."
    try:
        epoch_num = float(fname[len(prefix):].removesuffix('.pt'))
    except ValueError:
        epoch_num = float('-inf')  # unparsable names sort first among equal mtimes
    return (os.path.getmtime(f'checkpoints/{fname}'), epoch_num)

In [None]:
#| export
def load_checkpoint(model, ckpt_path:str, return_all=False, weights_only=False, strict=False):
    """Load model weights (and optionally the whole checkpoint dict) from a checkpoint file.

    Args:
        model: module to load weights into. If it exposes `_orig_mod`, the weights
            are loaded into that wrapped module, mirroring how `save_checkpoint`
            unwraps before saving.
        ckpt_path: path to a file holding a dict with a `'model_state_dict'` entry.
        return_all: if True, return `(model, ckpt)` instead of just `model`.
        weights_only: forwarded to `torch.load`.
        strict: forwarded to `load_state_dict`.

    Returns:
        `model`, or `(model, ckpt)` when `return_all` is True. `ckpt['model_state_dict']`
        is returned with every `'_orig_mod.'` substring removed from its keys.
    """
    # Load tensors onto the model's device; a parameter-less model falls back to CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary objects;
    # only use it on checkpoint files from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=weights_only)
    print(f">>> Loaded model checkpoint from {ckpt_path}")
    ckpt['model_state_dict'] = {k.replace('_orig_mod.', ''): v for k, v in ckpt['model_state_dict'].items()}
    # Keys are now prefix-free, so load into the unwrapped module. Loading into the
    # wrapper with strict=False would match no keys and silently load nothing.
    target = getattr(model, '_orig_mod', model)
    target.load_state_dict(ckpt['model_state_dict'], strict=strict)
    return (model, ckpt) if return_all else model

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()