# Tensor Core Math — Python Companion

This notebook runs the same math as the browser simulator, but with **real NumPy float16** instead of JS bit manipulation. Use this alongside the interactive tools to verify the numbers are accurate and to go deeper where the simulator simplifies.

Covers:
1. FP16 precision — what gets lost and why
2. MMA: D = A×B + C with real mixed precision
3. Precision error analysis vs FP32 reference
4. Binary reduction tree depth
5. GEMM tiling
6. Transformer FLOP breakdown
7. Roofline model

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd

# Notebook-wide dark theme: start from matplotlib's built-in dark style, then
# override individual colours and the font family.
plt.style.use('dark_background')
_THEME_OVERRIDES = {
    'figure.facecolor': '#08090b',
    'axes.facecolor':   '#0f1117',
    'axes.edgecolor':   '#252b3a',
    'axes.labelcolor':  '#8a9ab0',
    'xtick.color':      '#5a6a80',
    'ytick.color':      '#5a6a80',
    'text.color':       '#dce5f0',
    'grid.color':       '#171a23',
    'font.family':      'monospace',
}
plt.rcParams.update(_THEME_OVERRIDES)

# Accent palette referenced by the plotting cells.
ACC, GRN, ORG, RED, PUR = '#5b8fff', '#2ecc7a', '#f0853a', '#e05555', '#a67cff'

---
## 1. FP16 Precision — What Gets Lost

NumPy's `float16` is real IEEE 754 half-precision: **1 sign + 5 exponent + 10 mantissa bits**.  
The JS simulator approximates this by truncating the FP32 mantissa. Here we use the actual dtype.

Key limits:
- Max representable value: **65504**
- Smallest normal: **~6.1e-5**
- Machine epsilon (gap between 1.0 and next value): **~0.00098** (≈ 2⁻¹⁰)

In [None]:
# Report the limits NumPy exposes for half and single precision.
info = np.finfo(np.float16)
info32 = np.finfo(np.float32)
report = [
    "=== FP16 dtype info ===",
    f"  max value   : {info.max}",
    f"  min normal  : {info.tiny}",
    f"  epsilon     : {info.eps}  (gap between 1.0 and next representable)",
    f"  precision   : ~{info.precision} decimal digits",
    "",
    "=== FP32 dtype info ===",
    f"  max value   : {info32.max:.3e}",
    f"  epsilon     : {info32.eps}",
    f"  precision   : ~{info32.precision} decimal digits",
    "",
]
print("\n".join(report))

# Show where FP16 rounds: the same literal stored as FP32 and as FP16,
# with the absolute gap between the two stored values.
vals = [0.1, 0.3333, 1.0, 1.001, 100.0, 1000.0, 10000.0]
print("{:>10}  {:>12}  {:>12}  {:>12}".format('Value', 'FP32', 'FP16', 'Abs Error'))
print(52 * "-")
for v in vals:
    f32, f16 = np.float32(v), np.float16(v)
    as_f32, as_f16 = float(f32), float(f16)
    err = abs(as_f32 - as_f16)
    print(f"{v:>10}  {as_f32:>12.6f}  {as_f16:>12.6f}  {err:>12.2e}")

In [None]:
# Visualize the density of representable FP16 values
# (they're not evenly spaced — denser near 0, sparser at large values)
ranges = [
    (0.0, 2.0, 'Dense near 0'),
    (1.0, 4.0, 'Mid range'),
    (512.0, 1024.0, 'Large values'),
]

# Enumerate every one of the 65,536 half-precision bit patterns rather than
# sampling each range: a fixed number of samples caps the "distinct values"
# count at the sample size, which undercounts ranges that contain more FP16
# values than samples. Reinterpret the bits as float16, drop inf/NaN, and
# de-duplicate (np.unique also sorts and merges +0.0 with -0.0).
fp16_bits = np.arange(2**16, dtype=np.uint32).astype(np.uint16)
fp16_all = fp16_bits.view(np.float16)
fp16_finite = np.unique(fp16_all[np.isfinite(fp16_all)].astype(np.float32))

