# MeanFlow Training Diagnostics

This notebook helps diagnose why training isn't improving.

In [None]:
import os
import sys

# These must be set before `import jax`: JAX reads them when it initialises
# its GPU backend. Disable full preallocation and set the memory fraction.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.8'

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

# Location of the project sources. Set the MEANFLOW_SRC environment variable to
# run the notebook on another machine; the original path is kept as the default.
SRC_DIR = os.environ.get(
    'MEANFLOW_SRC',
    '/home/emil/KTH/Adv. Deep Learning/Project/Means Flow/src',
)
# Guard against piling up duplicate entries when this cell is re-run.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

from data.cifar10 import make_cifar10
from core.schedules import linear_path, sample_r_t
from models.meanflow_net import MeanFlowNet
import optax
from flax.training import train_state
import tensorflow as tf

# Keep TensorFlow off the GPU so it does not compete with JAX for device memory.
# NOTE(review): assumes TF is only needed for the input pipeline (the dataset
# exposes `as_numpy_iterator`) - confirm nothing else needs TF on GPU.
try:
    tf.config.set_visible_devices([], 'GPU')
except RuntimeError as err:
    # Raised when TF has already initialised its devices; nothing more we can do.
    print(f"Could not hide GPUs from TensorFlow: {err}")

print(f"JAX devices: {jax.devices()}")
print(f"JAX backend: {jax.default_backend()}")

## Test 1: Load and inspect a single batch

In [None]:
# Pull one fixed mini-batch: first 100 training images, batches of 8, no shuffling.
ds = make_cifar10(batch_size=8, split="train[:100]", shuffle=False, cache=True)
batch_iter = iter(ds.as_numpy_iterator())
single_batch = next(batch_iter)
images, labels = single_batch

print(f"Batch shape: {images.shape}")
print(f"Labels: {labels}")
print(f"Image range: [{images.min():.3f}, {images.max():.3f}]")

# Independent copies for later cells, so reassigning `images` / `labels`
# elsewhere cannot change what those cells see.
images_test, labels_test = images.copy(), labels.copy()

# Show the batch as a 2x4 grid (pixel values clipped to [0, 1] for display only).
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i, ax in enumerate(axes.ravel()):
    ax.imshow(np.clip(images[i], 0, 1))
    ax.set_title(f"Class {labels[i]}")
    ax.axis('off')
plt.tight_layout()
plt.show()

## Test 2: Verify time sampling and noising

In [None]:
# Draw one (r, t) pair per example from the project's time sampler.
rng = jax.random.PRNGKey(42)
r, t = sample_r_t(rng, batch=8)

print("Time samples:")
for i in range(8):
    print(f"  r={r[i]:.3f}, t={t[i]:.3f}, t-r={t[i]-r[i]:.3f}")

# The interval must be ordered: r may equal t, but must never exceed it.
ordering_ok = jnp.all(r <= t)
assert ordering_ok, "r should always be <= t"
print("\n✓ Time sampling constraint satisfied: r <= t")

In [None]:
# Test linear path noising: noise one image at several t values and plot both
# the noisy sample z_t and the velocity returned by `linear_path`.
rng = jax.random.PRNGKey(42)
rng, noise_rng, time_rng = jax.random.split(rng, 3)

# Get one image (slice keeps the leading batch axis)
x_clean = images[0:1]
eps = jax.random.normal(noise_rng, x_clean.shape)

# Test at different time steps
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
test_times = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0])

# z_t per time value, kept so the endpoints can be verified after plotting.
zt_at = {}

for i, t_val in enumerate(test_times):
    t = jnp.array([t_val])
    zt, v_t = linear_path(x_clean, eps, t)
    zt_at[float(t_val)] = np.asarray(zt)

    # Plot noisy image
    axes[0, i].imshow(np.clip(zt[0], 0, 1))
    axes[0, i].set_title(f't={t_val:.2f}')

    # Plot velocity field (as RGB), min-max normalised for display only
    v_vis = (v_t[0] - v_t[0].min()) / (v_t[0].max() - v_t[0].min() + 1e-8)
    axes[1, i].imshow(v_vis)
    axes[1, i].set_title('velocity')

    # Hide ticks and frame instead of calling axis('off'): axis('off') would
    # also hide the row labels set below.
    for ax in (axes[0, i], axes[1, i]):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)

