# Floating Point Precision Comparison with JAX
This notebook demonstrates how to compare different floating point precisions in JAX including bf16, fp16, fp32, and fp64


In [69]:


import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import time
import psutil
import tracemalloc
import warnings
warnings.filterwarnings('ignore')

# Configure JAX
print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")
print(f"Available devices: {jax.devices()}")

# Set up plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

JAX version: 0.7.1
JAX backend: cpu
Available devices: [CpuDevice(id=0)]


Available precisions:
- fp16: 16 bits (2 bytes) - Exponent: 5, Mantissa: 10
- bf16: 16 bits (2 bytes) - Exponent: 8, Mantissa: 7
- fp32: 32 bits (4 bytes) - Exponent: 8, Mantissa: 23
- fp64: 64 bits (8 bytes) - Exponent: 11, Mantissa: 52

### NOTE: CPU calculations are much slower than GPU calculations. Results may vary based on the hardware.

In [70]:
# Memory and Performance Profiling Functions

def get_memory_usage():
    """Get current memory usage in MB"""
    process = psutil.Process()
    return process.memory_info().rss / 1024 / 1024

def measure_memory_and_time(func):
    def wrapper(*args, **kwargs):
        # Start memory tracking
        tracemalloc.start()
        start_memory = get_memory_usage()
        
        # Time the function
        start_time = time.time()
        result = func(*args, **kwargs)
        jax.block_until_ready(result)  # Ensure computation is complete
        end_time = time.time()
        
        # Get memory after execution
        end_memory = get_memory_usage()
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        
        return {
            'result': result,
            'execution_time': end_time - start_time,
            'memory_delta': end_memory - start_memory,
            'peak_memory': peak / 1024 / 1024,  # Convert to MB
            'current_memory': current / 1024 / 1024  # Convert to MB
        }
    return wrapper




In [71]:
# Test Functions for Different Operations
@measure_memory_and_time
def matrix_multiplication_test(dtype, shape):
    """Test matrix multiplication with given precision"""
    key = jax.random.PRNGKey(42)
    key1, key2 = jax.random.split(key)
    
    def matmul_operation():
        a = jax.random.normal(key1, shape, dtype=dtype)
        b = jax.random.normal(key2, shape, dtype=dtype)
        return jnp.dot(a, b)
    
    return matmul_operation()

    
@measure_memory_and_time
def neural_network_forward_pass_test(dtype, input_size, hidden_size, output_size):
    """Test a simple neural network forward pass"""
    key = jax.random.PRNGKey(123)
    keys = jax.random.split(key, 3)
    
    def nn_forward():
        # Initialize weights
        W1 = jax.random.normal(keys[0], (input_size, hidden_size), dtype=dtype)
        b1 = jax.random.normal(keys[1], (hidden_size,), dtype=dtype)
        W2 = jax.random.normal(keys[2], (hidden_size, output_size), dtype=dtype)
        b2 = jax.random.normal(keys[2], (output_size,), dtype=dtype)
        
        # Input data
        x = jax.random.normal(keys[0], (input_size,), dtype=dtype)
        
        # Forward pass
        h = jnp.tanh(jnp.dot(x, W1) + b1)
        y = jnp.dot(h, W2) + b2
        
        return y
    
    return nn_forward()

@measure_memory_and_time
def precision_test(dtype, test_values):
    """Test numerical precision with specific values"""
    def precision_operation():
        # Convert test values to the specified dtype
        values = jnp.array(test_values, dtype=dtype)
        
        # Perform some operations that might reveal precision differences
        result1 = jnp.sum(values)
        result2 = jnp.mean(values)
        result3 = jnp.std(values)
        result4 = jnp.prod(values)
        
        return jnp.array([result1, result2, result3, result4])
    
    return precision_operation()


In [72]:
# Comprehensive Comparison Function
PRECISIONS = {
    'fp16': jnp.float16,
    'bf16': jnp.bfloat16, 
    'fp32': jnp.float32,
    'fp64': jnp.float64
}

def compare_precisions(operation_name, test_func, *args, **kwargs):
    """
    Compare all precisions for a given operation and return a summary table (pandas DataFrame).
    """
    import pandas as pd

    results = []
   
    for precision_name, dtype in PRECISIONS.items():
        # Run the test
        result = test_func(dtype, *args, **kwargs)
        row = {
            'Operation': operation_name,
            'Precision': precision_name,
            'Execution Time (s)': result.get('execution_time', None),
            'Memory Delta (MB)': result.get('memory_delta', None),
            'Peak Memory (MB)': result.get('peak_memory', None),
            'Result Shape': getattr(result.get('result', None), 'shape', None),
            'Result Dtype': str(getattr(result.get('result', None), 'dtype', None))
        }
    
        
        results.append(row)

    df = pd.DataFrame(results)
    return df
    

