In [39]:
import squigglepy as sq
import numpy as np
import pandas as pd
from squigglepy.numbers import K, M, B

# Seed both squigglepy and NumPy's global RNG so repeated runs draw the same samples.
sq.set_seed(42)
np.random.seed(42)
# Raise FloatingPointError (instead of only warning) on invalid floating-point
# operations, i.e. the ones that would otherwise produce NaN.
np.seterr(invalid='raise')
N_SAMPLES = 5000  # Number of Monte Carlo draws used by the main (correlated) simulation

from chip_estimates_utils import (
    normalize_shares,
    estimate_chip_sales,
    estimate_cumulative_chip_sales,
    aggregate_by_chip_type,
    interpolate_samples_to_calendar_quarters,
    compute_running_totals,
)


In [40]:
# ==============================================
# GOOGLE SHEETS CONFIGURATION
# ==============================================

# Workbook holding every input table used by this notebook.
SPREADSHEET_ID = "1eGk2AAdewEO81vx-YBRTtdlhZvAMstY7vZuHrf3sgNI"

# Tab names inside that workbook.
REVENUE_SHEET, PROD_MIX_SHEET, PRICES_SHEET = "TPU_Revenue", "Production_Mix", "prices"

# Direct CSV-export URL for each tab (same endpoint, only the sheet name differs).
REVENUE_URL, PROD_MIX_URL, PRICES_URL = (
    f"https://docs.google.com/spreadsheets/d/{SPREADSHEET_ID}/gviz/tq?tqx=out:csv&sheet={tab}"
    for tab in (REVENUE_SHEET, PROD_MIX_SHEET, PRICES_SHEET)
)


In [41]:
# ==============================================
# LOAD DATA FROM GOOGLE SHEETS
# ==============================================

# Fetch the three tabs (revenue, production mix, prices — in that order) and discard
# any column that is entirely empty.
revenue_df, prod_mix_df, prices_df = (
    pd.read_csv(sheet_url).dropna(axis=1, how="all")
    for sheet_url in (REVENUE_URL, PROD_MIX_URL, PRICES_URL)
)

# Quarter labels in sheet order; TPU versions de-duplicated with missing values removed.
QUARTERS = revenue_df['quarter'].tolist()
TPU_VERSIONS = prod_mix_df['version'].dropna().unique().tolist()

# Quick visual check of what was loaded.
print(f"Loaded {len(revenue_df)} quarters of revenue data")
print(
    revenue_df[
        ['quarter', 'start_date', 'end_date', 'revenue_p5', 'revenue_p95',
         'broadcom_margin_p5', 'broadcom_margin_p95']
    ].head()
)
print()
print(prod_mix_df[['quarter', 'version', 'share_p5', 'share_p95']].head(10))
print()
print(prices_df)


Loaded 12 quarters of revenue data
   quarter  start_date    end_date  revenue_p5  revenue_p95  \
0  Q1_FY23  10/31/2022   1/29/2023        0.42         0.53   
1  Q2_FY23   1/30/2023   4/30/2023        0.53         0.66   
2  Q3_FY23    5/1/2023   7/30/2023        0.53         0.66   
3  Q4_FY23   7/31/2023  10/29/2023        0.79         0.99   
4  Q1_FY24  10/30/2023    2/4/2024        1.30         1.45   

   broadcom_margin_p5  broadcom_margin_p95  
0                0.55                 0.70  
1                0.55                 0.70  
2                0.55                 0.70  
3                0.55                 0.70  
4                0.50                 0.65  

   quarter version  share_p5  share_p95
0  Q1_FY23      v4      0.60       0.80
1  Q1_FY23     v4i      0.10       0.20
2  Q1_FY23     v5e      0.10       0.20
3  Q2_FY23      v4      0.60       0.90
4  Q2_FY23     v5e      0.10       0.40
5  Q3_FY23     v5e      0.65       0.95
6  Q3_FY23      v4      0.05       

In [42]:
# ==============================================
# CONVERT TO SQUIGGLEPY DISTRIBUTIONS
# ==============================================

# quarter -> normal distribution built from the sheet's revenue_p5 / revenue_p95 columns.
TPU_REVENUE = {
    row['quarter']: sq.norm(row['revenue_p5'], row['revenue_p95'])
    for _, row in revenue_df.iterrows()
}