axes[0, 0].set_ylabel('Noisy Image', fontsize=12)
axes[1, 0].set_ylabel('Velocity', fontsize=12)
plt.tight_layout()
plt.show()

# Actually check the endpoints the success message claims, rather than
# printing it unconditionally.
starts_clean = np.allclose(zt_at[0.0], np.asarray(x_clean), atol=1e-5)
ends_noise = np.allclose(zt_at[1.0], np.asarray(eps), atol=1e-5)
if starts_clean and ends_noise:
    print("✓ Linear path interpolates correctly from clean (t=0) to noise (t=1)")
else:
    print("⚠️  WARNING: linear_path endpoints differ from clean (t=0) / noise (t=1) - check the path convention")

## Test 3: Check model initialization and forward pass

In [None]:
# Initialize model - FRESH INITIALIZATION (no cached params)
# If you get shape errors, make sure to run this cell fresh after any model changes!

# Use the copies saved in Test 1, capped at 8 examples.
test_images = images_test[:8]
test_labels = labels_test[:8]
# Derive the batch size from the data so every shape below stays consistent.
batch_size = test_images.shape[0]

print(f"DEBUG: test_images.shape = {test_images.shape}")
print(f"DEBUG: test_labels.shape = {test_labels.shape}")
print(f"DEBUG: test_labels = {test_labels}")
print()

model = MeanFlowNet(
    in_ch=3,
    latent_hw=32,
    ch=32,
    num_classes=10,
    ch_mult=(1, 2, 4),
    num_res_blocks=2
)

# Use a fresh RNG key
rng_model = jax.random.PRNGKey(999)
rng_model, init_rng = jax.random.split(rng_model)
# Do not reuse `init_rng` for the model's `rng=` argument: derive a distinct
# key from it. `init_rng` itself is unchanged, so parameter init is unaffected.
init_drop_rng = jax.random.fold_in(init_rng, 1)

# Dummy inputs shaped like the real batch
dummy_x = jnp.zeros(test_images.shape)
dummy_r = jnp.zeros((batch_size,))
dummy_t = jnp.ones((batch_size,))
dummy_cls = jnp.zeros((batch_size,), dtype=jnp.int32)

print("Initializing model with fresh parameters...")
params = model.init(init_rng, dummy_x, dummy_r, dummy_t, dummy_cls,
                   train_cfg_drop=0.1, rng=init_drop_rng)["params"]
print("Model initialized successfully!")

# Forward pass with its own r/t names, so earlier cells' `r` / `t` are not reused.
rng_model, forward_rng, time_rng = jax.random.split(rng_model, 3)
r_test, t_test = sample_r_t(time_rng, batch=batch_size)

output = model.apply({"params": params}, test_images, r_test, t_test, test_labels.astype(jnp.int32),
                    train_cfg_drop=0.1, rng=forward_rng)

print(f"\nInput shape: {test_images.shape}")
print(f"Output shape: {output.shape}")
print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")
print(f"Output mean: {output.mean():.6f}")
print(f"Output std: {output.std():.6f}")

# Check if output is all zeros (bad initialization)
max_abs_output = jnp.abs(output).max()
if max_abs_output < 1e-5:
    print("\n⚠️  WARNING: Output is nearly zero! Model might not be initialized properly.")
else:
    print(f"\n✓ Model produces non-zero outputs")
    print(f"  Max abs output: {max_abs_output:.6f}")

## Test 4: Check loss computation

In [None]:
from core.identity import meanflow_target

# Sample some noisy data with fresh variables
rng_loss = jax.random.PRNGKey(42)
rng_loss, noise_rng, time_rng = jax.random.split(rng_loss, 3)

eps = jax.random.normal(noise_rng, images.shape)

