In [1]:
# Load the autoreload extension and switch it to mode 0 (automatic reloading
# off), so no module is re-imported behind our back while cells are being
# timed/profiled.
%load_ext autoreload
%autoreload 0

# Leave a visible marker in the cell output that autoreload is off for this session.
print("Autoreload disabled for clean performance testing")

Autoreload disabled for clean performance testing


In [1]:
import os
# Configure JAX to use TPU
os.environ['JAX_PLATFORMS'] = 'tpu'

import jax
import jax.numpy as jnp
import math
import time
import sys
import copy

import flax.linen as nn
from ml_collections import config_dict

# Add the md4 module to the path
sys.path.append('/home/wuhao/md4')

from md4.configs.md4 import molecular
from md4.models import utils as model_utils

# Report the devices and backend JAX selected for this kernel.
available_devices = jax.devices()
print(f"JAX devices: {available_devices}")
print(f"JAX default backend: {jax.default_backend()}")
print(f"Number of TPU devices: {len(available_devices)}")

2025-07-24 14:49:44.787611: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753368584.804220  305969 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753368584.809473  305969 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753368584.822999  305969 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753368584.823013  305969 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753368584.823015  305969 computation_placer.cc:177] computation placer alr

JAX devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
JAX default backend: tpu
Number of TPU devices: 1


In [None]:
# Start from the stock molecular configuration, then override the model-size
# fields used for this profiling run.
config = molecular.get_config()
config.vocab_size = 1024
config.feature_dim = 64
config.num_heads = 12
config.n_layers = 12

# Assemble the configuration report first and emit it in a single print call.
config_report = [
    "Configuration:",
    f"  Model type: {config.model_type}",
    f"  Dataset: {config.dataset}",
    f"  Vocab size: {config.vocab_size}",
    f"  Max length: {config.max_length}",
    f"  Feature dim: {config.feature_dim}",
    f"  Number of layers: {config.n_layers}",
    f"  Number of heads: {config.num_heads}",
    f"  Dropout rate: {config.dropout_rate}",
    f"  Fingerprint dim: {config.fingerprint_dim}",
    f"  Timesteps: {config.timesteps}",
    f"  Batch size: {config.batch_size}",
]
print("\n".join(config_report))

# Build the model described by the (overridden) configuration.
model = model_utils.get_model(config)
print(f"\nModel created: {type(model)}")
print(f"Model: {model}")

Configuration:
  Model type: md4
  Dataset: pubchem_large
  Vocab size: 1024
  Max length: 128
  Feature dim: 64
  Number of layers: 12
  Number of heads: 12
  Dropout rate: 0.02
  Fingerprint dim: 2048
  Timesteps: 1000
  Batch size: 1024

Model created: <class 'md4.models.diffusion.md4.MD4'>
Model: MD4(
    # attributes
    data_shape = (128,)
    cont_time = True
    timesteps = 1000
    feature_dim = 64
    num_heads = 12
    antithetic_time_sampling = True
    n_layers = 12
    n_dit_layers = 0
    dit_num_heads = 12
    dit_hidden_size = 768
    ch_mult = (1,)
    vocab_size = 1024
    noise_schedule_type = 'cosine'
    dropout_rate = 0.02
    use_attn_dropout = True
    mlp_type = 'swiglu'
    depth_scaled_init = True
    cond_type = 'adaln_zero'
    outside_embed = True
    time_features = 't'
    classes = -1
    sampler = 'ancestral'
    sampling_grid = 'uniform'
    topp = 0.98
    model_sharding = False
    fingerprint_dim = 2048
    atom_type_size = 0
)


In [None]:
# Initialize the model with dummy data
batch_size = 8  # Use smaller batch size for performance testing
seq_length = config.max_length

# Derive separate keys for the "sample" and "params" rng streams.
rng = jax.random.PRNGKey(42)
rng, sample_rng, init_rng = jax.random.split(rng, 3)

# Dummy token ids (SMILES tokens) and an all-zero fingerprint as conditioning.
token_shape = (batch_size, seq_length)
fingerprint_shape = (batch_size, config.fingerprint_dim)
dummy_input = jnp.ones(token_shape, dtype=jnp.int32)
conditioning = {"fingerprint": jnp.zeros(fingerprint_shape, dtype=jnp.int32)}

print(f"Input shape: {dummy_input.shape}")
print(f"Conditioning fingerprint shape: {conditioning['fingerprint'].shape}")

# Run one forward pass that also creates the variables.
print("\nInitializing model parameters...")
init_rngs = {"sample": sample_rng, "params": init_rng}
output, variables = model.init_with_output(
    init_rngs, dummy_input, cond=conditioning, train=False
)

# Split trainable parameters from every other variable collection.
params = variables["params"]
state = {name: value for name, value in variables.items() if name != "params"}