fig, axes = plt.subplots(1, 3, figsize=(14, 3))
fig.suptitle('FP16 Representable Value Density', color='#dce5f0', fontsize=11)

for ax, (lo, hi, title) in zip(axes, ranges):
    # Exact set of FP16 values inside [lo, hi] (both endpoints included).
    unique_vals = fp16_finite[(fp16_finite >= lo) & (fp16_finite <= hi)]
    ax.scatter(unique_vals, np.ones_like(unique_vals), s=2, color=ACC, alpha=0.7)
    ax.set_title(title, fontsize=9, color='#8a9ab0')
    ax.set_xlabel('value', fontsize=8)
    ax.set_yticks([])
    ax.set_xlim(lo, hi)
    ax.annotate(f'{len(unique_vals)} distinct values', xy=(0.5, 0.75),
                xycoords='axes fraction', ha='center', fontsize=8, color=GRN)

plt.tight_layout()
plt.show()

---
## 2. MMA: D = A×B + C

The tensor core's one operation. Inputs A, B are FP16. C (accumulator) and output D are FP32.  
This mixed-precision path is what lets you compute fast (FP16 multiplies) while accumulating accurately (FP32 adds).

In [None]:
# Single seeded generator shared by the random draws in this notebook.
rng = np.random.default_rng(42)

def make_matrix_fp16(n=4, scale=3.0):
    """Draw an n×n matrix uniformly from [-scale, scale) and round it to FP16."""
    samples = rng.uniform(low=-scale, high=scale, size=(n, n))
    return samples.astype(np.float16)

def mma(A_fp16, B_fp16, C_fp32):
    """
    Model one mixed-precision multiply-accumulate: D = A @ B + C.

    The notebook uses this as a stand-in for the tensor core data path:
      - A and B arrive as FP16 (already rounded by the caller)
      - both are widened to FP32 before the matrix product, so products
        and partial sums are carried in FP32
      - the accumulator C is added to the product
      - the result follows NumPy promotion: FP32 when C is FP32
    """
    # Widen both operands first so the multiply and the reduction run in FP32.
    lhs, rhs = (operand.astype(np.float32) for operand in (A_fp16, B_fp16))
    return np.matmul(lhs, rhs) + C_fp32

def mma_fp32_reference(n=4, scale=3.0):
    """
    Pure FP32 multiply-accumulate on freshly drawn operands.

    Draws its own A and B (uniform in [-scale, scale)) and C (uniform in
    [-0.5, 0.5)) from the shared `rng`, all as FP32, and returns A @ B + C.
    Because the operands are new random draws, the result is not comparable
    element-by-element with an earlier `mma` call, and calling this advances
    the shared generator.
    """
    shape = (n, n)
    # Draw order (A, then B, then C) determines which random numbers each gets.
    lhs = rng.uniform(-scale, scale, shape).astype(np.float32)
    rhs = rng.uniform(-scale, scale, shape).astype(np.float32)
    acc = rng.uniform(-0.5, 0.5, shape).astype(np.float32)
    return np.matmul(lhs, rhs) + acc

# Run one MMA: two FP16 operands plus a small FP32 accumulator.
A, B = make_matrix_fp16(), make_matrix_fp16()
C = rng.uniform(-0.5, 0.5, (4, 4)).astype(np.float32)

D = mma(A, B, C)

# Print each matrix under its heading; C and D are rounded for display only.
for heading, matrix in (
    ("A (FP16):", A),
    ("\nB (FP16):", B),
    ("\nC accumulator (FP32):", np.round(C, 4)),
    ("\nD = A×B + C (FP32 output):", np.round(D, 4)),
):
    print(heading)
    print(matrix)
print(f"\nD dtype: {D.dtype}  ← always FP32 out")

---
## 3. Precision Error Analysis

How much error does the FP16 input path introduce vs a pure FP32 computation?  
We run many MMAs, collect the max absolute error per run, and look at the distribution.

