In [None]:
from urllib.parse import quote

import numpy as np
import pandas as pd
import squigglepy as sq
from squigglepy.numbers import K, M, B

from chip_estimates_utils import (
    normalize_shares,
    compute_h100_equivalents,
    export_quarterly_by_version,
    print_cumulative_summary,
    estimate_chip_sales
)

# Seed squigglepy's sampler and numpy's global RNG so Restart & Run All reproduces results.
sq.set_seed(42)
np.random.seed(42)

N_SAMPLES = 5_000  # number of Monte Carlo samples per simulation run
H100_TOPS = 1_979  # H100 8-bit TOPS: the yardstick for "H100 equivalent" conversions

In [None]:
# ==============================================
# GOOGLE SHEETS CONFIGURATION
# ==============================================
# SPREADSHEET_ID identifies the source spreadsheet. To use your own copy, replace it
# with your sheet's ID, which is the long string in its URL:
# https://docs.google.com/spreadsheets/d/SPREADSHEET_ID/edit

SPREADSHEET_ID = "1CRXA0T7jpyg7tDJNQ4KHIznpkDopPH3krspvXj0OlWk"

# Sheet (tab) names; change these if you rename the sheets.
REVENUE_SHEET = "TPU_Revenue"
PROD_MIX_SHEET = "Production_Mix"


def gviz_csv_url(spreadsheet_id, sheet_name):
    """Return the Google Visualization ("gviz") URL that exports one sheet as CSV.

    Args:
        spreadsheet_id: the spreadsheet ID from its /d/<ID>/edit URL.
        sheet_name: the tab name to export.

    Returns:
        The CSV export URL. The sheet name is percent-encoded (``safe=''``) so
        names containing spaces, '&', '#' or '/' still yield a valid query string.
    """
    return (
        f"https://docs.google.com/spreadsheets/d/{spreadsheet_id}"
        f"/gviz/tq?tqx=out:csv&sheet={quote(sheet_name, safe='')}"
    )


# Construct URLs for direct CSV export
REVENUE_URL = gviz_csv_url(SPREADSHEET_ID, REVENUE_SHEET)
PROD_MIX_URL = gviz_csv_url(SPREADSHEET_ID, PROD_MIX_SHEET)
print(REVENUE_URL)
print(PROD_MIX_URL)

In [None]:
# ==============================================
# LOAD DATA FROM GOOGLE SHEETS
# ==============================================

def require_columns(frame, required, sheet_name):
    """Raise ValueError naming the sheet if `frame` lacks any of the `required` columns.

    Later cells read these columns by name, so checking right after the download
    gives a readable error instead of a KeyError further down the notebook.

    Args:
        frame: the DataFrame read from the sheet export.
        required: iterable of column names that must be present.
        sheet_name: sheet name used in the error message.
    """
    missing = set(required) - set(frame.columns)
    if missing:
        raise ValueError(
            f"Sheet {sheet_name!r} is missing columns {sorted(missing)}; "
            f"got columns {list(frame.columns)}"
        )


print("Loading TPU revenue data from Google Sheets...")
revenue_df = pd.read_csv(REVENUE_URL)
require_columns(revenue_df, ['quarter', 'revenue_p5', 'revenue_p95'], REVENUE_SHEET)
print(f"Loaded {len(revenue_df)} quarters of revenue data")
print(revenue_df.head())

print("\nLoading production mix data from Google Sheets...")
prod_mix_df = pd.read_csv(PROD_MIX_URL)
require_columns(prod_mix_df, ['quarter', 'version', 'share_p5', 'share_p95'], PROD_MIX_SHEET)
print(f"Loaded {len(prod_mix_df)} version-quarter combinations")
print(prod_mix_df.head(10))

In [None]:
# ==============================================
# CONVERT TO SQUIGGLEPY DISTRIBUTIONS
# ==============================================
# Each sheet row gives a (p5, p95) pair; sq.to(p5, p95) builds a distribution with
# that pair as its credible interval (90% by squigglepy's default).