# Sample (r, t) once and noise the images at that same t. Previously z_t was
# built from a separate uniform draw (reusing `time_rng`), so the model was
# conditioned on a different t than the one that produced its input.
r_loss, t_loss = sample_r_t(time_rng, images.shape[0])
zt, v_t = linear_path(images, eps, t_loss)

# Adapter with the call signature `meanflow_target` is given below.
def u_apply(params, zt_, r_, t_, cls_idx_, rng_local):
    """Apply the model to (zt_, r_, t_, cls_idx_) with train_cfg_drop=0.1."""
    return model.apply({"params": params}, zt_, r_, t_, cls_idx_,
                      train_cfg_drop=0.1, rng=rng_local)

# Compute target
rng_loss, target_rng = jax.random.split(rng_loss)
u_pred, u_star = meanflow_target(u_apply, params, zt, r_loss, t_loss,
                                labels.astype(jnp.int32), v_t, rng=target_rng)

loss = jnp.mean((u_pred - u_star)**2)

print("Loss computation test:")
print(f"  u_pred shape: {u_pred.shape}")
print(f"  u_star shape: {u_star.shape}")
print(f"  u_pred range: [{u_pred.min():.3f}, {u_pred.max():.3f}]")
print(f"  u_star range: [{u_star.min():.3f}, {u_star.max():.3f}]")
print(f"  Loss: {loss:.6f}")

if loss > 10.0:
    print("\n⚠️  WARNING: Initial loss is very high. This might make learning difficult.")
elif loss < 0.01:
    print("\n⚠️  WARNING: Initial loss is very low. Check if target is being computed correctly.")
else:
    print("\n✓ Loss is in a reasonable range")

## Test 5: Overfit on single batch

In [None]:
# Adam at 1e-3: a deliberately high learning rate, since this test only tries
# to memorise a single batch.
tx = optax.adam(learning_rate=1e-3)

# Bundle the model's apply function, the current `params` and the optimizer.
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@jax.jit
def train_step_simple(state, batch, rng):
    """Run one optimisation step of the MeanFlow loss on `batch`.

    Args:
        state: TrainState; `state.apply_fn` is called as the model and
            `state.params` are the parameters being optimised.
        batch: `(images, labels)` pair; labels are cast to int32 class indices.
        rng: PRNG key. It is split internally; the carried-forward key is
            returned so the caller can thread it into the next step.

    Returns:
        `(new_state, metrics, rng)`, where `metrics` holds the scalar "loss"
        and the global gradient norm "grad_norm".
    """
    images, labels = batch
    B = images.shape[0]

    rng, rng_r_t, rng_eps, rng_drop = jax.random.split(rng, 4)

    eps = jax.random.normal(rng_eps, images.shape)
    # Sample (r, t) once and build z_t at that same t. The earlier version
    # noised the images at an unrelated uniform draw (from the same key), so
    # the network was conditioned on a t that did not match its input.
    r, t = sample_r_t(rng_r_t, B)
    zt, v_t = linear_path(images, eps, t)
    cls_idx = labels.astype(jnp.int32)

    def u_apply(params, zt_, r_, t_, cls_idx_, rng_local):
        # Adapter with the call signature `meanflow_target` is given below.
        return state.apply_fn({"params": params}, zt_, r_, t_, cls_idx_,
                             train_cfg_drop=0.1, rng=rng_local)

    def loss_fn(params):
        # Mean squared error between the prediction and the MeanFlow target.
        u_pred, u_star = meanflow_target(u_apply, params, zt, r, t, cls_idx, v_t, rng=rng_drop)
        return jnp.mean((u_pred - u_star)**2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    new_state = state.apply_gradients(grads=grads)

    # Global L2 norm over every gradient leaf, reported for diagnostics only.
    grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grads)))

    return new_state, {"loss": loss, "grad_norm": grad_norm}, rng

print("Overfitting test on single batch...")
print("If the model can learn, loss should decrease dramatically.\n")

# Per-step history, stored as Python floats for plotting.
losses, grad_norms = [], []