In [None]:
def run_error_trial(scale=3.0, n=4):
    """
    Run one random trial and return the largest absolute element-wise gap
    between the FP16-input MMA and the all-FP32 computation.

    A and B are drawn in FP32 from [-scale, scale), C from [-0.5, 0.5).
    The reference uses them as-is; the tensor-core path first rounds A and B
    to FP16, so the gap measures the input-rounding error only.
    """
    shape = (n, n)
    A_fp32 = rng.uniform(-scale, scale, shape).astype(np.float32)
    B_fp32 = rng.uniform(-scale, scale, shape).astype(np.float32)
    C = rng.uniform(-0.5, 0.5, shape).astype(np.float32)

    # Tensor-core path: operands rounded to FP16 before the multiply.
    tensor_core = mma(A_fp32.astype(np.float16), B_fp32.astype(np.float16), C)
    # Reference path: identical values, never rounded below FP32.
    reference = np.matmul(A_fp32, B_fp32) + C

    return np.abs(tensor_core - reference).max()

N_TRIALS = 5000
# One max-abs-error sample per trial, using run_error_trial's defaults.
errors = np.array([run_error_trial() for _ in range(N_TRIALS)])

print(f"Over {N_TRIALS} random 4×4 MMAs (FP16 inputs, scale ±3):")
for label, value in (
    ("mean max error", errors.mean()),
    ("median max error", np.median(errors)),
    ("p99 max error", np.percentile(errors, 99)),
    ("worst case", errors.max()),
):
    print(f"  {label:<16}: {value:.4f}")
print()
print("Compare to FP16 epsilon (~0.001) — error scales with magnitude of values.")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(13, 4))

# Summary statistics are computed once and reused for both the marker lines
# and their legend labels.
mean_err = errors.mean()
p99_err = np.percentile(errors, 99)