In [73]:


matrix_shape = (5000, 5000)
results_matrix_multiplication = compare_precisions(
    "Matrix Multiplication (1000x1000)", 
    matrix_multiplication_test, 
    matrix_shape
)

results_matrix_multiplication


Unnamed: 0,Operation,Precision,Execution Time (s),Memory Delta (MB),Peak Memory (MB),Result Shape,Result Dtype
0,Matrix Multiplication (1000x1000),fp16,1.024406,448.609375,0.013244,"(5000, 5000)",float16
1,Matrix Multiplication (1000x1000),bf16,0.97766,49.328125,0.008247,"(5000, 5000)",bfloat16
2,Matrix Multiplication (1000x1000),fp32,0.942557,-49.25,0.008247,"(5000, 5000)",float32
3,Matrix Multiplication (1000x1000),fp64,0.927667,7.84375,0.008503,"(5000, 5000)",float32


In [74]:
results_neural_network = compare_precisions(
    "Neural Network Forward Pass (784->256->10)", 
    neural_network_forward_pass_test, 
    784, 256, 10
)

results_neural_network

Unnamed: 0,Operation,Precision,Execution Time (s),Memory Delta (MB),Peak Memory (MB),Result Shape,Result Dtype
0,Neural Network Forward Pass (784->256->10),fp16,0.006846,11.59375,0.020218,"(10,)",float16
1,Neural Network Forward Pass (784->256->10),bf16,0.004057,4.8125,0.015495,"(10,)",bfloat16
2,Neural Network Forward Pass (784->256->10),fp32,0.003193,7.15625,0.015388,"(10,)",float32
3,Neural Network Forward Pass (784->256->10),fp64,0.002193,0.140625,0.015644,"(10,)",float32


In [75]:


# Test with values that might reveal precision differences
test_values = [0.1, 0.2, 0.3, 0.4, 0.5, 1e-6, 1e6, np.pi, np.e]

results_precision_test = compare_precisions(
    "Numerical Precision Test", 
    precision_test, 
    test_values
)


results_precision_test

Unnamed: 0,Operation,Precision,Execution Time (s),Memory Delta (MB),Peak Memory (MB),Result Shape,Result Dtype
0,Numerical Precision Test,fp16,0.004691,2.890625,0.022018,"(4,)",float16
1,Numerical Precision Test,bf16,0.010344,2.109375,0.016884,"(4,)",bfloat16
2,Numerical Precision Test,fp32,0.021163,1.734375,0.01709,"(4,)",float32
3,Numerical Precision Test,fp64,0.000902,0.03125,0.017176,"(4,)",float32


In [86]:

# Example of mixed precision training

# Create a simple model
key = jax.random.PRNGKey(42)
key1, key2, key3 = jax.random.split(key, 3)

# Model parameters in FP32 (for stability)
W = jax.random.normal(key1, (10, 5), dtype=jnp.float32)
b = jax.random.normal(key2, (5,), dtype=jnp.float32)

# Input data in bf16 (for memory efficiency)
x = jax.random.normal(key3, (10,), dtype=jnp.bfloat16)

@measure_memory_and_time
def forward_pass(x, W, b):
    # Convert to FP64 for computation
    x_fp32 = x.astype(jnp.float64)
    W_fp32 = W.astype(jnp.float64)
    b_fp32 = b.astype(jnp.float64)
    
    # Forward pass
    y = jnp.dot(x_fp32, W_fp32) + b_fp32
    
    # Convert back to FP16 for memory efficiency
    return y.astype(jnp.float16)

# Run forward pass
result = forward_pass(x, W, b)
print(result)

print(f"Input dtype: {x.dtype}")
print(f"Weight dtype: {W.dtype}")
print(f"Output dtype: {result['result'].dtype}") 
print(f"Output shape: {result['result'].shape}")
print(f"Memory usage:")
print(f"  Input: {x.nbytes} bytes")
print(f"  Weights: {W.nbytes + b.nbytes} bytes")
print(f"  Output: {result['result'].nbytes} bytes")
print(f"  Total: {x.nbytes + W.nbytes + b.nbytes + result['result'].nbytes} bytes")



{'result': Array([ 0.7227,  4.688 , -0.503 ,  3.291 , -6.97  ], dtype=float16), 'execution_time': 0.03721189498901367, 'memory_delta': 18.4375, 'peak_memory': 0.041400909423828125, 'current_memory': 0.03329658508300781}
Input dtype: bfloat16
Weight dtype: float32
Output dtype: float16
Output shape: (5,)
Memory usage:
  Input: 20 bytes
  Weights: 220 bytes
  Output: 10 bytes
  Total: 250 bytes
