In [1]:
import jax
import jax.numpy as jnp
import time
from flax import linen as nn

# Function to benchmark matrix multiplication
def benchmark_matmul(dtype, size=(1024, 1024)):
    # Generate random matrices of the given dtype
    key = jax.random.PRNGKey(0)
    A = jax.random.normal(key, size, dtype=dtype)
    B = jax.random.normal(key, size, dtype=dtype)

    # JIT compile the matrix multiplication for speed
    matmul_fn = jax.jit(lambda x, y: jnp.dot(x, y))

    # Warm-up to ensure accurate timing
    matmul_fn(A, B).block_until_ready()

    # Measure execution time
    start_time = time.time()
    matmul_fn(A, B).block_until_ready()
    end_time = time.time()

    return end_time - start_time

# Matrix size
matrix_size = (100, 1024, 1024)

# Benchmark float32
f32_time = benchmark_matmul(jnp.float32, size=matrix_size)
print(f"float32 time: {f32_time:.6f} seconds")

# Benchmark float16
f16_time = benchmark_matmul(jnp.float16, size=matrix_size)
print(f"float16 time: {f16_time:.6f} seconds")

# Benchmark bfloat16
bf16_time = benchmark_matmul(jnp.bfloat16, size=matrix_size)
print(f"bfloat16 time: {bf16_time:.6f} seconds")

# Print the comparison
if bf16_time < f32_time:
    print(f"bfloat16 is faster by {f32_time - bf16_time:.6f} seconds")
else:
    print(f"float32 is faster by {bf16_time - f32_time:.6f} seconds")


2025-01-22 07:14:42.021961: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 18.53GiB (19901340057 bytes) by rematerialization; only reduced to 39.46GiB (42366664720 bytes), down from 39.46GiB (42366664720 bytes) originally


float32 time: 0.192535 seconds
float16 time: 0.124722 seconds
bfloat16 time: 0.093090 seconds
bfloat16 is faster by 0.099445 seconds


In [29]:
# Function to benchmark a simple MLP network
def benchmark_mlp(dtype, input_size=(1024, 512), hidden_size=512, output_size=256, num_layers=200):
    # Generate random input of the given dtype
    key = jax.random.PRNGKey(0)
    inputs = jax.random.normal(key, input_size, dtype=dtype)

    # Initialize random weights for the MLP
    params = []
    in_features = input_size[1]
    for _ in range(num_layers):
        W = jax.random.normal(key, (in_features, hidden_size), dtype=dtype)
        b = jax.random.normal(key, (hidden_size,), dtype=dtype)
        params.append((W, b))
        in_features = hidden_size
    W_out = jax.random.normal(key, (hidden_size, output_size), dtype=dtype)
    b_out = jax.random.normal(key, (output_size,), dtype=dtype)
    params.append((W_out, b_out))

    # Define the MLP forward function
    def mlp_forward(x, params):
        for W, b in params[:-1]:
            x = jnp.dot(x, W) + b
            x = jax.nn.relu(x)
        W_out, b_out = params[-1]
        return jnp.dot(x, W_out) + b_out

    # JIT compile the MLP function
    mlp_fn = jax.jit(lambda x: mlp_forward(x, params))

    # Warm-up to ensure accurate timing
    mlp_fn(inputs).block_until_ready()

    # Measure execution time
    start_time = time.time()
    mlp_fn(inputs).block_until_ready()
    end_time = time.time()

    return end_time - start_time

# Input size
input_size = (1024, 512)

# Benchmark float32
f32_time = benchmark_mlp(jnp.float32, input_size=input_size)
print(f"float32 time: {f32_time:.6f} seconds")

# Benchmark bfloat16
bf16_time = benchmark_mlp(jnp.bfloat16, input_size=input_size)
print(f"bfloat16 time: {bf16_time:.6f} seconds")

# Benchmark bfloat16
f16_time = benchmark_mlp(jnp.float16, input_size=input_size)
print(f"float16 time: {f16_time:.6f} seconds")

# Print the comparison
print(f"bfloat16 takes {bf16_time / f32_time:.6f}x time")

float32 time: 0.004519 seconds
bfloat16 time: 0.002950 seconds
float16 time: 0.002773 seconds
bfloat16 takes 0.652896x time


In [2]:
# Define the MLP model using Flax
class MLP(nn.Module):
    hidden_size: int
    output_size: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        for _ in range(self.num_layers):
            x = nn.Dense(self.hidden_size)(x)
            x = nn.relu(x)
            x = nn.LayerNorm(epsilon=1e-4, use_bias=False, use_scale=False)(x)
            
        x = nn.Dense(self.output_size)(x)    
        return x