# quarter -> {version -> share distribution}; each share is a normal built from
# share_p5 / share_p95 and clipped to [0, 1].
PROD_MIX = {
    quarter: {
        row['version']: sq.norm(row['share_p5'], row['share_p95'], lclip=0, rclip=1)
        for _, row in prod_mix_df[prod_mix_df['quarter'] == quarter].iterrows()
    }
    for quarter in prod_mix_df['quarter'].unique()
}

# Per-version spec table keyed by the short version label used in the sheets.
# NOTE(review): 'tops' looks like a peak tera-ops figure per chip; the numeric precision
# it refers to is not stated here — confirm before comparing across versions.
TPU_SPECS = {
    version: {'tops': tops, 'full_name': f'TPU {version}'}
    for version, tops in (
        ('v3', 123),
        ('v4i', 138),
        ('v4', 275),
        ('v5e', 393),
        ('v5p', 918),
        ('v6e', 1836),
        ('v7', 4614),
    )
}

# TPU manufacturing costs from prices sheet: version -> distribution from cost_p5 / cost_p95
TPU_COST = {}
for _, price_row in prices_df.iterrows():
    TPU_COST[price_row['version']] = sq.to(price_row['cost_p5'], price_row['cost_p95'])

# Broadcom margins by quarter from revenue sheet: quarter -> distribution from margin p5 / p95
MARGIN_BY_QUARTER = {}
for _, revenue_row in revenue_df.iterrows():
    MARGIN_BY_QUARTER[revenue_row['quarter']] = sq.to(
        revenue_row['broadcom_margin_p5'], revenue_row['broadcom_margin_p95']
    )

# Echo the bounds each distribution was built from (.x = lower, .y = upper).
print("TPU costs (90% CI):")
for version_name, cost_dist in TPU_COST.items():
    print(f"  {version_name}: ${cost_dist.x:,.0f}-${cost_dist.y:,.0f}")

print("Broadcom margins (first 6 quarters):")
for quarter_name in list(MARGIN_BY_QUARTER)[:6]:
    quarter_margin = MARGIN_BY_QUARTER[quarter_name]
    print(f"  {quarter_name}: {quarter_margin.x:.0%}-{quarter_margin.y:.0%}")


TPU costs (90% CI):
  v3: $940-$1,400
  v4i: $700-$1,100
  v4: $1,100-$1,500
  v5e: $950-$1,400
  v5p: $2,300-$2,900
  v6e: $1,600-$1,900
  v7: $4,600-$5,500
Broadcom margins (first 6 quarters):
  Q1_FY23: 55%-70%
  Q2_FY23: 55%-70%
  Q3_FY23: 55%-70%
  Q4_FY23: 55%-70%
  Q1_FY24: 50%-65%
  Q2_FY24: 50%-65%


In [43]:
# ==============================================
# SAMPLING FUNCTIONS
# ==============================================

# Base quarter for correlated price sampling: each chip's price is sampled once using this
# quarter's margin (see sample_base_price) and is then rescaled for other quarters by the
# deflation factor. It is simply the first quarter listed in the revenue sheet.
BASE_QUARTER = QUARTERS[0]


def sample_revenue(quarter):
    """Draw a single revenue sample for `quarter` and scale it by `B`.

    The sheet values look like billions (e.g. 0.42-0.53), which would explain the `B`
    multiplier — TODO confirm the unit against the source sheet.
    """
    single_draw = TPU_REVENUE[quarter] @ 1
    return single_draw * B


def sample_shares(quarter):
    """Draw one share per TPU version for `quarter` and hand them to `normalize_shares`.

    Returns whatever `normalize_shares` produces for the {version: sampled share} dict
    (presumably shares rescaled to sum to 1 — confirm in chip_estimates_utils).
    """
    raw_shares = {}
    for version, share_dist in PROD_MIX[quarter].items():
        raw_shares[version] = share_dist @ 1
    return normalize_shares(raw_shares)


def sample_base_price(version):
    """Draw one base price for `version`: cost / (1 - margin), using BASE_QUARTER's margin."""
    cost_dist = TPU_COST[version]
    base_margin = MARGIN_BY_QUARTER[BASE_QUARTER]
    price_dist = cost_dist / (1 - base_margin)
    return price_dist @ 1


