In [1]:
import os, sys
from pathlib import Path

# PixNerd repo root. Set the PIXNERD_ROOT environment variable to point this
# notebook at a different checkout; when it is unset (or empty) the original
# cluster path below is used, so existing runs behave exactly as before.
ROOT = Path(
    os.environ.get("PIXNERD_ROOT") or "/pscratch/sd/k/kevinval/PNBase/PixNerd"
).expanduser()
if not ROOT.exists():
    raise FileNotFoundError(
        f"PixNerd repo root not found: {ROOT} (set PIXNERD_ROOT to your checkout)"
    )

# Relative paths used later in the notebook ("./data", "./workdirs") are
# resolved against the current working directory, hence the chdir.
os.chdir(ROOT)

# Put the repo first on sys.path so the `src.*` imports resolve to this
# checkout. Dropping any earlier copy first keeps the cell idempotent when it
# is re-run in the same kernel (no duplicate entries pile up).
_root_str = str(ROOT)
sys.path[:] = [entry for entry in sys.path if entry != _root_str]
sys.path.insert(0, _root_str)

import torch
from functools import partial
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import CSVLogger, WandbLogger, TensorBoardLogger

from src.models.autoencoder.pixel import PixelAE
from src.models.conditioner.class_label import LabelConditioner
from src.models.transformer.pixnerd_c2i_heavydecoder import PixNerDiT

from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.flow_matching.sampling import EulerSampler, ode_step_fn
from src.diffusion.base.guidance import simple_guidance_fn
from src.diffusion.flow_matching.training import FlowMatchingTrainer

from src.callbacks.simple_ema import SimpleEMA
from src.callbacks.save_images import SaveImagesHook
from src.lightning_model import LightningModel
from src.lightning_data import DataModule
from src.data.dataset.cifar10 import PixCIFAR10, CIFAR10RandomNDataset


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from types import SimpleNamespace

# -------------------------
# Config / hyperparameters
# -------------------------
# Single namespace holding every tunable value; the rest of the notebook
# reads these as cfg.<name>.
cfg = SimpleNamespace()

# --- training ---
cfg.max_steps = 300_000
cfg.batch_size = 128
cfg.lr = 1e-4
cfg.num_workers = 4

# --- model (same as train_cifar10.py) ---
cfg.hidden_size = 512
cfg.decoder_hidden_size = 64
cfg.num_encoder_blocks = 8
cfg.num_decoder_blocks = 2
cfg.patch_size = 8
cfg.num_groups = 8
cfg.num_classes = 10

# --- flow matching / sampling ---
cfg.guidance = 2.0
cfg.num_sample_steps = 50

# --- sparsity-conditioning ---
# sparsity: total observed (cond + target), e.g. 40%
cfg.sparsity = 0.4
# cond_fraction: fraction of observed used as cond (so 20% cond, 20% target)
cfg.cond_fraction = 0.5

# --- logging / trainer ---
cfg.exp_name = "cifar10_c2i_sparse_flowmatch_test"
cfg.output_dir = "./workdirs"
cfg.use_wandb = False
cfg.wandb_project = "pixnerd_cifar10"
cfg.save_every_n_steps = 5_000
cfg.val_every_n_epochs = 10
cfg.resume = None
cfg.precision = "bf16-mixed"
cfg.devices = 1

# -------------------------
# Build model (same as build_model, but inline)
# -------------------------
# Objects are created in the same order as before; only the layout differs.

# One scheduler instance is shared by the sampler and the training objective.
main_scheduler = LinearScheduler()

vae = PixelAE(scale=1.0)
conditioner = LabelConditioner(num_classes=cfg.num_classes)

# Denoiser network, sized entirely from cfg.
denoiser = PixNerDiT(
    in_channels=3,
    num_classes=cfg.num_classes,
    patch_size=cfg.patch_size,
    num_groups=cfg.num_groups,
    hidden_size=cfg.hidden_size,
    decoder_hidden_size=cfg.decoder_hidden_size,
    num_encoder_blocks=cfg.num_encoder_blocks,
    num_decoder_blocks=cfg.num_decoder_blocks,
)

# Sampler used for image generation; guidance is applied over the whole
# [0.0, 1.0] interval.
# NOTE(review): the recorded output contains "current sampler is ODE sampler,
# but w_scheduler is enabled" -- presumably triggered by passing w_scheduler
# together with ode_step_fn here; confirm whether w_scheduler is needed.
sampler = EulerSampler(
    scheduler=main_scheduler,
    w_scheduler=LinearScheduler(),
    num_steps=cfg.num_sample_steps,
    guidance=cfg.guidance,
    guidance_interval_min=0.0,
    guidance_interval_max=1.0,
    guidance_fn=simple_guidance_fn,
    step_fn=ode_step_fn,
)

# Flow-matching training objective.
fm_trainer = FlowMatchingTrainer(scheduler=main_scheduler, lognorm_t=True, timeshift=1.0)

ema_tracker = SimpleEMA(decay=0.9999)

# The optimizer is passed as a partially-applied constructor, not an instance.
optimizer_ctor = partial(torch.optim.AdamW, lr=cfg.lr, weight_decay=0.0)

# Tie everything together in the Lightning module.
model = LightningModel(
    vae=vae,
    conditioner=conditioner,
    denoiser=denoiser,
    diffusion_trainer=fm_trainer,
    diffusion_sampler=sampler,
    ema_tracker=ema_tracker,
    optimizer=optimizer_ctor,
    lr_scheduler=None,
    eval_original_model=False,
    sparsity=cfg.sparsity,            # sparse conditioning
    cond_fraction=cfg.cond_fraction,  # cond vs target split
)

