In [24]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
import os
# Configure JAX to use TPU
os.environ['JAX_PLATFORMS'] = 'tpu'

import jax
import jax.numpy as jnp
import math
import time
import sys
import copy
import copy

import flax.linen as nn
from ml_collections import config_dict

# Add the md4 module to the path
sys.path.append('/home/wuhao/md4')

from md4.configs.md4 import molecular
from md4.models import utils as model_utils

print(f"JAX devices: {jax.devices()}")
print(f"JAX default backend: {jax.default_backend()}")
print(f"Number of TPU devices: {len(jax.devices())}")

JAX devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
JAX default backend: tpu
Number of TPU devices: 1


In [2]:
# Load the molecular configuration
config = molecular.get_config()
config.vocab_size = 1023
print("Configuration:")
print(f"  Model type: {config.model_type}")
print(f"  Dataset: {config.dataset}")
print(f"  Vocab size: {config.vocab_size}")
print(f"  Max length: {config.max_length}")
print(f"  Feature dim: {config.feature_dim}")
print(f"  Number of layers: {config.n_layers}")
print(f"  Number of heads: {config.num_heads}")
print(f"  Dropout rate: {config.dropout_rate}")
print(f"  Fingerprint dim: {config.fingerprint_dim}")
print(f"  Timesteps: {config.timesteps}")
print(f"  Batch size: {config.batch_size}")

# Create the model
model = model_utils.get_model(config)
print(f"\nModel created: {type(model)}")
print(f"Model: {model}")

Configuration:
  Model type: md4
  Dataset: pubchem_large
  Vocab size: 1023
  Max length: 128
  Feature dim: 64
  Number of layers: 12
  Number of heads: 12
  Dropout rate: 0.02
  Fingerprint dim: 2048
  Timesteps: 1000
  Batch size: 1024

Model created: <class 'md4.models.diffusion.md4.MD4'>
Model: MD4(
    # attributes
    data_shape = (128,)
    cont_time = True
    timesteps = 1000
    feature_dim = 64
    num_heads = 12
    antithetic_time_sampling = True
    n_layers = 12
    n_dit_layers = 0
    dit_num_heads = 12
    dit_hidden_size = 768
    ch_mult = (1,)
    vocab_size = 1023
    noise_schedule_type = 'cosine'
    dropout_rate = 0.02
    use_attn_dropout = True
    mlp_type = 'swiglu'
    depth_scaled_init = True
    cond_type = 'adaln_zero'
    outside_embed = True
    time_features = 't'
    classes = -1
    sampler = 'ancestral'
    sampling_grid = 'uniform'
    topp = 0.98
    model_sharding = False
    fingerprint_dim = 2048
    atom_type_size = 0
)


In [3]:
# Initialize the model with dummy data
batch_size = 8  # Use smaller batch size for performance testing
seq_length = config.max_length

# Create dummy inputs
rng = jax.random.PRNGKey(42)
rng, sample_rng, init_rng = jax.random.split(rng, 3)

# Input shape for molecular data (SMILES tokens)
dummy_input = jnp.ones((batch_size, seq_length), dtype="int32")

# Create conditioning (fingerprint)
conditioning = {
    "fingerprint": jnp.zeros((batch_size, config.fingerprint_dim), dtype="int32"),
}

print(f"Input shape: {dummy_input.shape}")
print(f"Conditioning fingerprint shape: {conditioning['fingerprint'].shape}")

# Initialize the model
print("\nInitializing model parameters...")
output, variables = model.init_with_output(
    {"sample": sample_rng, "params": init_rng},
    dummy_input,
    cond=conditioning,
    train=False,
)

params = variables["params"]
state = {k: v for k, v in variables.items() if k != "params"}

print(f"Model output keys: {list(output.keys())}")
print(f"Output loss: {output.get('loss', 'N/A')}")
print(f"Number of parameters: {sum(x.size for x in jax.tree_util.tree_leaves(params)):,}")
print("Model initialized successfully!")

Input shape: (8, 128)
Conditioning fingerprint shape: (8, 2048)

Initializing model parameters...
Model output keys: ['loss', 'loss_diff', 'loss_prior', 'loss_recon']
Output loss: 8.797445297241211
Number of parameters: 91,821,120
Model initialized successfully!


In [None]:
from clu import parameter_overview
print("\nParameter overview:")
overview = parameter_overview.get_parameter_overview(params)
print(overview)

In [None]:
# Performance testing - Train step timing using actual train.py functions
print("=" * 60)
print("PERFORMANCE TESTING - ACTUAL TRAIN STEP")
print("=" * 60)

import optax
import functools
from md4.train import TrainState, get_learning_rate, create_metrics_class_from_keys, loss_fn

