In [6]:
"""
AI Exposure and College Enrollment Analysis - 4-DIGIT CIP VERSION
==================================================================
Updated for:
- 4-digit CIP codes (436 programs vs 49 at 2-digit level)
- 2019-2025 enrollment data (combined from both files)
- More granular analysis (e.g., Computer Science vs Information Systems)
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# CONFIGURATION: MANUAL FOD → CIP4 MAPPINGS
# =============================================================================
# These are mappings not present in the crosswalk file or that need correction.
# Each entry: dict with FOD, CIP4, CIP4_title, notes
# To add more: just append to this list!

MANUAL_MAPPINGS = [
    {
        'FOD': 6107,
        'CIP4': '5138',
        'CIP4_title': 'Registered Nursing/Nursing Administration/Nursing Research and Clinical Nursing',
        'notes': 'Added Nov 6 2025 - FOD 6107 missing mapping to CIP 5138 (490K students)'
    },
    {
        'FOD': 3611,
        'CIP4': '2615',
        'CIP4_title': 'Neurobiology and Neurosciences',
        'notes': 'Added Nov 6 2025 - FOD 3611 not in original crosswalk'
    },
    {
        'FOD': 5202,
        'CIP4': '4228',
        'CIP4_title': 'Clinical, Counseling and Applied Psychology',
        'notes': 'Added Nov 6 2025 - FOD 5202 (Clinical Psychology) maps to CIP 4228'
    },
    {
        'FOD': 5203,
        'CIP4': '4228',
        'CIP4_title': 'Clinical, Counseling and Applied Psychology',
        'notes': 'Added Nov 6 2025 - FOD 5203 (Counseling Psychology) maps to CIP 4228'
    },

      {
        'FOD': 5203,
        'CIP4': '4228',
        'CIP4_title': 'Clinical, Counseling and Applied Psychology',
        'notes': 'Added Nov 6 2025 - FOD 5203 (Counseling Psychology) maps to CIP 4228'
    },
     {
        'FOD': 5098,
        'CIP4': '4099',
        'CIP4_title': 'Physical Sciences, other',
        'notes': 'Added Nov 6 2025 - FOD 5098 (Multi-disciplinary or General Science) maps to CIP 4099'
    },
    
]

print(f"Loaded {len(MANUAL_MAPPINGS)} manual FOD→CIP4 mappings")

# =============================================================================
# STEP 1: LOAD FELTEN AIOE DATA
# =============================================================================

def load_felten_data(filepath: str) -> pd.DataFrame:
    """Load Felten et al. (2021) AIOE scores."""
    print("="*70)
    print("LOADING FELTEN AIOE DATA")
    print("="*70)
    
    felten = pd.read_excel(filepath, sheet_name='LM AIOE')
    felten['soc_clean'] = felten['SOC Code'].str.replace('-', '').str.replace('.', '')
    felten['AIOE'] = felten['Language Modeling AIOE'] 
    print(f"\nLoaded {len(felten)} occupations")
    print(f"AIOE range: {felten['AIOE'].min():.2f} to {felten['AIOE'].max():.2f}")
    
    return felten


# =============================================================================
# STEP 2: LOAD FOD TO 4-DIGIT CIP CROSSWALK
# =============================================================================

def load_fod_cip4_crosswalk(filepath: str, manual_mappings: list = None) -> pd.DataFrame:
    """
    Load FOD to 4-digit CIP mapping from crosswalk file.
    
    The crosswalk has detailed 6-digit CIP codes (like 11.0701).
    We extract 4-digit CIP:
    - Family (2 digits): 11 = Computer Science
    - Group (next 2 digits): 07 = Computer Science 
    - Combined: 1107 = Computer Science (4-digit)
    
    Examples:
    - 11.0000 → 1100 (Computer Science, General)
    - 11.0701 → 1107 (Computer Science)
    - 52.0201 → 5202 (Business Administration)
    
    Parameters:
    -----------
    filepath : str
        Path to crosswalk Excel file
    manual_mappings : list of dict
        Additional manual mappings to append. Each dict should have keys:
        'FOD', 'CIP4', 'CIP4_title', 'notes'
    
    Returns:
    --------
    DataFrame with columns ['FOD', 'CIP4', 'CIP4_title']
    """
    print("\n" + "="*70)
    print("LOADING FOD TO 4-DIGIT CIP CROSSWALK")
    print("="*70)
    
    # Read from "CIP code by HHES code" sheet
    df = pd.read_excel(filepath, sheet_name='CIP code by HHES code', skiprows=1)
    
    # Extract FOD and CIP columns
    crosswalk = df[['HHES Code', 'CIP \nCode', 'CIP Title']].copy()
    crosswalk.columns = ['FOD', 'CIP', 'CIP_title']
    crosswalk = crosswalk.dropna(subset=['FOD', 'CIP'])
    
    # Convert FOD to integer
    crosswalk['FOD'] = crosswalk['FOD'].astype(int)
    
    # Extract 4-digit CIP from 6-digit CIP code
    # CIP format: XX.XXXX where first 2 are family, next 2 are group
    # E.g., 11.0701 → 1107
    def extract_cip4(cip_6digit):
        try:
            cip_float = float(cip_6digit)
            # Get integer part (family, 2 digits)
            family = int(cip_float)  # e.g., 11
            # Get first 2 decimal digits (group)
            decimal_part = cip_float - family  # e.g., 0.0701
            # Extract first 2 decimal digits
            group = int(round(decimal_part * 10000)) // 100  # e.g., 07
            # Combine to 4-digit code
            cip4 = f"{family:02d}{group:02d}"  # e.g., "1107"
            return cip4
        except:
            return None
    
    crosswalk['CIP4'] = crosswalk['CIP'].apply(extract_cip4)
    crosswalk = crosswalk.dropna(subset=['CIP4'])
    
    # Create many-to-many mapping (each FOD can map to multiple CIP4s)
    # Keep CIP4_title for the most general title per CIP4
    fod_to_cip4_df = crosswalk.groupby(['FOD', 'CIP4']).agg({
        'CIP_title': 'first'  # Take first title (they're usually the same for same CIP4)
    }).reset_index()
    fod_to_cip4_df.columns = ['FOD', 'CIP4', 'CIP4_title']
    
    print(f"\nLoaded {len(fod_to_cip4_df)} FOD→CIP4 mappings from crosswalk file")
    print(f"  {fod_to_cip4_df['FOD'].nunique()} unique FODs")
    print(f"  {fod_to_cip4_df['CIP4'].nunique()} unique CIP4 codes")
    print(f"  Average {len(fod_to_cip4_df) / fod_to_cip4_df['FOD'].nunique():.1f} CIP4 codes per FOD")
    
    # Append manual mappings if provided
    if manual_mappings:
        manual_df = pd.DataFrame(manual_mappings)[['FOD', 'CIP4', 'CIP4_title']]
        fod_to_cip4_df = pd.concat([fod_to_cip4_df, manual_df], ignore_index=True)
        print(f"\n✓ Added {len(manual_mappings)} manual mappings")
        for mapping in manual_mappings:
            print(f"  FOD {mapping['FOD']} → CIP4 {mapping['CIP4']} ({mapping['CIP4_title']})")
    
    # Show sample mappings
    print("\nSample mappings:")
    sample_fods = sorted(fod_to_cip4_df['FOD'].unique())[:10]
    for fod in sample_fods:
        cips = fod_to_cip4_df[fod_to_cip4_df['FOD'] == fod]['CIP4'].tolist()
        print(f"  FOD {fod} → CIP4 {cips}")
    
    return fod_to_cip4_df


# =============================================================================


# =============================================================================
# STEP 2B: ADD EMPIRICAL ENROLLMENT WEIGHTS TO FOD→CIP4 MAPPING
# =============================================================================

def add_empirical_weights_to_crosswalk(
    fod_to_cip4: pd.DataFrame,
    enrollment: pd.DataFrame,
    base_year: int = 2019
) -> pd.DataFrame:
    """
    Add empirical enrollment weights to FOD→CIP4 mapping.
    
    For each FOD that maps to multiple CIP4s, calculate weights based on
    actual 2019 enrollment: weight_i = enrollment_i / sum(enrollment for all CIP4s that FOD maps to)
    
    This creates a Bayesian update: P(CIP4 | FOD) ∝ enrollment(CIP4)
    
    Parameters:
    -----------
    fod_to_cip4 : DataFrame with columns ['FOD', 'CIP4', 'CIP4_title']
    enrollment : DataFrame with columns ['CIP4', 'year', 'enrollment']
    base_year : Year to use for calculating weights (default 2019)
    
    Returns:
    --------
    DataFrame with columns ['FOD', 'CIP4', 'CIP4_title', 'empirical_weight']
    """
    print("\n" + "="*70)
    print("ADDING EMPIRICAL ENROLLMENT WEIGHTS TO FOD→CIP4 MAPPING")
    print("="*70)
    
    # Get base year enrollment
    enroll_base = enrollment[enrollment['year'] == base_year][['CIP4', 'enrollment', 'CIP4_title']].copy()
    print(f"\nUsing {base_year} enrollment as basis for weights")
    print(f"  {len(enroll_base)} CIP4 codes have enrollment data")
    
    # Merge enrollment into crosswalk
    crosswalk_with_enroll = fod_to_cip4.merge(
        enroll_base[['CIP4', 'enrollment']],
        on='CIP4',
        how='left'
    )
    
    # For CIP4s with no enrollment data, use a small value (1.0) as placeholder
    crosswalk_with_enroll['enrollment'] = crosswalk_with_enroll['enrollment'].fillna(1.0)
    
    # For each FOD, calculate weights as proportion of total enrollment
    # weight_i = enrollment_i / sum_j(enrollment_j) for all j that FOD maps to
    fod_totals = crosswalk_with_enroll.groupby('FOD')['enrollment'].transform('sum')
    crosswalk_with_enroll['empirical_weight'] = crosswalk_with_enroll['enrollment'] / fod_totals
    
    # Clean up
    crosswalk_weighted = crosswalk_with_enroll[['FOD', 'CIP4', 'CIP4_title', 'empirical_weight']].copy()
    
    # Report
    print(f"\nCalculated empirical weights for {len(crosswalk_weighted)} FOD→CIP4 mappings")
    
    # Show examples
    print("\nSample weighted mappings:")
    sample_fods = crosswalk_weighted['FOD'].unique()[:3]
    for fod in sample_fods:
        fod_mappings = crosswalk_weighted[crosswalk_weighted['FOD'] == fod]
        print(f"\n  FOD {fod} maps to {len(fod_mappings)} CIP4 codes:")
        for _, row in fod_mappings.iterrows():
            print(f"    CIP4 {row['CIP4']}: weight = {row['empirical_weight']:.3f}")
    
    # Sanity check: weights should sum to 1.0 for each FOD
    weight_sums = crosswalk_weighted.groupby('FOD')['empirical_weight'].sum()
    if not np.allclose(weight_sums, 1.0):
        print(f"\n⚠ WARNING: Some FOD weights don't sum to 1.0!")
        print(f"  Min: {weight_sums.min():.6f}, Max: {weight_sums.max():.6f}")
    else:
        print(f"\n✓ All FOD weights sum to 1.0")
    
    return crosswalk_weighted


# STEP 3: LOAD AND PROCESS ACS PUMS DATA
# =============================================================================

def load_and_filter_acs(
    filepath: str,
    age_min: int = 22,
    age_max: int = 35
) -> pd.DataFrame:
    """
    Load and filter IPUMS ACS PUMS data.
    
    Your ACS columns: DEGFIELDD, OCCSOC, PERWT, AGE, EDUC, YEAR
    """
    print("\n" + "="*70)
    print("LOADING ACS PUMS DATA")
    print("="*70)
    
    acs = pd.read_csv(filepath)
    
    print(f"\nInitial sample: {len(acs):,} observations")
    
    # Filter out missing, invalid, and zero FODs
    acs_filtered = acs[
        (acs['AGE'] >= age_min) & 
        (acs['AGE'] <= age_max) &
        (acs['OCCSOC'].notna()) &
        (acs['DEGFIELDD'].notna()) &
        (acs['DEGFIELDD'] != 0)  # Exclude FOD = 0 (invalid/no field of degree)
    ].copy()
    
    print(f"Filtered sample: {len(acs_filtered):,} observations")
    print(f"  - Age {age_min}-{age_max}")
    print(f"  - Valid occupation (OCCSOC) and field of degree (DEGFIELDD)")
    
    # Clean SOC codes
    acs_filtered['soc_clean'] = acs_filtered['OCCSOC'].astype(str).str.replace('-', '').str.replace('.', '')
    
    print(f"\nUnique DEGFIELDD codes: {acs_filtered['DEGFIELDD'].nunique()}")
    print(f"Unique OCCSOC codes: {acs_filtered['OCCSOC'].nunique()}")
    
    return acs_filtered


# =============================================================================
# STEP 4: MAP FOD TO 4-DIGIT CIP AND MERGE WITH EXPOSURE
# =============================================================================

def process_acs_with_exposure(
    acs: pd.DataFrame,
    felten: pd.DataFrame,
    fod_to_cip4: pd.DataFrame
) -> pd.DataFrame:
    """
    Map ACS FOD codes to 4-digit CIP using many-to-many relationship.
    """
    print("\n" + "="*70)
    print("MAPPING FOD TO 4-DIGIT CIP AND MERGING AI EXPOSURE")
    print("="*70)
    
    # Map FOD to CIP4
    # Map FOD to CIP4 using many-to-many relationship
    # Each ACS observation can contribute to multiple CIP4 codes
    acs_with_cip = acs.merge(
        fod_to_cip4,
        left_on='DEGFIELDD',
        right_on='FOD',
        how='inner'
    )
    
    # Report mapping success
    n_original = len(acs)
    n_after_mapping = len(acs_with_cip)
    n_unique_people = acs_with_cip['DEGFIELDD'].nunique() if 'DEGFIELDD' in acs_with_cip.columns else len(acs_with_cip)
    
    # Use empirical weights from crosswalk (already calculated based on 2019 enrollment)
    # Each ACS person contributes weight_split = PERWT * empirical_weight to each CIP4
    acs_with_cip['weight_split'] = acs_with_cip['PERWT'] * acs_with_cip['empirical_weight']
    
    avg_cips = len(acs_with_cip) / len(acs)
    print(f"  Each ACS person contributes to avg {avg_cips:.1f} CIP4 codes (weighted by 2019 enrollment)")
    
    
    print(f"\nMapped {n_unique_people:,} ACS observations to {len(acs_with_cip):,} CIP4 mappings")
    print(f"  (Average {n_after_mapping/n_unique_people if n_unique_people > 0 else 0:.1f} CIP4 codes per person)")
    
    # Check unmapped FODs
    if len(acs_with_cip) < len(acs):
        unmapped = acs[~acs['DEGFIELDD'].isin(fod_to_cip4['FOD'])]
        unmapped_fods = unmapped['DEGFIELDD'].value_counts().head(10)
        print("\nTop 10 unmapped FOD codes:")
        print(unmapped_fods)
    
    
    # Filter to successfully mapped
    
    # Merge with Felten AIOE scores on SOC code
    acs_with_exposure = acs_with_cip.merge(
        felten[['soc_clean', 'AIOE']],
        on='soc_clean',
        how='left'
    )
    
    # Report merge success
    n_matched = acs_with_exposure['AIOE'].notna().sum()
    pct_matched = 100 * n_matched / len(acs_with_exposure)
    print(f"\nMatched {n_matched:,}/{len(acs_with_exposure):,} observations to AI exposure ({pct_matched:.1f}%)")
    
    
    # Drop observations with missing AIOE (do NOT impute with mean)
    n_missing = acs_with_exposure['AIOE'].isna().sum()
    if n_missing > 0:
        print(f"\n⚠ Dropping {n_missing:,} observations with missing AIOE scores")
        acs_with_exposure = acs_with_exposure[acs_with_exposure['AIOE'].notna()].copy()
    
    print(f"\nFinal sample: {len(acs_with_exposure):,} observations")
    print(f"Unique 4-digit CIP codes: {acs_with_exposure['CIP4'].nunique()}")
    
    return acs_with_exposure


# =============================================================================
# STEP 5: CALCULATE 4-DIGIT CIP-LEVEL AI EXPOSURE
# =============================================================================

def calculate_cip4_exposure(
    acs: pd.DataFrame,
    weight_var: str = 'weight_split'
) -> pd.DataFrame:
    """
    Calculate weighted average AI exposure by 4-digit CIP code.
    
    For each CIP4: AI_exposure = Σ [P(occupation|CIP4) × AIOE(occupation)]
    where P(occupation|CIP4) is weighted by split weights (PERWT × empirical_weight).
    """
    print("\n" + "="*70)
    print("CALCULATING 4-DIGIT CIP-LEVEL AI EXPOSURE SCORES")
    print("="*70)
    
    # Calculate weighted average by CIP4
    cip_exposure = acs.groupby('CIP4').apply(
        lambda x: pd.Series({
            'ai_exposure_score': np.average(x['AIOE'], weights=x[weight_var]),
            'n_obs': len(x),
            'n_weighted': x[weight_var].sum(),
            'min_exposure': x['AIOE'].min(),
            'max_exposure': x['AIOE'].max(),
            'std_exposure': np.sqrt(np.average((x['AIOE'] - np.average(x['AIOE'], weights=x[weight_var]))**2, 
                                                weights=x[weight_var])),
            'CIP4_title': x['CIP4_title'].iloc[0] if 'CIP4_title' in x.columns else ''
        })
    ).reset_index()
    
    print(f"\nCalculated AI exposure for {len(cip_exposure)} 4-digit CIP codes")
    print("\nAI Exposure Score Distribution:")
    print(cip_exposure['ai_exposure_score'].describe())
    
    # Show top and bottom CIPs
    print("\n\nTop 20 most AI-exposed majors (4-digit CIP):")
    top20 = cip_exposure.nlargest(20, 'ai_exposure_score')[['CIP4', 'CIP4_title', 'ai_exposure_score', 'n_obs']]
    print(top20.to_string(index=False))
    
    print("\n\nBottom 20 least AI-exposed majors (4-digit CIP):")
    bottom20 = cip_exposure.nsmallest(20, 'ai_exposure_score')[['CIP4', 'CIP4_title', 'ai_exposure_score', 'n_obs']]
    print(bottom20.to_string(index=False))
    
    return cip_exposure


# =============================================================================
# STEP 6: LOAD AND COMBINE ENROLLMENT DATA (2019-2025)
# =============================================================================

def load_and_combine_enrollment_data(
    filepath_2024: str, 
    filepath_2025: str
) -> pd.DataFrame:
    """
    Load and combine enrollment data from two sources with proper header handling.
    
    2024 file: Major Field (4-year, Undergrad) sheet, years 2019-2024
    2025 file: CIP Group Enrollment sheet, years 2020-2025 (filter to Undergraduate 4-year)
    
    Returns combined dataset with 4-digit CIP codes (2019-2025), including CIP4_title.
    """
    print("\n" + "="*70)
    print("LOADING AND COMBINING ENROLLMENT DATA (2019-2025)")
    print("="*70)
    
    # ===== LOAD 2024 FILE =====
    print("\nLoading 2019-2024 data from CTEESpring2024-Appendix.xlsx...")
    df_2024 = pd.read_excel(
        filepath_2024, 
        sheet_name='Major Field (4-year, Undergrad)',
        header=2  # Row 2 has the actual column headers
    )
    print(f"  Loaded {len(df_2024)} rows")
    
    # Rename columns for clarity
    df_2024 = df_2024.rename(columns={
        'Major Field Family (2-digit CIP)': 'CIP2',
        'Major Field Family (2-digit) Title': 'CIP2_title',
        'Major Field Group (4-digit CIP)': 'CIP4',
        'Major Field Group (4-digit) Title': 'CIP4_title'
    })
    
    # Get enrollment columns (years 2019-2024)
    years_2024 = [2019, 2020, 2021, 2022, 2023, 2024]
    enrollment_cols = [col for col in df_2024.columns if 'Enrollment' in str(col) and '% Change' not in str(col)]
    print(f"  Found {len(enrollment_cols)} enrollment columns for years 2019-2024")
    
    # Reshape to long format
    data_2024 = []
    for idx, row in df_2024.iterrows():
        cip4 = row['CIP4']
        cip4_title = row['CIP4_title']
        if pd.isna(cip4) or cip4 == 'Total':
            continue
        for year, col in zip(years_2024, enrollment_cols):
            enrollment = row[col]
            if pd.notna(enrollment) and enrollment != '*':
                data_2024.append({
                    'CIP4': str(cip4)[:4] if pd.notna(cip4) and str(cip4) != 'Total' else None,
                    'CIP4_title': cip4_title,
                    'year': year,
                    'enrollment': float(enrollment)
                })
    
    df_2024_long = pd.DataFrame(data_2024)
    print(f"  Reshaped to {len(df_2024_long)} observations")
    
    # ===== LOAD 2025 FILE =====
    print("\nLoading 2020-2025 data from CTEESpring2025-DataAppendix.xlsx...")
    df_2025 = pd.read_excel(
        filepath_2025,
        sheet_name='CIP Group Enrollment',
        header=2  # Row 2 has the actual column headers
    )
    print(f"  Loaded {len(df_2025)} rows")
    
    # Filter to Undergraduate 4-year only
    df_2025 = df_2025[df_2025['Award Level and Institution Type'] == 'Undergraduate 4-year'].copy()
    print(f"  Filtered to {len(df_2025)} Undergraduate 4-year rows")
    
    # Rename columns
    df_2025 = df_2025.rename(columns={
        'Major Field Family \n(2-digit CIP)': 'CIP2',
        'Major Field Family \n(2-digit CIP) Title': 'CIP2_title',
        'Major Field Group \n(4-digit CIP)': 'CIP4',
        'Major Field Group \n(4-digit CIP) Title': 'CIP4_title'
    })
    
    # Get enrollment columns (years 2020-2025)
    # The enrollment columns alternate: Enrollment, % Change, Enrollment, % Change...
    # Columns 5, 6, 8, 10, 12, 14 correspond to years 2020-2025
    years_2025 = [2020, 2021, 2022, 2023, 2024, 2025]
    enrollment_col_indices = [5, 6, 8, 10, 12, 14]
    
    # Reshape to long format
    data_2025 = []
    for idx, row in df_2025.iterrows():
        cip4 = row['CIP4']
        cip4_title = row['CIP4_title']
        if pd.isna(cip4) or cip4 == 'Total':
            continue
        for year, col_idx in zip(years_2025, enrollment_col_indices):
            enrollment = row.iloc[col_idx]
            if pd.notna(enrollment) and enrollment != '*':
                data_2025.append({
                    'CIP4': str(cip4)[:4] if pd.notna(cip4) and str(cip4) != 'Total' else None,
                    'CIP4_title': cip4_title,
                    'year': year,
                    'enrollment': float(enrollment)
                })
    
    df_2025_long = pd.DataFrame(data_2025)
    print(f"  Reshaped to {len(df_2025_long)} observations")
    
    # ===== COMBINE DATASETS =====
    print("\nCombining datasets...")
    
    # For overlapping years (2020-2024), use 2025 file data (more recent)
    df_2024_unique = df_2024_long[df_2024_long['year'] == 2019].copy()
    
    enrollment = pd.concat([df_2024_unique, df_2025_long], axis=0, ignore_index=True)
    enrollment = enrollment.sort_values(['CIP4', 'year']).reset_index(drop=True)
    
    print(f"\n✓ Combined dataset: {len(enrollment)} observations")
    print(f"  Years: {sorted(enrollment['year'].unique())}")
    print(f"  Unique 4-digit CIP codes: {enrollment['CIP4'].nunique()}")
    
    # Summary stats
    print("\nTotal enrollment by year:")
    yearly_enrollment = enrollment.groupby('year')['enrollment'].sum()
    for year, total in yearly_enrollment.items():
        print(f"  {year}: {total:,.0f}")
    
    return enrollment


# =============================================================================
# STEP 7: MERGE AND FINALIZE
# =============================================================================

def merge_enrollment_exposure(
    enrollment: pd.DataFrame,
    cip_exposure: pd.DataFrame
) -> pd.DataFrame:
    """
    Merge enrollment data with AI exposure scores.
    Preserves CIP4_title from enrollment data (more complete).
    """
    print("\n" + "="*70)
    print("MERGING ENROLLMENT WITH AI EXPOSURE")
    print("="*70)
    # Normalize CIP4 codes to match (both as zero-padded 4-char strings)
    enrollment['CIP4'] = enrollment['CIP4'].astype(str).str.zfill(4)
    cip_exposure['CIP4'] = cip_exposure['CIP4'].astype(str).str.zfill(4)
    
    
    # Merge on 4-digit CIP code
    # Keep CIP4_title from enrollment (left) as it's more complete
    df_final = enrollment.merge(
        cip_exposure[['CIP4', 'ai_exposure_score', 'n_obs']],
        on='CIP4',
        how='left'
    )
    
    # Report merge success
    n_matched = df_final['ai_exposure_score'].notna().sum()
    pct_matched = 100 * n_matched / len(df_final)
    print(f"\nMatched {n_matched}/{len(df_final)} enrollment records ({pct_matched:.1f}%)")
    
    # Create treatment variables
    median_exposure = df_final['ai_exposure_score'].median()
    df_final['high_ai_exposure'] = (
        df_final['ai_exposure_score'] > median_exposure
    ).astype(int)
    
    # Standardized exposure
    df_final['ai_exposure_std'] = (
        (df_final['ai_exposure_score'] - df_final['ai_exposure_score'].mean()) /
        df_final['ai_exposure_score'].std()
    )
    
    # Terciles (with error handling for insufficient unique values)
    try:
        df_final['ai_exposure_tercile'] = pd.qcut(
            df_final['ai_exposure_score'],
            q=3,
            labels=['Low', 'Medium', 'High'],
            duplicates='drop'
        )
    except ValueError as e:
        # If qcut fails (e.g., too many NaNs or duplicates), use simple cut
        print(f"⚠ Could not create terciles: {e}")
        print("  Using quartile-based cut instead")
        df_final['ai_exposure_tercile'] = pd.cut(
            df_final['ai_exposure_score'],
            bins=3,
            labels=['Low', 'Medium', 'High']
        )
    
    # Create log enrollment
    df_final['log_enrollment'] = np.log(df_final['enrollment'] + 1)
    
    print("\n\nFinal dataset:")
    print(df_final.head(20))
    print(f"\nShape: {df_final.shape}")
    print(f"Columns: {list(df_final.columns)}")
    
    return df_final


# =============================================================================
# STEP 8: VISUALIZATION
# =============================================================================

def create_descriptive_plots(df: pd.DataFrame, output_path: str = 'enrollment_trends_4digit.png'):
    """
    Create descriptive visualizations for 4-digit CIP analysis.
    """
    print("\n" + "="*70)
    print("CREATING VISUALIZATIONS")
    print("="*70)
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # 1. Enrollment trends by AI exposure group
    ax1 = axes[0, 0]
    trend_data = df.groupby(['year', 'high_ai_exposure'])['enrollment'].sum().reset_index()
    # Normalize to 2019 (show as % of 2019 enrollment)
    trend_2019 = trend_data[trend_data['year'] == 2019].set_index('high_ai_exposure')['enrollment']
    trend_data['enrollment_pct_2019'] = trend_data.apply(
        lambda row: (row['enrollment'] / trend_2019[row['high_ai_exposure']]) * 100
            if row['high_ai_exposure'] in trend_2019.index else 100,
        axis=1
    )
    
    for group in [0, 1]:
        data = trend_data[trend_data['high_ai_exposure'] == group]
        label = 'High AI Exposure' if group else 'Low AI Exposure'
        ax1.plot(data['year'], data['enrollment_pct_2019'], marker='o', label=label, linewidth=2)
    ax1.set_xlabel('Year', fontsize=12)
    ax1.set_ylabel('Enrollment (% of 2019)', fontsize=12)
    ax1.set_title('Enrollment Trends by AI Exposure (4-digit CIP)', fontsize=14, fontweight='bold')
    ax1.legend()
    ax1.grid(alpha=0.3)
    ax1.axvline(2022.5, color='red', linestyle='--', alpha=0.5, label='ChatGPT Launch')
    
    # 2. Distribution of AI exposure
    ax2 = axes[0, 1]
    cip_scores = df.groupby('CIP4')['ai_exposure_score'].first()
    ax2.hist(cip_scores, bins=50, edgecolor='black', alpha=0.7, color='steelblue')
    ax2.axvline(cip_scores.median(), color='red', linestyle='--', linewidth=2, label='Median')
    ax2.set_xlabel('AI Exposure Score', fontsize=12)
    ax2.set_ylabel('Number of 4-digit CIP Codes', fontsize=12)
    ax2.set_title('Distribution of AI Exposure Across Majors', fontsize=14, fontweight='bold')
    ax2.legend()
    ax2.grid(alpha=0.3)
    
    # 3. Scatter: enrollment growth vs exposure
    ax3 = axes[1, 0]
    first_year = df['year'].min()
    last_year = df['year'].max()
    
    growth_data = []
    for cip in df['CIP4'].unique():
        cip_data = df[df['CIP4'] == cip]
        enroll_first = cip_data[cip_data['year'] == first_year]['enrollment'].values
        enroll_last = cip_data[cip_data['year'] == last_year]['enrollment'].values
        if len(enroll_first) > 0 and len(enroll_last) > 0 and enroll_first[0] > 0:
            growth = (enroll_last[0] - enroll_first[0]) / enroll_first[0] * 100
            exposure = cip_data['ai_exposure_score'].iloc[0] if len(cip_data) > 0 else None
            if exposure is not None and pd.notna(exposure):
                growth_data.append({'CIP4': cip, 'growth_rate': growth, 'ai_exposure': exposure})
    
    growth_df = pd.DataFrame(growth_data)
    if len(growth_df) > 0:
        ax3.scatter(growth_df['ai_exposure'], growth_df['growth_rate'], alpha=0.6, s=30)
        ax3.set_xlabel('AI Exposure Score', fontsize=12)
        ax3.set_ylabel(f'Enrollment Growth Rate ({first_year}-{last_year}, %)', fontsize=12)
        ax3.set_title('Growth Rate vs AI Exposure', fontsize=14, fontweight='bold')
        ax3.axhline(0, color='black', linestyle='--', linewidth=1, alpha=0.5)
        ax3.grid(alpha=0.3)
        
        # Add correlation
        corr = growth_df[['ai_exposure', 'growth_rate']].corr().iloc[0, 1]
        ax3.text(0.05, 0.95, f'Correlation: {corr:.3f}', 
                transform=ax3.transAxes, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    # 4. Enrollment by tercile over time
    ax4 = axes[1, 1]
    tercile_data = df.groupby(['year', 'ai_exposure_tercile'])['enrollment'].sum().reset_index()
    # Normalize to 2019
    tercile_2019 = tercile_data[tercile_data['year'] == 2019].set_index('ai_exposure_tercile')['enrollment']
    tercile_data['enrollment_pct_2019'] = tercile_data.apply(
        lambda row: (row['enrollment'] / tercile_2019[row['ai_exposure_tercile']]) * 100
            if row['ai_exposure_tercile'] in tercile_2019.index else 100,
        axis=1
    )
    
    for tercile in ['Low', 'Medium', 'High']:
        data = tercile_data[tercile_data['ai_exposure_tercile'] == tercile]
        if len(data) > 0:
            ax4.plot(data['year'], data['enrollment_pct_2019'], marker='o', label=f'{tercile} Exposure', linewidth=2)
    ax4.set_xlabel('Year', fontsize=12)
    ax4.set_ylabel('Enrollment (% of 2019)', fontsize=12)
    ax4.set_title('Enrollment by AI Exposure Tercile', fontsize=14, fontweight='bold')
    ax4.legend()
    ax4.grid(alpha=0.3)
    ax4.axvline(2022.5, color='red', linestyle='--', alpha=0.5)
    
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\n✓ Saved plots to {output_path}")
    plt.close()




# =============================================================================
# STEP 8B: TERCILE DEEP-DIVE VISUALIZATION
# =============================================================================

def create_tercile_deepdive_plots(df: pd.DataFrame, output_path: str = 'enrollment_tercile_deepdive.png'):
    """
    Create detailed enrollment trend plots for top 5 majors within each AI exposure tercile.
    
    For each tercile (Low/Medium/High), shows:
    - Top 5 CIP4 codes by 2019 enrollment
    - Enrollment trends 2019-2025 (normalized to 2019 = 100%)
    - CIP4 labels with titles
    - 2025: Actual enrollment number labeled
    - 2023: AI exposure score labeled
    - % coverage: what fraction of tercile enrollment these top 5 represent
    """
    print("\n" + "="*70)
    print("CREATING TERCILE DEEP-DIVE PLOTS (TOP 5)")
    print("="*70)
    
    # Filter to rows with valid tercile assignment
    df_valid = df[df['ai_exposure_tercile'].notna()].copy()
    
    # Create figure with 3 subplots (1 row x 3 cols)
    fig, axes = plt.subplots(1, 3, figsize=(20, 6))
    
    # Color palette for 5 lines
    colors = plt.cm.tab10(np.linspace(0, 1, 5))
    
    for idx, tercile in enumerate(['Low', 'Medium', 'High']):
        ax = axes[idx]
        
        # Get data for this tercile
        tercile_data = df_valid[df_valid['ai_exposure_tercile'] == tercile].copy()
        
        if len(tercile_data) == 0:
            print(f"⚠ No data for {tercile} tercile, skipping...")
            continue
        
        # Get 2019 baseline
        tercile_2019 = tercile_data[tercile_data['year'] == 2019].copy()
        
        # Get top 5 CIP4 by 2019 enrollment
        top5_cip4s = tercile_2019.nlargest(5, 'enrollment')['CIP4'].values
        
        # Calculate coverage
        top5_enrollment = tercile_2019[tercile_2019['CIP4'].isin(top5_cip4s)]['enrollment'].sum()
        total_enrollment = tercile_2019['enrollment'].sum()
        coverage_pct = (top5_enrollment / total_enrollment * 100) if total_enrollment > 0 else 0
        
        print(f"\n{tercile} Tercile:")
        print(f"  Top 5 CIP4s: {list(top5_cip4s)}")
        print(f"  Coverage: {coverage_pct:.1f}% of {tercile} tercile enrollment")
        print(f"  2019 enrollment in top 5: {top5_enrollment:,.0f} / {total_enrollment:,.0f}")
        
        # For each of the top 5 CIP4s, plot enrollment trend
        for i, cip4 in enumerate(top5_cip4s):
            cip_data = tercile_data[tercile_data['CIP4'] == cip4].copy()
            
            if len(cip_data) == 0:
                continue
            
            # Get 2019 baseline for this CIP4
            baseline_2019 = cip_data[cip_data['year'] == 2019]['enrollment'].values
            if len(baseline_2019) == 0 or baseline_2019[0] == 0:
                continue
            baseline_2019 = baseline_2019[0]
            
            # Normalize to 2019 = 100%
            cip_data['enrollment_pct'] = (cip_data['enrollment'] / baseline_2019) * 100
            
            # Get CIP4 title (truncate if too long)
            cip4_title = cip_data['CIP4_title'].iloc[0] if len(cip_data) > 0 else ''
            if len(cip4_title) > 30:
                cip4_title = cip4_title[:27] + '...'
            
            # Get AI exposure score
            ai_exposure = cip_data['ai_exposure_score'].iloc[0] if len(cip_data) > 0 else None
            
            # Plot
            label = f"{cip4}: {cip4_title}"
            ax.plot(cip_data['year'], cip_data['enrollment_pct'], 
                   marker='o', label=label, linewidth=2.5, color=colors[i], alpha=0.8, markersize=6)
            
            # Add label for 2025 (actual enrollment)
            data_2025 = cip_data[cip_data['year'] == 2025]
            if len(data_2025) > 0:
                enrollment_2025 = data_2025['enrollment'].values[0]
                enrollment_pct_2025 = data_2025['enrollment_pct'].values[0]
                
                # Smart vertical offset to avoid overlap
                offset = (i - 2) * 8  # Spread labels vertically (-16, -8, 0, 8, 16)
                
                ax.annotate(f'{enrollment_2025:,.0f}',
                           xy=(2025, enrollment_pct_2025),
                           xytext=(8, offset),
                           textcoords='offset points',
                           fontsize=8,
                           color=colors[i],
                           fontweight='bold',
                           bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor=colors[i], alpha=0.8))
            
            # Add label for 2023 (AI exposure score)
            data_2023 = cip_data[cip_data['year'] == 2023]
            if len(data_2023) > 0 and ai_exposure is not None:
                enrollment_pct_2023 = data_2023['enrollment_pct'].values[0]
                
                # Smart vertical offset to avoid overlap
                offset_y = (i - 2) * 6  # Spread labels vertically
                
                ax.annotate(f'AI: {ai_exposure:.3f}',
                           xy=(2023, enrollment_pct_2023),
                           xytext=(-35, offset_y),
                           textcoords='offset points',
                           fontsize=7,
                           color=colors[i],
                           style='italic',
                           bbox=dict(boxstyle='round,pad=0.2', facecolor='lightyellow', edgecolor=colors[i], alpha=0.7))
        
        # Styling
        ax.set_xlabel('Year', fontsize=12)
        ax.set_ylabel('Enrollment (% of 2019)', fontsize=12)
        ax.set_title(f'{tercile} AI Exposure - Top 5 Majors', fontsize=14, fontweight='bold')
        ax.grid(alpha=0.3)
        ax.axvline(2022.5, color='red', linestyle='--', alpha=0.5, linewidth=1.5, label='ChatGPT Launch')
        ax.axhline(100, color='gray', linestyle=':', alpha=0.5, linewidth=1)
        
        # Add coverage annotation
        annotation_text = f"Top 5: {coverage_pct:.1f}% of tercile\nN = {top5_enrollment:,.0f} (2019)"
        ax.text(0.02, 0.98, annotation_text, 
               transform=ax.transAxes, 
               verticalalignment='top',
               fontsize=9,
               bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        # Legend - smaller font, outside plot
        ax.legend(loc='upper left', bbox_to_anchor=(0, -0.12), 
                 ncol=1, fontsize=9, framealpha=0.9)
    
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\n✓ Saved tercile deep-dive plots to {output_path}")
    plt.close()


# =============================================================================
# # STEP 9: DIAGNOSTIC REPORTING
# =============================================================================

def generate_diagnostic_report(
    acs: pd.DataFrame,
    fod_to_cip4: pd.DataFrame,
    enrollment: pd.DataFrame
):
    """
    Generate diagnostic report showing coverage gaps:
    (i) ACS FOD codes not in crosswalk
    (ii) CIP4 codes with enrollment but no FOD mapping
    (iii) Top unmapped ACS FODs by weighted person-count
    """
    print("\n" + "#"*70)
    print("# DIAGNOSTIC REPORT: COVERAGE ANALYSIS")
    print("#"*70)
    
    # (i) ACS FOD codes not in crosswalk
    acs_fods = set(acs['DEGFIELDD'].dropna().unique())
    crosswalk_fods = set(fod_to_cip4['FOD'].unique())
    missing_fods = acs_fods - crosswalk_fods
    
    print(f"\n(i) ACS FOD codes NOT in crosswalk mapping:")
    print(f"    Total: {len(missing_fods)} FOD codes")
    if len(missing_fods) > 0:
        print(f"    FODs: {sorted(list(missing_fods))[:20]}")
        # How many ACS observations do these represent?
        missing_fod_count = acs[acs['DEGFIELDD'].isin(missing_fods)]['PERWT'].sum()
        total_count = acs['PERWT'].sum()
        print(f"    Represents {missing_fod_count:,.0f} / {total_count:,.0f} ACS observations ({missing_fod_count/total_count*100:.1f}%)")
    
    # (ii) CIP codes with enrollment but no FOD mapping
    enrollment_cips = set(enrollment['CIP4'].unique())
    crosswalk_cips = set(fod_to_cip4['CIP4'].unique())
    unmapped_cips = enrollment_cips - crosswalk_cips
    
    print(f"\n(ii) CIP4 codes with enrollment but NOT mapped from any FOD:")
    print(f"     Total: {len(unmapped_cips)} CIP4 codes")
    if len(unmapped_cips) > 0:
        # Get enrollment counts for these
        unmapped_enroll = enrollment[enrollment['CIP4'].isin(unmapped_cips)]
        unmapped_2019 = unmapped_enroll[unmapped_enroll['year'] == 2019]['enrollment'].sum()
        total_2019 = enrollment[enrollment['year'] == 2019]['enrollment'].sum()
        print(f"     CIP4s: {sorted(list(unmapped_cips))[:30]}")
        print(f"     2019 enrollment: {unmapped_2019:,.0f} / {total_2019:,.0f} ({unmapped_2019/total_2019*100:.1f}%)")
        print(f"\n     Top 10 unmapped CIP4s by 2019 enrollment:")
        top_unmapped = unmapped_enroll[unmapped_enroll['year'] == 2019].nlargest(10, 'enrollment')[['CIP4', 'CIP4_title', 'enrollment']]
        for _, row in top_unmapped.iterrows():
            print(f"       CIP4 {row['CIP4']} ({row['CIP4_title']}): {row['enrollment']:,.0f} students")
    
    # (iii) NEW: Top unmapped ACS FODs by weighted person-count
    print(f"\n(iii) Top 20 unmapped ACS FOD codes by weighted person-count:")
    if len(missing_fods) > 0:
        unmapped_acs = acs[acs['DEGFIELDD'].isin(missing_fods)]
        top_unmapped_fods = unmapped_acs.groupby('DEGFIELDD')['PERWT'].sum().sort_values(ascending=False).head(20)
        print(f"\n     {'FOD':<8} {'Weighted Count':>15} {'% of Total':>10}")
        print(f"     {'-'*8} {'-'*15} {'-'*10}")
        for fod, count in top_unmapped_fods.items():
            pct = count / total_count * 100
            print(f"     {int(fod):<8} {count:>15,.0f} {pct:>9.2f}%")
    else:
        print("     All ACS FODs are mapped!")


# =============================================================================
# MAIN EXECUTION
# =============================================================================

def main():
    """
    Main execution function.
    """
    
    print("\n" + "#"*70)
    print("# AI EXPOSURE AND ENROLLMENT ANALYSIS - 4-DIGIT CIP")
    print("#"*70 + "\n")
    
    # FILE PATHS
    FELTEN_PATH = '/Users/jeffreyohl/Dropbox/CollegeMajorData/FeltenEtAl/2023_Language Modeling AIOE and AIIE.xlsx'
    CROSSWALK_PATH = '/Users/jeffreyohl/Dropbox/CollegeMajorData/Crosswalks/crosswalk_handout.xlsx'
    ACS_PATH = '/Users/jeffreyohl/Dropbox/CollegeMajorData/IPUMS/usa_00008.csv'
    ENROLLMENT_PATH_2025 = '/Users/jeffreyohl/Dropbox/CollegeMajorData/National Student Clearinghouse Data/CTEESpring2025-DataAppendix.xlsx'
    ENROLLMENT_PATH_2024 = '/Users/jeffreyohl/Dropbox/CollegeMajorData/National Student Clearinghouse Data/CTEESpring2024-Appendix.xlsx'
    OUTPUT_DIR = '/Users/jeffreyohl/Dropbox/CollegeMajorData/output'
    
    try:
        # Step 1: Load Felten data
        felten = load_felten_data(FELTEN_PATH)
        
        # Step 2: Load FOD to 4-digit CIP crosswalk (with manual mappings)
        fod_to_cip4 = load_fod_cip4_crosswalk(CROSSWALK_PATH, manual_mappings=MANUAL_MAPPINGS)
        
        # Step 3: Load and combine enrollment data (2019-2025) - MOVED UP!
        enrollment = load_and_combine_enrollment_data(ENROLLMENT_PATH_2024, ENROLLMENT_PATH_2025)
        
        # Step 4: Add empirical enrollment weights to crosswalk
        fod_to_cip4_weighted = add_empirical_weights_to_crosswalk(
            fod_to_cip4, enrollment, base_year=2019
        )
        
        # Step 5-6: Process ACS and calculate exposure scores
        print("\n" + "="*70)
        print("PROCESSING ACS DATA")
        print("="*70)
        acs = load_and_filter_acs(ACS_PATH)
        acs_with_exposure = process_acs_with_exposure(acs, felten, fod_to_cip4_weighted)
        cip_exposure = calculate_cip4_exposure(acs_with_exposure)
        
        # Save exposure scores
        cip_exposure.to_csv(f'{OUTPUT_DIR}/cip4_ai_exposure_scores.csv', index=False)
        print(f"\n✓ Saved exposure scores to {OUTPUT_DIR}/cip4_ai_exposure_scores.csv")
        
        # Step 7: Merge enrollment with exposure
        df_final = merge_enrollment_exposure(enrollment, cip_exposure)
        
        # Save final dataset
        df_final.to_csv(f'{OUTPUT_DIR}/enrollment_with_ai_exposure_4digit.csv', index=False)
        print(f"\n✓ Saved final dataset to {OUTPUT_DIR}/enrollment_with_ai_exposure_4digit.csv")
        
        # Step 8: Create visualizations
        create_descriptive_plots(df_final, f'{OUTPUT_DIR}/enrollment_trends_4digit.png')
        
        # Step 8B: Create tercile deep-dive plots
        create_tercile_deepdive_plots(df_final, f'{OUTPUT_DIR}/enrollment_tercile_deepdive.png')
        
        # Step 9: Diagnostic reporting - what's missing?
        generate_diagnostic_report(acs, fod_to_cip4_weighted, enrollment)
        
        print("\n" + "#"*70)
        print("# DATA PREPARATION COMPLETE (4-DIGIT CIP)")
        print("#"*70)
        print("\nNext steps:")
        print("1. Review cip4_ai_exposure_scores.csv to validate exposure scores")
        print("2. Check enrollment_with_ai_exposure_4digit.csv for data quality")
        print("3. Review diagnostic report and add more manual mappings if needed")
        print("4. Run econometric_analysis.py for DiD and event study")
        print("\n4-digit CIP analysis provides:")
        print("  - Computer Science (1107) vs Information Systems (1104)")
        print("  - Business Administration (5202) vs Finance (5208) vs Accounting (5203)")  
        print("  - More granular treatment effects and heterogeneity analysis")
        
    except FileNotFoundError as e:
        print(f"\n❌ Error: File not found - {e}")
        print("\nPlease check that all data files exist at the specified paths:")
        print(f"  - Felten: {FELTEN_PATH}")
        print(f"  - Crosswalk: {CROSSWALK_PATH}")
        print(f"  - ACS: {ACS_PATH}")
        print(f"  - Enrollment 2024: {ENROLLMENT_PATH_2024}")
        print(f"  - Enrollment 2025: {ENROLLMENT_PATH_2025}")
    except Exception as e:
        print(f"\n❌ Error occurred: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()

Loaded 6 manual FOD→CIP4 mappings

######################################################################
# AI EXPOSURE AND ENROLLMENT ANALYSIS - 4-DIGIT CIP
######################################################################

LOADING FELTEN AIOE DATA

Loaded 774 occupations
AIOE range: -1.85 to 1.93

LOADING FOD TO 4-DIGIT CIP CROSSWALK

Loaded 614 FOD→CIP4 mappings from crosswalk file
  191 unique FODs
  398 unique CIP4 codes
  Average 3.2 CIP4 codes per FOD

✓ Added 6 manual mappings
  FOD 6107 → CIP4 5138 (Registered Nursing/Nursing Administration/Nursing Research and Clinical Nursing)
  FOD 3611 → CIP4 2615 (Neurobiology and Neurosciences)
  FOD 5202 → CIP4 4228 (Clinical, Counseling and Applied Psychology)
  FOD 5203 → CIP4 4228 (Clinical, Counseling and Applied Psychology)
  FOD 5203 → CIP4 4228 (Clinical, Counseling and Applied Psychology)
  FOD 5098 → CIP4 4099 (Physical Sciences, other)

Sample mappings:
  FOD 1100 → CIP4 ['0100']
  FOD 1101 → CIP4 ['0100', '0101', '0102',