rng = jax.random.PRNGKey(42)
for step in range(200):
    # Same batch every step; only the noise and time samples change (via `rng`).
    state, metrics, rng = train_step_simple(state, single_batch, rng)
    losses.append(float(metrics["loss"]))
    grad_norms.append(float(metrics["grad_norm"]))

    if step % 20 == 0:
        print(f"Step {step:3d}: loss={metrics['loss']:.6f}, grad_norm={metrics['grad_norm']:.3f}")

# Plot loss and gradient norm side by side.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))
panels = (
    (ax1, losses, 'Loss', 'Loss on Single Batch (Overfitting Test)'),
    (ax2, grad_norms, 'Gradient Norm', 'Gradient Norms'),
)
for ax, series, ylabel, title in panels:
    ax.plot(series)
    ax.set_xlabel('Step')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Diagnosis: relative drop from the first recorded loss to the last one.
initial_loss = losses[0]
final_loss = losses[-1]
improvement = (initial_loss - final_loss) / initial_loss * 100

print(f"\n{'='*60}")
print("DIAGNOSIS")
print(f"{'='*60}")
print(f"Initial loss: {initial_loss:.6f}")
print(f"Final loss: {final_loss:.6f}")
print(f"Improvement: {improvement:.2f}%")

# Pick the verdict first, then print it line by line.
if improvement > 50:
    verdict = [
        "\n✓ GOOD: Model can learn! Loss decreased significantly.",
        "   Issue is likely: learning rate too small, need more training, or dataset size.",
    ]
elif improvement > 10:
    verdict = [
        "\n⚠️  MODERATE: Model is learning slowly.",
        "   Suggestions: increase learning rate, simplify model, or check data preprocessing.",
    ]
else:
    verdict = [
        "\n❌ PROBLEM: Model is not learning!",
        "   Possible issues:",
        "   - Loss function implementation",
        "   - Model architecture bugs",
        "   - Gradient flow problems",
        "   - Data preprocessing issues",
    ]
for line in verdict:
    print(line)

## Test 6: Step-by-step loss analysis after the overfitting run

In [None]:
# Walk through one loss evaluation with the trained `state.params`, printing
# summary statistics of every intermediate quantity.
from core.identity import meanflow_target

banner = "="*60
print(banner)
print("DETAILED LOSS ANALYSIS")
print(banner)

# Reuse the overfitting batch and compute everything step by step
rng_debug = jax.random.PRNGKey(999)
test_batch = single_batch
test_images, test_labels = test_batch
B = test_images.shape[0]
cls_idx_debug = test_labels.astype(jnp.int32)

# Sample random components
rng_debug, rng_r_t, rng_eps = jax.random.split(rng_debug, 3)
eps = jax.random.normal(rng_eps, test_images.shape)
r, t = sample_r_t(rng_r_t, B)

print(f"\n1. Input data:")
print(f"   Images shape: {test_images.shape}")
print(f"   Images range: [{test_images.min():.3f}, {test_images.max():.3f}]")
print(f"   Labels: {test_labels}")

print(f"\n2. Time sampling:")
print(f"   r range: [{r.min():.3f}, {r.max():.3f}]")
print(f"   t range: [{t.min():.3f}, {t.max():.3f}]")
print(f"   t-r range: [{(t-r).min():.3f}, {(t-r).max():.3f}]")

# Noise the images at the sampled t
zt, v_t = linear_path(test_images, eps, t)
print(f"\n3. Noisy images:")
print(f"   zt range: [{zt.min():.3f}, {zt.max():.3f}]")
print(f"   v_t (true velocity) range: [{v_t.min():.3f}, {v_t.max():.3f}]")
print(f"   v_t mean: {v_t.mean():.6f}, std: {v_t.std():.6f}")

# Direct model call, with train_cfg_drop=0.0
rng_debug, pred_rng = jax.random.split(rng_debug)
u_pred_direct = model.apply(
    {"params": state.params},
    zt, r, t, cls_idx_debug,
    train_cfg_drop=0.0,
    rng=pred_rng,
)