# Uncorrelated price sampler (for comparison): every call draws a fresh price from the
# given quarter's own margin distribution.
def sample_price(quarter, version):
    """Draw one price for `version` in `quarter`: cost / (1 - margin)."""
    cost_dist = TPU_COST[version]
    quarter_margin = MARGIN_BY_QUARTER[quarter]
    price_dist = cost_dist / (1 - quarter_margin)
    return price_dist @ 1


# Margin deflation factor for correlated model
# 
# We want to adjust prices based on each quarter's margin while keeping price
# uncertainty correlated across time. The base price is sampled once using
# BASE_QUARTER's margin, then scaled by this deflation factor for other quarters.
#
# We compute this empirically rather than using closed-form because:
# 1. price = cost / (1 - margin), and 1/(1-x) is convex (Jensen's inequality)
# 2. We want E[1/(1-margin_q)] / E[1/(1-margin_base)], not the ratio of medians
# 3. For lognormal margins, the mean of the transform ≠ transform of the mean
#
# Since margin only changes once (FY23: 55-70% → FY24+: 50-65%), we compute
# the deflation factor once and apply it based on fiscal year.

def _compute_expected_price_multiplier(margin_dist, n=10000):
    """Estimate E[1 / (1 - margin)] by Monte Carlo sampling.

    Args:
        margin_dist: object that yields `n` margin samples via ``margin_dist @ n``
            (e.g. a squigglepy distribution).
        n: number of samples to draw.

    Returns:
        The mean of ``1 / (1 - sample)`` over the drawn samples.

    Raises:
        ValueError: if any sampled margin is >= 1. Such a draw makes the multiplier
            infinite or negative, so the mean would be meaningless rather than merely noisy.
    """
    samples = np.asarray(margin_dist @ n, dtype=float)
    if np.any(samples >= 1):
        raise ValueError(
            "sampled margin >= 1: price multiplier 1 / (1 - margin) is undefined or negative"
        )
    return np.mean(1 / (1 - samples))

# Compute deflation factor for FY24+ relative to FY23
# NOTE(review): these bounds are hard-coded. They match the broadcom_margin_p5/p95 values
# printed from the revenue sheet (55%-70% for the FY23 quarters, 50%-65% for Q1-Q2 FY24),
# but they are not read from MARGIN_BY_QUARTER, so a change in the sheet will not
# propagate here.
FY23_MARGIN = sq.to(0.55, 0.70)
FY24_MARGIN = sq.to(0.50, 0.65)
# Ratio of expected price multipliers: E[1/(1-margin_FY24)] / E[1/(1-margin_FY23)].
FY24_DEFLATION = _compute_expected_price_multiplier(FY24_MARGIN) / _compute_expected_price_multiplier(FY23_MARGIN)

def get_margin_deflation_factor(quarter, version):
    """Price adjustment factor for `quarter`: 1.0 for FY23, ~0.88 (FY24_DEFLATION) for FY24+.

    `version` is accepted but not used; the factor depends only on the quarter label.
    """
    is_base_fiscal_year = 'FY23' in quarter
    return 1.0 if is_base_fiscal_year else FY24_DEFLATION

# Report the factor; it is expected to be below 1 because the FY24+ margin range
# (50%-65%) sits below the FY23 range (55%-70%).
print(f"FY24+ deflation factor: {FY24_DEFLATION:.4f}")

FY24+ deflation factor: 0.8763


In [None]:
# ==============================================
# RUN CORRELATED SIMULATION
# ==============================================

# Correlated model: a base price is sampled per chip and then adjusted per quarter through
# the deflation-factor callback, instead of drawing an independent price every quarter.
quarterly_samples = estimate_cumulative_chip_sales(
    quarters=QUARTERS,
    chip_types=TPU_VERSIONS,
    sample_revenue=sample_revenue,
    sample_shares=sample_shares,
    sample_base_price=sample_base_price,  # price drawn with BASE_QUARTER's margin
    get_deflation_factor=get_margin_deflation_factor,  # 1.0 for FY23, FY24_DEFLATION otherwise
    sample_revenue_uncertainty=None,
    n_samples=N_SAMPLES,
)

# NOTE(review): presumably collapses the per-quarter samples into one array per chip
# type — confirm in chip_estimates_utils.
cumulative_samples = aggregate_by_chip_type(quarterly_samples)

print("Simulation complete.")

In [None]:
# ==============================================
# CUMULATIVE SUMMARY
# ==============================================

