# PixNerd T2I (heavy decoder) class-conditioned training
# Self-contained notebook that reuses the text-to-image heavy decoder with flow matching on ImageNet-style folders, but replaces text conditioning with learnable class embeddings.

In [None]:
import os
import math
import json
import random
import pathlib
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from src.models.autoencoder.pixel import PixelAE
from src.models.transformer.pixnerd_t2i_heavydecoder import PixNerDiT
from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.flow_matching.adam_sampling import AdamLMSampler, ode_step_fn
from src.diffusion.base.guidance import simple_guidance_fn
from src.diffusion.flow_matching.training_repa import REPATrainer
from src.callbacks.simple_ema import SimpleEMA
from src.lightning_model import LightningModel
from src.models.encoder import IndentityMapping
from src.models.autoencoder.base import fp2uint8
from src.models.conditioner.base import BaseConditioner

# Seed the global RNGs through Lightning so runs are reproducible.
pl.seed_everything(42)

# Device used by the sampling cell; the Trainer picks its own accelerator.
_cuda_available = torch.cuda.is_available()
DEVICE = torch.device('cuda') if _cuda_available else torch.device('cpu')

In [None]:
# Paths and hyperparameters
# The original string had no leading '/', so it was resolved relative to the
# notebook's working directory. The path looks like a Perlmutter $PSCRATCH
# location (TODO confirm). Set the DATA_ROOT environment variable to point
# elsewhere without editing the notebook.
DATA_ROOT = os.environ.get('DATA_ROOT', '/pscratch/sd/k/kevinval/datasets/imagenet256')
BATCH_SIZE = 32
NUM_WORKERS = 8
IMAGE_SIZE = 256
MAX_EPOCHS = 1  # change as needed
# Explicit Lightning 2.x precision strings; 'bf16' is only a legacy alias of
# 'bf16-mixed', and 32 is equivalent to '32-true'.
PRECISION = 'bf16-mixed' if torch.cuda.is_available() else '32-true'

# Model hyperparameters (mirroring configs_t2i/inference_heavydecoder.yaml)
HIDDEN_SIZE = 1536
TXT_EMBED_DIM = 2048
PATCH_SIZE = 16
TXT_MAX_LENGTH = 128

LOG_DIR = 'logs/pixnerd_t2i_notext'
CKPT_DIR = os.path.join(LOG_DIR, 'checkpoints')
os.makedirs(CKPT_DIR, exist_ok=True)

In [None]:
# Datasets and loaders with metadata for REPA feature alignment
transform_raw = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


class ImageFolderWithMetadata(datasets.ImageFolder):
    """ImageFolder that also returns the un-normalised image for REPA.

    Each item is ``(img_tensor, target, {"raw_image": raw_tensor})``:
    ``img_tensor`` is resized and normalised with mean/std 0.5, and
    ``raw_tensor`` is the same resize without normalisation. Both come from the
    module-level ``transform`` / ``transform_raw`` pipelines; ``self.transform``
    is not used. ``target_transform`` is honoured if one was given.
    """

    def __getitem__(self, index):
        path, target = self.samples[index]
        img = self.loader(path)
        raw_tensor = transform_raw(img)
        img_tensor = transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        metadata = {"raw_image": raw_tensor}
        return img_tensor, target, metadata


def _resolve_split_dirs(root):
    """Return ``(train_dir, val_dir)`` for a dataset root.

    If ``root`` contains ``train/`` (and optionally ``val/``) sub-directories,
    use them. Pointing ImageFolder at ``root`` directly would otherwise treat
    those split folders as the only classes. Without a ``train/`` folder,
    ``root`` is assumed to hold class folders directly and is used for both
    splits. A warning is emitted whenever validation reuses training images.
    """
    import warnings

    train_dir = os.path.join(root, 'train')
    if not os.path.isdir(train_dir):
        warnings.warn(f'No train/ split under {root!r}; validating on the training images.')
        return root, root
    val_dir = os.path.join(root, 'val')
    if os.path.isdir(val_dir):
        return train_dir, val_dir
    warnings.warn(f'No val/ split under {root!r}; validating on the training images.')
    return train_dir, train_dir