print(f"\n4. Model prediction (u_pred):")
print(f"   u_pred range: [{u_pred_direct.min():.3f}, {u_pred_direct.max():.3f}]")
print(f"   u_pred mean: {u_pred_direct.mean():.6f}, std: {u_pred_direct.std():.6f}")


def u_apply_debug(params, zt_, r_, t_, cls_idx_, rng_local):
    """Apply the model with train_cfg_drop=0.0; passed to `meanflow_target`."""
    return model.apply({"params": params}, zt_, r_, t_, cls_idx_,
                      train_cfg_drop=0.0, rng=rng_local)


# Prediction and regression target as returned by `meanflow_target`
rng_debug, target_rng = jax.random.split(rng_debug)
u_pred, u_star = meanflow_target(
    u_apply_debug, state.params, zt, r, t,
    cls_idx_debug, v_t, rng=target_rng,
)

print(f"\n5. MeanFlow target computation:")
print(f"   u_pred range: [{u_pred.min():.3f}, {u_pred.max():.3f}]")
print(f"   u_star (target) range: [{u_star.min():.3f}, {u_star.max():.3f}]")
print(f"   u_star mean: {u_star.mean():.6f}, std: {u_star.std():.6f}")

# Same MSE as the training loss
mse_loss = jnp.mean((u_pred - u_star)**2)
print(f"\n6. Loss:")
print(f"   MSE loss: {mse_loss:.6f}")

# Element-wise error statistics
diff = u_pred - u_star
print(f"\n7. Prediction error analysis:")
print(f"   Error (u_pred - u_star) range: [{diff.min():.3f}, {diff.max():.3f}]")
print(f"   Error mean: {diff.mean():.6f}, std: {diff.std():.6f}")
print(f"   Relative error: {(jnp.abs(diff).mean() / jnp.abs(u_star).mean()):.2%}")

# Pearson correlation between flattened prediction and target
correlation = jnp.corrcoef(u_pred.flatten(), u_star.flatten())[0, 1]
print(f"   Correlation between u_pred and u_star: {correlation:.4f}")

print("\n" + banner)
print("DIAGNOSIS:")
print(banner)

# Name each condition, then report the first one that holds (same priority order).
outputs_near_zero = jnp.abs(u_pred).max() < 1e-4
uncorrelated = jnp.abs(correlation) < 0.1
loss_very_high = mse_loss > 1.0
loss_very_low = mse_loss < 0.01

if outputs_near_zero:
    print("❌ CRITICAL: Model outputs are nearly zero!")
    print("   → Check model initialization")
elif uncorrelated:
    print("❌ CRITICAL: No correlation between prediction and target!")
    print("   → Model is not learning the right thing")
elif loss_very_high:
    print("⚠️  WARNING: Loss is very high")
    print("   → Model needs more training or higher learning rate")
elif loss_very_low:
    print("✓ EXCELLENT: Model is predicting well!")
else:
    print("⚠️  MODERATE: Model is learning but slowly")
    print(f"   → Loss {mse_loss:.4f} should decrease with more training")

print(banner)

## Test 5.5: Check Time Conditioning

Verify that the model output actually changes when only the time input `t` changes.

In [None]:
# Feed the same image, label and r while sweeping t, then measure how far each
# output moves away from the t=0 output.
print("Testing if model responds to time conditioning...")
print("="*60)

x_test = test_images[0:1]  # one image, batch axis kept
cls_test = test_labels[0:1].astype(jnp.int32)
r_fixed = jnp.array([0.0])  # r is held constant throughout

test_t_values = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0])
outputs_at_different_t = []

for t_val in test_t_values:
    t_test = jnp.array([t_val])
    output = model.apply(
        {"params": state.params},
        x_test, r_fixed, t_test, cls_test,
        train_cfg_drop=0.0,
        rng=None,
    )
    outputs_at_different_t.append(output)
    print(f"t={t_val:.2f}: output mean={output.mean():.6f}, std={output.std():.6f}")