# Error distribution
ax = axes[0]
ax.hist(errors, bins=60, color=ACC, alpha=0.8, edgecolor='none')
ax.axvline(mean_err, color=GRN, lw=1.5, ls='--', label=f'mean={mean_err:.3f}')
ax.axvline(p99_err, color=ORG, lw=1.5, ls='--', label=f'p99={p99_err:.3f}')
ax.set_xlabel('Max absolute error per MMA')
ax.set_ylabel('Count')
# The trial count comes from the data being plotted, so the title cannot
# drift out of sync if the number of trials is changed.
ax.set_title(f'FP16-path vs FP32 reference error\n(4×4 MMA, {len(errors)} trials)', fontsize=10)
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# Error vs input scale: 200 fresh trials per scale, averaged.
ax = axes[1]
scales = np.linspace(0.1, 10, 40)
mean_errors = [np.mean([run_error_trial(scale=s) for _ in range(200)]) for s in scales]
ax.plot(scales, mean_errors, color=ORG, lw=2)
ax.set_xlabel('Input value scale (±s)')
ax.set_ylabel('Mean max error')
ax.set_title('Error grows with input magnitude\n(FP16 has fixed relative precision)', fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
print("\nKey insight: FP16 error is proportional to value magnitude (relative precision is fixed).")
print("This is why weight initialization and gradient scaling matter in mixed-precision training.")

---
## 4. Binary Reduction Tree

Inside the tensor core, N multiplier outputs are summed using a binary adder tree.  
Depth = ⌈log₂(N)⌉ — this is the critical path that limits clock speed, not N itself.

In [None]:
import math

def tree_depth(n):
    """
    Depth of a binary reduction tree over n inputs: ceil(log2(n)).

    Computed with exact integer arithmetic. Going through floating-point
    log2 rounds large inputs to the nearest float first, so a value just
    above a power of two (e.g. 2**60 + 1) would report one level too few.

    Args:
        n: number of inputs to reduce; must be >= 1. Non-integer values are
           rounded up to the next integer, which leaves ceil(log2(n))
           unchanged for n >= 1.

    Returns:
        int: number of adder levels (0 when n == 1).

    Raises:
        ValueError: if n < 1 (no tree exists).
    """
    if n < 1:
        raise ValueError(f"tree_depth requires n >= 1, got {n!r}")
    # For an integer m >= 1, (m - 1).bit_length() equals ceil(log2(m)) exactly.
    return (math.ceil(n) - 1).bit_length()

def simulate_reduction_tree(values, input_bits=16):
    """
    Sum `values` by repeatedly adding adjacent pairs, one tree level at a time.

    Returns a tuple (total, depth, bit_widths):
      - total: the value left after the final level
      - depth: how many levels of adds were performed
      - bit_widths: width at each level, starting at `input_bits` and
        growing by one bit per level to make room for the carry
    """
    current = list(values)
    widths = [input_bits]
    levels = 0
    while len(current) > 1:
        # Add neighbours (0,1), (2,3), ...; zip stops before an unpaired tail.
        reduced = [left + right for left, right in zip(current[::2], current[1::2])]
        if len(current) % 2:
            # An odd element has no partner and moves up a level unchanged.
            reduced.append(current[-1])
        current = reduced
        levels += 1
        widths.append(input_bits + levels)
    return current[0], levels, widths

# Tabulate tree depth against a sequential chain of adds for several sizes.
ns = [2, 4, 8, 16, 32, 64, 128, 256]
depths = [tree_depth(n) for n in ns]
df = pd.DataFrame({
    'N (multipliers)': ns,
    'Tree depth (log₂N)': depths,
    'Sequential depth (N-1)': [n - 1 for n in ns],
    'Speedup factor': [round((n - 1) / depth, 1) for n, depth in zip(ns, depths)],
    'Final bit width (FP16 in)': [16 + depth for depth in depths],
})
print(df.to_string(index=False))
print()
print("Full Volta tensor core: N=64 multipliers → depth=6 adds on critical path")
print("(vs 63 sequential adds — 10.5× shallower pipeline)")

In [None]:
fig, ax = plt.subplots(figsize=(10, 4))

# Depth of both reduction strategies for every N from 2 to 256.
ns_plot = range(2, 257)
xs = list(ns_plot)
tree_depths = [tree_depth(n) for n in xs]
seq_depths = [n - 1 for n in xs]

ax.plot(xs, tree_depths, color=ACC, lw=2.5, label='Tree depth: ⌈log₂(N)⌉')
ax.plot(xs, seq_depths, color=RED, lw=1.5, ls='--', label='Sequential: N-1')

# Mark the N=64 point the notebook attributes to the Volta tensor core.
volta_n = 64
ax.axvline(volta_n, color=GRN, lw=1, ls=':', alpha=0.8)
ax.annotate('Volta TC\nN=64, depth=6',
            xy=(volta_n, tree_depth(volta_n)), xytext=(80, 20),
            color=GRN, fontsize=9,
            arrowprops=dict(arrowstyle='->', color=GRN, lw=1))

ax.set(xlabel='N (number of multipliers)',
       ylabel='Adder depth (critical path)',
       xlim=(2, 256))
ax.set_title('Binary Reduction Tree: log₂(N) depth vs sequential N-1', fontsize=11)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---
## 5. GEMM Tiling

A large GEMM (e.g. 1024×1024) is broken into 4×4 tiles that fit in on-chip SRAM.  
Each (output tile, K-step) pair executes one tensor core MMA, so a single 4×4 output tile accumulates K/4 MMAs. The outer loops are software-managed.

In [None]:
def tiled_gemm_fp16(A_fp16, B_fp16, tile=4):
    """
    Compute A @ B by walking tile×tile blocks and issuing one `mma` per step.

    A: [M, K] FP16
    B: [K, N] FP16
    M, K and N must all be multiples of `tile`.

    Returns (D, mma_count):
      D         — [M, N] FP32 result
      mma_count — number of `mma` calls made, (M/tile)·(N/tile)·(K/tile)
    """
    rows, inner = A_fp16.shape
    inner_b, cols = B_fp16.shape
    assert inner == inner_b
    assert rows % tile == 0 and inner % tile == 0 and cols % tile == 0

    out = np.zeros((rows, cols), dtype=np.float32)  # FP32 output
    calls = 0

    for r0 in range(0, rows, tile):
        row_span = slice(r0, r0 + tile)
        for c0 in range(0, cols, tile):
            col_span = slice(c0, c0 + tile)
            # Fresh FP32 accumulator for this output tile.
            acc = np.zeros((tile, tile), dtype=np.float32)
            for k0 in range(0, inner, tile):
                k_span = slice(k0, k0 + tile)
                # One MMA per K-step: FP16 tiles in, running FP32 sum carried in acc.
                acc = mma(A_fp16[row_span, k_span], B_fp16[k_span, col_span], acc)
                calls += 1
            out[row_span, col_span] = acc

    return out, calls

# Check the tiled GEMM against a direct FP32 matmul on a 16×16 problem.
M = N = K = 16
A = rng.uniform(-2, 2, (M, K)).astype(np.float16)
B = rng.uniform(-2, 2, (K, N)).astype(np.float16)

D_tiled, n_mmas = tiled_gemm_fp16(A, B, tile=4)
D_ref = np.matmul(A.astype(np.float32), B.astype(np.float32))

flops = 2 * M * N * K
summary_lines = [
    f"GEMM {M}×{K} × {K}×{N}:",
    f"  MMA calls     : {n_mmas}  (= {M//4} × {N//4} × {K//4} tiles)",
    f"  FLOPs         : {flops:,}  (= 2MNK)",
    f"  FLOPs/MMA     : {flops // n_mmas}  (= 2×4³ = 128 per 4×4×4 MMA)",
    f"  Max abs error : {np.max(np.abs(D_tiled - D_ref)):.5f}",
    "",
]
print("\n".join(summary_lines))

# Scaling table for square problems: MMA count, FLOPs, input bytes, and
# their ratio. Bytes count only the A and B operands at 2 bytes per element.
sizes = [16, 32, 64, 128, 256, 512, 1024]
print("{:>8}  {:>8}  {:>12}  {:>20}  {:>16}".format(
    'M=N=K', 'MMAs', 'FLOPs', 'Bytes moved (FP16)', 'Arith Intensity'))
print(72 * "-")
for s in sizes:
    mmas = (s // 4) ** 3
    flops = 2 * s ** 3
    bytes_ = 2 * s ** 2 * 2   # A + B, FP16 = 2 bytes each
    ai = flops / bytes_
    print(f"{s:>8}  {mmas:>8,}  {flops:>12,}  {bytes_:>20,}  {ai:>15.1f}")

print("\nArithmetic intensity (FLOPs/byte) grows with matrix size — large GEMMs are compute-bound.")

---
## 6. Transformer FLOP Breakdown

For a transformer layer with `d_model` and sequence length `S`, almost all FLOPs are GEMMs and therefore tensor core workloads.

In [None]:
# Architecture shapes per model: hidden width d, layer count, MLP width dff,
# and attention head count.
MODELS = {
    'GPT-2 (1.5B)':  {'d': 1600,  'layers': 48,  'dff': 6400,  'heads': 25},
    'GPT-3 (175B)':  {'d': 12288, 'layers': 96,  'dff': 49152, 'heads': 96},
    'LLaMA-3 (70B)': {'d': 8192,  'layers': 80,  'dff': 28672, 'heads': 64},
    'GPT-4 class':   {'d': 18432, 'layers': 120, 'dff': 73728, 'heads': 128},
}

# Peak FP16 throughput per GPU in FLOP/s (e.g. V100 = 125 TFLOPS).
GPUS = {
    'V100': 125e12,
    'A100': 312e12,
    'H100': 989e12,
}

def transformer_flops(d, layers, dff, S, heads=1):
    """
    Estimate forward-pass FLOPs per operation for a stack of transformer layers.

    Args:
        d:      model (hidden) width
        layers: number of layers; every per-layer count is multiplied by this
        dff:    MLP hidden width
        S:      sequence length
        heads:  number of attention heads. Optional (default 1, which keeps
                the previous single-head softmax count) so that a full model
                config dict containing 'heads' can be passed with ** without
                raising TypeError.

    Returns:
        dict mapping operation label -> total FLOPs across all layers.
        Labels containing 'TC' are the matrix-multiply (tensor core) entries.

    NOTE(review): the 'QKᵀ' and '·V' entries are sized as 2·S·d² each. The
    attention-score products are conventionally 2·S²·d each, and the Q/K/V
    input projections (3 × 2·S·d²) have no entry of their own. The values are
    left unchanged here because downstream cells key on this exact label set;
    confirm the intended accounting before relying on the per-op split.
    """
    per_layer = {
        'QKᵀ  (TC)':      2 * S * d * d,        # attention: Q·Kᵀ (see NOTE above)
        '·V   (TC)':      2 * S * d * d,        # attention: softmax·V (see NOTE above)
        'OutP (TC)':      2 * S * d * d,        # output projection
        'MLP↑ (TC)':      2 * S * d * dff,      # MLP up-project
        'MLP↓ (TC)':      2 * S * dff * d,      # MLP down-project
        'Softmax':        heads * S * S,         # one S×S score matrix per head — NOT TC
        'LayerNorm':      10 * S * d,            # elementwise — NOT TC
    }
    return {k: v * layers for k, v in per_layer.items()}

SEQ_LEN = 2048

# One summary row per model: total FLOPs, the share from tensor-core ops,
# and the time a single forward pass would take at each GPU's peak rate.
rows = []
for model_name, cfg in MODELS.items():
    flops = transformer_flops(cfg['d'], cfg['layers'], cfg['dff'], SEQ_LEN)
    total = sum(flops.values())
    tc = sum(count for op, count in flops.items() if 'TC' in op)
    row = {
        'Model': model_name,
        'Total FLOPs': total,
        'TC FLOPs %': f"{tc/total*100:.1f}%",
    }
    for gpu in ('V100', 'A100', 'H100'):
        row[f'{gpu} time'] = f"{total/GPUS[gpu]*1000:.1f} ms"
    rows.append(row)

df = pd.DataFrame(rows).set_index('Model')
# Switch the raw totals to scientific notation for display.
df['Total FLOPs'] = df['Total FLOPs'].map(lambda total_flops: f"{total_flops:.2e}")
print(f"Transformer FLOPs — seq_len={SEQ_LEN} (single forward pass, peak TFLOPS, no batching)")
print()
print(df.to_string())

In [None]:
# Stacked bar: FLOP breakdown by operation for each model
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

op_colors = {
    'QKᵀ  (TC)': ACC,
    '·V   (TC)': '#4a7ce8',
    'OutP (TC)': '#3d6cd0',
    'MLP↑ (TC)': GRN,
    'MLP↓ (TC)': '#26b86e',
    'Softmax':   ORG,
    'LayerNorm': '#5a6a80',
}

model_names = list(MODELS.keys())
# Pass the same explicit arguments as the summary-table cell so both views
# use identical numbers. The MODELS configs also carry a 'heads' entry that
# the table call does not forward, so the config dict is not splatted here.
all_flops   = {m: transformer_flops(cfg['d'], cfg['layers'], cfg['dff'], SEQ_LEN)
               for m, cfg in MODELS.items()}
op_names    = list(op_colors.keys())
# Per-model totals do not depend on the operation, so compute them once.
totals      = np.array([sum(all_flops[m].values()) for m in model_names])

# Left: absolute FLOPs
ax = axes[0]
bottoms = np.zeros(len(model_names))
for op in op_names:
    vals = np.array([all_flops[m][op] for m in model_names])
    ax.bar(model_names, vals / 1e12, bottom=bottoms / 1e12,
           color=op_colors[op], label=op.strip(), width=0.6)
    bottoms += vals
# The bars show an operation count scaled by 1e12 (TFLOP), not a rate.
ax.set_ylabel('FLOPs (TFLOP)')
ax.set_title(f'Total FLOPs by operation\n(seq_len={SEQ_LEN})', fontsize=10)
ax.tick_params(axis='x', labelrotation=15)
ax.legend(fontsize=8, loc='upper left')
ax.grid(True, alpha=0.2, axis='y')

# Right: percentage breakdown
ax = axes[1]
bottoms = np.zeros(len(model_names))
for op in op_names:
    vals   = np.array([all_flops[m][op] for m in model_names])
    pcts   = vals / totals * 100
    ax.bar(model_names, pcts, bottom=bottoms,
           color=op_colors[op], label=op.strip(), width=0.6)
    bottoms += pcts
ax.set_ylabel('% of total FLOPs')
ax.set_title('FLOP distribution (% of total)\nBlue/green = tensor core, orange/gray = elementwise', fontsize=10)
ax.tick_params(axis='x', labelrotation=15)
ax.set_ylim(0, 105)
ax.axhline(100, color='#252b3a', lw=1)
ax.grid(True, alpha=0.2, axis='y')

# Annotate TC % on right chart
for i, m in enumerate(model_names):
    total = totals[i]
    tc    = sum(v for k, v in all_flops[m].items() if 'TC' in k)
    ax.text(i, 102, f"{tc/total*100:.0f}% TC", ha='center', fontsize=8, color=GRN)

plt.tight_layout()
plt.show()

---
## 7. Roofline Model

Two limits to GPU throughput: **peak compute** and **memory bandwidth**.  
The roofline shows which constraint applies given a workload's arithmetic intensity.

In [None]:
gpu_specs = {
    'V100': dict(peak_tflops=125,  bw_gbps=900,   color='#5a6a80'),
    'A100': dict(peak_tflops=312,  bw_gbps=2000,  color=ACC),
    'H100': dict(peak_tflops=989,  bw_gbps=3350,  color=GRN),
}

# Arithmetic intensity = FLOPs / bytes
ai = np.logspace(-2, 3, 500)  # 0.01 to 1000 FLOPs/byte

fig, ax = plt.subplots(figsize=(12, 6))

# Ridge point per GPU, computed once and reused for the printed summary.
ridge_points = {}
for gpu_name, spec in gpu_specs.items():
    peak  = spec['peak_tflops'] * 1e12   # FLOPs/s
    bw    = spec['bw_gbps'] * 1e9        # bytes/s
    roof  = np.minimum(peak, ai * bw)    # roofline: min(compute bound, memory bound)
    ax.plot(ai, roof / 1e12, color=spec['color'], lw=2, label=f"{gpu_name} ({spec['peak_tflops']} TFLOPS, {spec['bw_gbps']} GB/s)")
    ridge = peak / bw  # arithmetic intensity at the "ridge point"
    ridge_points[gpu_name] = ridge
    ax.axvline(ridge, color=spec['color'], lw=0.8, ls=':', alpha=0.6)

# Annotate workload operating points.
# Each workload is described only by its arithmetic intensity; its height is
# taken from the reference GPU's roofline. Free-hand heights can land above
# every curve (e.g. 50 TFLOPS at 0.05 FLOPs/byte, where the fastest GPU here
# tops out near 0.17 TFLOPS), which contradicts the "attainable" axis.
WORKLOAD_GPU = 'H100'
ref_peak = gpu_specs[WORKLOAD_GPU]['peak_tflops'] * 1e12
ref_bw   = gpu_specs[WORKLOAD_GPU]['bw_gbps'] * 1e9
workloads = [
    ('LLM inference\n(batch=1)', 0.05, RED, '^'),
    ('LLM inference\n(batch=32)', 1.5, ORG, '^'),
    ('LLM training\n(large batch)', 100, GRN, 's'),
]
for label, x, color, marker in workloads:
    y = min(ref_peak, x * ref_bw) / 1e12   # attainable TFLOPS at this intensity
    ax.scatter([x], [y], color=color, s=80, zorder=5, marker=marker)
    ax.annotate(label, (x, y), textcoords='offset points', xytext=(10, 5),
                fontsize=8, color=color)

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Arithmetic Intensity (FLOPs / byte)')
ax.set_ylabel('Attainable throughput (TFLOPS)')
ax.set_title('Roofline Model — V100 / A100 / H100', fontsize=12)
ax.legend(fontsize=9, loc='upper left')
ax.grid(True, alpha=0.3, which='both')

ax.annotate('← memory bound', xy=(0.02, 5), fontsize=9, color=ORG, alpha=0.8)
ax.annotate('compute bound →', xy=(200, 5), fontsize=9, color=GRN, alpha=0.8)

plt.tight_layout()
plt.show()

print("Ridge points (memory bound → compute bound transition):")
for gpu_name, ridge in ridge_points.items():
    print(f"  {gpu_name}: {ridge:.0f} FLOPs/byte")