# Create learning rate schedule function like in actual training
num_train_steps = 10000  # Dummy value for schedule
schedule_fn = functools.partial(
    get_learning_rate,
    base_learning_rate=config.learning_rate,
    num_steps=num_train_steps,
    warmup_steps=config.warmup_steps,
    schedule_type=getattr(config, 'learning_rate_schedule', 'cosine'),
)

# Create optimizer exactly like in actual training
optimizer = optax.chain(
    optax.clip(config.clip) if config.clip > 0.0 else optax.identity(),
    optax.adamw(
        schedule_fn,
        b1=0.9,
        b2=config.b2,
        weight_decay=config.weight_decay,
    ),
)

# Create TrainState exactly like in actual training
train_state = TrainState(
    step=0,
    rng=jax.random.PRNGKey(42),
    params=params,
    ema_params=copy.deepcopy(params) if getattr(config, 'ema_rate', 0.0) > 0.0 else None,
    opt_state=optimizer.init(params),
    state=state,
)

# Create metrics class like in actual training
# Get output keys from a dummy forward pass
dummy_output, _ = model.init_with_output(
    {"sample": jax.random.PRNGKey(0), "params": jax.random.PRNGKey(1)},
    dummy_input,
    cond=conditioning,
    train=False,
)
metric_keys = sorted(list(dummy_output.keys()) + ["learning_rate"])
train_metrics_class = create_metrics_class_from_keys(metric_keys)

print(f"Created TrainState with step: {train_state.step}")
print(f"Optimizer state initialized: {train_state.opt_state is not None}")
print(f"Metrics class keys: {metric_keys}")

# Create a simplified train step function without pmap operations
@jax.jit
def train_step_fn(train_state, batch):
    """Simplified train step function without pmap for performance testing."""
    rng, new_rng = jax.random.split(train_state.rng)
    # Remove the pmap-specific fold_in operation
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, (new_state, metrics_dict)), grads = grad_fn(
        train_state.params, train_state.state, rng, model, batch, train=True
    )
    
    # Apply optimizer updates (without pmean for gradients)
    updates, new_opt_state = optimizer.update(
        grads, train_state.opt_state, train_state.params
    )
    new_params = optax.apply_updates(train_state.params, updates)
    
    # Handle EMA if configured
    ema_rate = getattr(config, 'ema_rate', 0.0)
    if ema_rate > 0.0:
        new_ema_params = jax.tree_util.tree_map(
            lambda x, y: x + (1.0 - ema_rate) * (y - x),
            train_state.ema_params,
            new_params,
        )
    else:
        new_ema_params = None
        
    new_train_state = train_state.replace(
        step=train_state.step + 1,
        rng=new_rng,
        params=new_params,
        ema_params=new_ema_params,
        opt_state=new_opt_state,
        state=new_state,
    )

    
    return new_train_state

# Performance testing with different batch sizes
batch_sizes = [64, 128]  # Start with larger batch sizes since this is more realistic
num_runs = 10

print(f"\nTesting actual train step performance:")
print(f"Sequence length: {seq_length}")
print(f"Number of runs per batch size: {num_runs}")
print("-" * 60)

results = []

for bs in batch_sizes:
    print(f"\nTesting batch size {bs}...")
    
    # Create batch in the format expected by train.py
    batch = {
        "smiles": jnp.ones((bs, seq_length), dtype="int32"),
        "fingerprint": jnp.zeros((bs, config.fingerprint_dim), dtype="int32"),
    }
    
    # Reset train state for each batch size
    test_train_state = TrainState(
        step=0,
        rng=jax.random.PRNGKey(42),
        params=params,
        ema_params=copy.deepcopy(params) if getattr(config, 'ema_rate', 0.0) > 0.0 else None,
        opt_state=optimizer.init(params),
        state=state,
    )
    
    # Warm up for this batch size (important for JIT compilation)
    print(f"  Warming up for batch size {bs}...")
    for _ in range(3):
        test_train_state = train_step_fn(
            train_state=test_train_state, 
            batch=batch
        )
    
    # Reset for timing
    test_train_state = TrainState(
        step=0,
        rng=jax.random.PRNGKey(42),
        params=params,
        ema_params=copy.deepcopy(params) if getattr(config, 'ema_rate', 0.0) > 0.0 else None,
        opt_state=optimizer.init(params),
        state=state,
    )
    
    # Time multiple runs
    print(f"  Running {num_runs} timed iterations...")
    times = []
    for i in range(num_runs):
        start_time = time.time()
        new_train_state = train_step_fn(
            train_state=test_train_state,
            batch=batch
        )
        # Block until computation is complete
        jax.block_until_ready([new_train_state])
        end_time = time.time()
        times.append(end_time - start_time)
        
        # Use updated state for next iteration (realistic training simulation)
        test_train_state = new_train_state
    
    avg_time = sum(times) / len(times)
    min_time = min(times)
    max_time = max(times)
    
    # Calculate throughput
    samples_per_sec = bs / avg_time
    tokens_per_sec = bs * seq_length / avg_time
    
    print(f"Batch size {bs:3d}: {avg_time*1000:6.2f}ms avg ({min_time*1000:5.2f}-{max_time*1000:5.2f}ms) | "
          f"{samples_per_sec:6.1f} samples/sec | {tokens_per_sec:8.0f} tokens/sec")
    
    results.append({
        'batch_size': bs,
        'avg_time': avg_time,
        'min_time': min_time,
        'max_time': max_time,
        'samples_per_sec': samples_per_sec,
        'tokens_per_sec': tokens_per_sec
    })

