What's new
Added
examples/config-mps.yaml: Recommended training config for Apple Silicon — batch_size=2, use_fp16=false, use_gradient_checkpointing=true, workers=0.
scripts/profile_mps.py: Profiles CPU fallbacks during a VAE forward+backward pass on MPS.
ctx_loss console logging: Context dim loss shown in training progress line as | Ctx: 0.0123.
ctx_loss graph curve: Plotted as a separate curve in training_losses.png.
ctx_loss_weight config param: Tunes the relative weight of context-dim v-prediction loss.
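As a rough illustration of how ctx_loss_weight and the new console field might fit together, the sketch below combines two scalar losses and formats a progress line. The function names (combine_losses, format_progress), the additive weighting, and the "step"/"Loss" fields are assumptions made for illustration, not the trainer's actual code; only the `| Ctx: 0.0123` field format comes from the notes above.

```python
def combine_losses(vae_dim_loss: float, ctx_loss: float, ctx_loss_weight: float = 1.0) -> float:
    """Hypothetical weighting: add the context-dim loss scaled by ctx_loss_weight."""
    return vae_dim_loss + ctx_loss_weight * ctx_loss


def format_progress(step: int, total_loss: float, ctx_loss: float) -> str:
    """Build a progress line that ends with the `| Ctx: ...` field (four decimals)."""
    return f"step {step} | Loss: {total_loss:.4f} | Ctx: {ctx_loss:.4f}"


# Example values only; a weight below 1.0 reduces the context-dim contribution.
total = combine_losses(vae_dim_loss=0.25, ctx_loss=0.0123, ctx_loss_weight=0.5)
line = format_progress(step=100, total_loss=total, ctx_loss=0.0123)
print(line)
```

Under this assumed formulation, setting ctx_loss_weight to 0.0 would leave the context-dim loss visible in the log while removing it from the optimised total.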
Fixed
MPS cache clearing: torch.mps.empty_cache() added alongside CUDA clearing in VAETrainer and FlowTrainer.
MPS adaptive pooling: Replaced ad-hoc try/except with mps_safe_pool2d from fluxflow.utils.mps.
v-prediction loss target: The flow trainer computed its loss against clean x0 even though prediction_type="v_prediction". It now computes the loss against the v target, v = alpha_t * noise - sigma_t * x0.
Context dim train/inference mismatch: Training now noises all dims uniformly to match inference.
Context dim loss scale: VAE and context dims normalised independently to prevent ~10x loss imbalance.
Gradient clipping: Replaced broken adaptive formula with straightforward clip_grad_norm_.
GAN instance noise never applied: Inverted guard removed; noise now always applied.
GAN adaptive weight explosion at startup: _compute_adaptive_weight clamped to max_weight=5.0.
Discriminator trained on deterministic latents: Latents for the discriminator are now produced with training=True under torch.no_grad(), so reparameterisation sampling is applied.
ctx_loss normalisation scale: Now normalises by the v-target's own standard deviation for a consistent scale across timesteps.
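The v-prediction fix above can be checked numerically with a small standard-library sketch. It builds the target v = alpha_t * noise - sigma_t * x0 and confirms that x0 is recoverable from the noised sample and v. The cosine schedule used here (alpha_t^2 + sigma_t^2 = 1) and all function names are assumptions for illustration; the release notes do not say which schedule or tensor code FluxFlow uses.

```python
import math
import random


def schedule(t: float) -> tuple[float, float]:
    """Assumed variance-preserving schedule: returns (alpha_t, sigma_t)."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)


def v_target(x0: list[float], noise: list[float], t: float) -> list[float]:
    """Element-wise v = alpha_t * noise - sigma_t * x0."""
    alpha_t, sigma_t = schedule(t)
    return [alpha_t * n - sigma_t * x for x, n in zip(x0, noise)]


def recover_x0(x_t: list[float], v: list[float], t: float) -> list[float]:
    """Invert the parameterisation: x0 = alpha_t * x_t - sigma_t * v."""
    alpha_t, sigma_t = schedule(t)
    return [alpha_t * xt - sigma_t * vi for xt, vi in zip(x_t, v)]


rng = random.Random(0)
x0 = [rng.gauss(0.0, 1.0) for _ in range(4)]
noise = [rng.gauss(0.0, 1.0) for _ in range(4)]
t = 0.3

alpha_t, sigma_t = schedule(t)
x_t = [alpha_t * x + sigma_t * n for x, n in zip(x0, noise)]  # noised sample
v = v_target(x0, noise, t)
x0_hat = recover_x0(x_t, v, t)

# Largest reconstruction error; should be at floating-point rounding level.
max_err = max(abs(a - b) for a, b in zip(x0, x0_hat))
print(max_err)
```

Because v mixes noise and x0, a model trained against clean x0 while being sampled as a v-predictor would be optimising the wrong quantity at every timestep except where the two targets happen to coincide.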
Changed
fluxflow dependency bumped to >=0.8.1: Required for fluxflow.utils.mps.mps_safe_pool2d.