def print_cumulative_summary(cumulative_samples, versions, title="Cumulative Production"):
    """Print a p5 / p50 / p95 table of cumulative samples per version plus a grand total.

    Args:
        cumulative_samples: mapping of version -> array-like of samples.
        versions: versions to report, in print order. Versions whose samples do not sum
            to a positive value are skipped.
        title: heading printed above the table.

    The TOTAL row is the percentile of the element-wise sum of the reported versions and
    is only printed when at least one version was reported.
    """
    def _format_row(label, samples):
        # Percentiles are truncated to int for display, as thousands-separated columns.
        p5, p50, p95 = (int(p) for p in np.percentile(samples, [5, 50, 95]))
        return f"{label:<8} {p5:>12,} {p50:>12,} {p95:>12,}"

    print(f"{title}")
    print(f"{'Version':<8} {'p5':>12} {'p50':>12} {'p95':>12}")
    print("-" * 51)

    grand_total = None
    for v in versions:
        arr = np.asarray(cumulative_samples[v])
        if arr.sum() > 0:
            if grand_total is None:
                # Always accumulate in float: zeros_like(arr) would inherit an integer
                # dtype from the first array and make a later float `+=` raise.
                grand_total = np.zeros(arr.shape, dtype=float)
            grand_total += arr
            print(_format_row(v, arr))

    if grand_total is not None:
        print("-" * 51)
        print(_format_row('TOTAL', grand_total))

# Headline table for the correlated run, one row per TPU version from the production-mix sheet.
print_cumulative_summary(cumulative_samples, TPU_VERSIONS, "Cumulative TPU Production (Correlated Model)")

In [None]:
# ==============================================
# CUMULATIVE RUNNING TOTALS BY FISCAL QUARTER
# ==============================================

running_totals_samples = compute_running_totals(quarterly_samples)

print("Cumulative Running Totals by Fiscal Quarter")
print(f"{'Quarter':<10} {'Version':<8} {'p5':>12} {'p50':>12} {'p95':>12}")
print("=" * 60)

for quarter in QUARTERS:
    totals_for_quarter = running_totals_samples[quarter]

    # One row per version that has any positive running total in this quarter.
    printed_any = False
    for v in TPU_VERSIONS:
        arr = totals_for_quarter[v]
        if not (arr.sum() > 0):
            continue
        printed_any = True
        p5, p50, p95 = (int(np.percentile(arr, pct)) for pct in (5, 50, 95))
        print(f"{quarter:<10} {v:<8} {p5:>12,} {p50:>12,} {p95:>12,}")

    # Quarters with nothing to show get neither a TOTAL row nor a separator.
    if not printed_any:
        continue
    total = sum(totals_for_quarter[v] for v in TPU_VERSIONS)
    p5, p50, p95 = (int(np.percentile(total, pct)) for pct in (5, 50, 95))
    print(f"{quarter:<10} {'TOTAL':<8} {p5:>12,} {p50:>12,} {p95:>12,}")
    print("-" * 60)

In [None]:
# ==============================================
# CALENDAR QUARTER INTERPOLATION (SAMPLE-BASED)
# ==============================================

# fiscal quarter -> (start_date, end_date), taken from the first matching revenue-sheet row.
quarter_dates = {}
for fiscal_q in QUARTERS:
    matching_rows = revenue_df.loc[revenue_df['quarter'] == fiscal_q]
    quarter_dates[fiscal_q] = (
        matching_rows['start_date'].iloc[0],
        matching_rows['end_date'].iloc[0],
    )

calendar_quarterly_samples = interpolate_samples_to_calendar_quarters(quarterly_samples, quarter_dates)
calendar_running_totals_samples = compute_running_totals(calendar_quarterly_samples)

print("Cumulative Running Totals by Calendar Quarter")
print(f"{'Quarter':<10} {'Version':<8} {'p5':>12} {'p50':>12} {'p95':>12}")
print("=" * 60)