print("-" * 60)

PERFORMANCE TESTING - ACTUAL TRAIN STEP
Created TrainState with step: 0
Optimizer state initialized: True
Metrics class keys: ['learning_rate', 'loss', 'loss_diff', 'loss_prior', 'loss_recon']

Testing actual train step performance:
Sequence length: 128
Number of runs per batch size: 10
------------------------------------------------------------

Testing batch size 32...
  Warming up for batch size 32...


In [15]:
# Analyze and summarize results
print("\nPERFORMANCE SUMMARY")
print("=" * 60)

if results:
    print(f"Model: {config.model_type} with {config.n_layers} layers, {config.num_heads} heads")
    print(f"Sequence length: {seq_length}, Vocab size: {config.vocab_size}")
    print(f"Feature dimension: {config.feature_dim}, Fingerprint dimension: {config.fingerprint_dim}")
    print(f"Total parameters: {sum(x.size for x in jax.tree_util.tree_leaves(params)):,}")
    print(f"Device: {jax.devices()[0]}")
    print()
    
    # Find optimal batch size (highest throughput)
    best_throughput = max(results, key=lambda x: x['tokens_per_sec'])
    print(f"Optimal batch size: {best_throughput['batch_size']} "
          f"({best_throughput['tokens_per_sec']:.0f} tokens/sec)")
    
    # Memory efficiency analysis
    print("\nScaling analysis:")
    for i in range(1, len(results)):
        prev_bs = results[i-1]['batch_size']
        curr_bs = results[i]['batch_size']
        prev_tps = results[i-1]['tokens_per_sec']
        curr_tps = results[i]['tokens_per_sec']
        
        bs_ratio = curr_bs / prev_bs
        tps_ratio = curr_tps / prev_tps
        efficiency = tps_ratio / bs_ratio
        
        print(f"  {prev_bs} -> {curr_bs}: {efficiency:.2f}x efficiency "
              f"({tps_ratio:.2f}x throughput for {bs_ratio:.1f}x batch size)")
    
    print("\nRecommendations:")
    if best_throughput['batch_size'] < max(r['batch_size'] for r in results):
        print("- Consider using smaller batch sizes for better efficiency")
    else:
        print("- Larger batch sizes provide better throughput")
    
    print(f"- For single inference: ~{results[0]['avg_time']*1000:.1f}ms per sample")
    print(f"- For batch processing: use batch size {best_throughput['batch_size']} "
          f"for {best_throughput['tokens_per_sec']:.0f} tokens/sec")
else:
    print("No performance results available.")

print("=" * 60)


PERFORMANCE SUMMARY
Model: md4 with 12 layers, 12 heads
Sequence length: 128, Vocab size: 1024
Feature dimension: 64, Fingerprint dimension: 2048
Total parameters: 91,821,952
Device: TPU_0(process=0,(0,0,0,0))

Optimal batch size: 64 (1305923 tokens/sec)

Scaling analysis:
  1 -> 4: 0.90x efficiency (3.60x throughput for 4.0x batch size)
  4 -> 8: 0.88x efficiency (1.77x throughput for 2.0x batch size)
  8 -> 16: 0.85x efficiency (1.70x throughput for 2.0x batch size)
  16 -> 32: 0.69x efficiency (1.38x throughput for 2.0x batch size)
  32 -> 64: 0.51x efficiency (1.02x throughput for 2.0x batch size)
  64 -> 128: 0.36x efficiency (0.72x throughput for 2.0x batch size)
  128 -> 256: 0.43x efficiency (0.85x throughput for 2.0x batch size)

Recommendations:
- Consider using smaller batch sizes for better efficiency
- For single inference: ~1.5ms per sample
- For batch processing: use batch size 64 for 1305923 tokens/sec