TRAIN_DIR, VAL_DIR = _resolve_split_dirs(DATA_ROOT)
train_dataset = ImageFolderWithMetadata(TRAIN_DIR)
val_dataset = ImageFolderWithMetadata(VAL_DIR)
# Labels are class indices; both splits must map folder names to the same index.
if val_dataset.class_to_idx != train_dataset.class_to_idx:
    raise ValueError(f'Class folders differ between {TRAIN_DIR!r} and {VAL_DIR!r}; labels would not line up.')
NUM_CLASSES = len(train_dataset.classes)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

In [None]:
# Conditioning: replace text with learned class token embeddings but keep the same interface
class LabelTokenConditioner(BaseConditioner):
    """Map integer class labels to a fixed-length token sequence.

    Stands in for the text encoder of the T2I heavy-decoder model. Each label
    is embedded with a learnable ``nn.Embedding``, and the vector is repeated
    ``max_length`` times, giving a ``(B, max_length, embed_dim)`` tensor. This
    matches the ``txt_max_length`` / ``txt_embed_dim`` the PixNerDiT is built
    with. The unconditional branch is an all-zero tensor of the same shape.

    NOTE(review): the table only learns if the training loop optimises the
    conditioner's parameters and does not call it under ``torch.no_grad()``.
    Confirm against BaseConditioner and LightningModel's optimizer setup;
    otherwise the class embeddings stay at their random initialisation.
    """

    def __init__(self, num_classes: int, embed_dim: int, max_length: int = 128):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_dim)
        self.max_length = max_length
        # Non-persistent: rebuilt in __init__, never stored in state_dict.
        self.register_buffer('uncond', torch.zeros(1, max_length, embed_dim), persistent=False)

    def _as_label_tensor(self, y):
        """Return ``y`` (tensor or int sequence) as a long tensor on the embedding's device."""
        return torch.as_tensor(y, dtype=torch.long, device=self.embed.weight.device)

    def _impl_condition(self, y, metadata=None):
        tokens = self.embed(self._as_label_tensor(y))  # (B, embed_dim)
        # Repeat to max_length so positional embeddings stay aligned with the heavy decoder.
        return tokens.unsqueeze(1).repeat(1, self.max_length, 1)

    def _impl_uncondition(self, y, metadata=None):
        # len() works for both tensors and plain label sequences.
        return self.uncond.repeat(len(y), 1, 1)


# Lightweight vision encoder for REPA feature alignment
class SimpleVisionEncoder(nn.Module):
    """Small, randomly initialised CNN used as the REPA feature target.

    Three stride-2 convolutions (3 -> 64 -> 128 -> 256 channels) with SiLU,
    then global average pooling and a linear projection. It returns a single
    token per image, ``(B, 1, proj_dim)``. Global pooling makes it independent
    of the input height/width.

    NOTE(review): REPA normally aligns against patch-level features from a
    pretrained encoder, but this one is untrained and yields one pooled token.
    Confirm that REPATrainer accepts a (B, 1, D) target, and check whether it
    trains or freezes this encoder.
    """

    def __init__(self, proj_dim: int):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.proj = nn.Linear(256, proj_dim)

    def forward(self, x):
        feat = self.backbone(x).flatten(1)  # (B, 256)
        return self.proj(feat).unsqueeze(1)  # (B, 1, proj_dim)

In [None]:
# Instantiate model components
import functools

main_scheduler = LinearScheduler()
vae = PixelAE(scale=1.0)
# NOTE(review): the class embeddings live in this conditioner; confirm that
# LightningModel hands its parameters to the optimizer, or they never train.
conditioner = LabelTokenConditioner(num_classes=NUM_CLASSES, embed_dim=TXT_EMBED_DIM, max_length=TXT_MAX_LENGTH)

# heavy decoder DiT
denoiser = PixNerDiT(
    in_channels=3,
    patch_size=PATCH_SIZE,
    num_groups=24,
    hidden_size=HIDDEN_SIZE,
    txt_embed_dim=TXT_EMBED_DIM,
    txt_max_length=TXT_MAX_LENGTH,
    num_text_blocks=4,
    decoder_hidden_size=64,
    num_encoder_blocks=16,
    num_decoder_blocks=2,
)