# Revenue: {quarter: distribution}. A repeated quarter would silently overwrite the
# earlier row, so reject duplicates up front.
dup_revenue_quarters = revenue_df.loc[revenue_df['quarter'].duplicated(), 'quarter'].tolist()
assert not dup_revenue_quarters, f"Duplicate quarters in revenue data: {dup_revenue_quarters}"

TPU_REVENUE = {
    row.quarter: sq.to(row.revenue_p5, row.revenue_p95)
    for row in revenue_df.itertuples(index=False)
}

print(f"Created {len(TPU_REVENUE)} revenue distributions")
print(f"Quarters: {list(TPU_REVENUE.keys())}")

# Production mix: {quarter: {version: share distribution}}, quarters in first-seen
# order. Duplicate (quarter, version) rows would silently overwrite each other.
dup_mix_rows = prod_mix_df.duplicated(subset=['quarter', 'version'])
assert not dup_mix_rows.any(), (
    "Duplicate (quarter, version) rows in production mix:\n"
    f"{prod_mix_df.loc[dup_mix_rows, ['quarter', 'version']]}"
)

# groupby(sort=False) keeps quarters in order of first appearance, matching unique().
PROD_MIX = {
    quarter: {
        row.version: sq.to(row.share_p5, row.share_p95)
        for row in quarter_rows.itertuples(index=False)
    }
    for quarter, quarter_rows in prod_mix_df.groupby('quarter', sort=False)
}

# The simulation runs over the revenue quarters and looks up PROD_MIX[quarter] for
# each one, so every revenue quarter needs a production mix.
missing_mix_quarters = [q for q in TPU_REVENUE if q not in PROD_MIX]
assert not missing_mix_quarters, f"No production mix for revenue quarters: {missing_mix_quarters}"

print(f"\nCreated production mix for {len(PROD_MIX)} quarters")
# Show Q1_FY23 when present; otherwise fall back to the first quarter instead of a KeyError.
example_quarter = 'Q1_FY23' if 'Q1_FY23' in PROD_MIX else next(iter(PROD_MIX), None)
if example_quarter is not None:
    print(f"Example {example_quarter} versions: {list(PROD_MIX[example_quarter].keys())}")

In [None]:
# ======================
# TPU Specs and Margins
# ======================

# Per-version specs:
#   'tops' - 8-bit TOPS per chip (used for H100-equivalent conversion)
#   'cost' - squigglepy distribution over per-chip manufacturing cost
TPU_SPECS = {
    'v3':  dict(tops=123,  cost=sq.to(940, 1400)),
    'v4':  dict(tops=275,  cost=sq.to(1100, 1500)),
    'v5e': dict(tops=393,  cost=sq.to(950, 1400)),
    'v5p': dict(tops=918,  cost=sq.to(2300, 2900)),
    'v6e': dict(tops=1836, cost=sq.to(1600, 1900)),
    'v7':  dict(tops=4614, cost=sq.to(4600, 5500)),
}

# Broadcom margin regimes: higher during FY23, lower from FY24 onward.
MARGIN_FY23 = sq.to(0.60, 0.75)
MARGIN = sq.to(0.50, 0.70)

In [None]:
def get_price_distribution(version, is_fy23=False):
    """Return the price distribution for a TPU version: cost / (1 - margin).

    Args:
        version: key into TPU_SPECS (e.g. 'v5p').
        is_fy23: if True use the FY23 margin regime (MARGIN_FY23), otherwise MARGIN.

    Returns:
        A squigglepy distribution built from the version's cost distribution and
        the selected margin distribution.
    """
    margin = MARGIN_FY23 if is_fy23 else MARGIN
    return TPU_SPECS[version]['cost'] / (1 - margin)

# Pre-compute price distributions for each version and margin regime
PRICE_DIST_FY23 = {version: get_price_distribution(version, is_fy23=True) for version in TPU_SPECS}
PRICE_DIST = {version: get_price_distribution(version, is_fy23=False) for version in TPU_SPECS}

