# Robotics Industry Analytics

Last updated by Michael Harries and Claude Code, Feb 21, 2026.

Part of the IEEE RAS robotics industry report generation pipeline.

---

Generates publication-ready funding and company-formation charts from [Tracxn](https://tracxn.com/) data exports. Upload two Excel files (companies + funding rounds), run all cells, and download a ZIP of 300 DPI PNG charts.

**Required inputs** (two Tracxn Excel exports):

| File | Key columns |
|---|---|
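| Companies export | `Subcategory`, `Domain Name`, `Category`, `Country`, `Founded Year`, `Is Funded` (plus optional `Is Deadpooled`, `Is Acquired`, `Is IPO` for the outcomes chart) |
| Funding rounds export | `Domain Name`, `Round Amount (in USD)` or `Round Amount (USD)`, `Round Date`, `Round Name` |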

**How to run:** Runtime > Run all, upload the two files when prompted, then download `robotics_charts.zip` from the file browser (folder icon, left sidebar).

**Output:** Companies-founded charts, global/sector/region/subcategory funding charts -- all using a colorblind-friendly palette (Paul Tol) and IEEE-compatible styling.

In [None]:
# Cell 2: Setup & Imports
!pip install -q openpyxl

%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.ticker import FuncFormatter
import pandas as pd
import numpy as np
import os
import zipfile
import gc
from google.colab import files

In [None]:
# Cell 3: Configuration

# Start years for the plotted ranges (long history vs. default window)
START_YEAR_ALL, START_YEAR_DEFAULT = 1900, 2000

# Stacked funding charts split rounds at this USD amount ($100M)
ROUND_SIZE_THRESHOLD = 100 * 10**6

# How many sectors get their own series before the rest fold into 'Other'
TOP_SECTORS_COUNT = 7

# Directory that receives the generated PNG files
CHART_OUTPUT_DIR = 'no_title_charts'

# --- Color Palette (Paul Tol Colorblind-Friendly) ---
PAUL_TOL_PRIMARY = [
    '#4477AA',
    '#EE6677',
    '#228833',
    '#CCBB44',
    '#66CCEE',
    '#AA3377',
    '#BBBBBB',
]

# Extended palette = the primary colors followed by nine additional ones
PAUL_TOL_EXTENDED = [
    *PAUL_TOL_PRIMARY,
    '#332288', '#88CCEE', '#44AA99',
    '#117733', '#999933', '#DDCC77',
    '#CC6677', '#882255', '#AA4499',
]

# --- Country Standardization ---
# Maps raw country spellings to the canonical names used by the region lists below
COUNTRY_STANDARDIZATION = dict([
    ('USA', 'United States'),
    ('Republic of Korea', 'South Korea'),
    ('United Arab Emirates', 'UAE'),
])

# --- Second-Tier Region Mapping ---
# Region label -> member countries; every country name is the
# POST-standardization spelling. Insertion order is the iteration order
# of anything that loops over this dict.
SECOND_TIER_REGIONS = {
    'USA': ['United States'],
    'Canada': ['Canada'],
    'China': ['China'],
    'United Kingdom': ['United Kingdom'],
    'India': ['India'],
    'APAC': [
        'Japan', 'Kazakhstan', 'Nepal', 'Pakistan',
        'Australia', 'Bangladesh', 'New Zealand', 'South Korea',
        'Singapore', 'Indonesia', 'Malaysia', 'Thailand',
        'Vietnam', 'Philippines', 'Sri Lanka', 'Taiwan',
    ],
    'Middle East': [
        'Jordan', 'Kuwait', 'Oman', 'Armenia',
        'Israel', 'Turkey', 'UAE', 'Iran',
        'Lebanon', 'Saudi Arabia',
    ],
    'Africa': [
        'Senegal', 'Rwanda', 'Ghana', 'Tunisia',
        'Morocco', 'South Africa', 'Egypt', 'Nigeria',
        'Kenya',
    ],
    'Western Europe': [
        'Germany', 'France', 'Spain', 'Italy',
        'Netherlands', 'Austria', 'Belgium', 'Switzerland',
        'Luxembourg', 'Portugal', 'Ireland',
    ],
    'Eastern Europe': [
        'Belarus', 'Serbia', 'Bulgaria', 'Ukraine',
        'Poland', 'Russia', 'Czech Republic', 'Slovakia',
        'Slovenia', 'Croatia', 'Cyprus', 'Greece',
        'Hungary', 'Romania',
    ],
    'Nordics and Baltics': [
        'Sweden', 'Norway', 'Finland', 'Denmark',
        'Iceland', 'Estonia', 'Latvia', 'Lithuania',
    ],
    'Americas (No Canada and USA)': [
        'Guatemala', 'Puerto Rico', 'Costa Rica', 'Uruguay',
        'Panama', 'Ecuador', 'Mexico', 'Brazil',
        'Argentina', 'Chile', 'Colombia', 'Peru',
    ],
}

# --- Reverse Region Lookup ---
# Country -> region label, derived from the mapping above
COUNTRY_TO_REGION = {}
for region, countries in SECOND_TIER_REGIONS.items():
    COUNTRY_TO_REGION.update(dict.fromkeys(countries, region))

In [None]:
# Cell 4: Data Upload & Cleaning
import re
from datetime import datetime

# --- Cleaning Functions ---

def clean_amount(amount):
    """Convert a raw round-amount cell to a finite float.

    Args:
        amount: a number (Python or NumPy scalar), a string such as
            '$1,250,000', or a missing value (None / NaN).

    Returns:
        The amount as a float, or 0.0 when the value is missing, is not a
        number or string, cannot be parsed, or is not finite.
    """
    if amount is None or pd.isna(amount):
        return 0.0

    if isinstance(amount, str):
        # Drop thousands separators, currency sign and embedded spaces
        text = amount.replace(',', '').replace('$', '').replace(' ', '')
        if not text:
            return 0.0
        try:
            value = float(text)
        except ValueError:
            return 0.0
    elif isinstance(amount, (int, float, np.number)):
        # np.number covers NumPy integer scalars, which are not `int` subclasses
        value = float(amount)
    else:
        return 0.0

    # float() accepts 'inf' and 'nan'; treat those as unparseable rather than
    # letting a non-finite amount reach the funding totals
    return value if np.isfinite(value) else 0.0

def extract_year(date_str):
    """Pull a four-digit year out of a date-like value.

    Tries, in order: a bare 4-digit year, the "Jan 01, 2020" layout, a list
    of common date layouts, and finally a regex search for an 18xx/19xx/20xx
    token. The first strategy that parses decides the result.

    Returns:
        The year as an int when it lies in 1800-2030, otherwise None
        (also None for missing, empty or 'none' input).
    """
    def _bounded(candidate):
        # Accept only years inside the supported window
        return candidate if 1800 <= candidate <= 2030 else None

    if date_str is None or pd.isna(date_str):
        return None

    text = str(date_str).strip()
    if not text or text.lower() == 'none':
        return None

    # Bare year, e.g. "2015"
    if text.isdigit() and len(text) == 4:
        return _bounded(int(text))

    # "Jan 01, 2020" layout (Tracxn); on a parse failure fall through
    if ',' in text:
        try:
            parsed = datetime.strptime(text, '%b %d, %Y')
        except (ValueError, TypeError):
            parsed = None
        if parsed is not None:
            return _bounded(parsed.year)

    # Standard layouts: the first one that parses wins
    for layout in ('%Y-%m-%d', '%m/%d/%Y', '%d/%m/%Y', '%b %Y', '%B %Y'):
        try:
            parsed = datetime.strptime(text, layout)
        except ValueError:
            continue
        return _bounded(parsed.year)

    # Last resort: any standalone 18xx/19xx/20xx token in the text
    found = re.search(r'\b(18|19|20)\d{2}\b', text)
    return _bounded(int(found.group())) if found else None

def clean_category(category):
    """Normalize a category cell to a single name.

    Missing or blank values become 'Unknown'; for a comma-separated list
    only the first entry (stripped) is kept.
    """
    if category is None or pd.isna(category):
        return 'Unknown'

    text = str(category).strip()
    if not text:
        return 'Unknown'

    first, comma, _rest = text.partition(',')
    return first.strip() if comma else text

def clean_country(country):
    """Normalize a country cell to a single standardized name.

    Missing or blank values become 'Unknown'; for a comma-separated list
    only the first entry (stripped) is kept. The result is then looked up in
    COUNTRY_STANDARDIZATION and returned unchanged when no entry exists.
    """
    if country is None or pd.isna(country):
        return 'Unknown'

    text = str(country).strip()
    if not text:
        return 'Unknown'

    first, comma, _rest = text.partition(',')
    name = first.strip() if comma else text
    return COUNTRY_STANDARDIZATION.get(name, name)

# --- File Upload ---
print("Upload two Excel files: one companies export and one funding rounds export from Tracxn.")
uploaded = files.upload()
filenames = list(uploaded.keys())
# Raise explicitly instead of asserting, matching the ValueError checks below
if len(filenames) != 2:
    raise ValueError(f"Expected 2 files, got {len(filenames)}: {filenames}")

# Read each uploaded workbook into a DataFrame keyed by its filename
dfs = {}
for fn in filenames:
    dfs[fn] = pd.read_excel(fn, engine='openpyxl')
    print(f"  Loaded {fn}: {dfs[fn].shape[0]} rows, {dfs[fn].shape[1]} columns")

# --- Auto-Detection ---
# The companies export is identified by its 'Subcategory' column; the other
# uploaded file is taken to be the funding rounds export.
companies_candidates = [fn for fn, df in dfs.items() if 'Subcategory' in df.columns]

if len(companies_candidates) > 1:
    raise ValueError("Both files contain 'Subcategory'. Please upload one companies file and one funding rounds file.")

if not companies_candidates:
    all_cols = {fn: list(df.columns) for fn, df in dfs.items()}
    raise ValueError(f"Neither uploaded file contains a 'Subcategory' column. "
                     f"The companies file from Tracxn must include this column. "
                     f"Available columns: {all_cols}")

companies_file = companies_candidates[0]
# Exactly two files were uploaded, so the remaining one is the funding export
funding_file = next(fn for fn in filenames if fn != companies_file)

print(f"\nDetected companies file: {companies_file}")
print(f"Detected funding file:  {funding_file}")

companies_raw = dfs[companies_file]
funding_raw = dfs[funding_file]

# --- Validation ---
def _require_column(df, column, label):
    """Raise ValueError naming the file when *column* is absent from *df*."""
    if column not in df.columns:
        raise ValueError(f"{label} file missing '{column}'. Columns: {list(df.columns)}")


_require_column(companies_raw, 'Domain Name', 'Companies')
_require_column(funding_raw, 'Domain Name', 'Funding')

# Round amount column may be named differently across Tracxn export versions
ROUND_AMOUNT_COL = next(
    (candidate for candidate in ('Round Amount (in USD)', 'Round Amount (USD)')
     if candidate in funding_raw.columns),
    None,
)
if ROUND_AMOUNT_COL is None:
    raise ValueError(
        f"Funding file missing round amount column. Looked for 'Round Amount (in USD)' or 'Round Amount (USD)'. Columns: {list(funding_raw.columns)}")
print(f"Using round amount column: '{ROUND_AMOUNT_COL}'")

_require_column(funding_raw, 'Round Date', 'Funding')

# --- Inner Join ---
# pandas matches null join keys to each other, so rows lacking a Domain Name
# in both files would be paired up as if they were the same company. Exclude
# them from the join only; companies_raw / funding_raw stay untouched.
companies_keyed = companies_raw.dropna(subset=['Domain Name'])
funding_keyed = funding_raw.dropna(subset=['Domain Name'])
rows_without_domain = (len(companies_raw) - len(companies_keyed)) + (len(funding_raw) - len(funding_keyed))
if rows_without_domain:
    print(f"Note: {rows_without_domain} rows without a Domain Name were left out of the merge")

# Each extra companies row for a domain repeats that domain's funding rows
duplicate_domains = int(companies_keyed['Domain Name'].duplicated().sum())
if duplicate_domains:
    print(f"Warning: {duplicate_domains} duplicate Domain Name rows in the companies file; "
          f"their funding rounds appear once per duplicate in the merged data")

merged_df = pd.merge(companies_keyed, funding_keyed, on='Domain Name', how='inner',
                     suffixes=('_Companies', '_Funding'))
print(f"\nMerged: {merged_df.shape[0]} rows (inner join on Domain Name)")

# --- Apply Cleaning ---
def _merged_col(base, suffix):
    """Return the suffixed column name if the merge renamed *base*, else *base*."""
    suffixed = f'{base}{suffix}'
    return suffixed if suffixed in merged_df.columns else base


# Fail with a readable message instead of a bare KeyError further down
missing_company_cols = [c for c in ('Category', 'Country', 'Founded Year')
                        if c not in companies_raw.columns]
if missing_company_cols:
    raise ValueError(f"Companies file missing required column(s) {missing_company_cols}. "
                     f"Columns: {list(companies_raw.columns)}")

# Columns present in both files receive a suffix during the merge
cat_col = _merged_col('Category', '_Companies')
country_col = _merged_col('Country', '_Companies')
amount_col = _merged_col(ROUND_AMOUNT_COL, '_Funding')
round_date_col = _merged_col('Round Date', '_Funding')
round_name_col = _merged_col('Round Name', '_Funding')

merged_df['amount_usd'] = merged_df[amount_col].apply(clean_amount)
merged_df['year'] = merged_df[round_date_col].apply(extract_year)
merged_df['category_clean'] = merged_df[cat_col].apply(clean_category)
merged_df['country_clean'] = merged_df[country_col].apply(clean_country)

if round_name_col in merged_df.columns:
    # Strip before the blank check so whitespace-only names also become 'Unspecified'
    round_names = merged_df[round_name_col].fillna('Unspecified').astype(str).str.strip()
    merged_df['investment_type'] = round_names.replace('', 'Unspecified')
else:
    print("Note: funding file has no 'Round Name' column; investment_type set to 'Unspecified'")
    merged_df['investment_type'] = 'Unspecified'

# Note: Rows with category_clean='Unknown' or country_clean='Unknown' are kept.
# They contribute to global totals and charts that don't filter by that dimension.
# A region filter won't match an 'Unknown' country, so those rows drop out of
# the region charts.

# --- Filter to analysis_df (disclosed amounts, valid years) ---
has_year = merged_df['year'].notna()
analysis_df = merged_df[(merged_df['amount_usd'] > 0) & has_year].copy()
analysis_df['year'] = analysis_df['year'].astype(int)

# --- all_events_df (all events with valid years, including undisclosed) ---
all_events_df = merged_df[has_year].copy()
all_events_df['year'] = all_events_df['year'].astype(int)

# --- companies_df (companies-only, for founded-year charts) ---
companies_df = companies_raw.copy()
companies_df['founded_year'] = companies_df['Founded Year'].apply(extract_year)
companies_df = companies_df[companies_df['founded_year'].notna()].copy()
companies_df['founded_year'] = companies_df['founded_year'].astype(int)
# Keep founded years inside the 1800-2030 window
companies_df = companies_df[companies_df['founded_year'].between(1800, 2030)].copy()
companies_df['category_clean'] = companies_df['Category'].apply(clean_category)
companies_df['country_clean'] = companies_df['Country'].apply(clean_country)

# --- Summary Statistics ---
# min()/max() of an empty column are NaN and int(NaN) raises, so report the
# year range only when analysis_df actually has rows.
if analysis_df.empty:
    year_range_text = "n/a (no rounds with a disclosed amount and valid year)"
else:
    year_range_text = f"{int(analysis_df['year'].min())} - {int(analysis_df['year'].max())}"

print(f"\n{'='*50}")
print("DATA PIPELINE SUMMARY")
print(f"{'='*50}")
print(f"Companies uploaded:    {companies_raw.shape[0]:,}")
print(f"Funding rounds uploaded: {funding_raw.shape[0]:,}")
print(f"Merged rows:           {merged_df.shape[0]:,}")
print(f"analysis_df rows:      {analysis_df.shape[0]:,} (disclosed amounts, valid years)")
print(f"all_events_df rows:    {all_events_df.shape[0]:,} (all events, valid years)")
print(f"companies_df rows:     {companies_df.shape[0]:,} (valid founded year)")
print(f"Year range:            {year_range_text}")
print(f"Unique companies:      {merged_df['Domain Name'].nunique():,}")
print(f"Unique sectors:        {analysis_df['category_clean'].nunique()}")
print(f"{'='*50}")

In [None]:
# Cell 5: Shared Utilities

def setup_charts():
    """Apply the notebook's matplotlib rcParams (IEEE-compatible styling in Colab)."""
    typography = {
        'font.family': 'serif',
        'font.serif': ['DejaVu Serif', 'Liberation Serif', 'serif'],
        'text.usetex': False,
        'font.size': 10,
        'axes.labelsize': 10,
        'axes.titlesize': 12,
        'xtick.labelsize': 9,
        'ytick.labelsize': 9,
        'legend.fontsize': 9,
    }
    figure_and_output = {
        'figure.dpi': 100,
        'figure.figsize': [7.25, 4.48],
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.05,
        'savefig.facecolor': 'white',
        'savefig.edgecolor': 'none',
    }
    axes_and_lines = {
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.linewidth': 0.8,
        'axes.grid': True,
        'axes.edgecolor': '#CCCCCC',  # spine color
        'grid.alpha': 0.3,
        'grid.linewidth': 0.5,
        'lines.linewidth': 1.5,
        'lines.markersize': 6,
    }
    plt.rcParams.update({**typography, **figure_and_output, **axes_and_lines})


# Apply the styling once, as soon as this cell runs
setup_charts()


def format_currency_axis(ax, axis='y', scale='auto'):
    """Install a dollar-amount tick formatter on one axis.

    Args:
        ax: matplotlib Axes object
        axis: 'y' formats the y-axis; any other value formats the x-axis
        scale: 'B' labels tick values as billions ('$1.5B'); any other
            explicit value labels them as millions ('$250M'). With an
            explicit scale the tick values are printed as-is, so the caller
            must already have divided the data accordingly. 'auto' inspects
            the current axis limits: if the largest absolute limit is >= 1.0
            the values are labelled as billions, otherwise they are
            multiplied by 1000 and labelled as millions.
            NOTE(review): 'auto' therefore assumes the plotted values are in
            billions of USD -- confirm before relying on it.
    """
    on_y = axis == 'y'
    multiplier = 1.0

    if scale == 'auto':
        low, high = ax.get_ylim() if on_y else ax.get_xlim()
        if max(abs(low), abs(high)) >= 1.0:
            scale = 'B'
        else:
            scale = 'M'
            # Without rescaling, sub-1.0 values would all print as "$0M"/"$1M"
            multiplier = 1000.0

    if scale == 'B':
        formatter = FuncFormatter(lambda x, p: f'${x:.1f}B')
    else:
        formatter = FuncFormatter(lambda x, p, m=multiplier: f'${x * m:.0f}M')

    target_axis = ax.yaxis if on_y else ax.xaxis
    target_axis.set_major_formatter(formatter)


def save_figure(fig, filename):
    """Write *fig* to CHART_OUTPUT_DIR as '<filename>_no_title.png' at 300 DPI.

    The output directory is created if it does not exist, and the saved path
    is printed.
    """
    os.makedirs(CHART_OUTPUT_DIR, exist_ok=True)
    target = os.path.join(CHART_OUTPUT_DIR, '{}_no_title.png'.format(filename))
    save_options = {'dpi': 300, 'bbox_inches': 'tight', 'pad_inches': 0.05}
    fig.savefig(target, **save_options)
    print(f'Saved: {target}')


def get_colors(n):
    """Return a list of n colors from the Paul Tol colorblind-friendly palettes.

    Up to 7 colors come from the primary palette and up to 16 from the
    extended palette; beyond that the extended palette repeats from its start.
    """
    if n > 16:
        # More colors requested than exist: wrap around the extended palette
        palette_size = len(PAUL_TOL_EXTENDED)
        return [PAUL_TOL_EXTENDED[k % palette_size] for k in range(n)]

    palette = PAUL_TOL_PRIMARY if n <= 7 else PAUL_TOL_EXTENDED
    return palette[:n]


def sanitize_filename(name):
    """Turn a display name into a filename-safe string (matching the desktop pipeline).

    '&' becomes 'and'; '/' and spaces become '_'; commas and parentheses are
    removed; any remaining character outside [a-zA-Z0-9_.] is dropped.
    """
    import re as _re

    # One pass over the string; a value of None deletes the character
    substitutions = str.maketrans({
        '&': 'and',
        '/': '_',
        ' ': '_',
        ',': None,
        '(': None,
        ')': None,
    })
    return _re.sub(r'[^a-zA-Z0-9_.]', '', name.translate(substitutions))


def complete_year_range(data, start_year=None, end_year=None, year_col=None):
    """Return *data* with a row/entry for every year in the range, gaps set to 0.

    Args:
        data: pandas Series indexed by year, or a DataFrame with a year column
        start_year: first year of the range (default: smallest year in data)
        end_year: last year of the range, inclusive (default: largest year in data)
        year_col: name of the year column; required for DataFrame input,
            ignored for Series

    Returns:
        For a Series, the Series reindexed to the full range with
        fill_value=0. For a DataFrame, a left merge of the full year range
        with *data*, with all missing values replaced by 0.

    Raises:
        ValueError: if *data* is not a Series and year_col is None.
    """
    series_input = isinstance(data, pd.Series)
    if not series_input and year_col is None:
        raise ValueError("year_col required for DataFrame input")

    def _resolve(bound, pick):
        # Only consult the data when the caller left this bound open
        if bound is not None:
            return bound
        observed_years = data.index if series_input else data[year_col]
        return int(pick(observed_years))

    first_year = _resolve(start_year, lambda years: years.min())
    last_year = _resolve(end_year, lambda years: years.max())
    every_year = range(first_year, last_year + 1)

    if series_input:
        return data.reindex(every_year, fill_value=0)

    scaffold = pd.DataFrame({year_col: every_year})
    return scaffold.merge(data, on=year_col, how='left').fillna(0)

In [None]:
# Cell 6: Chart Function Definitions

# --- Companies Founded Charts ---

def create_companies_founded_line_chart(companies_df, start_year, filename):
    """Plot companies founded per year as a filled line chart and save it.

    Counts rows of *companies_df* per 'founded_year', fills missing years
    from *start_year* onward with 0, saves the figure via save_figure and
    returns it.
    """
    founded_per_year = complete_year_range(
        companies_df.groupby('founded_year').size(),
        start_year=start_year,
    )
    xs, ys = founded_per_year.index, founded_per_year.values

    fig, ax = plt.subplots()
    line_color = get_colors(1)[0]
    marker_style = dict(marker='o', markersize=5, markerfacecolor=line_color,
                        markeredgecolor='black', markeredgewidth=0.5)
    ax.plot(xs, ys, linewidth=2.5, color=line_color, **marker_style)
    # Light shading under the line
    ax.fill_between(xs, ys, alpha=0.15, color=line_color)
    ax.set(xlabel='Year', ylabel='Number of Companies Founded')

    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    save_figure(fig, filename)
    return fig


def create_companies_founded_bar_chart(companies_df, start_year, filename):
    """Plot companies founded per year as a bar chart and save it.

    Counts rows of *companies_df* per 'founded_year', fills missing years
    from *start_year* onward with 0, saves the figure via save_figure and
    returns it.
    """
    founded_per_year = complete_year_range(
        companies_df.groupby('founded_year').size(),
        start_year=start_year,
    )

    fig, ax = plt.subplots()
    bar_style = dict(color=get_colors(1)[0], edgecolor='black',
                     linewidth=0.5, alpha=0.8)
    ax.bar(founded_per_year.index, founded_per_year.values, **bar_style)
    ax.set(xlabel='Year', ylabel='Number of Companies Founded')

    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    save_figure(fig, filename)
    return fig


def create_companies_founded_outcomes_chart(companies_df, start_year, filename):
    """Stacked bar chart of company outcomes by founding year, with a stats box.

    Each company gets one outcome from its 'Is Deadpooled', 'Is Acquired'
    and 'Is IPO' values (a value counts when it reads 'yes', ignoring case
    and surrounding whitespace; a missing column counts as not set). The
    figure is saved via save_figure and returned.
    """
    # Checked in this order; the first flag that is set decides the outcome
    flag_priority = (
        ('Is Deadpooled', 'Failed'),
        ('Is Acquired', 'Acquired'),
        ('Is IPO', 'IPO'),
    )
    outcome_order = ['Active/Unknown', 'Failed', 'Acquired', 'IPO']

    def _classify(row):
        for flag_column, label in flag_priority:
            if str(row.get(flag_column, '')).strip().lower() == 'yes':
                return label
        return 'Active/Unknown'

    labelled = companies_df.copy()
    labelled['outcome'] = labelled.apply(_classify, axis=1)
    pivot = labelled.groupby(['founded_year', 'outcome']).size().unstack(fill_value=0)

    # Guarantee every outcome column exists, then fix the stacking order
    for missing in (name for name in outcome_order if name not in pivot.columns):
        pivot[missing] = 0
    pivot = pivot[outcome_order]

    # One row per year from start_year through the latest founding year
    last_founded = int(pivot.index.max())
    pivot = pivot.reindex(range(start_year, last_founded + 1), fill_value=0)

    fig, ax = plt.subplots()
    years = pivot.index
    stack_base = np.zeros(len(pivot))
    for outcome, bar_color in zip(outcome_order, get_colors(4)):
        heights = pivot[outcome].values.astype(float)
        ax.bar(years, heights, bottom=stack_base, color=bar_color, label=outcome,
               alpha=0.8, edgecolor='black', linewidth=0.5)
        stack_base = stack_base + heights

    ax.set_xlabel('Year')
    ax.set_ylabel('Number of Companies Founded')
    ax.legend(loc='upper left', frameon=True, fancybox=True, shadow=True)

    # Summary box: overall total plus count and share per outcome
    outcome_totals = pivot.sum()
    total_companies = outcome_totals.sum()
    shares = {name: (outcome_totals[name] / total_companies) * 100
              for name in outcome_order}
    end_year = int(pivot.index.max())

    summary_lines = [f"Total Companies ({start_year}-{end_year}): {total_companies:,.0f}"]
    summary_lines.extend(
        f"{name}: {outcome_totals[name]:,.0f} ({shares[name]:.1f}%)"
        for name in outcome_order
    )
    ax.text(0.98, 0.98, "\n".join(summary_lines),
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='right',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='white', alpha=0.8),
            fontsize=9)

    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    save_figure(fig, filename)
    return fig


def create_companies_founded_by_sector_chart(companies_df, start_year, filename):
    """Stacked bar chart of companies founded per year, split by sector.

    The TOP_SECTORS_COUNT sectors with the most companies get their own
    series; all remaining sectors are summed into 'Other'. The figure is
    saved via save_figure and returned.
    """
    pivot = (companies_df
             .groupby(['founded_year', 'category_clean'])
             .size()
             .unstack(fill_value=0))

    # Rank sectors by total company count and keep the largest ones
    ranked = pivot.sum().sort_values(ascending=False)
    top_sectors = ranked.index[:TOP_SECTORS_COUNT].tolist()
    minor_sectors = [name for name in pivot.columns if name not in top_sectors]

    # Fold everything outside the top group into a single 'Other' series
    if minor_sectors:
        pivot['Other'] = pivot[minor_sectors].sum(axis=1)
        pivot = pivot.drop(columns=minor_sectors)

    # Top sectors first, 'Other' last
    if 'Other' in pivot.columns:
        col_order = top_sectors + ['Other']
    else:
        col_order = top_sectors + []
    pivot = pivot[col_order]

    # One row per year from start_year through the latest founding year
    last_founded = int(pivot.index.max())
    pivot = pivot.reindex(range(start_year, last_founded + 1), fill_value=0)

    fig, ax = plt.subplots()
    pivot.plot(kind='bar', stacked=True, ax=ax,
               color=get_colors(len(pivot.columns)),
               edgecolor='black', linewidth=0.3)
    ax.set_xlabel('Year')
    ax.set_ylabel('Number of Companies Founded')
    ax.legend(loc='best', fontsize=8)

    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    save_figure(fig, filename)
    return fig


def create_companies_founded_by_funding_status_chart(companies_df, start_year, filename):
    """Stacked bar chart of funded vs. not-funded companies by founding year.

    A company counts as 'Funded' when its 'Is Funded' value reads 'yes'
    (ignoring case and surrounding whitespace); anything else is
    'Not Funded'. The figure is saved via save_figure and returned.
    """
    status_order = ['Funded', 'Not Funded']

    def _status(raw_flag):
        if str(raw_flag).strip().lower() == 'yes':
            return 'Funded'
        return 'Not Funded'

    labelled = companies_df.copy()
    labelled['funding_status'] = labelled['Is Funded'].apply(_status)

    pivot = (labelled
             .groupby(['founded_year', 'funding_status'])
             .size()
             .unstack(fill_value=0))

    # Both series must exist, 'Funded' at the bottom of the stack
    for missing in (name for name in status_order if name not in pivot.columns):
        pivot[missing] = 0
    pivot = pivot[status_order]

    # One row per year from start_year through the latest founding year
    last_founded = int(pivot.index.max())
    pivot = pivot.reindex(range(start_year, last_founded + 1), fill_value=0)

    fig, ax = plt.subplots()
    pivot.plot(kind='bar', stacked=True, ax=ax, color=get_colors(2),
               edgecolor='black', linewidth=0.5)
    ax.set_xlabel('Year')
    ax.set_ylabel('Number of Companies Founded')
    ax.legend(loc='best')

    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    save_figure(fig, filename)
    return fig


# --- Funding Charts ---

def create_total_funding_by_year_chart(funding_df, start_year, filename):
    """Dual-axis chart: bars for total annual funding, line for round counts.

    Args:
        funding_df: DataFrame with 'year' and 'amount_usd' columns
        start_year: first year shown; years without rows are plotted as 0
        filename: output filename stem passed to save_figure

    Returns:
        The Figure, after it has been saved via save_figure.
    """
    yearly = funding_df.groupby('year').agg(
        total=('amount_usd', 'sum'),
        count=('amount_usd', 'size')
    )
    yearly['total_billions'] = yearly['total'] / 1e9

    # Complete year range
    full_years = range(start_year, int(yearly.index.max()) + 1)
    yearly = yearly.reindex(full_years, fill_value=0)

    fig, ax = plt.subplots()
    bar_color, line_color = get_colors(2)

    # Bars: total funding in billions
    ax.bar(yearly.index, yearly['total_billions'], color=bar_color,
           alpha=0.8, edgecolor='black', linewidth=0.5)
    format_currency_axis(ax, axis='y', scale='B')
    ax.set_xlabel('Year')
    ax.set_ylabel('Total Funding')

    # Line: round counts on secondary axis
    ax2 = ax.twinx()
    ax2.plot(yearly.index, yearly['count'], color=line_color,
             marker='o', markersize=4, linewidth=2)
    ax2.set_ylabel('Number of Rounds', color=line_color)
    ax2.tick_params(axis='y', labelcolor=line_color)
    # The right spine is hidden by the notebook-wide rcParams; show it for the twin
    ax2.spines['right'].set_visible(True)
    ax2.spines['right'].set_color(line_color)

    # Rotate the PRIMARY axes' labels explicitly. plt.xticks() acts on the
    # current axes, which after twinx() is ax2 -- and the twin's x-axis is
    # not drawn, so the rotation would never be visible.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    save_figure(fig, filename)
    return fig


def create_stacked_funding_chart(funding_df, all_events_df, scope_name, scope_filter,
                                  start_year, filename):
    """Parameterized stacked bar+line chart for funding by round size.

    Bars stack yearly funding from rounds below ROUND_SIZE_THRESHOLD under
    funding from rounds at or above it; a line on a secondary axis shows the
    number of rounds per year.

    Args:
        funding_df: analysis_df (disclosed amounts only, for bar heights)
        all_events_df: all events including undisclosed (for round count line)
        scope_name: display label (e.g., 'All Robotics', 'USA'); accepted for
            the caller's convenience but not used in the chart itself
        scope_filter: None (no filter) or callable(df) -> filtered_df
        start_year: first year shown; years without rows are plotted as 0
        filename: output filename stem passed to save_figure

    Returns:
        Figure object, or None if no data matches the filter.
    """
    # Apply scope filter
    if scope_filter is not None:
        filtered_df = scope_filter(funding_df)
        filtered_events = scope_filter(all_events_df)
    else:
        filtered_df = funding_df
        filtered_events = all_events_df

    if len(filtered_df) == 0:
        return None

    small_label, large_label = '< $100M', '>= $100M'

    # Categorize round sizes (on a copy, so the caller's frame is untouched)
    filtered_df = filtered_df.copy()
    filtered_df['round_size_category'] = np.where(
        filtered_df['amount_usd'] >= ROUND_SIZE_THRESHOLD,
        large_label, small_label
    )

    # Auto-scale: check total funding across ALL years
    total_funding = filtered_df['amount_usd'].sum()
    if total_funding >= 1_000_000_000:
        scale, divisor = 'B', 1e9
    else:
        scale, divisor = 'M', 1e6

    # Pivot by year and round size category
    pivot = filtered_df.groupby(['year', 'round_size_category'])['amount_usd'].sum().unstack(fill_value=0)

    # Ensure both columns exist in correct order (< $100M at bottom)
    for col in (small_label, large_label):
        if col not in pivot.columns:
            pivot[col] = 0
    pivot = pivot[[small_label, large_label]]

    # Scale amounts to the unit chosen above
    pivot = pivot / divisor

    # Complete year range
    full_years = range(start_year, int(pivot.index.max()) + 1)
    pivot = pivot.reindex(full_years, fill_value=0)

    # Round counts from all_events_df (including undisclosed)
    round_counts = filtered_events.groupby('year').size()
    round_counts = round_counts.reindex(full_years, fill_value=0)

    # Plot
    fig, ax = plt.subplots()
    small_color, large_color, line_color = get_colors(3)

    # Stacked bars: large rounds sit on top of the small-round bars
    years = pivot.index
    small_values = pivot[small_label].values
    large_values = pivot[large_label].values

    ax.bar(years, small_values, color=small_color,
           label=small_label, alpha=0.8, edgecolor='black', linewidth=0.5, width=0.8)
    ax.bar(years, large_values, bottom=small_values, color=large_color,
           label=large_label, alpha=0.8, edgecolor='black', linewidth=0.5, width=0.8)

    format_currency_axis(ax, axis='y', scale=scale)
    ax.set_xlabel('Year')
    ax.set_ylabel('Total Funding')

    # Round count line on secondary axis
    ax2 = ax.twinx()
    ax2.plot(years, round_counts.values, color=line_color,
             marker='o', markersize=6, linewidth=3, label='Round Count')
    ax2.set_ylabel('Number of Rounds', color=line_color)
    ax2.tick_params(axis='y', labelcolor=line_color)
    # The right spine is hidden by the notebook-wide rcParams; show it for the twin
    ax2.spines['right'].set_visible(True)
    ax2.spines['right'].set_color(line_color)

    # Combined legend covering both axes
    lines1, labels1 = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines1 + lines2, labels1 + labels2, loc='best', fontsize=8)

    # Rotate the PRIMARY axes' labels explicitly. plt.xticks() acts on the
    # current axes, which after twinx() is ax2 -- and the twin's x-axis is
    # not drawn, so the rotation would never be visible.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    save_figure(fig, filename)
    return fig

In [None]:
# Cell 7: Master Generation

os.makedirs(CHART_OUTPUT_DIR, exist_ok=True)
charts_generated, charts_failed = [], []


def generate_chart(name, chart_fn, *args, **kwargs):
    """Run one chart function and record the outcome.

    A returned figure is recorded in charts_generated, displayed and closed.
    A None result is recorded nowhere. Any exception is recorded in
    charts_failed as (name, message), reported, and followed by closing all
    open figures. Garbage collection runs in every case.
    """
    try:
        fig = chart_fn(*args, **kwargs)
        if fig is None:
            return
        charts_generated.append(name)
        plt.show()
        plt.close(fig)
    except Exception as e:
        # Keep going with the remaining charts; the failure is reported later
        charts_failed.append((name, str(e)))
        print(f"  ERROR generating {name}: {e}")
        plt.close('all')
    finally:
        gc.collect()

# --- 8.2 Companies Founded Charts (7 charts) ---
print("Generating companies-founded charts...")

# (filename stem, chart function, first year plotted). Start years come from
# the configuration cell so that editing START_YEAR_ALL / START_YEAR_DEFAULT
# actually changes the charts; with the default values the stems are
# unchanged (..._1900, ..._2000).
companies_chart_specs = [
    (f'companies_founded_absolute_{START_YEAR_ALL}',
     create_companies_founded_line_chart, START_YEAR_ALL),
    ('companies_founded_absolute_all_years',
     create_companies_founded_line_chart, START_YEAR_ALL),
    (f'companies_founded_absolute_{START_YEAR_DEFAULT}',
     create_companies_founded_line_chart, START_YEAR_DEFAULT),
    (f'companies_founded_bar_{START_YEAR_DEFAULT}',
     create_companies_founded_bar_chart, START_YEAR_DEFAULT),
    ('companies_founded_outcomes_absolute',
     create_companies_founded_outcomes_chart, START_YEAR_DEFAULT),
    (f'companies_founded_by_sector_stacked_{START_YEAR_DEFAULT}',
     create_companies_founded_by_sector_chart, START_YEAR_DEFAULT),
    (f'companies_founded_by_funding_status_stacked_{START_YEAR_DEFAULT}',
     create_companies_founded_by_funding_status_chart, START_YEAR_DEFAULT),
]

for chart_stem, chart_fn, first_year in companies_chart_specs:
    generate_chart(chart_stem, chart_fn, companies_df, first_year, chart_stem)

plt.close('all')
gc.collect()
print(f"  Companies charts: {len(charts_generated)} generated")

# --- 8.3 Global Funding Charts (2 charts) ---
print("\nGenerating global funding charts...")
count_before = len(charts_generated)

total_stem = f'total_global_funding_by_year_{START_YEAR_DEFAULT}'
generate_chart(total_stem,
    create_total_funding_by_year_chart, analysis_df, START_YEAR_DEFAULT, total_stem)

stacked_stem = f'robotics_funding_by_year_{START_YEAR_DEFAULT}_onwards'
generate_chart(stacked_stem,
    create_stacked_funding_chart, analysis_df, all_events_df,
    'All Robotics', None, START_YEAR_DEFAULT, stacked_stem)

plt.close('all')
gc.collect()
print(f"  Global funding charts: {len(charts_generated) - count_before} generated")

# --- 8.4-8.6 Scoped Funding Charts (sector, region, subcategory) ---

# Filename stems already used by a scoped chart, to detect overwrites
_scoped_stems_used = set()


def _equals_filter(column, value):
    """Return a filter keeping the rows where *column* equals *value*."""
    return lambda df: df[df[column] == value]


def _isin_filter(column, values):
    """Return a filter keeping the rows where *column* is one of *values*."""
    return lambda df: df[df[column].isin(values)]


def _generate_scoped_charts(group_label, scopes):
    """Generate one stacked funding chart per (scope_name, scope_filter) pair.

    The start year comes from START_YEAR_DEFAULT, so the configuration cell
    controls these charts; with the default value the filenames are the
    usual 'funding_by_round_size_<name>_2000_onwards'.
    """
    print(f"\nGenerating {group_label} charts...")
    before = len(charts_generated)

    for scope_name, scope_filter in scopes:
        stem = (f'funding_by_round_size_{sanitize_filename(scope_name)}'
                f'_{START_YEAR_DEFAULT}_onwards')
        if stem in _scoped_stems_used:
            # Two scopes mapping to one stem share a PNG; the later one wins
            print(f"  WARNING: '{stem}' was already generated; the {group_label} "
                  f"chart for '{scope_name}' overwrites the earlier file")
        _scoped_stems_used.add(stem)
        generate_chart(stem,
            create_stacked_funding_chart, analysis_df, all_events_df,
            scope_name, scope_filter, START_YEAR_DEFAULT, stem)

    plt.close('all')
    gc.collect()
    print(f"  {group_label.capitalize()} charts: {len(charts_generated) - before} generated")


# --- 8.4 Sector Charts ---
sectors = sorted(analysis_df['category_clean'].unique())
_generate_scoped_charts(
    'sector',
    [(sector, _equals_filter('category_clean', sector)) for sector in sectors])

# --- 8.5 Region Charts (12 charts) ---
_generate_scoped_charts(
    'region',
    [(region, _isin_filter('country_clean', countries))
     for region, countries in SECOND_TIER_REGIONS.items()])

# --- 8.6 Subcategory Charts ---
# Determine Subcategory column name (may have suffix from merge)
subcat_col = 'Subcategory_Companies' if 'Subcategory_Companies' in analysis_df.columns else 'Subcategory'
subcategories = sorted(analysis_df[subcat_col].dropna().unique())
_generate_scoped_charts(
    'subcategory',
    [(subcategory, _equals_filter(subcat_col, subcategory))
     for subcategory in subcategories])

# --- 8.7 Summary Report ---
# Totals for this run, each failure with its error message, and the number
# of PNG files now present in the output directory.
banner = '=' * 60
print(f"\n{banner}")
print("Chart Generation Complete")
print(banner)
print(f"Generated: {len(charts_generated)} charts")
print(f"Failed: {len(charts_failed)} charts")
if charts_failed:
    print("\nFailed charts:")
    for name, error in charts_failed:
        print(f"  - {name}: {error}")
png_count = sum(1 for entry in os.listdir(CHART_OUTPUT_DIR) if entry.endswith('.png'))
print(f"\nOutput directory: {CHART_OUTPUT_DIR}")
print(f"Total PNG files: {png_count}")
print(banner)

In [None]:
# Cell 8: Export & Download

zip_filename = 'robotics_charts.zip'
zip_count = 0

# Pack every PNG from the chart directory, in sorted order, at the archive root
with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zf:
    png_files = [entry for entry in sorted(os.listdir(CHART_OUTPUT_DIR))
                 if entry.endswith('.png')]
    for png_file in png_files:
        zf.write(os.path.join(CHART_OUTPUT_DIR, png_file), arcname=png_file)
        zip_count += 1

zip_size_mb = os.path.getsize(zip_filename) / 1024 / 1024
print(f"ZIP created: {zip_filename}")
print(f"Files in ZIP: {zip_count}")
print(f"ZIP size: {zip_size_mb:.1f} MB")
print(f"\nNote: Colab's filesystem is ephemeral. Download the ZIP before the session ends.")

# Hand the archive to Colab's download helper
files.download(zip_filename)