In [None]:
from pathlib import Path
from nn_core.common import PROJECT_ROOT
import pandas as pd


# Index of all trained models; despite the .csv extension the file is tab-separated.
index_path = PROJECT_ROOT / "models" / "index.csv"
models = pd.read_csv(index_path, sep="\t")
models.head()

In [None]:
# Pick the run to inspect: the first index entry matching both the model class and the dataset.
# Both conditions are combined into ONE boolean mask: chaining `models[m1][m2]` would index the
# already-filtered frame with a mask built on the full frame, which pandas only tolerates by
# reindexing the mask (and emits a "Boolean Series key will be reindexed" warning).
MODEL_NAME = "modelzoo.modules.aes.ae.VanillaAE"
DATASET_NAME = "cifar100"

matching_runs = models.loc[(models.model == MODEL_NAME) & (models.dataset == DATASET_NAME)]
if matching_runs.empty:
    raise LookupError(f"No run in the model index for model={MODEL_NAME!r}, dataset={DATASET_NAME!r}")
row = matching_runs.iloc[0]

entity, project_name, run_id = row["entity"], row["project_name"], row["wandb_id"]
entity, project_name, run_id

## W&B loading (alternative; the cell below is commented out)

In [None]:
# Alternative to local loading: fetch the checkpoint of the selected run (entity / project_name / run_id)
# from W&B instead of from disk. Disabled; uncomment to use.
# NOTE(review): presumably requires network access and W&B credentials for the project — confirm.
# from modelzoo.utils.io_model import load_wandb_ckpt

# # Load the remote model
# model, ckpt = load_wandb_ckpt(entity, project_name, run_id)
# model

## Local loading

In [None]:
from modelzoo.utils.io_model import load_local_ckpt

# Local checkpoints live under models/checkpoints/ and are named after the W&B run id.
checkpoints_dir = PROJECT_ROOT / "models" / "checkpoints"
filepath = checkpoints_dir / f"{run_id}.ckpt.zip"

model, ckpt = load_local_ckpt(filepath)
model

## Extract configuration from ckpt

In [None]:
from omegaconf import OmegaConf

# Rebuild the run's configuration from the checkpoint (presumably stored as a plain container — confirm)
# so later cells can use attribute access such as `cfg.nn.data`.
stored_cfg = ckpt["cfg"]
cfg = OmegaConf.create(stored_cfg)
cfg

In [None]:
from torch.utils.data import DataLoader
import hydra
from functools import partial
from modelzoo.data.vision.datamodule import collate_fn
from omegaconf import OmegaConf

# Instantiate the datamodule from the checkpoint's config so the training and validation
# loaders use the transforms specified for this run.
datamodule = hydra.utils.instantiate(OmegaConf.to_container(cfg.nn.data), _recursive_=False)
datamodule.setup(stage="fit")
train_dataset = datamodule.train_dataset
val_dataset = datamodule.val_datasets[0]  # only the first validation dataset is inspected


def _make_loader(dataset, split: str, batch_size: int) -> DataLoader:
    """Build a DataLoader over `dataset` that collates with the datamodule's metadata and batch transform.

    Args:
        dataset: dataset to iterate over.
        split: split name forwarded to `collate_fn` ("train" or "val").
        batch_size: number of samples per batch.

    Returns:
        A non-shuffled DataLoader with pinned memory and 4 worker processes.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,  # fixed order, so the same samples are inspected on every run
        num_workers=4,
        collate_fn=partial(collate_fn, split=split, metadata=datamodule.metadata, transform=datamodule.transform_batch),
    )


train_loader = _make_loader(train_dataset, split="train", batch_size=128)
val_loader = _make_loader(val_dataset, split="val", batch_size=32)

train_dataset, val_dataset

In [None]:
# Test inference on a single training batch.
# Switch to eval mode so layers such as dropout / batch-norm (if the model has any) behave as at
# inference time and do not update running statistics, and disable autograd since no gradients
# are needed. NOTE(review): assumes `model` is a torch.nn.Module — confirm.
import torch

batch = next(iter(train_loader))
model.eval()
with torch.no_grad():
    reconstruction = model(batch["x"])["reconstruction"].cpu()

In [None]:
# Position, within `batch`, of the sample shown by the plotting cell; that cell advances it after each run.
idx: int = 0

In [None]:
import matplotlib.pyplot as plt
import torch

# Show sample `idx` of the batch next to its reconstruction and report their MSE.
# Re-running the cell steps through the batch; the index wraps around at the end of the batch
# instead of raising IndexError.
original = batch["x"][idx]
reconstructed = reconstruction[idx]

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, image, title in zip(axes, (original, reconstructed), ("Original", "Reconstruction")):
    # permute(1, 2, 0) moves the channel axis last, the (H, W, C) layout imshow expects.
    # NOTE(review): imshow clips float RGB data to [0, 1]; if transform_batch normalizes the
    # inputs, displayed colors will be distorted — confirm the value range.
    ax.imshow(image.permute(1, 2, 0))
    ax.set_title(title)
    ax.axis("off")
fig.suptitle(f"Sample {idx}")

# .item() prints a plain float instead of a tensor repr such as "tensor(0.0123)".
mse = torch.nn.functional.mse_loss(original, reconstructed).item()
print(f"MSE: {mse}")
idx = (idx + 1) % len(batch["x"])