In [None]:
from modelzoo import MODELS_INDEX
import pandas as pd


# The model index is tab-separated; `read_table` uses "\t" as its default delimiter.
models = pd.read_table(MODELS_INDEX)
models.head()

In [None]:
# Pick the checkpoint to inspect from the index.
MODEL_NAME = "modelzoo.modules.aes.ae.VanillaAE"
DATASET = "cifar100"

# Build one combined boolean mask. Chaining two `[...]` filters would index the
# already-filtered frame with a mask built on the full frame, which pandas only
# accepts by reindexing (with a UserWarning).
matches = models[(models["model"] == MODEL_NAME) & (models["dataset"] == DATASET)]
if matches.empty:
    raise ValueError(
        f"No entry in MODELS_INDEX for model={MODEL_NAME!r}, dataset={DATASET!r}"
    )
row = matches.iloc[0]  # if several runs match, the first one is used

entity, project_name, run_id, path = row["entity"], row["project_name"], row["wandb_id"], row["path"]
entity, project_name, run_id, path

## W&B loading (optional)

In [None]:
# Alternative to the local loading below: fetch the checkpoint from the W&B run
# identified by (entity, project_name, run_id) taken from the selected index row.
# Uncomment to use it instead of the local checkpoint.
# from modelzoo.utils.io_model import load_wandb_ckpt

# # Load the remote model
# model, ckpt = load_wandb_ckpt(entity, project_name, run_id)
# model

## Local loading

In [None]:
from modelzoo import PACKAGE_ROOT
from modelzoo.utils.io_model import load_local_ckpt

# The index row's `path` is resolved relative to the package root.
filepath = PACKAGE_ROOT / path

# Load the model and the raw checkpoint. NOTE(review): strict=False is presumably
# forwarded to `load_state_dict` to tolerate key mismatches — confirm in io_model.
model, ckpt = load_local_ckpt(filepath, strict=False)
model

## Extract the configuration from the checkpoint

In [None]:
from omegaconf import OmegaConf

# The checkpoint stores its configuration under the "cfg" key; rebuild an
# OmegaConf object from it so nested sections can be reached as attributes
# (e.g. `cfg.nn.data` when instantiating the datamodule).
raw_cfg = ckpt["cfg"]
cfg = OmegaConf.create(raw_cfg)
cfg

## Instantiate datamodule

In [None]:
from torch.utils.data import DataLoader
import hydra
from functools import partial
from modelzoo.data.vision.datamodule import collate_fn
from omegaconf import OmegaConf

# Instantiate the datamodule from the `nn.data` section of the checkpoint config.
# resolve=True: converting the sub-node to a plain container detaches it from the
# root config. Any `${...}` interpolations must therefore be resolved now, while the
# full config is still reachable. With the default resolve=False they are kept as
# literal strings and no longer resolve against the root config.
# _recursive_=False: hydra does not instantiate nested configs; the datamodule
# receives them as configs.
datamodule = hydra.utils.instantiate(
    OmegaConf.to_container(cfg.nn.data, resolve=True), _recursive_=False
)
datamodule.setup(stage="fit")

# Using the loaders from the datamodule ensures that the correct transforms are applied
train_loader = datamodule.train_dataloader()
# val_dataloader() returns an indexable collection of loaders here; use the first.
val_loader = datamodule.val_dataloader()[0]

train_loader, val_loader

## Test inference on a batch

In [None]:
import torch

# Take one batch and reconstruct it. This is inference only:
# - eval() makes layers such as dropout / batch-norm (if the model has any) use
#   their inference behaviour;
# - no_grad() avoids building an autograd graph, so no .detach() is needed.
batch = next(iter(train_loader))
model.eval()
with torch.no_grad():
    reconstruction = model(batch["x"])["reconstruction"].cpu()

In [None]:
# Position within `batch` of the sample to visualise. The plotting cell below
# advances it on every run, so re-run this cell to start again from sample 0.
idx = 0

In [None]:
import matplotlib.pyplot as plt
import torch

# Show sample `idx` of the batch next to its reconstruction. Each re-run steps to
# the next sample and wraps around at the end of the batch instead of raising
# IndexError.
idx %= len(reconstruction)
original = batch["x"][idx].cpu()  # `reconstruction` is on CPU; keep both on the same device
recon = reconstruction[idx]

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# Tensors are assumed (C, H, W), as the permute implies; imshow wants (H, W, C).
# imshow clips float RGB data to [0, 1] anyway; clamping the display copy only
# avoids its "Clipping input data" warning.
axes[0].imshow(original.permute(1, 2, 0).clamp(0, 1))
axes[0].set_title("Original")
axes[1].imshow(recon.permute(1, 2, 0).clamp(0, 1))
axes[1].set_title("Reconstruction")
for ax in axes:
    ax.axis("off")

# MSE is computed on the unclamped tensors (the model's input space);
# .item() prints a plain number instead of a `tensor(...)` repr.
print(f"MSE: {torch.nn.functional.mse_loss(original, recon).item()}")
idx = (idx + 1) % len(reconstruction)