# PixNerd Unconditional Generation

This notebook demonstrates unconditional image generation with a pretrained PixNerd text-to-image checkpoint.
By setting `guidance=1.0` and passing the unconditional embedding (an empty prompt encoded by Qwen3) in place of a text condition, we can generate random but coherent images
without any text or class conditioning.

Super-resolution is still supported via the NF decoder patch scaling mechanism.
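
As a rough illustration, the scale factor that `sample_unconditional` (defined below) hands to the decoder is just the target size divided by the 512-pixel training resolution. The sketch reproduces that arithmetic in plain Python; `decoder_scale_for` is a hypothetical name, and it raises `ValueError` where the notebook's helper uses an `assert`.

```python
# Hypothetical helper mirroring the scale logic in sample_unconditional():
# the NF decoder patch scale is the ratio of target size to training size.
BASE_RES = 512  # training resolution used by this notebook


def decoder_scale_for(height: int, width: int, base_res: int = BASE_RES) -> float:
    """Return the (square) decoder patch scale for a target resolution."""
    scale_h = height / float(base_res)
    scale_w = width / float(base_res)
    if scale_h != scale_w:
        # The notebook's helper only supports square scaling.
        raise ValueError(f"non-square scaling: {scale_h} vs {scale_w}")
    return scale_h


for size in (512, 768, 1024, 1920):
    print(f"{size}x{size} -> scale {decoder_scale_for(size, size):.2f}")
```

For example, 1920x1920 gives a scale of 3.75, which matches the label used in the high-resolution cell.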

## Environment Setup

- Requires a GPU runtime
- Assumes dependencies from `requirements.txt` are installed
- Place the checkpoint at `PixNerd/checkpoints/checkpoints/PixNerd-XXL-P16-T2I/model.ckpt`
- Text encoder: Auto-detects local Qwen3-1.7B or downloads from HuggingFace (`Qwen/Qwen3-1.7B`)
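
The auto-detection amounts to "use the first local path that exists, otherwise fall back to the Hugging Face repo id". A minimal sketch of that lookup as a reusable function; the name `resolve_text_encoder` is hypothetical, since the notebook inlines the same loop in its configuration cell.

```python
import os
from typing import Sequence


def resolve_text_encoder(local_paths: Sequence[str], hub_id: str = "Qwen/Qwen3-1.7B") -> str:
    """Return the first existing local path, else the Hugging Face repo id."""
    for path in local_paths:
        if os.path.exists(path):
            return path
    # No local copy found: hand back the Hub id so the weights are fetched remotely.
    return hub_id
```

In the configuration cell this would read `TEXT_ENCODER_PATH = resolve_text_encoder(LOCAL_QWEN_PATHS)`.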

In [None]:
# Navigate to PixNerd folder where src/ is located
import os
import sys

# Get the directory containing this notebook
NOTEBOOK_DIR = os.getcwd()
print(f"Starting directory: {NOTEBOOK_DIR}")

# Navigate to PixNerd folder (where src/ lives)
PIXNERD_DIR = os.path.join(NOTEBOOK_DIR, "PixNerd")
if os.path.exists(PIXNERD_DIR):
    os.chdir(PIXNERD_DIR)
    print(f"Changed to: {os.getcwd()}")
elif os.path.basename(NOTEBOOK_DIR) == "PixNerd":
    print(f"Already in PixNerd directory: {NOTEBOOK_DIR}")
else:
    # Try parent directory
    parent = os.path.dirname(NOTEBOOK_DIR)
    pixnerd_in_parent = os.path.join(parent, "PixNerd")
    if os.path.exists(pixnerd_in_parent):
        os.chdir(pixnerd_in_parent)
        print(f"Changed to: {os.getcwd()}")
    else:
        print(f"WARNING: Could not find PixNerd folder. Current dir: {NOTEBOOK_DIR}")

# Verify src/ exists
if os.path.exists("src"):
    print("Found src/ directory")
else:
    print("ERROR: src/ directory not found!")

In [None]:
import os
from pathlib import Path
import math
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image

# Paths (relative to PixNerd directory)
PIXNERD_ROOT = Path(os.getcwd())

# ============================================================
# CHECKPOINT PATH - UPDATE THIS TO YOUR CHECKPOINT LOCATION
# ============================================================
CKPT_PATH = PIXNERD_ROOT / "checkpoints" / "checkpoints" / "PixNerd-XXL-P16-T2I" / "model.ckpt"

# Alternative: uncomment and modify if your checkpoint is elsewhere
# CKPT_PATH = Path("/path/to/your/model.ckpt")
# ============================================================

# Text encoder path - tries local paths first, falls back to HuggingFace
LOCAL_QWEN_PATHS = [
    "/pscratch/sd/d/dpark1/models/Qwen3-1.7B",  # Your scratch space
    "/pscratch/sd/k/kevinval/models/Qwen3-1.7B",  # Alternative location
    str(PIXNERD_ROOT / "models" / "Qwen3-1.7B"),  # Local to project
]

TEXT_ENCODER_PATH = None
for path in LOCAL_QWEN_PATHS:
    if os.path.exists(path):
        TEXT_ENCODER_PATH = path
        print(f"Found local Qwen3 at: {path}")
        break

if TEXT_ENCODER_PATH is None:
    # Fall back to HuggingFace (will download automatically)
    TEXT_ENCODER_PATH = "Qwen/Qwen3-1.7B"
    print(f"No local Qwen3 found, will download from HuggingFace: {TEXT_ENCODER_PATH}")

# Output directory for generated images
OUTPUT_DIR = PIXNERD_ROOT / "outputs" / "unconditional"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE != "cuda":
    raise RuntimeError("PixNerd inference requires a CUDA GPU")

print(f"PixNerd root: {PIXNERD_ROOT}")
print(f"Checkpoint path: {CKPT_PATH}")
print(f"Checkpoint exists: {CKPT_PATH.exists()}")
if not CKPT_PATH.exists():
    print("\n⚠️  CHECKPOINT NOT FOUND!")
    print("   Please update CKPT_PATH in this cell to point to your model.ckpt")
    print(f"   Looking for: {CKPT_PATH}")
print(f"Text encoder: {TEXT_ENCODER_PATH}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Device: {DEVICE}")

## Build Model Components

We still need the Qwen3TextEncoder to produce the unconditional embedding (the encoding of an empty string).
The model architecture is identical to the T2I setup; only the way it is used during sampling differs.

In [None]:
# Import PixNerd components from src/
from src.models.autoencoder.pixel import PixelAE
from src.models.conditioner.qwen3_text_encoder import Qwen3TextEncoder
from src.models.transformer.pixnerd_t2i_heavydecoder import PixNerDiT
from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.flow_matching.adam_sampling import AdamLMSampler, ode_step_fn
from src.diffusion.base.guidance import simple_guidance_fn
from src.diffusion.flow_matching.training_repa import REPATrainer
from src.callbacks.simple_ema import SimpleEMA
from src.lightning_model import LightningModel
from src.models.encoder import IndentityMapping
from src.models.autoencoder.base import fp2uint8

# Model hyperparameters (must match checkpoint)
HIDDEN_SIZE = 1536
TXT_EMBED_DIM = 2048
PATCH_SIZE = 16
TXT_MAX_LENGTH = 128
BASE_RES = 512  # Training resolution

print("Imports successful!")

In [None]:
# Initialize model components
print("Initializing model components...")

main_scheduler = LinearScheduler()

vae = PixelAE(scale=1.0)

print("Loading Qwen3 text encoder (this may take a moment)...")
conditioner = Qwen3TextEncoder(
    weight_path=TEXT_ENCODER_PATH,
    embed_dim=TXT_EMBED_DIM,
    max_length=TXT_MAX_LENGTH,
)

denoiser = PixNerDiT(
    in_channels=3,
    patch_size=PATCH_SIZE,
    num_groups=24,
    hidden_size=HIDDEN_SIZE,
    txt_embed_dim=TXT_EMBED_DIM,
    txt_max_length=TXT_MAX_LENGTH,
    num_text_blocks=4,
    decoder_hidden_size=64,
    num_encoder_blocks=16,
    num_decoder_blocks=2,
)

# Sampler configured for unconditional generation (guidance=1.0)
sampler = AdamLMSampler(
    num_steps=25,
    guidance=1.0,  # No guidance amplification for unconditional
    timeshift=3.0,
    order=2,
    scheduler=main_scheduler,
    guidance_fn=simple_guidance_fn,
    step_fn=ode_step_fn,
)

# REPATrainer stub (needed for checkpoint loading)
trainer_stub = REPATrainer(
    scheduler=main_scheduler,
    lognorm_t=True,
    timeshift=4.0,
    feat_loss_weight=0.5,
    encoder=IndentityMapping(),
    align_layer=6,
    proj_denoiser_dim=HIDDEN_SIZE,
    proj_hidden_dim=HIDDEN_SIZE,
    proj_encoder_dim=768,
)

ema_tracker = SimpleEMA(decay=0.9999)

model = LightningModel(
    vae=vae,
    conditioner=conditioner,
    denoiser=denoiser,
    diffusion_trainer=trainer_stub,
    diffusion_sampler=sampler,
    ema_tracker=ema_tracker,
    optimizer=None,
    lr_scheduler=None,
    eval_original_model=False,
)

model.eval()
model.to(DEVICE)
print("Model initialized and moved to GPU.")

## Load Checkpoint

In [None]:
print(f"Loading checkpoint from: {CKPT_PATH}")
ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)
print(f"Missing keys: {len(missing)} | Unexpected keys: {len(unexpected)}")
print("Checkpoint loaded. Ready for unconditional sampling!")

## Helper Functions

In [None]:
def set_decoder_scale(scale: float):
    """Set NF decoder patch scaling for super-resolution."""
    for net in [model.denoiser, getattr(model, "ema_denoiser", None)]:
        if net is None:
            continue
        net.decoder_patch_scaling_h = scale
        net.decoder_patch_scaling_w = scale


@torch.no_grad()
def sample_unconditional(
    batch_size: int = 4,
    height: int = 512,
    width: int = 512,
    seed: int = 42,
    num_steps: int = 25,
    base_res: int = BASE_RES,
):
    """
    Generate images without any conditioning.
    
    Args:
        batch_size: Number of images to generate
        height: Output image height (can be different from base_res for super-res)
        width: Output image width
        seed: Random seed for reproducibility
        num_steps: Number of ODE solver steps
        base_res: Training resolution (used to compute scaling factor)
    
    Returns:
        torch.Tensor: Generated images as uint8 [B, 3, H, W]
    """
    torch.manual_seed(seed)
    
    # Compute and set decoder scale for super-resolution
    if height == base_res and width == base_res:
        set_decoder_scale(1.0)
        print(f"Generating at native {base_res}x{base_res}")
    else:
        scale_h = height / float(base_res)
        scale_w = width / float(base_res)
        assert scale_h == scale_w, "Only square scaling supported"
        set_decoder_scale(scale_h)
        print(f"Generating at {height}x{width} (scale={scale_h:.2f}x)")
    
    # Configure sampler for unconditional generation
    model.diffusion_sampler.guidance = 1.0  # No CFG amplification
    model.diffusion_sampler.num_steps = num_steps
    
    # Start from Gaussian noise at target resolution
    noise = torch.randn(batch_size, 3, height, width, device=DEVICE)
    
    # Get unconditional embedding (empty string encoded by Qwen3)
    # We pass uncondition for BOTH condition and uncondition arguments
    dummy_prompts = [""] * batch_size
    _, uncondition = model.conditioner(dummy_prompts)
    uncondition = uncondition.to(DEVICE)
    
    # Run sampling with uncondition for both slots
    # CFG formula: uncond + guidance * (cond - uncond)
    # With cond = uncond and guidance = 1.0: output = uncond (pure unconditional)
    samples = model.diffusion_sampler(
        model.ema_denoiser,
        noise,
        uncondition,  # condition slot (using uncond)
        uncondition,  # uncondition slot
    )
    
    # Decode to pixel space
    images = model.vae.decode(samples)
    images = torch.clamp(images, -1.0, 1.0)
    images_uint8 = fp2uint8(images)
    
    return images_uint8.cpu()


def show_images(images_uint8, title="", cols=None):
    """Display a batch of images."""
    if isinstance(images_uint8, torch.Tensor):
        imgs_np = images_uint8.permute(0, 2, 3, 1).cpu().numpy()
    else:
        imgs_np = np.transpose(images_uint8, (0, 2, 3, 1))
    
    n = len(imgs_np)
    if cols is None:
        cols = min(n, 4)
    rows = math.ceil(n / cols)
    
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))
    if rows == 1 and cols == 1:
        axes = np.array([axes])
    axes = axes.flatten()
    
    for i, (ax, img) in enumerate(zip(axes, imgs_np)):
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(f"{title} #{i}")
    
    # Hide empty subplots
    for ax in axes[n:]:
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()


def save_grid(images_uint8, filename, cols=None):
    """Save images as a grid."""
    if isinstance(images_uint8, torch.Tensor):
        imgs_np = images_uint8.permute(0, 2, 3, 1).cpu().numpy()
    else:
        imgs_np = np.transpose(images_uint8, (0, 2, 3, 1))
    
    imgs = [Image.fromarray(img) for img in imgs_np]
    
    n = len(imgs)
    if cols is None:
        cols = min(n, 4)
    rows = math.ceil(n / cols)
    
    w, h = imgs[0].size
    grid = Image.new("RGB", (cols * w, rows * h))
    for idx, img in enumerate(imgs):
        r, c = divmod(idx, cols)
        grid.paste(img, (c * w, r * h))
    
    out_path = OUTPUT_DIR / filename
    grid.save(out_path)
    print(f"Saved: {out_path}")
    return out_path


print("Helper functions defined.")

## Unconditional Generation at Base Resolution (512x512)

Generate random images without any text or class conditioning.

In [None]:
# Generate 4 unconditional images at 512x512
print("=" * 50)
print("Unconditional Generation: 512x512 (native)")
print("=" * 50)

images_512 = sample_unconditional(
    batch_size=4,
    height=512,
    width=512,
    seed=42,
    num_steps=25,
)

print(f"Output shape: {images_512.shape}")
show_images(images_512, title="Unconditional 512x512")
save_grid(images_512, "unconditional_512x512_grid.png")

## Unconditional Super-Resolution (1024x1024)

The NF decoder can upsample to higher resolutions even in unconditional mode.
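
Because `sample_unconditional` draws its starting noise at the target resolution, the number of pixels processed per image grows with the square of the per-side scale factor. Actual GPU memory use depends on the model's internals, but the pixel count alone helps explain why the 1920x1920 cell uses a smaller batch. A quick sketch (the `pixel_ratio` helper is hypothetical):

```python
BASE_RES = 512  # training resolution


def pixel_ratio(size: int, base_res: int = BASE_RES) -> float:
    """Pixels in a size x size image relative to a base_res x base_res image."""
    return (size * size) / float(base_res * base_res)


for size in (768, 1024, 1920):
    scale = size / BASE_RES
    print(f"{size}x{size}: {scale:.2f}x per side, {pixel_ratio(size):.2f}x the pixels")
```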

In [None]:
# Generate unconditional images at 1024x1024 (2x super-res)
print("=" * 50)
print("Unconditional Generation: 1024x1024 (2x super-res)")
print("=" * 50)

images_1024 = sample_unconditional(
    batch_size=4,
    height=1024,
    width=1024,
    seed=42,  # Same seed as 512 for comparison
    num_steps=25,
)

print(f"Output shape: {images_1024.shape}")
show_images(images_1024, title="Unconditional 1024x1024")
save_grid(images_1024, "unconditional_1024x1024_grid.png")

## Unconditional Super-Resolution (768x768)

In [None]:
# Generate at 768x768 (1.5x super-res)
print("=" * 50)
print("Unconditional Generation: 768x768 (1.5x super-res)")
print("=" * 50)

images_768 = sample_unconditional(
    batch_size=4,
    height=768,
    width=768,
    seed=123,
    num_steps=25,
)

print(f"Output shape: {images_768.shape}")
show_images(images_768, title="Unconditional 768x768")
save_grid(images_768, "unconditional_768x768_grid.png")

## Seed Variation

Different seeds produce different random images.
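
Reproducibility comes from reseeding before the noise is drawn: `sample_unconditional` calls `torch.manual_seed(seed)` before `torch.randn`. The sketch below illustrates the same idea with Python's standard `random` module as a stand-in. torch uses its own generator, so the actual values differ, but the same-seed and different-seed behavior is analogous.

```python
import random


def draw_noise(seed: int, n: int = 4) -> list:
    """Draw n standard-normal values from a freshly seeded generator."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


a = draw_noise(42)
b = draw_noise(42)
c = draw_noise(123)
print(a == b)  # same seed: identical draws
print(a == c)  # different seed: different draws
```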

In [None]:
# Generate with different seeds
print("=" * 50)
print("Seed Variation at 512x512")
print("=" * 50)

all_images = []
seeds = [0, 42, 123, 456, 789, 1000, 2024, 9999]

for seed in seeds:
    img = sample_unconditional(
        batch_size=1,
        height=512,
        width=512,
        seed=seed,
        num_steps=25,
    )
    all_images.append(img)

# Stack all images
all_images = torch.cat(all_images, dim=0)
print(f"Generated {len(seeds)} images with seeds {seeds}")
# Subplot titles show the index into `seeds`, not the seed value itself
show_images(all_images, title="Seed index", cols=4)
save_grid(all_images, "unconditional_seed_variation.png", cols=4)

## Compare: Base vs Super-Resolution (Same Seed)

Using the same seed, compare 512x512 native generation, a bilinear upscale of that image to 1024x1024, and 1024x1024 NF super-resolution.
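
Note that `sample_unconditional` draws its initial noise at the target resolution, so the same seed does not give spatially aligned noise at 512 and 1024. The toy sketch below shows why using a row-major layout, with Python's `random` module standing in for torch's generator (an assumption: torch's CUDA generator need not even share one value stream across tensor sizes). Even in that best case, a given draw lands at a different pixel position in each grid, so the two outputs are generally not expected to depict the same scene.

```python
import random


def seeded_grid(seed: int, size: int) -> list:
    """Fill a size x size grid row-major from one seeded normal stream."""
    rng = random.Random(seed)
    flat = [rng.gauss(0.0, 1.0) for _ in range(size * size)]
    return [flat[r * size:(r + 1) * size] for r in range(size)]


small = seeded_grid(777, 4)  # stands in for the 512x512 noise
large = seeded_grid(777, 8)  # stands in for the 1024x1024 noise

# The 5th draw lands at row 1, col 0 in the small grid
# but at row 0, col 4 in the large grid.
print(small[1][0] == large[0][4])
```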

In [None]:
print("=" * 50)
print("Comparison: Same seed, different resolutions")
print("=" * 50)

COMPARE_SEED = 777

# 512x512
img_512 = sample_unconditional(
    batch_size=1,
    height=512,
    width=512,
    seed=COMPARE_SEED,
    num_steps=25,
)

# 1024x1024
img_1024 = sample_unconditional(
    batch_size=1,
    height=1024,
    width=1024,
    seed=COMPARE_SEED,
    num_steps=25,
)

# Bilinear upscale 512 to 1024 for comparison
img_512_tensor = img_512.float() / 255.0
img_512_upscaled = F.interpolate(
    img_512_tensor,
    size=(1024, 1024),
    mode='bilinear',
    align_corners=False,
)
img_512_upscaled = (img_512_upscaled * 255).round().to(torch.uint8)  # round rather than truncate

# Display comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(img_512[0].permute(1, 2, 0).numpy())
axes[0].set_title("512x512 Native")
axes[0].axis('off')

axes[1].imshow(img_512_upscaled[0].permute(1, 2, 0).numpy())
axes[1].set_title("512 -> Bilinear 1024")
axes[1].axis('off')

axes[2].imshow(img_1024[0].permute(1, 2, 0).numpy())
axes[2].set_title("1024x1024 NF Super-Res")
axes[2].axis('off')

plt.suptitle(f"Seed {COMPARE_SEED}: Native vs Bilinear vs NF Super-Resolution")
plt.tight_layout()
plt.show()

## High-Resolution Unconditional (1920x1920)

In [None]:
# Generate at very high resolution
print("=" * 50)
print("Unconditional Generation: 1920x1920 (3.75x super-res)")
print("=" * 50)

images_1920 = sample_unconditional(
    batch_size=2,  # Fewer images due to GPU memory
    height=1920,
    width=1920,
    seed=42,
    num_steps=25,
)

print(f"Output shape: {images_1920.shape}")
show_images(images_1920, title="Unconditional 1920x1920", cols=2)
save_grid(images_1920, "unconditional_1920x1920_grid.png", cols=2)

## Summary

This notebook demonstrated **unconditional generation** with PixNerd:

### How It Works
- Set `guidance=1.0` (no CFG amplification)
- Use the unconditional embedding (empty string `""` encoded by Qwen3) for **both** condition and uncondition slots
- CFG formula becomes: `uncond + 1.0 * (uncond - uncond) = uncond`
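
A toy check of the CFG formula, with plain Python lists standing in for the denoiser's outputs. It assumes `simple_guidance_fn` implements `uncond + guidance * (cond - uncond)` as written in the comments of `sample_unconditional`; its source is not shown in this notebook. The last two calls also show that once `cond == uncond`, the guidance value no longer affects the result, while `guidance=1.0` on its own would simply return the conditional prediction.

```python
from typing import List


def cfg_combine(uncond: List[float], cond: List[float], guidance: float) -> List[float]:
    """Classifier-free guidance, elementwise: uncond + guidance * (cond - uncond)."""
    return [u + guidance * (c - u) for u, c in zip(uncond, cond)]


uncond_pred = [0.5, -1.25, 2.0]  # toy stand-in for the unconditional prediction
cond_pred = [1.0, 0.75, -2.0]    # toy stand-in for a text-conditioned prediction

print(cfg_combine(uncond_pred, uncond_pred, 1.0))  # cond == uncond: returns uncond
print(cfg_combine(uncond_pred, uncond_pred, 4.0))  # still uncond: difference term is zero
print(cfg_combine(uncond_pred, cond_pred, 1.0))    # guidance 1.0 alone: returns cond
```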

### Super-Resolution Still Works
- The NF decoder patch scaling mechanism is independent of conditioning
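- Square outputs at other resolutions can be requested without retraining; this notebook uses 512, 768, 1024, and 1920 (the helper only supports equal height and width scaling)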

### Trade-offs
| Aspect | Unconditional | Conditional (T2I) |
|--------|---------------|-------------------|
| Control | None | Full text control |
| Output | Random from training distribution | Specified content |
| Quality | Good | Slightly better with CFG |

### Use Cases
- Exploring the model's learned image distribution
- Generating diverse random images
- Baseline comparison for conditional generation
- Testing super-resolution without conditioning effects

In [None]:
print("Done!")