print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# -------------------------
# Build datamodule (same as build_datamodule)
# -------------------------
# Training images: CIFAR-10 train split under ./data, with random flips,
# downloaded on first use.
train_dataset = PixCIFAR10(root="./data", train=True, random_flip=True, download=True)

# Evaluation and prediction use two separately constructed datasets with
# identical arguments.
eval_dataset = CIFAR10RandomNDataset(
    latent_shape=(3, 32, 32),
    num_classes=cfg.num_classes,
    max_num_instances=1000,
)
pred_dataset = CIFAR10RandomNDataset(
    latent_shape=(3, 32, 32),
    num_classes=cfg.num_classes,
    max_num_instances=1000,
)

# Training loader settings come from cfg; prediction loader settings are fixed.
datamodule = DataModule(
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    pred_dataset=pred_dataset,
    train_batch_size=cfg.batch_size,
    train_num_workers=cfg.num_workers,
    pred_batch_size=64,
    pred_num_workers=2,
)

# -------------------------
# Logging, callbacks, Trainer (same as main())
# -------------------------
# Per-experiment output directory: <cfg.output_dir>/exp_<cfg.exp_name>.
output_dir = Path(cfg.output_dir) / f"exp_{cfg.exp_name}"
output_dir.mkdir(parents=True, exist_ok=True)
print("Output dir:", output_dir)

# logger choice: W&B when requested, otherwise TensorBoard with a CSV fallback
if cfg.use_wandb:
    logger = WandbLogger(
        project=cfg.wandb_project,
        name=cfg.exp_name,
        save_dir=str(output_dir),
    )
else:
    # use TensorBoard if available, else CSV
    try:
        # Availability probe only; TensorBoardLogger itself is already
        # imported at the top of the notebook.
        import tensorboard  # noqa: F401
        logger = TensorBoardLogger(
            save_dir=str(output_dir),
            name="logs",
        )
    except Exception as exc:
        # Deliberately best-effort: any failure to set up TensorBoard falls
        # back to CSV logging. Report the cause so the fallback is not silent.
        print(f"TensorBoard logger unavailable ({type(exc).__name__}: {exc})")
        logger = CSVLogger(
            save_dir=str(output_dir),
            name="logs",
        )
        print("Using CSVLogger")

# Callbacks: step-based checkpointing under <output_dir>/checkpoints (all
# checkpoints kept, plus a "last" one), per-step learning-rate logging, and
# the image-saving hook.
callbacks = [
    ModelCheckpoint(
        dirpath=output_dir / "checkpoints",
        every_n_train_steps=cfg.save_every_n_steps,
        save_last=True,
        save_top_k=-1,
    ),
    LearningRateMonitor(logging_interval="step"),
    SaveImagesHook(save_dir="val", save_compressed=True),
]

# Trainer: the run length is bounded by max_steps; the pre-training sanity
# validation pass is disabled (num_sanity_val_steps=0).
trainer = Trainer(
    accelerator="auto",
    devices=cfg.devices,
    precision=cfg.precision,
    max_steps=cfg.max_steps,
    check_val_every_n_epoch=cfg.val_every_n_epochs,
    num_sanity_val_steps=0,
    log_every_n_steps=50,
    default_root_dir=str(output_dir),
    logger=logger,
    callbacks=callbacks,
)

# -------------------------
# Start training from notebook
# -------------------------
# cfg.resume is forwarded as the checkpoint path (None in the config above).
trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.resume)


current sampler is ODE sampler, but w_scheduler is enabled


Total parameters: 126,977,158
Files already downloaded and verified


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Output dir: workdirs/exp_cifar10_c2i_sparse_flowmatch_test
Using CSVLogger


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name              | Type                | Params | Mode 
------------------------------------------------------------------
0 | vae               | PixelAE             | 0      | eval 
1 | conditioner       | LabelConditioner    | 0      | eval 
2 | denoiser          | PixNerDiT           | 63.5 M | train
3 | ema_denoiser      | PixNerDiT           | 63.5 M | eval 
4 | diffusion_sampler | EulerSampler        | 0      | train
5 | diffusion_trainer | FlowMatchingTrainer | 0      | train
------------------------------------------------------------------
63.5 M    Trainable params
63.5 M    Non-trainable params
126 M     Total params
507.909   Total estimated model params size (MB)
147       Modules in train mode
147       Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 0:   0%|          | 0/391 [00:00<?, ?it/s] 

  buf20.copy_(buf19)


Epoch 0: 100%|█████████▉| 390/391 [01:16<00:00,  5.09it/s, v_num=0, loss=0.0927]

  File "/pscratch/sd/k/kevinval/PNBase/PixNerd/src/models/transformer/pixnerd_c2i_heavydecoder.py", line 310, in forward
    s = self.blocks[i](s, condition, xpos)
  File "/pscratch/sd/k/kevinval/PNBase/PixNerd/src/models/transformer/pixnerd_c2i_heavydecoder.py", line 99, in forward
    x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
  File "/pscratch/sd/k/kevinval/PNBase/PixNerd/src/models/layers/swiglu.py", line 15, in forward
    return self.w3(torch.nn.functional.silu(x1)*x2)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:110.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

Detected KeyboardInterrupt, attempting graceful shutdown ...

KeyboardInterrupt