for cq in calendar_running_totals_samples:
    totals_for_quarter = calendar_running_totals_samples[cq]

    # One row per version that has any positive running total in this calendar quarter.
    printed_any = False
    for v in TPU_VERSIONS:
        arr = totals_for_quarter[v]
        if not (arr.sum() > 0):
            continue
        printed_any = True
        p5, p50, p95 = (int(np.percentile(arr, pct)) for pct in (5, 50, 95))
        print(f"{cq:<10} {v:<8} {p5:>12,} {p50:>12,} {p95:>12,}")

    # Quarters with nothing to show get neither a TOTAL row nor a separator.
    if not printed_any:
        continue
    total = sum(totals_for_quarter[v] for v in TPU_VERSIONS)
    p5, p50, p95 = (int(np.percentile(total, pct)) for pct in (5, 50, 95))
    print(f"{cq:<10} {'TOTAL':<8} {p5:>12,} {p50:>12,} {p95:>12,}")
    print("-" * 60)

In [None]:
# ==============================================
# CSV EXPORTS: CUMULATIVE RUNNING TOTALS BY CHIP TYPE
# ==============================================
# Export calendar quarter running totals broken down by TPU version

def get_calendar_quarter_dates(cal_q):
    """Return (start_date, end_date) strings for a calendar quarter like 'Q1 2024'.

    Args:
        cal_q: label of the form '<letter><quarter number> <year>', e.g. 'Q3 2023'.

    Returns:
        Tuple of 'M/D/YYYY' strings for the first and last day of that quarter.

    Raises:
        ValueError: if the quarter number is not 1-4 (previously any other digit was
            silently treated as Q4, and 'Q12' was read as Q1) or a token is not numeric.
        IndexError: if the label has fewer than two whitespace-separated tokens.
    """
    # Month/day bounds of each calendar quarter; the year is appended below.
    quarter_bounds = {
        1: ("1/1", "3/31"),
        2: ("4/1", "6/30"),
        3: ("7/1", "9/30"),
        4: ("10/1", "12/31"),
    }
    parts = cal_q.split()
    q_num = int(parts[0][1:])  # everything after the leading letter, so 'Q12' is rejected
    year = int(parts[1])
    if q_num not in quarter_bounds:
        raise ValueError(f"invalid calendar quarter {cal_q!r}: quarter number must be 1-4")
    start_md, end_md = quarter_bounds[q_num]
    return f"{start_md}/{year}", f"{end_md}/{year}"

# One record per (calendar quarter, chip type) whose cumulative samples sum to a positive value.
rows = []
for cq in calendar_running_totals_samples:
    start_date, end_date = get_calendar_quarter_dates(cq)
    totals_for_quarter = calendar_running_totals_samples[cq]
    for version in TPU_VERSIONS:
        arr = totals_for_quarter[version]
        if not (arr.sum() > 0):
            continue
        units_p5, units_p50, units_p95 = (int(np.percentile(arr, pct)) for pct in (5, 50, 95))
        rows.append({
            'Name': f"Google {version} cumulative through {cq}",
            'Start date': start_date,
            'End date': end_date,
            'Chip type': version,
            'Number of units (5th percentile)': units_p5,
            'Number of units (median)': units_p50,
            'Number of units (95th percentile)': units_p95,
        })

by_chip_df = pd.DataFrame(rows)
by_chip_df.to_csv('tpu_cumulative_by_chip.csv', index=False)
print(f"Exported {len(by_chip_df)} rows to tpu_cumulative_by_chip.csv")
print(by_chip_df.head(10))

In [None]:
# ==============================================
# CSV EXPORTS: FULL-TPU AGGREGATE STATS
# ==============================================
# Export calendar quarter running totals with aggregate metrics across all TPU versions
# Metrics: total units, H100e compute

H100_TOPS = 1979  # Reference for H100-equivalent calculation

rows = []
for cq in calendar_running_totals_samples:
    start_date, end_date = get_calendar_quarter_dates(cq)
    totals_for_quarter = calendar_running_totals_samples[cq]

    # Per-sample accumulators across all TPU versions.
    n_samples = N_SAMPLES
    units_total = np.zeros(n_samples)
    h100e_total = np.zeros(n_samples)

    for version in TPU_VERSIONS:
        arr = totals_for_quarter[version]
        units_total += arr
        # NOTE(review): a version missing from TPU_SPECS still counts toward units but
        # contributes nothing to the H100e total.
        spec = TPU_SPECS.get(version)
        if spec is not None:
            tops = spec['tops']
            h100e_total += arr * (tops / H100_TOPS)

    units_p5, units_p50, units_p95 = (int(np.percentile(units_total, pct)) for pct in (5, 50, 95))
    h100e_p5, h100e_p50, h100e_p95 = (int(np.percentile(h100e_total, pct)) for pct in (5, 50, 95))

    rows.append({
        'Name': f"Google TPU cumulative through {cq}",
        'Designer': 'Google',
        'Start date': start_date,
        'End date': end_date,
        'Number of units (5th percentile)': units_p5,
        'Number of units (median)': units_p50,
        'Number of units (95th percentile)': units_p95,
        'Compute estimate in H100e (5th percentile)': h100e_p5,
        'Compute estimate in H100e (median)': h100e_p50,
        'Compute estimate in H100e (95th percentile)': h100e_p95,
    })