# Mean squared distance of every later output from the t=0 output; the values
# are collected so the maximum can be taken without recomputing them.
print(f"\nVariance across time steps:")
sq_dists = []
for i in range(1, len(outputs_at_different_t)):
    diff = outputs_at_different_t[i] - outputs_at_different_t[0]
    sq_dists.append(jnp.mean(diff**2))
    print(f"  ||output(t={test_t_values[i]:.2f}) - output(t=0.0)||² = {jnp.mean(diff**2):.6f}")

# If even the largest distance is near zero, the model is ignoring time.
max_diff = max(sq_dists)

print(f"\nMax difference from t=0.0: {max_diff:.6f}")

if max_diff < 1e-6:
    print("❌ CRITICAL BUG: Model output doesn't change with time!")
    print("   → Time embedding is not working")
elif max_diff < 0.01:
    print("⚠️  WARNING: Model barely responds to time changes")
    print("   → Time conditioning might be too weak")
else:
    print("✓ Model responds to time conditioning")

print("="*60)

## Test 5.6: Visualize what the model predicts

Compare the predicted velocity with the target velocity at several noise levels.

In [None]:
# Generate predictions at different noise levels and show, per column, the
# noisy input, the velocity from `linear_path`, and the model output.
rng_viz = jax.random.PRNGKey(42)
x_clean = images[0:1]
cls = labels[0:1].astype(jnp.int32)

fig, axes = plt.subplots(3, 5, figsize=(15, 9))
test_times = [0.0, 0.25, 0.5, 0.75, 1.0]

# One noise draw shared by every column, so only t differs between them.
rng_viz, noise_rng = jax.random.split(rng_viz)
eps = jax.random.normal(noise_rng, x_clean.shape)

for i, t_val in enumerate(test_times):
    t_viz = jnp.array([t_val])
    r_viz = jnp.array([max(0.0, t_val - 0.1)])  # r slightly less than t, floored at 0

    zt, v_t = linear_path(x_clean, eps, t_viz)

    rng_viz, pred_rng = jax.random.split(rng_viz)
    u_pred = model.apply({"params": state.params}, zt, r_viz, t_viz, cls,
                        train_cfg_drop=0.0, rng=pred_rng)

    # Noisy input (clipped to [0, 1] for display)
    axes[0, i].imshow(np.clip(zt[0], 0, 1))
    axes[0, i].set_title(f't={t_val:.2f}')

    # Velocity target (min-max normalised for visualization)
    v_vis = (v_t[0] - v_t[0].min()) / (v_t[0].max() - v_t[0].min() + 1e-8)
    axes[1, i].imshow(v_vis)
    axes[1, i].set_title('Target velocity')

    # Predicted velocity (min-max normalised for visualization)
    u_vis = (u_pred[0] - u_pred[0].min()) / (u_pred[0].max() - u_pred[0].min() + 1e-8)
    axes[2, i].imshow(u_vis)
    axes[2, i].set_title('Predicted velocity')

    # Hide ticks and frame instead of calling axis('off'): axis('off') would
    # also hide the row labels set below.
    for ax in axes[:, i]:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)

axes[0, 0].set_ylabel('Noisy Input', fontsize=12)
axes[1, 0].set_ylabel('Target', fontsize=12)
axes[2, 0].set_ylabel('Prediction', fontsize=12)

plt.tight_layout()
plt.show()

print("Check if predictions look similar to targets (especially after overfitting test)")

## Summary and Recommendations

Based on the tests above, we can diagnose the training issues:

1. **If overfitting test shows improvement**: The model CAN learn, but needs:
   - Higher learning rate (try 3e-4 or 5e-4)
   - Fewer images initially (start with 1000-5000)
   - More training steps

2. **If overfitting test shows NO improvement**: There's a bug in:
   - Loss computation
   - Model architecture
   - Data preprocessing

3. **If gradients are vanishing** (grad_norm < 0.01): Check:
   - Model initialization
   - Activation functions
   - Gradient flow through architecture