# PixNerd ImageNet Training & Super-Resolution Notebook

This notebook mirrors the repository's standard Lightning-CLI workflow for PixNerd: install dependencies, create a run-ready config for ImageNet training, launch fitting, and generate higher-resolution samples from a trained checkpoint.


## Prerequisites

- A multi-GPU machine with enough memory to match the original PixNerd-XL ImageNet runs (the default batch sizes assume multiple high-memory GPUs).
- A local ImageNet-1K training set (folder structure compatible with `torchvision.datasets.ImageFolder`).
- Access to the DINOv2 ViT-B/14 weights used for the feature-alignment loss.
- (Optional) Weights & Biases credentials if you want online logging; otherwise you can disable the logger.
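
Before committing to a long run, it can help to confirm that the training folder really has the one-subdirectory-per-class layout that `ImageFolder` expects. The sketch below uses only the standard library. The helper name `summarize_imagefolder` and the extension set are illustrative assumptions, not part of the PixNerd repository.

```python
from pathlib import Path

# Extensions treated as images in this sketch (an assumption;
# torchvision's own list is longer).
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def summarize_imagefolder(root):
    """Return {class_dir_name: image_count} for an ImageFolder-style tree.

    Hypothetical helper: it only inspects the directory layout and never
    opens an image file.
    """
    root = Path(root)
    if not root.is_dir():
        raise FileNotFoundError(f"{root} is not a directory")
    counts = {}
    # Each immediate subdirectory is treated as one class.
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS
        )
    return counts
```

For ImageNet-1K, `len(summarize_imagefolder(IMAGENET_ROOT))` should be 1000, and any class with a count of zero suggests an incomplete copy.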


In [None]:
# If you haven't installed dependencies in this environment, uncomment the next line.
# !pip install -r requirements.txt


In [None]:
from omegaconf import OmegaConf
from pathlib import Path
import datetime

# Paths you need to customize
IMAGENET_ROOT = "/path/to/imagenet/train"          # ImageNet training folder
DINOV2_WEIGHTS = "/path/to/dinov2_vitb14"          # Local DINOv2 ViT-B/14 checkpoint directory
WORKDIR_ROOT = Path("./workdirs_notebook")          # Where logs, checkpoints, and samples will be saved

EXP_NAME = f"pixnerd_imagenet256_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}"
BASE_CONFIG = Path("configs_c2i/pix256std1_repa_pixnerd_xl.yaml")

cfg = OmegaConf.load(BASE_CONFIG)

# Override training/logging paths to local locations
cfg.tags.exp = EXP_NAME
cfg.trainer.default_root_dir = str(WORKDIR_ROOT)
if cfg.trainer.logger:
    cfg.trainer.logger.init_args.name = EXP_NAME
    # Set your own project if desired; keep the default to match the paper runs
    cfg.trainer.logger.init_args.project = cfg.trainer.logger.init_args.get("project", "universal_pix_flow")

# Point datasets and frozen encoder to your local data
cfg.data.train_dataset.init_args.root = IMAGENET_ROOT
cfg.data.train_dataset.init_args.resolution = 256
cfg.data.eval_dataset.init_args.latent_shape = [3, 256, 256]
cfg.data.pred_dataset.init_args.latent_shape = [3, 256, 256]
cfg.diffusion_trainer.init_args.encoder.init_args.weight_path = DINOV2_WEIGHTS

# Cache overrides are optional; keep them None unless you need a custom hub path
cfg.torch_hub_dir = None
cfg.huggingface_cache_dir = None

# Save the run-specific config next to this notebook
notebook_cfg = Path("notebooks/configs/pixnerd_imagenet256_notebook.yaml")
notebook_cfg.parent.mkdir(parents=True, exist_ok=True)
OmegaConf.save(cfg, notebook_cfg)
print("Saved config to", notebook_cfg)


The saved config mirrors `configs_c2i/pix256std1_repa_pixnerd_xl.yaml` but with local paths/tags. Training outputs land in `WORKDIR_ROOT/exp_<EXP_NAME>/...`, and the `SaveImagesHook` writes validation/prediction images under the same root.
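
To see exactly which keys differ from the base file, one option is a small recursive diff over plain dictionaries. This is a minimal sketch: `diff_config` and `MISSING` are hypothetical names introduced here, and the function expects ordinary `dict` objects such as those returned by `OmegaConf.to_container`.

```python
MISSING = "<missing>"  # placeholder for a key that exists on only one side


def diff_config(base, override, prefix=""):
    """Yield (dotted_key, base_value, override_value) for each differing leaf.

    Hypothetical helper for this notebook. Nested dicts are walked
    recursively; lists and scalars are compared as whole values.
    """
    for key in sorted(set(base) | set(override), key=str):
        path = f"{prefix}.{key}" if prefix else str(key)
        a = base.get(key, MISSING)
        b = override.get(key, MISSING)
        if isinstance(a, dict) and isinstance(b, dict):
            yield from diff_config(a, b, path)
        elif a != b:
            yield (path, a, b)
```

In this notebook it could be called as `diff_config(OmegaConf.to_container(OmegaConf.load(BASE_CONFIG)), OmegaConf.to_container(cfg))`, printing each tuple to review the overrides before launching training.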


In [None]:
# Inspect a subset of the resolved config (feel free to delete or extend)
print(OmegaConf.to_yaml(cfg, resolve=True)[:2000])


## Launch full PixNerd ImageNet training

This cell runs the same Lightning-CLI entrypoint the repository uses. It picks up every trainer setting from the saved config, including the multi-GPU setup, mixed precision, EMA, checkpointing, and periodic validation image saves. Expect long runtimes, comparable to those of the paper's runs.


In [None]:
# Starts training; comment this out if you only want inference with a pre-trained checkpoint.
# !python main.py fit -c notebooks/configs/pixnerd_imagenet256_notebook.yaml


After training, Lightning stores checkpoints in `WORKDIR_ROOT/exp_<EXP_NAME>/checkpoints/`. The cell below points `CKPT_PATH` at `last.ckpt`; change it if you would rather sample from a specific step.


In [None]:
from pathlib import Path

RUN_DIR = Path(cfg.trainer.default_root_dir) / f"exp_{cfg.tags.exp}"
CKPT_PATH = RUN_DIR / "checkpoints/last.ckpt"  # update if you prefer a specific step
print("Expected checkpoint:", CKPT_PATH)
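
If `last.ckpt` is not present (for example, because checkpointing is configured differently in your run), one fallback is to pick the newest `.ckpt` file by modification time. The helper below is a hypothetical convenience, not something the repository provides, and "newest" does not necessarily mean "best".

```python
from pathlib import Path


def latest_checkpoint(ckpt_dir, pattern="*.ckpt"):
    """Return the most recently modified checkpoint in ckpt_dir, or None.

    Hypothetical helper: it ranks files by modification time only and
    knows nothing about validation metrics.
    """
    candidates = [p for p in Path(ckpt_dir).glob(pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

A possible use is `CKPT_PATH = latest_checkpoint(RUN_DIR / "checkpoints") or CKPT_PATH`, which keeps the default path when the folder is empty.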


## Sample at training resolution (256×256)

Use the `predict` subcommand to generate class-conditional samples at the training size. Outputs are written to `<default_root_dir>/<save_dir>/predict/` (with `save_dir` coming from `SaveImagesHook`, default `val`).


In [None]:
# !python main.py predict -c notebooks/configs/pixnerd_imagenet256_notebook.yaml --ckpt_path $CKPT_PATH


## Super-resolution sampling (e.g., 512×512)

PixNerd can sample larger grids without architecture changes. Clone the saved config, bump the latent shape to the target resolution, and lower the batch size if needed. The same checkpoint is reused.
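
As a starting point for choosing a smaller batch size, the sketch below scales it inversely with the pixel count. This rests on the assumption that memory per sample grows roughly in proportion to the number of output pixels, which may not hold exactly for this model. `scaled_batch_size` is a hypothetical helper, and its result should be treated as a first guess to tune against your GPU.

```python
def scaled_batch_size(base_batch, base_res, target_res, minimum=1):
    """Scale a batch size inversely with pixel count (rough heuristic).

    Assumes memory per sample is about proportional to target_res ** 2.
    """
    if min(base_batch, base_res, target_res) <= 0:
        raise ValueError("batch size and resolutions must be positive")
    ratio = (base_res / target_res) ** 2
    # Round down so the guess errs on the side of using less memory.
    return max(minimum, int(base_batch * ratio))
```

For instance, a batch of 16 that fits at 256×256 gives `scaled_batch_size(16, 256, 512) == 4` under this heuristic.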


In [None]:
superres_cfg = OmegaConf.load(notebook_cfg)
superres_cfg.data.pred_dataset.init_args.latent_shape = [3, 512, 512]
superres_cfg.data.pred_batch_size = 4  # adjust for your memory budget
superres_cfg.trainer.default_root_dir = str(Path(cfg.trainer.default_root_dir) / "superres_runs")

superres_cfg_path = Path("notebooks/configs/pixnerd_imagenet512_predict.yaml")
OmegaConf.save(superres_cfg, superres_cfg_path)
print("Super-resolution config saved to", superres_cfg_path)


In [None]:
# !python main.py predict -c notebooks/configs/pixnerd_imagenet512_predict.yaml --ckpt_path $CKPT_PATH


## Visualize predicted images

The `SaveImagesHook` writes PNGs (first few samples) and a compressed `.npz` array. Update `PRED_DIR` to point at the prediction folder for either the 256×256 or super-resolution run and visualize a handful of images.


In [None]:
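import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image

# Point this to the prediction directory you want to browse.
# callbacks[1] assumes SaveImagesHook is the second callback in the config; adjust the index if yours differs.
PRED_DIR = Path(superres_cfg.trainer.default_root_dir) / superres_cfg.trainer.callbacks[1].init_args.save_dir / "predict"
print("Reading from", PRED_DIR)

images = sorted(PRED_DIR.glob("*.png"))[:8]
if not images:
    print("No PNG files found; run the predict cell first or check PRED_DIR.")
else:
    # squeeze=False keeps `axes` a 2-D array even when only one image is found
    fig, axes = plt.subplots(1, len(images), figsize=(4 * len(images), 4), squeeze=False)
    for ax, image_path in zip(axes[0], images):
        ax.imshow(Image.open(image_path))
        ax.set_title(image_path.name)
        ax.axis("off")
    plt.tight_layout()
    plt.show()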
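
The array names stored inside the `.npz` file are not stated in this notebook. Because an `.npz` file is a zip archive with one `.npy` member per array, the standard library can list those names without loading any data. `list_npz_arrays` below is a hypothetical helper written for this purpose.

```python
import zipfile


def list_npz_arrays(npz_path):
    """Return {array_name: uncompressed_size_in_bytes} for an .npz archive.

    Hypothetical helper: it reads only the zip directory, so no array
    data is decompressed.
    """
    with zipfile.ZipFile(npz_path) as archive:
        return {
            info.filename[: -len(".npy")]: info.file_size
            for info in archive.infolist()
            if info.filename.endswith(".npy")
        }
```

Once a name is known, `numpy.load(path)[name]` returns the corresponding array.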