totals_df = pd.DataFrame(rows)
totals_df.to_csv('tpu_cumulative_totals.csv', index=False)
print(f"Exported {len(totals_df)} rows to tpu_cumulative_totals.csv")
print(totals_df)

In [None]:
# ==============================================
# CHRONOLOGICAL VIEW: FISCAL + CALENDAR INTERLEAVED
# ==============================================

from datetime import datetime

selected_fiscal = QUARTERS[:5]
selected_calendar = list(calendar_running_totals_samples.keys())[:4]

# (month, day) on which each calendar quarter ends.
end_dates = {1: (3, 31), 2: (6, 30), 3: (9, 30), 4: (12, 31)}

timeline = []

# Fiscal quarters: end date comes from the revenue sheet.
for q in selected_fiscal:
    fiscal_end = revenue_df.loc[revenue_df['quarter'] == q, 'end_date'].iloc[0]
    timeline.append({
        'end_date': pd.to_datetime(fiscal_end),
        'label': q,
        'data': running_totals_samples[q],
        'type': 'FISCAL',
    })

# Calendar quarters: end date is derived from a label of the form 'Q<n> <year>'.
for cq in selected_calendar:
    label_parts = cq.split()
    q_num = int(label_parts[0][1])
    year = int(label_parts[1])
    end_month, end_day = end_dates[q_num]
    timeline.append({
        'end_date': datetime(year, end_month, end_day),
        'label': cq,
        'data': calendar_running_totals_samples[cq],
        'type': 'CALENDAR',
    })

# Sort chronologically
timeline = sorted(timeline, key=lambda item: item['end_date'])

print("Chronological Comparison: Fiscal vs Calendar Quarter Running Totals")
print("=" * 80)

rule = f"  {'-'*50}"
for entry in timeline:
    kind, label = entry['type'], entry['label']
    ends_on = entry['end_date'].strftime('%Y-%m-%d')
    print(f"\n{kind}: {label} (ends {ends_on})")
    print(f"  {'Version':<8} {'p5':>12} {'p50':>12} {'p95':>12}")
    print(rule)

    # Sum only the versions that are actually shown.
    total = np.zeros(N_SAMPLES)
    for v in TPU_VERSIONS:
        arr = entry['data'][v]
        if not (arr.sum() > 0):
            continue
        total += arr
        p5, p50, p95 = (int(np.percentile(arr, pct)) for pct in (5, 50, 95))
        print(f"  {v:<8} {p5:>12,} {p50:>12,} {p95:>12,}")

    print(rule)
    p5, p50, p95 = (int(np.percentile(total, pct)) for pct in (5, 50, 95))
    print(f"  {'TOTAL':<8} {p5:>12,} {p50:>12,} {p95:>12,}")

In [None]:
# ==============================================
# UNCORRELATED MODEL COMPARISON
# ==============================================

# Comparison run: prices come from sample_price, which draws from each quarter's own
# margin distribution on every call instead of reusing one base price per chip.
uncorrelated_samples = estimate_chip_sales(
    quarters=QUARTERS,
    versions=TPU_VERSIONS,
    sample_revenue=sample_revenue,
    sample_shares=sample_shares,
    sample_price=sample_price,
    n_samples=1000,  # smaller for speed (the correlated run uses N_SAMPLES)
)

cumulative_uncorrelated_samples = aggregate_by_chip_type(uncorrelated_samples)

In [None]:
# Print both models back to back. Note the uncorrelated run was made with 1000 samples
# while the correlated run used N_SAMPLES, so its percentiles rest on fewer draws.
print_cumulative_summary(cumulative_uncorrelated_samples, TPU_VERSIONS, "Cumulative TPU Production (Uncorrelated Model)")
print(" ")
print_cumulative_summary(cumulative_samples, TPU_VERSIONS, "Cumulative TPU Production (Correlated Model)")