param_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"Model output keys: {list(output.keys())}")
print(f"Output loss: {output.get('loss', 'N/A')}")
print(f"Number of parameters: {param_count:,}")
print("Model initialized successfully!")

Input shape: (8, 128)
Conditioning fingerprint shape: (8, 2048)

Initializing model parameters...
xq: (8, 12, 128, 64), xk: (8, 12, 128, 64), xv: (8, 12, 128, 64)
scores: (8, 12, 128, 128)
xq: (8, 12, 128, 64), xk: (8, 12, 128, 64), xv: (8, 12, 128, 64)
scores: (8, 12, 128, 128)
xq: (8, 12, 128, 64), xk: (8, 12, 128, 64), xv: (8, 12, 128, 64)
scores: (8, 12, 128, 128)
xq: (8, 12, 128, 64), xk: (8, 12, 128, 64), xv: (8, 12, 128, 64)
scores: (8, 12, 128, 128)
xq: (8, 12, 128, 64), xk: (8, 12, 128, 64), xv: (8, 12, 128, 64)
scores: (8, 12, 128, 128)
xq: (8, 12, 128, 64), xk: (8, 12, 128, 64), xv: (8, 12, 128, 64)
scores: (8, 12, 128, 128)
xq: (8, 12, 128, 64), xk: (8, 12, 128, 64), xv: (8, 12, 128, 64)
scores: (8, 12, 128, 128)
xq: (8, 12, 128, 64), xk: (8, 12, 128, 64), xv: (8, 12, 128, 64)
scores: (8, 12, 128, 128)
xq: (8, 12, 128, 64), xk: (8, 12, 128, 64), xv: (8, 12, 128, 64)
scores: (8, 12, 128, 128)
xq: (8, 12, 128, 64), xk: (8, 12, 128, 64), xv: (8, 12, 128, 64)
scores: (8, 12, 12

In [4]:
from clu import parameter_overview

# Build the per-parameter summary table and print it under its heading in one call.
overview = parameter_overview.get_parameter_overview(params)
print("\nParameter overview:", overview, sep="\n")


Parameter overview:
+---------------------------------------------------------------------+--------------+---------+-----------+-----------+---------+
| Name                                                                | Shape        | Dtype   | Size      | Mean      | Std     |
+---------------------------------------------------------------------+--------------+---------+-----------+-----------+---------+
| classifier/CondEmbedding_0/Dense_0/bias                             | (64,)        | float32 | 64        | 0.0       | 0.0     |
| classifier/CondEmbedding_0/Dense_0/kernel                           | (256, 64)    | float32 | 16,384    | -0.000179 | 0.0625  |
| classifier/CondEmbedding_0/dense0/bias                              | (256,)       | float32 | 256       | 0.0       | 0.0     |
| classifier/CondEmbedding_0/dense0/kernel                            | (128, 256)   | float32 | 32,768    | 0.000325  | 0.0882  |
| classifier/Embed_0/embedding                                

In [18]:
# Performance testing - Train step timing using actual train.py functions
print("=" * 60)
print("PERFORMANCE TESTING - ACTUAL TRAIN STEP")
print("=" * 60)

import optax
import functools
from md4.train import TrainState, get_learning_rate, create_metrics_class_from_keys, loss_fn

# Learning-rate schedule built from the same helper the training script uses.
num_train_steps = 10000  # Dummy value for schedule
schedule_fn = functools.partial(
    get_learning_rate,
    base_learning_rate=config.learning_rate,
    num_steps=num_train_steps,
    warmup_steps=config.warmup_steps,
    schedule_type=getattr(config, 'learning_rate_schedule', 'cosine'),
)

# Optimizer: optional element-wise gradient clipping followed by AdamW.
optimizer = optax.chain(
    optax.clip(config.clip) if config.clip > 0.0 else optax.identity(),
    optax.adamw(
        schedule_fn,
        b1=0.9,
        b2=config.b2,
        weight_decay=config.weight_decay,
    ),
)

# Step-0 TrainState built from the parameters created in the initialization cell.
# EMA parameters are only materialized when the config enables an EMA rate.
train_state = TrainState(
    step=0,
    rng=jax.random.PRNGKey(42),
    params=params,
    ema_params=copy.deepcopy(params) if getattr(config, 'ema_rate', 0.0) > 0.0 else None,
    opt_state=optimizer.init(params),
    state=state,
)

# Metric keys = the model's output keys plus "learning_rate".
# `output` is the forward-pass result returned by model.init_with_output in the
# initialization cell (the same call that produced `params` and `state`), so the
# keys are read from it instead of initializing the whole model a second time.
metric_keys = sorted(list(output.keys()) + ["learning_rate"])
train_metrics_class = create_metrics_class_from_keys(metric_keys)

print(f"Created TrainState with step: {train_state.step}")
print(f"Optimizer state initialized: {train_state.opt_state is not None}")
print(f"Metrics class keys: {metric_keys}")

# Single-device train step (no pmap collectives), compiled once per input shape.
@jax.jit
def train_step_fn(train_state, batch):
    """Apply one optimizer update and return the resulting TrainState.

    Simplified variant of the training step for profiling: per the original
    author's notes it omits the pmap-specific rng fold_in and gradient pmean.

    Args:
        train_state: Current TrainState (step, rng, params, ema_params,
            opt_state, state).
        batch: Batch passed through unchanged to `loss_fn`.

    Returns:
        A new TrainState with the step incremented, a fresh rng, updated
        params / optimizer state / model state, and updated EMA params
        (None when EMA is disabled). Metrics from `loss_fn` are discarded.
    """
    step_rng, next_rng = jax.random.split(train_state.rng)

    # loss_fn returns (loss, (model_state, metrics)); only the state and grads are used.
    (_, (model_state, _metrics)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        train_state.params, train_state.state, step_rng, model, batch, train=True
    )

    param_updates, opt_state_next = optimizer.update(
        grads, train_state.opt_state, train_state.params
    )
    params_next = optax.apply_updates(train_state.params, param_updates)

    # The EMA branch is resolved at trace time from the config.
    ema_rate = getattr(config, 'ema_rate', 0.0)
    ema_next = (
        jax.tree_util.tree_map(
            lambda ema, fresh: ema + (1.0 - ema_rate) * (fresh - ema),
            train_state.ema_params,
            params_next,
        )
        if ema_rate > 0.0
        else None
    )

    return train_state.replace(
        step=train_state.step + 1,
        rng=next_rng,
        params=params_next,
        ema_params=ema_next,
        opt_state=opt_state_next,
        state=model_state,
    )

# Performance testing with different batch sizes - PROFILING ONLY
batch_sizes = [256]  # Batch sizes to profile; add entries to compare several
num_profile_steps = 100  # Number of steps to profile for each batch size
num_warmup_steps = 3  # Untraced steps run first so JIT compilation happens before tracing
profile_root = "/home/wuhao/md4/profilenew"  # Parent directory of the per-batch-size traces


def _fresh_train_state():
    """Return a step-0 TrainState built from the initial `params` and `state`.

    Every call uses the same PRNGKey(42) and the same initial parameters, so the
    warm-up run and the profiled run start from identical inputs.
    """
    ema_enabled = getattr(config, 'ema_rate', 0.0) > 0.0
    return TrainState(
        step=0,
        rng=jax.random.PRNGKey(42),
        params=params,
        ema_params=copy.deepcopy(params) if ema_enabled else None,
        opt_state=optimizer.init(params),
        state=state,
    )


print("\nProfiling actual train step performance:")
print(f"Sequence length: {seq_length}")
print(f"Number of profiling steps per batch size: {num_profile_steps}")
print("-" * 60)

profile_results = []

for bs in batch_sizes:
    print(f"\nProfiling batch size {bs}...")

    # Create batch in the format expected by train.py
    batch = {
        "smiles": jnp.ones((bs, seq_length), dtype="int32"),
        "fingerprint": jnp.zeros((bs, config.fingerprint_dim), dtype="int32"),
    }

    # Warm up for this batch size (important for JIT compilation)
    print(f"  Warming up for batch size {bs}...")
    test_train_state = _fresh_train_state()
    for _ in range(num_warmup_steps):
        test_train_state = train_step_fn(train_state=test_train_state, batch=batch)
        # Block so compilation and execution finish before the next step
        jax.block_until_ready([test_train_state])

    # Start the profiled run from a fresh state rather than the warmed-up one
    test_train_state = _fresh_train_state()

    # Capture profiler trace for this batch size
    profile_dir = f"{profile_root}/profile-data-bs-{bs}"
    print(f"  Capturing profiler trace to {profile_dir}...")

    with jax.profiler.trace(profile_dir):
        # Feed each step's output state into the next step, as training would
        for _ in range(num_profile_steps):
            test_train_state = train_step_fn(train_state=test_train_state, batch=batch)

        # Dispatch is asynchronous: wait inside the trace context so the work of
        # the final steps has finished before the trace is closed.
        jax.block_until_ready([test_train_state])
    print(f"  Profiler trace saved to {profile_dir}")
    print(f"  Final step: {test_train_state.step}")

    profile_results.append({
        'batch_size': bs,
        'profile_dir': profile_dir,
        'final_step': test_train_state.step
    })

print("-" * 60)
print("\nProfiler traces captured:")
for result in profile_results:
    print(f"  Batch size {result['batch_size']}: {result['profile_dir']} (steps: {result['final_step']})")

print("\nYou can analyze these traces using TensorBoard:")
for result in profile_results:
    print(f"  tensorboard --logdir {result['profile_dir']}")

# Build the combined command from the traces that were actually written, so the
# printed paths and batch sizes always match this run.
print("\nOr open all traces in one TensorBoard session:")
logdir_specs = ",".join(f"bs{r['batch_size']}:{r['profile_dir']}" for r in profile_results)
print(f"  tensorboard --logdir_spec {logdir_specs}")

PERFORMANCE TESTING - ACTUAL TRAIN STEP
Created TrainState with step: 0
Optimizer state initialized: True
Metrics class keys: ['learning_rate', 'loss', 'loss_diff', 'loss_prior', 'loss_recon']

Profiling actual train step performance:
Sequence length: 128
Number of profiling steps per batch size: 100
------------------------------------------------------------

Profiling batch size 256...
  Warming up for batch size 256...
  Capturing profiler trace to /home/wuhao/md4/profilenew/profile-data-bs-256...
  Profiler trace saved to /home/wuhao/md4/profilenew/profile-data-bs-256
  Final step: 100
------------------------------------------------------------

Profiler traces captured:
  Batch size 256: /home/wuhao/md4/profilenew/profile-data-bs-256 (steps: 100)

You can analyze these traces using TensorBoard:
  tensorboard --logdir /home/wuhao/md4/profilenew/profile-data-bs-256

Or open all traces in one TensorBoard session:
  tensorboard --logdir /tmp --logdir_spec bs128:/tmp/profile-data-bs-

In [20]:
# Profile analysis summary: recap the model setup and how to inspect the traces.
print("\nPROFILING SUMMARY")
print("=" * 60)

# Guard first: without results from the profiling cell there is nothing to summarize.
if 'profile_results' not in locals() or not profile_results:
    print("No profiling results available. Run the profiling cell first.")
else:
    print(f"Model: {config.model_type} with {config.n_layers} layers, {config.num_heads} heads")
    print(f"Sequence length: {seq_length}, Vocab size: {config.vocab_size}")
    print(f"Feature dimension: {config.feature_dim}, Fingerprint dimension: {config.fingerprint_dim}")
    total_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
    print(f"Total parameters: {total_params:,}")
    print(f"Device: {jax.devices()[0]}")
    print()

    print("Profiling completed for the following configurations:")
    for entry in profile_results:
        print(f"  Batch size {entry['batch_size']:3d}: {entry['final_step']} training steps profiled")
        print(f"    Profile data: {entry['profile_dir']}")

    print("\nProfile Analysis Instructions:")
    print("1. Open TensorBoard to analyze the traces:")
    for entry in profile_results:
        print(f"   tensorboard --logdir {entry['profile_dir']}  # For batch size {entry['batch_size']}")

    print("\n2. Compare all batch sizes in one view:")
    logdir_specs = ",".join(
        f"bs{entry['batch_size']}:{entry['profile_dir']}" for entry in profile_results
    )
    print(f"   tensorboard --logdir_spec {logdir_specs}")

    # Static guidance text, emitted one line per argument.
    print(
        "\n3. Key metrics to analyze in TensorBoard:",
        "   - Trace Viewer: Step-by-step execution timeline",
        "   - Overview Page: Operation statistics and recommendations",
        "   - Memory Profile: Memory usage patterns",
        "   - Kernel Stats: GPU/TPU kernel execution times",
        sep="\n",
    )

    print(
        "\n4. What to look for:",
        "   - Memory bandwidth utilization",
        "   - Compute vs memory bound operations",
        "   - Batch size scaling efficiency",
        "   - Gradient computation vs parameter update ratios",
        sep="\n",
    )

    # Tokens processed per step = batch size * sequence length.
    total_tokens_per_batch = {
        entry['batch_size']: entry['batch_size'] * seq_length for entry in profile_results
    }
    print(f"\nBatch configurations:")
    for entry in profile_results:
        tokens = total_tokens_per_batch[entry['batch_size']]
        print(f"  Batch size {entry['batch_size']:3d}: {tokens:6,} tokens per step")

    print(f"\nRealistic training simulation completed:")
    print(f"- Each profile includes {num_profile_steps} actual training steps")
    print(f"- State updates and optimizer steps are realistic")
    print(f"- Gradients computed and parameters updated each step")

print("=" * 60)


PERFORMANCE SUMMARY
Model: md4 with 12 layers, 12 heads
Sequence length: 128, Vocab size: 1023
Feature dimension: 64, Fingerprint dimension: 2048
Total parameters: 91,821,120
Device: TPU_0(process=0,(0,0,0,0))

Optimal batch size: 128 (211032 tokens/sec)

Scaling analysis:
  128 -> 256: 0.48x efficiency (0.96x throughput for 2.0x batch size)

Recommendations:
- Consider using smaller batch sizes for better efficiency
- For single inference: ~77.6ms per sample
- For batch processing: use batch size 128 for 211032 tokens/sec