# Prices exist only for versions in TPU_SPECS. Any other version named in the
# production mix (e.g. a typo like 'v5P') would still be given a normalized revenue
# share but would have no price to turn that share into chips, so stop before simulating.
unknown_versions = sorted(
    {version for mix in PROD_MIX.values() for version in mix} - set(TPU_SPECS), key=str
)
assert not unknown_versions, f"Production mix versions missing from TPU_SPECS: {unknown_versions}"

# Define sampling functions which get passed into estimate_chip_sales
def sample_revenue(quarter):
    """Draw one total-TPU-revenue sample for `quarter`, in dollars.

    TPU_REVENUE values are scaled by B here, so the sheet presumably records
    revenue in billions of dollars.
    """
    # draw one sample from the quarter's distribution
    return (TPU_REVENUE[quarter] @ 1) * B

def sample_shares(quarter):
    """Draw one share per version from PROD_MIX[quarter] and normalize them.

    Returns the {version: share} dict produced by normalize_shares.
    """
    mix = PROD_MIX[quarter]
    # draw one sample from each version's distribution
    raw_shares = {version: dist @ 1 for version, dist in mix.items()}
    return normalize_shares(raw_shares)

def sample_price(quarter, version):
    """Draw one price for `version`, using the FY23 margin regime for FY23 quarters.

    A quarter counts as FY23 when its label contains 'FY23' (e.g. 'Q1_FY23').
    """
    price_dists = PRICE_DIST_FY23 if 'FY23' in quarter else PRICE_DIST
    return price_dists[version] @ 1

In [None]:
# Run the Monte Carlo simulation to estimate chip volumes (estimate_chip_sales).
#
# Args:
#     quarters:       quarter identifiers, e.g. ['Q1_FY23', 'Q2_FY23', ...]
#     versions:       chip types, e.g. ['v3', 'v4', 'v5e', ...]
#     sample_revenue: fn(quarter) -> float; samples (or looks up) the total chip
#                     revenue for a quarter, in dollars
#     sample_shares:  fn(quarter) -> dict; samples {version: share} for a quarter
#                     (shares should sum to 1)
#     sample_price:   fn(quarter, version) -> float; samples (or looks up) the price
#                     of a chip type in a quarter
#     n_samples:      number of Monte Carlo samples
#
# Returns:
#     {quarter: {version: [samples of chip unit counts]}}. Take percentiles of these
#     sample arrays to get medians, confidence intervals, etc.
#
# Cross-quarter correlations:
#     The sampling functions are called independently for each quarter within each
#     iteration. A parameter you want shared across quarters (e.g. one margin value
#     for every quarter) is therefore NOT correlated by default. To correlate it,
#     pre-sample it outside estimate_chip_sales and have the sampling functions read
#     the pre-sampled value.
sim_results = estimate_chip_sales(
    quarters=list(TPU_REVENUE),
    versions=list(TPU_SPECS),
    sample_revenue=sample_revenue,
    sample_shares=sample_shares,
    sample_price=sample_price,
    n_samples=N_SAMPLES,
)

In [None]:
# Summarize quarterly results
def summarize_results(results, versions=None):
    """Summarize simulated unit counts per quarter as a DataFrame of percentiles.

    Args:
        results: mapping ``{quarter: {version: samples}}`` as returned by
            estimate_chip_sales. All of a quarter's sample arrays must share one length.
        versions: versions to include, in column order. Defaults to the keys of
            TPU_SPECS.

    Returns:
        A DataFrame with one row per quarter and these columns:
        - ``Quarter``.
        - ``{version}_p50`` for each version with nonzero production that quarter;
          NaN elsewhere.
        - ``total_p5``, ``total_p50`` and ``total_p95`` for the all-version total.

        Percentiles are truncated to ``int``.
    """
    if versions is None:
        versions = TPU_SPECS
    rows = []
    for quarter, by_version in results.items():
        row = {'Quarter': quarter}
        samples = [np.asarray(by_version[version], dtype=float) for version in versions]
        # Size the total from the data itself rather than the global N_SAMPLES, so the
        # summary works for any sample count.
        total = np.sum(samples, axis=0) if samples else np.zeros(1)
        for version, arr in zip(versions, samples):
            if arr.sum() > 0:
                row[f'{version}_p50'] = int(np.percentile(arr, 50))
        for pct in (5, 50, 95):
            row[f'total_p{pct}'] = int(np.percentile(total, pct))
        rows.append(row)
    return pd.DataFrame(rows)

