# AERO-VIT 1.1B Phase 1 Sandbox

This notebook verifies the stability of the 8-lane Manifold-Constrained Hyper-Connections (mHC) and benchmarks the Sinkhorn-Knopp routing.

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from src.model.aero_vit import AeroVitModel
from src.model.routing import RoutingLayer, sinkhorn_knopp
import time
import functools

# Show which accelerators JAX can see before running anything else.
print(f"JAX devices: {jax.devices()}")

## 1. Model Initialization

Initialize the AERO-VIT model with 8-lane mHC.

In [None]:
# Sandbox-sized hyperparameters for the model under test.
LAYERS = 2
HEADS = 8
DIM = 256
BATCH_SIZE = 1
SEQ_LEN = 1024  # shortened for local debugging; the target sequence length is 16k

# Fixed seed so parameter initialization is reproducible across runs.
key = jax.random.PRNGKey(0)

# All-ones placeholder batch of shape (BATCH_SIZE, SEQ_LEN, DIM), used to
# initialize the parameters here and reused by later cells.
dummy_input = jnp.ones(shape=(BATCH_SIZE, SEQ_LEN, DIM))

model = AeroVitModel(
    num_classes=10,
    num_heads=HEADS,
    dim=DIM,
    num_layers=LAYERS,
)
params = model.init(key, dummy_input)

print("Model initialized.")

## 2. Stability Check

Run a forward pass and check for NaNs/Infs.

In [None]:
@jax.jit
def forward(params, x):
    """JIT-compiled forward pass through the notebook-level `model`.

    Args:
        params: Variables produced by `model.init`.
        x: Input batch handed to `model.apply`.

    Returns:
        Whatever `model.apply(params, x)` returns.
    """
    outputs = model.apply(params, x)
    return outputs

output = forward(params, dummy_input)
print(f"Output shape: {output.shape}")

# The output is considered stable only when every element is finite,
# i.e. it contains neither NaN nor +/-Inf.
all_finite = bool(jnp.isfinite(output).all())
if all_finite:
    print("Stability Check PASSED.")
else:
    print("Stability Check FAILED: NaNs or Infs detected.")

## 3. MFU Benchmarking: Custom vs ott-jax

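Compare the wall-clock runtime of our custom JAX Sinkhorn-Knopp implementation against the ott-jax Sinkhorn solver. Note that this cell reports elapsed time only; it does not compute MFU.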

In [None]:
# Only the imports live inside `try`, so that an ImportError raised while the
# benchmark itself runs is not misreported as "ott-jax not installed."
try:
    import ott  # noqa: F401  (imported only to probe that ott-jax is installed)
    from ott.geometry import geometry
    from ott.problems.linear import linear_problem
    from ott.solvers.linear import sinkhorn
except ImportError:
    print("ott-jax not installed.")
else:
    print("ott-jax imported successfully.")

    def ott_sinkhorn_wrapper(log_a):
        """Run ott-jax Sinkhorn on a problem whose cost matrix is `-log_a`.

        Simple wrapper to approximately match the API of `sinkhorn_knopp`.

        Args:
            log_a: 2-D (n, m) matrix; its negation is used as the cost matrix.

        Returns:
            The `matrix` attribute of the ott-jax Sinkhorn output.
        """
        # NOTE(review): no epsilon is passed, so ott-jax picks its own default
        # regularization -- confirm this is a like-for-like comparison with
        # `sinkhorn_knopp` before drawing conclusions from the timings.
        geom = geometry.Geometry(cost_matrix=-log_a)
        prob = linear_problem.LinearProblem(geom)
        solver = sinkhorn.Sinkhorn()
        out = solver(prob)
        return out.matrix

    def _steady_state_seconds(fn, arg, repeats=3):
        """Return the best elapsed time (seconds) of `fn(arg)` over `repeats` runs.

        One untimed warm-up call is made first so that tracing/compilation
        does not pollute the measurement. Every call is wrapped in
        `jax.block_until_ready` because JAX dispatches work asynchronously.
        """
        jax.block_until_ready(fn(arg))  # warm-up: absorbs compile time
        timings = []
        for _ in range(repeats):
            start = time.perf_counter()  # monotonic, high-resolution timer
            jax.block_until_ready(fn(arg))
            timings.append(time.perf_counter() - start)
        return min(timings)

    # Benchmark setup
    N = 1024
    # Derive a fresh key instead of reusing the one already consumed by
    # `model.init`; JAX PRNG keys should not be used twice.
    matrix = jax.random.normal(jax.random.fold_in(key, 1), (N, N))

    # JAX Custom
    custom_seconds = _steady_state_seconds(sinkhorn_knopp, matrix)
    print(f"Custom JAX Sinkhorn Time: {custom_seconds:.4f}s")

    # ott-jax
    ott_seconds = _steady_state_seconds(ott_sinkhorn_wrapper, matrix)
    print(f"ott-jax Sinkhorn Time: {ott_seconds:.4f}s")