# Function to benchmark a simple MLP network
def benchmark_mlp_with_flax(dtype, input_size=(1024, 512), hidden_size=512, output_size=256, num_layers=200):
    # Generate random input of the given dtype
    key = jax.random.PRNGKey(0)
    inputs = jax.random.normal(key, input_size, dtype=dtype)

    # Initialize the MLP model
    model = MLP(hidden_size=hidden_size, output_size=output_size, num_layers=num_layers)
    params = model.init(key, inputs)  # Initialize parameters

    params = jax.tree_map(lambda x: x.astype(dtype), params)  # Convert to the given dtype

    # JIT compile the MLP function
    mlp_fn = jax.jit(lambda x: model.apply(params, x))

    # Warm-up to ensure accurate timing
    mlp_fn(inputs).block_until_ready()

    # Measure execution time
    start_time = time.time()
    mlp_fn(inputs).block_until_ready()
    end_time = time.time()

    return end_time - start_time

# Input size
input_size = (1024, 512)

# Benchmark float32
f32_time = benchmark_mlp_with_flax(jnp.float32, input_size=input_size)
print(f"float32 time: {f32_time:.6f} seconds")

# Benchmark bfloat16
bf16_time = benchmark_mlp_with_flax(jnp.bfloat16, input_size=input_size)
print(f"bfloat16 time: {bf16_time:.6f} seconds")

# Benchmark float16
f16_time = benchmark_mlp_with_flax(jnp.float16, input_size=input_size)
print(f"float16 time: {f16_time:.6f} seconds")

# Print the comparison
print(f"bfloat16 takes {bf16_time / f32_time:.6f}x time")

  params = jax.tree_map(lambda x: x.astype(dtype), params)  # Convert to the given dtype


float32 time: 0.004996 seconds
bfloat16 time: 0.003548 seconds
float16 time: 0.003498 seconds
bfloat16 takes 0.710175x time


In [None]:
# Define a simple loss function
def loss_fn(params, model, inputs, targets):
    predictions = model.apply(params, inputs)
    return jnp.mean((predictions - targets) ** 2)

# Function to benchmark a simple MLP network with gradient updates
def benchmark_mlp_with_grad_updates(dtype, input_size=(1024, 512), hidden_size=512, output_size=256, num_layers=200, num_updates=100):
    # Generate random input and targets of the given dtype
    key = jax.random.PRNGKey(0)
    key1, key2, key3 = jax.random.split(key, 3)
    inputs = jax.random.normal(key1, input_size, dtype=dtype)
    targets = jax.random.normal(key2, (input_size[0], output_size), dtype=dtype)

    # Initialize the MLP model
    model = MLP(hidden_size=hidden_size, output_size=output_size, num_layers=num_layers)
    params = model.init(key3, inputs)  # Initialize parameters
    params = jax.tree_map(lambda x: x.astype(dtype), params)  # Convert to the given dtype

    # Define the gradient update function
    def update(params, inputs, targets):
        grads = jax.grad(loss_fn)(params, model, inputs, targets)
        # import pdb; pdb.set_trace()
        new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
        # new_params = params
        return grads, new_params

    # JIT compile the update function
    update_fn = jax.jit(update)

    # This should include some bfloat16 casts to float32 because we applied layernorm
    # print(jax.make_jaxpr(update_fn)(params, inputs, targets))

    # Warm-up to ensure accurate timing
    grads, params = update_fn(params, inputs, targets)

    # Measure execution time for multiple updates
    start_time = time.time()
    for _ in range(num_updates):
        grads, params = update_fn(params, inputs, targets)
    end_time = time.time()

    return end_time - start_time

# Input size
input_size = (1024, 512)

# This makes no difference
# jax.config.update("jax_default_matmul_precision", "bfloat16")

# Benchmark float32
f32_time = benchmark_mlp_with_grad_updates(jnp.float32, input_size=input_size)
print(f"float32 time: {f32_time:.6f} seconds")

# Benchmark bfloat16
f16_time = benchmark_mlp_with_grad_updates(jnp.float16, input_size=input_size)
print(f"float16 time: {f16_time:.6f} seconds")

# Benchmark bfloat16
bf16_time = benchmark_mlp_with_grad_updates(jnp.bfloat16, input_size=input_size)
print(f"bfloat16 time: {bf16_time:.6f} seconds")


# Print the comparison
print(f"bfloat16 takes {bf16_time / f32_time:.6f}x time")