# Part 4, Lab 1: FP8 Conversion

**Time:** ~30 minutes

FP8 (8-bit floating point) is a popular choice for inference: it uses half the memory of FP16, usually with little quality loss when values are scaled appropriately. In this lab, you'll learn the two FP8 formats and implement conversion by hand.

## Learning Objectives

1. Understand FP8 E4M3 vs E5M2 formats
2. Implement FP8 conversion manually
3. Compare precision vs range trade-offs
4. Use PyTorch's native FP8 support

---
## 1. FP8 Format Overview

FP8 comes in two variants:
- **E4M3**: 4 exponent bits, 3 mantissa bits → more precision, less range (±448)
- **E5M2**: 5 exponent bits, 2 mantissa bits → more range (±57344), less precision

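E4M3 is typically used for the forward pass (weights and activations), where precision matters more. E5M2 is typically used for gradients in the backward pass, where dynamic range matters more.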
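
The ranges quoted above follow directly from the bit layouts. The sketch below derives them, assuming the common conventions: E5M2 reserves its top exponent code for infinity/NaN (IEEE-style), while E4M3 (the "fn" variant) only gives up its all-ones bit pattern for NaN. The function name `fp8_properties` and its `ieee_like` flag are illustrative, not part of any library.

```python
def fp8_properties(exp_bits, man_bits, ieee_like):
    """Derive key magnitudes of an 8-bit float from its bit layout.

    ieee_like=True : top exponent code is reserved for inf/NaN (E5M2).
    ieee_like=False: only the all-ones pattern is NaN (E4M3 "fn" variant).
    """
    bias = 2 ** (exp_bits - 1) - 1
    max_exp_code = 2 ** exp_bits - 1
    if ieee_like:
        top_exp = (max_exp_code - 1) - bias
        top_frac = 1.0 - 2.0 ** -man_bits        # all mantissa bits set
    else:
        top_exp = max_exp_code - bias
        top_frac = 1.0 - 2.0 ** -(man_bits - 1)  # all-ones mantissa is NaN
    return {
        "bias": bias,
        "max": (1.0 + top_frac) * 2.0 ** top_exp,
        "min_normal": 2.0 ** (1 - bias),
        "min_subnormal": 2.0 ** (1 - bias - man_bits),
        "spacing_at_1": 2.0 ** -man_bits,  # gap between 1.0 and the next value
    }

e4m3 = fp8_properties(exp_bits=4, man_bits=3, ieee_like=False)
e5m2 = fp8_properties(exp_bits=5, man_bits=2, ieee_like=True)
for name, props in (("E4M3", e4m3), ("E5M2", e5m2)):
    print(name, props)
```

The `spacing_at_1` entry makes the trade-off concrete: E4M3 resolves steps of 0.125 near 1.0, E5M2 only 0.25, in exchange for a much larger maximum.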

In [None]:
import numpy as np
import torch

# Check for hardware FP8 support (compute capability 8.9 = Ada, 9.0+ = Hopper or newer)
fp8_supported = False  # Defined up front so later cells can use it without a GPU
if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()  # (major, minor) tuple
    print(f"GPU Compute Capability: {capability}")
    fp8_supported = capability >= (8, 9)
    print(f"Native FP8 support: {fp8_supported}")
else:
    print("No GPU available - will use CPU emulation")

---
## 2. Manual FP8 Conversion

Let's understand FP8 by implementing conversion ourselves.

In [None]:
import math

def float_to_fp8_e4m3(value):
    """Convert float to FP8 E4M3 format (emulated).

    E4M3: 1 sign + 4 exponent + 3 mantissa bits
    Bias: 7, Range: ±448, no infinity (0x7F/0xFF encode NaN)
    Out-of-range values saturate to ±448.
    """
    if value == 0:
        return 0

    # Extract sign
    sign = 0 if value >= 0 else 1
    value = abs(value)

    # Saturate to the largest finite magnitude
    max_val = 448.0  # 1.75 * 2**8
    value = min(value, max_val)

    # Unbiased exponent; all subnormals share the minimum exponent -6
    exp = max(math.floor(math.log2(value)), -6)

    # Significand in units of 2**(exp - 3): 8..16 for normals, 0..8 for subnormals.
    # Python's round() rounds ties to even, matching IEEE round-to-nearest-even.
    sig = round(value / 2 ** (exp - 3))
    if sig == 16:  # Rounded up to the next power of two: carry into the exponent
        sig = 8
        exp += 1

    if sig < 8:
        # Subnormal (or rounds to zero): exponent field 0, no implicit leading 1
        exp_biased, mantissa_int = 0, sig
    else:
        exp_biased, mantissa_int = exp + 7, sig - 8  # Bias is 7 for E4M3

    # Pack into byte
    return (sign << 7) | (exp_biased << 3) | mantissa_int

def fp8_e4m3_to_float(fp8_byte):
    """Convert FP8 E4M3 back to float."""
    sign = (fp8_byte >> 7) & 1
    exp = (fp8_byte >> 3) & 0xF  # 4 bits
    mantissa = fp8_byte & 0x7     # 3 bits

    if exp == 0xF and mantissa == 0x7:
        return float("nan")  # Only NaN encoding per sign; E4M3 has no infinity

    if exp == 0:
        # Subnormal (includes zero): no implicit leading 1, fixed scale 2**-6
        value = (mantissa / 8.0) * 2.0 ** -6
    else:
        # Normal: implicit leading 1
        value = (1.0 + mantissa / 8.0) * 2.0 ** (exp - 7)
    return -value if sign else value

# Test conversion
test_values = [1.0, 0.5, 2.0, 100.0, 0.125, -3.14]
print("FP8 E4M3 Conversion Test:")
print("-" * 50)
for val in test_values:
    fp8 = float_to_fp8_e4m3(val)
    recovered = fp8_e4m3_to_float(fp8)
    error = abs(val - recovered) / abs(val) * 100 if val != 0 else 0
    print(f"{val:>8.3f} → 0x{fp8:02X} → {recovered:>8.3f}  (error: {error:.1f}%)")
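
Because E4M3 has only 256 bit patterns, a bit-level encoder can be cross-checked against a brute-force reference: decode every pattern once, then pick the nearest value. The sketch below is independent of the cell above; the names `decode_e4m3`, `POSITIVE_VALUES`, and `nearest_e4m3` are illustrative. It assumes the E4M3 "fn" layout (bias 7, subnormals at exponent code 0, NaN at 0x7F/0xFF), and on exact ties it picks the smaller magnitude, which may differ from round-to-nearest-even.

```python
def decode_e4m3(byte):
    """Decode one E4M3 bit pattern (bias 7, subnormals, NaN at 0x7F/0xFF)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal, no implicit 1
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

# Every finite non-negative value: bytes 0x00..0x7E (0x7F is NaN)
POSITIVE_VALUES = sorted(decode_e4m3(b) for b in range(0x7F))

def nearest_e4m3(x):
    """Brute-force nearest representable value; saturates at ±448."""
    mag = min(abs(x), POSITIVE_VALUES[-1])
    # min() keeps the first (smaller) candidate on an exact tie
    best = min(POSITIVE_VALUES, key=lambda v: abs(v - mag))
    return -best if x < 0 else best

for val in [1.0, 0.5, 2.0, 100.0, 0.125, -3.14]:
    print(f"{val:>8.3f} → {nearest_e4m3(val):>8.3f}")
```

Any disagreement between this reference and a bit-level encoder (other than at exact ties) points to a bug in the exponent, subnormal, or rounding logic.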

---
## 3. Quantization Error Analysis

Let's measure the quantization error across the representable range.

In [None]:
# Generate test values across the FP8 range
test_range = np.logspace(-3, 2.5, 1000)  # 0.001 to ~316; values below 2**-6 are subnormal

errors = []
for val in test_range:
    fp8 = float_to_fp8_e4m3(val)
    recovered = fp8_e4m3_to_float(fp8)
    rel_error = abs(val - recovered) / val * 100
    errors.append(rel_error)

print(f"Average relative error: {np.mean(errors):.2f}%")
print(f"Max relative error: {np.max(errors):.2f}%")
print(f"Values within 5% error: {sum(e < 5 for e in errors) / len(errors) * 100:.1f}%")
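
The measured statistics can be compared against a simple bound. As a sketch, assume round-to-nearest and a normal value $x = \pm(1+f)\,2^{e}$ with $0 \le f < 1$, stored with $m$ mantissa bits. Neighboring representable values are $2^{e-m}$ apart, so the rounded value $\hat{x}$ is off by at most half that spacing:

$$
\frac{|x - \hat{x}|}{|x|} \;\le\; \frac{2^{e-m-1}}{(1+f)\,2^{e}} \;\le\; 2^{-(m+1)}
$$

For E4M3 ($m = 3$) the bound is $2^{-4} = 6.25\%$; for E5M2 ($m = 2$) it is $2^{-3} = 12.5\%$. The bound does not hold for subnormal values (below $2^{-6}$ in E4M3), where the spacing stays fixed while $|x|$ shrinks, so relative error there can approach 100%.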

---
## 4. PyTorch FP8 (if available)

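PyTorch 2.1+ provides FP8 dtypes (`torch.float8_e4m3fn` and `torch.float8_e5m2`). Hardware-accelerated FP8 math additionally needs a compatible GPU (compute capability 8.9 or higher).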

In [None]:
# Note: The FP8 dtypes need PyTorch 2.1+; this cell allocates on CUDA, so it needs a GPU
try:
    # Create test tensor
    x = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    
    # Convert to FP8 (if supported)
    if hasattr(torch, 'float8_e4m3fn'):
        x_fp8 = x.to(torch.float8_e4m3fn)
        x_back = x_fp8.to(torch.float16)
        
        error = (x - x_back).abs().mean().item()
        print(f"Native FP8 conversion error: {error:.6f}")
    else:
        print("FP8 dtype not available in this PyTorch version")
except Exception as e:
    print(f"FP8 test failed: {e}")

---
## Exercises

1. **E5M2 Implementation**: Implement `float_to_fp8_e5m2()` with 5 exponent bits and 2 mantissa bits
2. **Error Distribution**: Plot a histogram of quantization errors for random neural network weights
3. **Scaling**: Implement per-tensor scaling to improve FP8 utilization for values outside the natural range

## Key Takeaways

- FP8 E4M3 keeps relative rounding error at or below 6.25% (2^-4) for values in its normal range (2^-6 to 448); subnormal and out-of-range values fare worse
- E4M3 is preferred for weights and activations (better precision), E5M2 for gradients (better range)
- Scaling factors are critical for mapping your actual value distribution to FP8's representable range
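
To make the last point concrete, here is a minimal sketch of per-tensor scaling, assuming a simple "map the largest magnitude to 448" rule. The helper names `quantize_e4m3` and `mean_rel_error` and the synthetic `weights` list are illustrative assumptions, not part of the lab code or of any library.

```python
import math

E4M3_MAX = 448.0

def quantize_e4m3(x):
    """Round x to the nearest E4M3 value (saturating), emulated in float."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    exp = max(math.floor(math.log2(mag)), -6)  # subnormals share exponent -6
    step = 2.0 ** (exp - 3)                    # spacing with 3 mantissa bits
    return math.copysign(round(mag / step) * step, x)

def mean_rel_error(values, scale=1.0):
    """Quantize values * scale, undo the scale, return the mean relative error."""
    total = 0.0
    for v in values:
        restored = quantize_e4m3(v * scale) / scale
        total += abs(v - restored) / abs(v)
    return total / len(values)

# Hypothetical small weights, all below E4M3's smallest normal value (2**-6)
weights = [0.0004 * (i + 1) * (-1) ** i for i in range(20)]

amax = max(abs(w) for w in weights)
scale = E4M3_MAX / amax  # map the largest magnitude onto 448

print(f"unscaled mean relative error: {mean_rel_error(weights):.1%}")
print(f"scaled mean relative error:   {mean_rel_error(weights, scale):.1%}")
```

Without scaling, these values land in the subnormal range (or round to zero), so most of their precision is lost. With scaling they occupy the normal range, and the scale factor is stored alongside the tensor so it can be divided back out.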