# PixNerd ImageNet Training & Super-Resolution (No Lightning CLI)

This self-contained notebook trains and samples PixNerd directly from Python objects—no Lightning CLI invocations required. It mirrors the original configuration while keeping everything editable inside the notebook.

## Prerequisites
- A multi-GPU machine with enough memory for PixNerd-XL ImageNet training.
- Local ImageNet-1K training data (folder layout compatible with `torchvision.datasets.ImageFolder`).
- Local DINOv2 ViT-B/14 weights for the feature-alignment loss.
- (Optional) Weights & Biases credentials if you want online logging.

In [None]:
# If dependencies are not installed in this environment, uncomment the next line.
# !pip install -r requirements.txt

In [None]:
import importlib
import os
from pathlib import Path
import datetime

import torch
from lightning import Trainer, seed_everything
import lightning.pytorch as pl
from omegaconf import OmegaConf

from src.lightning_data import DataModule
from src.lightning_model import LightningModel


def _resolve_target(target):
    module_path, class_name = target.rsplit('.', 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def instantiate_from_config(conf):
    """Instantiate an object from a config dict with `class_path` and `init_args`.

    Like the Lightning CLI, nested specs are built bottom-up. A value inside
    ``init_args`` may itself be a ``{"class_path": ..., "init_args": ...}``
    mapping, either directly or inside a list, tuple, or plain dict. Such a value
    is instantiated first, and the resulting object is passed to the parent
    constructor. Without this, nested components would reach their parent as raw
    dicts. One example is the frozen encoder configured under
    ``diffusion_trainer.init_args.encoder``.

    Args:
        conf: A plain ``dict`` (e.g. from ``OmegaConf.to_container``) with a
            ``class_path`` key and an optional ``init_args`` mapping, or ``None``.

    Returns:
        The constructed object, or ``None`` when ``conf`` is ``None``.
    """
    if conf is None:
        return None
    target_cls = _resolve_target(conf["class_path"])
    kwargs = conf.get("init_args", {}) or {}
    kwargs = {name: _instantiate_nested(value) for name, value in kwargs.items()}
    return target_cls(**kwargs)


def _instantiate_nested(value):
    """Recursively replace ``class_path`` specs inside ``value`` with instances."""
    # NOTE(review): every dict with a "class_path" key is treated as an object to
    # build eagerly; a spec meant to stay lazy (a callable) would need special-casing.
    if isinstance(value, dict):
        if "class_path" in value:
            return instantiate_from_config(value)
        return {key: _instantiate_nested(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_instantiate_nested(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_instantiate_nested(item) for item in value)
    return value


def callable_from_config(conf):
    """Create a factory (for optimizer/scheduler) that defers parameter binding.

    The returned object is a ``functools.partial`` of the configured class with
    ``init_args`` pre-bound. Call it with the remaining arguments, e.g. the model
    parameters for an optimizer. Keyword arguments given at call time override the
    configured ones, exactly as before. Unlike a ``lambda``, a ``partial`` can be
    pickled (as long as the target class can) and has an informative ``repr``.

    Args:
        conf: Plain dict with ``class_path`` and optional ``init_args``, or ``None``.

    Returns:
        The factory, or ``None`` when ``conf`` is ``None``.
    """
    from functools import partial

    if conf is None:
        return None
    target_cls = _resolve_target(conf["class_path"])
    kwargs = conf.get("init_args", {}) or {}
    return partial(target_cls, **kwargs)


def list_from_configs(confs):
    """Build one object per config spec in ``confs``, preserving order.

    Returns ``None`` when ``confs`` is ``None``; otherwise a list of the objects
    produced by ``instantiate_from_config`` for each entry.
    """
    return None if confs is None else [instantiate_from_config(spec) for spec in confs]

In [None]:
# --- Configure paths and training hyperparameters ---
IMAGENET_ROOT = "/path/to/imagenet/train"          # ImageNet training folder
DINOV2_WEIGHTS = "/path/to/dinov2_vitb14"          # Local DINOv2 ViT-B/14 checkpoint directory
WORKDIR_ROOT = Path("./workdirs_notebook")          # Where logs, checkpoints, and samples will be saved

EXP_NAME = f"pixnerd_imagenet256_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}"
BASE_CONFIG = Path("configs_c2i/pix256std1_repa_pixnerd_xl.yaml")

cfg = OmegaConf.load(BASE_CONFIG)

# Override training/logging paths to local locations
cfg.tags.exp = EXP_NAME
cfg.trainer.default_root_dir = str(WORKDIR_ROOT)
# Use .get(): on OmegaConf >= 2.1, attribute access to a missing key raises instead
# of returning None. A logger can also be a bare bool (e.g. `logger: false`), which
# has no init_args to override, so only touch non-empty class_path-style configs.
logger_cfg = cfg.trainer.get("logger")
if OmegaConf.is_dict(logger_cfg) and logger_cfg:
    if logger_cfg.get("init_args") is None:
        logger_cfg.init_args = {}
    logger_cfg.init_args.name = EXP_NAME
    logger_cfg.init_args.project = logger_cfg.init_args.get("project", "universal_pix_flow")

# Point datasets and frozen encoder to your local data
cfg.data.train_dataset.init_args.root = IMAGENET_ROOT
cfg.data.train_dataset.init_args.resolution = 256
cfg.data.eval_dataset.init_args.latent_shape = [3, 256, 256]
cfg.data.pred_dataset.init_args.latent_shape = [3, 256, 256]
cfg.diffusion_trainer.init_args.encoder.init_args.weight_path = DINOV2_WEIGHTS

# Cache overrides are optional; keep them None unless you need a custom hub path
cfg.torch_hub_dir = None
cfg.huggingface_cache_dir = None

# Save the run-specific config under ./notebooks/configs (relative to the working directory)
notebook_cfg = Path("notebooks/configs/pixnerd_imagenet256_notebook.yaml")
notebook_cfg.parent.mkdir(parents=True, exist_ok=True)
OmegaConf.save(cfg, notebook_cfg)
print("Saved config to", notebook_cfg)

The saved config mirrors `configs_c2i/pix256std1_repa_pixnerd_xl.yaml` but with local paths/tags. Training outputs land in `WORKDIR_ROOT/exp_<EXP_NAME>/...`, and the `SaveImagesHook` writes validation/prediction images under the same root.

In [None]:
# Optional: show the first 2000 characters of the config with interpolations resolved
resolved_yaml = OmegaConf.to_yaml(cfg, resolve=True)
print(resolved_yaml[:2000])

## Build Lightning objects without the CLI
This cell constructs the datamodule, model parts, callbacks, and trainer directly from the config so you can run training entirely inside the notebook.

In [None]:
# Respect optional hub caches
if cfg.get("huggingface_cache_dir"):
    os.environ["HUGGINGFACE_HUB_CACHE"] = cfg.huggingface_cache_dir
if cfg.get("torch_hub_dir"):
    os.environ["TORCH_HOME"] = cfg.torch_hub_dir
    torch.hub.set_dir(cfg.torch_hub_dir)

seed_everything(42, workers=True)

# Instantiate datasets and datamodule
train_dataset = instantiate_from_config(OmegaConf.to_container(cfg.data.train_dataset, resolve=True))
eval_dataset = instantiate_from_config(OmegaConf.to_container(cfg.data.eval_dataset, resolve=True))
pred_dataset = instantiate_from_config(OmegaConf.to_container(cfg.data.pred_dataset, resolve=True))

data_module = DataModule(
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    pred_dataset=pred_dataset,
    train_batch_size=cfg.data.train_batch_size,
    train_num_workers=cfg.data.train_num_workers,
    pred_batch_size=cfg.data.pred_batch_size,
    pred_num_workers=cfg.data.pred_num_workers,
)

# Instantiate model components
vae = instantiate_from_config(OmegaConf.to_container(cfg.model.vae, resolve=True))
conditioner = instantiate_from_config(OmegaConf.to_container(cfg.model.conditioner, resolve=True))
denoiser = instantiate_from_config(OmegaConf.to_container(cfg.model.denoiser, resolve=True))
diffusion_trainer = instantiate_from_config(OmegaConf.to_container(cfg.diffusion_trainer, resolve=True))
diffusion_sampler = instantiate_from_config(OmegaConf.to_container(cfg.diffusion_sampler, resolve=True))
ema_tracker = instantiate_from_config(OmegaConf.to_container(cfg.ema_tracker, resolve=True))
optimizer_fn = callable_from_config(OmegaConf.to_container(cfg.optimizer, resolve=True))
# The LR scheduler is optional. cfg.get() returns None when it is absent, and
# OmegaConf.to_container(None) raises, so only convert an existing section.
lr_scheduler_cfg = cfg.get("lr_scheduler")
lr_scheduler_fn = (
    callable_from_config(OmegaConf.to_container(lr_scheduler_cfg, resolve=True))
    if lr_scheduler_cfg is not None
    else None
)

lightning_model = LightningModel(
    vae=vae,
    conditioner=conditioner,
    denoiser=denoiser,
    diffusion_trainer=diffusion_trainer,
    diffusion_sampler=diffusion_sampler,
    ema_tracker=ema_tracker,
    optimizer=optimizer_fn,
    lr_scheduler=lr_scheduler_fn,
)

# Prepare EMA copies and compiled models.
# NOTE(review): Lightning also invokes the `configure_model` hook itself at the start
# of fit/validate/test/predict, so this manual call runs it one extra time. Lightning
# requires the hook to be idempotent. Confirm LightningModel.configure_model is
# idempotent and does not need `self.trainer` (no Trainer is attached yet here).
lightning_model.configure_model()

# Build callbacks/logger/plugins from the trainer section; the rest are Trainer kwargs
trainer_conf = OmegaConf.to_container(cfg.trainer, resolve=True)
logger_conf = trainer_conf.pop("logger", None)
callbacks_conf = trainer_conf.pop("callbacks", [])
plugins_conf = trainer_conf.pop("plugins", None)

logger = instantiate_from_config(logger_conf) if logger_conf else None
callbacks = list_from_configs(callbacks_conf) or None
plugins = list_from_configs(plugins_conf) if plugins_conf else None

trainer = Trainer(
    logger=logger,
    callbacks=callbacks,
    plugins=plugins,
    **trainer_conf,
)

run_dir = Path(trainer.default_root_dir) / f"exp_{cfg.tags.exp}"
print("Trainer default_root_dir:", trainer.default_root_dir)
print("Run directory:", run_dir)

## Launch training
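This runs standard PixNerd ImageNet training directly via `trainer.fit`, with the same model, callbacks, and logger as the CLI version. One difference: inside Jupyter, Lightning cannot launch the script-based `strategy="ddp"`. For multi-GPU runs, set `strategy="ddp_notebook"` (or leave it as `"auto"`) in `cfg.trainer`.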

In [None]:
# Uncomment to start training
# trainer.fit(lightning_model, datamodule=data_module)

After training, checkpoints appear under `run_dir / "checkpoints"`. Point `CKPT_PATH` at the checkpoint you want to sample.

In [None]:
CKPT_PATH = run_dir / "checkpoints/last.ckpt"  # update if you prefer a specific step
print("Expected checkpoint:", CKPT_PATH)

# Load a checkpoint back into the in-memory model (skip if you just finished training).
if CKPT_PATH.exists():
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
    # On torch >= 2.6 the default weights_only=True may reject checkpoints that store
    # non-tensor objects; confirm against your torch/Lightning versions.
    state = torch.load(CKPT_PATH, map_location="cpu")
    # strict=False is kept (the checkpoint may intentionally omit some submodules;
    # confirm). Report mismatches so that a wrong checkpoint does not silently leave
    # the model at its initial weights.
    incompatible = lightning_model.load_state_dict(state["state_dict"], strict=False)
    if incompatible.missing_keys:
        print(
            f"{len(incompatible.missing_keys)} model keys not in checkpoint, e.g.",
            incompatible.missing_keys[:5],
        )
    if incompatible.unexpected_keys:
        print(
            f"{len(incompatible.unexpected_keys)} checkpoint keys not in model, e.g.",
            incompatible.unexpected_keys[:5],
        )
    print("Loaded weights from", CKPT_PATH)
else:
    print("Checkpoint not found yet — train first or update CKPT_PATH.")

## Sample at training resolution (256×256)
Use `trainer.predict` with the EMA weights to generate class-conditional samples at the training size. Outputs are written by `SaveImagesHook` under the run directory.

In [None]:
# Uncomment to generate samples
# trainer.predict(lightning_model, dataloaders=data_module.predict_dataloader())

## Super-resolution sampling (e.g., 512×512)
Clone the saved config and bump the prediction latent shape to the target resolution. Lower the batch size if needed, then rebuild the datamodule and trainer and run `predict` again. The in-memory `lightning_model` is reused as-is.

In [None]:
superres_cfg = OmegaConf.load(notebook_cfg)
superres_cfg.data.pred_dataset.init_args.latent_shape = [3, 512, 512]
superres_cfg.data.pred_batch_size = 4  # adjust for your memory budget
superres_cfg.trainer.default_root_dir = str(Path(superres_cfg.trainer.default_root_dir) / "superres_runs")

superres_cfg_path = Path("notebooks/configs/pixnerd_imagenet512_predict.yaml")
# Create the folder here too so this cell does not rely on an earlier cell's mkdir.
superres_cfg_path.parent.mkdir(parents=True, exist_ok=True)
OmegaConf.save(superres_cfg, superres_cfg_path)
print("Super-resolution config saved to", superres_cfg_path)

# Build fresh prediction objects. Only the prediction dataset changes; the train/eval
# datasets and the in-memory `lightning_model` from the earlier cells are reused.
# All data settings come from superres_cfg so edits to it take effect.
sr_pred_dataset = instantiate_from_config(OmegaConf.to_container(superres_cfg.data.pred_dataset, resolve=True))
sr_data_module = DataModule(
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    pred_dataset=sr_pred_dataset,
    train_batch_size=superres_cfg.data.train_batch_size,
    train_num_workers=superres_cfg.data.train_num_workers,
    pred_batch_size=superres_cfg.data.pred_batch_size,
    pred_num_workers=superres_cfg.data.pred_num_workers,
)

# Split the trainer section into logger/callbacks/plugins specs and plain Trainer kwargs
sr_trainer_conf = OmegaConf.to_container(superres_cfg.trainer, resolve=True)
sr_logger_conf = sr_trainer_conf.pop("logger", None)
sr_callbacks_conf = sr_trainer_conf.pop("callbacks", [])
sr_plugins_conf = sr_trainer_conf.pop("plugins", None)

sr_logger = instantiate_from_config(sr_logger_conf) if sr_logger_conf else None
sr_callbacks = list_from_configs(sr_callbacks_conf) or None
sr_plugins = list_from_configs(sr_plugins_conf) if sr_plugins_conf else None

sr_trainer = Trainer(logger=sr_logger, callbacks=sr_callbacks, plugins=sr_plugins, **sr_trainer_conf)
print("Super-res default_root_dir:", sr_trainer.default_root_dir)

# Uncomment to generate 512x512 (or higher) samples
# sr_trainer.predict(lightning_model, dataloaders=sr_data_module.predict_dataloader())

## Visualize predicted images
`SaveImagesHook` writes PNGs (first few samples) and a compressed `.npz` array. Update `PRED_DIR` to point at the prediction folder for either the 256×256 or super-resolution run and visualize a handful of images.

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

# Point this to the prediction directory you want to browse
PRED_DIR = Path(run_dir) / "val" / "predict"  # adjust if you used a different save_dir
print("Reading from", PRED_DIR)

images = sorted(PRED_DIR.glob("*.png"))[:8]
if not images:
    print("No PNGs found yet. Run predict first or point PRED_DIR to your output folder.")
else:
    # squeeze=False keeps `axes` a 2-D array even for a single image. With the default
    # squeeze=True, plt.subplots(1, 1) returns a bare Axes and zip() over it fails.
    fig, axes = plt.subplots(1, len(images), figsize=(4 * len(images), 4), squeeze=False)
    for ax, image_path in zip(axes[0], images):
        # Copy the pixels into memory so the file handle is closed right away.
        with Image.open(image_path) as img:
            ax.imshow(img.copy())
        ax.set_title(image_path.name)
        ax.axis("off")
    plt.tight_layout()