sampler = AdamLMSampler(
    num_steps=25,
    guidance=4.0,
    timeshift=3.0,
    order=2,
    scheduler=main_scheduler,
    guidance_fn=simple_guidance_fn,
    step_fn=ode_step_fn,
)

# Single source of truth: the encoder's output width must equal proj_encoder_dim.
REPA_FEAT_DIM = 768
trainer_core = REPATrainer(
    scheduler=main_scheduler,
    lognorm_t=True,
    timeshift=4.0,
    feat_loss_weight=0.5,
    encoder=SimpleVisionEncoder(proj_dim=REPA_FEAT_DIM),
    align_layer=6,
    proj_denoiser_dim=HIDDEN_SIZE,
    proj_hidden_dim=HIDDEN_SIZE,
    proj_encoder_dim=REPA_FEAT_DIM,
)

ema_tracker = SimpleEMA(decay=0.9999)

# Optimizer factory for LightningModel: called with the parameters to optimise.
# functools.partial instead of a lambda keeps it picklable and avoids PEP 8 E731.
optimizer_fn = functools.partial(torch.optim.AdamW, lr=2e-4, betas=(0.9, 0.99))

model = LightningModel(
    vae=vae,
    conditioner=conditioner,
    denoiser=denoiser,
    diffusion_trainer=trainer_core,
    diffusion_sampler=sampler,
    ema_tracker=ema_tracker,
    optimizer=optimizer_fn,
    lr_scheduler=None,
    eval_original_model=False,
)

In [None]:
# Training
use_cuda = torch.cuda.is_available()

logger = CSVLogger(save_dir=LOG_DIR, name='pixnerd_t2i_notext')

# Keep only the best checkpoint by the logged 'loss' value.
# NOTE(review): 'loss' must be a metric name the LightningModel actually logs,
# or ModelCheckpoint may fail when it tries to save — confirm.
ckpt_cb = ModelCheckpoint(
    dirpath=CKPT_DIR,
    monitor='loss',
    mode='min',
    save_top_k=1,
)

trainer_kwargs = dict(
    max_epochs=MAX_EPOCHS,
    precision=PRECISION,
    accelerator='gpu' if use_cuda else 'cpu',
    devices=1,
    callbacks=[ckpt_cb],
    logger=logger,
)
trainer = pl.Trainer(**trainer_kwargs)

# NOTE(review): the validation step runs over every batch of val_loader each
# epoch; if that step samples images it can be very slow on a large folder.
# Consider limit_val_batches.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Sampling and (optional) super-resolution decoding
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Set decoder patch scaling > 1.0 to upsample at inference.
# NOTE(review): the sampler integrates from xT, so output presumably keeps xT's
# spatial size. If PixNerDiT expects high-resolution noise for upscaling, xT
# must be created at IMAGE_SIZE * upscale — confirm against the model code.
upscale = 2.0  # change to 1.0 for base resolution
# Sampling below runs model.ema_denoiser, so the scaling must be applied there
# too. Setting it only on model.denoiser left the sampled network unchanged.
for net in (model.denoiser, model.ema_denoiser):
    net.decoder_patch_scaling_h = upscale
    net.decoder_patch_scaling_w = upscale

# Inputs are moved to DEVICE below, and Lightning may leave the module on
# another device after fit(), so place the model explicitly.
model.to(DEVICE)
model.eval()

# On GPU, sample under bf16 autocast to match the training precision.
use_autocast = DEVICE.type == 'cuda'
with torch.no_grad(), torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16, enabled=use_autocast):
    images, y, metadata = next(iter(val_loader))
    xT = torch.randn_like(images)
    condition, uncondition = model.conditioner(y, metadata)
    samples = model.diffusion_sampler(
        model.ema_denoiser, xT.to(DEVICE), condition.to(DEVICE), uncondition.to(DEVICE)
    )
    # .float() so fp2uint8 receives a regular float tensor even if decode ran in bf16.
    imgs = model.vae.decode(samples).float().cpu()

imgs_uint8 = fp2uint8(imgs)
grid = make_grid(imgs_uint8, nrow=4)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(grid.permute(1, 2, 0))
plt.show()