In [1]:
def plot_combined_params_trend(open_license_only=False, use_active_params=True, show_model_names=False, benchmark_name='MMLU-Pro (Reasoning & Knowledge)'):
    """
    Plot record-small parameter trends for three MMLU ranges (30-50, 50-70, 70-90) 
    on a single graph.
    
    Parameters:
      - open_license_only: if True, only include models where 'License' contains 'open'
      - use_active_params: if True, use 'Active Parameters'; else, use 'Parameters'
      - show_model_names: if True, display model names next to record-small points
      - benchmark_name: name of the benchmark column to use
    """
    # Define MMLU ranges and plotting styles
    mmlu_ranges = [(30, 50), (50, 70), (70, 90)]
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c']  # Matplotlib default colors
    markers = ['o', 's', '^']
    
    # Setup plot
    plt.figure(figsize=(12, 8))
    
    # Column names
    mmlu_col = benchmark_name
    param_col = 'Active Parameters' if use_active_params else 'Parameters'
    license_col = 'License'

    # Loop over MMLU ranges
    for i, (min_mmlu, max_mmlu) in enumerate(mmlu_ranges):
        # --- Data Preparation ---
        df_work = df.copy()

        df_work[mmlu_col] = (
            df_work[mmlu_col].astype(str)
                             .str.replace('%', '', regex=False)
                             .astype(float)
        )
        
        df_work[param_col] = pd.to_numeric(df_work[param_col], errors='coerce')
        df_work['Params_B'] = df_work[param_col]

        if open_license_only:
            df_work = df_work[
                df_work[license_col].notna() &
                df_work[license_col].str.contains('open', case=False, na=False)
            ]

        df_sub = df_work[df_work[mmlu_col].between(min_mmlu, max_mmlu)].copy()
        
        df_sub = df_sub.dropna(subset=['Release Date', 'Params_B'])
        df_sub = df_sub[df_sub['Params_B'] > 0]
        
        if len(df_sub) < 2:
            print(f"Not enough data for MMLU range {min_mmlu}-{max_mmlu} to plot a trend.")
            continue
            
        # --- Record-Small Calculation & Regression ---
        df_sub = df_sub.sort_values('Release Date')
        df_sub['Date_Ordinal'] = df_sub['Release Date'].map(datetime.toordinal)
        
        df_sub['Is_Record_Small'] = df_sub['Params_B'].cummin() == df_sub['Params_B']
        record_small = df_sub[df_sub['Is_Record_Small']].copy()
        
        if len(record_small) < 2:
            print(f"Not enough record-small points for MMLU range {min_mmlu}-{max_mmlu} to plot a trend.")
            continue

        X_rec = record_small['Date_Ordinal'].values.reshape(-1, 1)
        y_rec_log = np.log10(record_small['Params_B'].values)
        rec_ols = LinearRegression().fit(X_rec, y_rec_log)
        
        lo, hi = record_small['Date_Ordinal'].min(), record_small['Date_Ordinal'].max()
        x_range = np.arange(lo, hi + 1)
        x_dates = [datetime.fromordinal(int(d)) for d in x_range]
        y_rec_log_pred = rec_ols.predict(x_range.reshape(-1, 1))

        annual_pct_rec = ((10**rec_ols.coef_[0])**365 - 1) * 100
        factor_rec_str = ""
        if annual_pct_rec < 0:
            factor_rec = 1 / (1 + annual_pct_rec / 100)
            factor_rec_str = f' ({factor_rec:.1f}× decrease/yr)'
        
        # --- Plotting ---
        plt.scatter(
            record_small['Release Date'], record_small['Params_B'],
            color=colors[i], s=80, marker=markers[i], alpha=0.8,
            label=f'MMLU {min_mmlu}-{max_mmlu}% record-small'
        )
        
        plt.plot(
            x_dates, 10**y_rec_log_pred,
            color=colors[i], linestyle='--', linewidth=2,
            label=f'MMLU {min_mmlu}-{max_mmlu}% trend: {annual_pct_rec:.1f}%/yr{factor_rec_str}'
        )

        if show_model_names and 'Model' in record_small.columns:
            for _, row in record_small.iterrows():
                plt.annotate(
                    row['Model'], (row['Release Date'], row['Params_B']),
                    xytext=(5, 5), textcoords='offset points',
                    fontsize=8, color=colors[i]
                )
    
    # --- Final Plot Formatting ---
    plt.yscale('log')
    plt.xlabel('Release Date', fontsize=12)
    param_type_label = 'Active Parameters' if use_active_params else 'Parameters'
    plt.ylabel(f'{param_type_label} (Billions)', fontsize=12)
    lic_label = 'open-license only' if open_license_only else 'all licenses'
    benchmark_short = benchmark_name.split(' ')[0]
    plt.title(f'Record-Small Model Size Trend by {benchmark_short} ({lic_label})', fontsize=14)
    plt.grid(True, which='both', ls='--', alpha=0.4)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

# Example usage for open-license models, showing model names
plot_combined_params_trend(open_license_only=True, use_active_params=True, show_model_names=True)

# Example usage for all models
# plot_combined_params_trend(open_license_only=False, use_active_params=True, show_model_names=True)

NameError: name 'plt' is not defined