# Headline table: total chips per quarter with its 90% interval.
df = summarize_results(sim_results)
print("TPU Production Volumes by Quarter (chips)")
print(df.loc[:, ['Quarter', 'total_p5', 'total_p50', 'total_p95']].to_string(index=False))

In [None]:
# Cumulative totals by TPU version: for each version, add its per-quarter sample
# arrays element-wise over every simulated quarter.
#
# Caveat: the quarters are sampled independently of one another, so this sum ignores
# any cross-quarter correlation. The resulting intervals are therefore probably too narrow.
cumulative = {
    version: sum(
        (np.array(sim_results[quarter][version]) for quarter in sim_results),
        np.zeros(N_SAMPLES),
    )
    for version in TPU_SPECS
}

print_cumulative_summary(cumulative, TPU_SPECS, "Cumulative TPU Production (FY23-FY25)")

In [None]:
# H100 equivalents on an 8-bit TOPS basis: compute_h100_equivalents receives the
# cumulative unit samples, the per-version specs (including 'tops') and H100_TOPS.
# Presumably it weights each version's units by tops / H100_TOPS; confirm in chip_estimates_utils.
h100_eq = compute_h100_equivalents(
    cumulative,
    TPU_SPECS,
    H100_TOPS,
)
print_cumulative_summary(
    h100_eq,
    TPU_SPECS,
    "H100 Equivalents (8-bit TOPS basis)",
)

In [None]:
# Fiscal year totals
def fiscal_year_totals(results, versions=None, n_samples=None):
    """Sum simulated unit counts over versions and quarters, grouped by fiscal year.

    The fiscal year is the part of the quarter label after the first '_'
    (e.g. 'Q1_FY23' -> 'FY23').

    Args:
        results: ``{quarter: {version: samples}}`` as returned by estimate_chip_sales.
        versions: versions to include; defaults to the keys of TPU_SPECS.
        n_samples: length of each sample array; defaults to N_SAMPLES.

    Returns:
        ``{fiscal_year: array of total-unit samples}``. FY23, FY24 and FY25 are always
        present (all zeros if no quarter falls in them). Any other fiscal year found
        in the quarter labels follows them, in first-seen order.
    """
    if versions is None:
        versions = TPU_SPECS
    if n_samples is None:
        n_samples = N_SAMPLES
    fy_totals = {year: np.zeros(n_samples) for year in ('FY23', 'FY24', 'FY25')}
    for quarter, by_version in results.items():
        year = quarter.split('_')[1]
        # A year beyond the three defaults (e.g. FY26) gets its own running total
        # instead of raising KeyError.
        if year not in fy_totals:
            fy_totals[year] = np.zeros(n_samples)
        for version in versions:
            fy_totals[year] += np.asarray(by_version[version])
    return fy_totals

fy = fiscal_year_totals(sim_results)
print("\nTotal TPU Production by Fiscal Year")
# Report every fiscal year in the totals rather than a hard-coded list, so years added
# to the sheet appear here without editing this cell.
for year, year_samples in fy.items():
    p5, p50, p95 = (int(np.percentile(year_samples, p)) for p in (5, 50, 95))
    print(f"{year}: {p50:,} chips (90% CI: {p5:,} - {p95:,})")

In [None]:
# Write per-quarter, per-version volumes (with H100 equivalents) to CSV and show the table.
export_df = export_quarterly_by_version(
    sim_results,
    TPU_SPECS,
    'tpu_volumes_by_quarter_version.csv',
    N_SAMPLES,
    H100_TOPS,
)
print(export_df.to_string(index=False))