## Cross entropy loss curves

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# ——— Data ———
# Model sizes (parameter counts, ascending) and the tick labels shown for them.
parameters = [8e6, 35e6, 150e6, 350e6, 650e6]
parameters_labels = ['8M', '35M', '150M', '350M', '650M']
datasets = ["Quarter Data", "Half Data", "Full Data"]

# Loss values obtained from the infer_test.ipynb notebook.
# Each list lines up position-by-position with `parameters`.
evaluation_losses = {
    "Quarter Data": [0.2932, 0.2920, 0.2937, 0.2938, 0.2942],
    "Half Data":    [0.2925, 0.2884, 0.2902, 0.2905, 0.2895],
    "Full Data":    [0.2903, 0.2868, 0.2857, 0.2859, 0.2866],
}

# ——— Which points to fit on and to plot? ———
subset_indices = {
    "Quarter Data": slice(0, 3),     # keep only 8M, 35M, 150M
    "Half Data":    slice(0, 4),     # keep only 8M, 35M, 150M, 350M
    "Full Data":    slice(1, None),  # keep only 35M, 150M, 350M, 650M
}


# ——— Fit & find minimum-in-range ———
def _min_loss_in_fitted_range(coeffs, logp, ps):
    """Locate the minimum of a fitted quadratic over the fitted interval.

    Args:
        coeffs: The three polynomial coefficients (highest power first) of a
            quadratic in log10(parameters).
        logp: log10 of the model sizes used for the fit. ``logp[0]`` and
            ``logp[-1]`` are treated as the two ends of the interval, so the
            values are expected to be in ascending order (as they are for the
            ascending ``parameters`` list above).
        ps: The model sizes matching ``logp``.

    Returns:
        A ``(min_param, min_loss)`` pair: the model size at which the fitted
        curve is lowest inside the interval, and the fitted loss there.
    """
    a, b, _ = coeffs
    poly = np.poly1d(coeffs)
    loss_first, loss_last = poly(logp[0]), poly(logp[-1])

    if abs(a) < 1e-9:
        # Essentially linear: the minimum sits on whichever end is lower
        # (np.argmin picks the first end on a tie).
        if np.argmin([loss_first, loss_last]) == 0:
            return ps[0], loss_first
        return ps[-1], loss_last

    if a > 0:
        # Parabola opens upwards: use the vertex, clamped to the fitted range.
        logp_min = -b / (2 * a)
        if logp_min < logp[0]:
            return ps[0], loss_first
        if logp_min > logp[-1]:
            return ps[-1], loss_last
        return 10**logp_min, poly(logp_min)

    # Parabola opens downwards: the minimum is at one of the two ends.
    if loss_first < loss_last:
        return ps[0], loss_first
    return ps[-1], loss_last


# Convert once; every dataset slices the same array.
param_array = np.array(parameters)

min_loss_points = {}
for ds in datasets:
    idx = subset_indices[ds]
    ps = param_array[idx]
    ls = np.array(evaluation_losses[ds])[idx]
    logp = np.log10(ps)

    # 2nd-degree fit of loss against log10(model size)
    coeffs = np.polyfit(logp, ls, deg=2)
    min_p, min_l = _min_loss_in_fitted_range(coeffs, logp, ps)

    min_loss_points[ds] = {
        'coeffs': coeffs,
        'logp_fit': logp,
        'min_param': min_p,
        'min_loss':  min_l
    }

# ——— Plot ———
# Draws, per dataset: the measured losses used for the fit, the fitted
# quadratic across the fitted range, and a triangle at the in-range minimum.
sns.set_context("notebook")
# The dataset list is reversed before pairing, so "Full Data" receives the
# first colour of the palette and "Quarter Data" the third.
palette = dict(zip(datasets[::-1],
                   sns.color_palette("colorblind", 3).as_hex()))
all_params = np.array(parameters)

plt.figure(figsize=(10, 7))
for ds in datasets:
    idx = subset_indices[ds]
    fit = min_loss_points[ds]
    color = palette[ds]

    # Measured points (no legend entry; the fitted curve carries the label).
    plt.scatter(all_params[idx], np.array(evaluation_losses[ds])[idx],
                color=color, s=50)

    # Fitted curve, evaluated in log10 space and mapped back to parameters.
    poly = np.poly1d(fit['coeffs'])
    smooth = np.linspace(fit['logp_fit'].min(), fit['logp_fit'].max(), 200)
    plt.plot(10**smooth, poly(smooth),
             color=color, linewidth=2,
             label=ds)

    # Minimum of the fitted curve inside the fitted range.
    plt.scatter([fit['min_param']], [fit['min_loss']],
                marker='v', edgecolor='black',
                s=100, color=color, zorder=5)


plt.xscale("log")
plt.xticks(parameters, parameters_labels, fontsize=14)
plt.xlabel("Model Size (Parameters)", fontsize=16)
plt.ylabel("Test Loss", fontsize=18)
plt.tick_params(axis='y', labelsize=14)

plt.legend(fontsize=12, title="Legend", title_fontsize=13, loc='lower left')
plt.grid(which="both", linestyle="--", linewidth=0.5)
plt.tight_layout()
# Plain ASCII underscore in the file name; the earlier name contained an
# Arabic tatweel (U+0640), which looks like an underscore but is not one.
plt.savefig('loss_reg.png', dpi=300, bbox_inches='tight')
plt.show()

## Projection plot 

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import linregress
from matplotlib.ticker import FuncFormatter, LogLocator, NullFormatter # Import LogLocator and NullFormatter

# --- 1. Data Definition ---
# Paired observations: entry i of one array belongs with entry i of the other.
model_size = np.array([30500000, 66200000, 163000000])  # Parameters
data_size = np.array([412182, 824363, 1648726])         # Sequences

# Build the table, order rows by ascending model size, renumber rows from 0.
df = pd.DataFrame({'Model Size': model_size, 'Data Size': data_size})
df = df.sort_values('Model Size')
df = df.reset_index(drop=True)
point_labels = ['Q', 'H', 'F']  # Labels for the sorted points

section_rule = "-" * 30  # separator printed after each report section

print("--- Input Data (Sorted by Model Size) ---")
print(df)
print(f"Point Labels (correspond to sorted points): {point_labels}")
print(section_rule)

# --- 2. Linear Regression (on original linear data) ---
# Ordinary least squares of data size against model size, no log transform.
slope, intercept, r_value, p_value, std_err = linregress(df['Model Size'], df['Data Size'])

regression_report = [
    "\n--- Linear Regression Results (on original scale) ---",
    f"  Slope (m): {slope:.4e}",
    f"  Intercept (c): {intercept:.4e}",
    f"  R-squared (R^2): {r_value**2:.4f}",
    f"  P-value: {p_value:.4f}",
    f"  Standard Error of Slope: {std_err:.4e}",
    f"\nEquation: Data Size ≈ ({slope:.4e} * Model Size) + {intercept:.4e}",
    section_rule,
]
print("\n".join(regression_report))

# --- 3. Prediction Calculation ---
new_model_size = 6.5e8  # 650 million parameters
predicted_data_size = (slope * new_model_size) + intercept

# A non-positive prediction cannot be placed on a log axis, so the plotting
# code further down consults this flag before drawing the prediction marker.
plot_prediction = not (predicted_data_size <= 0)
if plot_prediction:
    print("\n--- Prediction ---")
    print(f"Model Size for Prediction: {new_model_size:.2e} parameters (650 Million)")
    print(f"Predicted Required Data Size: Approximately {predicted_data_size:,.0f} sequences (based on linear fit)")
    print(section_rule)
else:
    print("\n--- Warning ---")
    print("Predicted data size is non-positive based on linear fit. Cannot plot prediction point on log scale.")
    print(section_rule)


# --- 4. Plotting ---
# Builds the figure: the three observed points with annotations, the linear
# fit line, and (when plottable) the extrapolated prediction. Both axes are
# switched to log scale at the end of this section.
plt.style.use('seaborn-v0_8-whitegrid')
plt.figure(figsize=(14, 9))
ax = plt.gca()

# One fixed colour per point label, listed in the same order as the rows of df.
palette = {'Q': 'green', 'H': 'orange', 'F': 'blue'}
point_colors = [palette[label] for label in point_labels]

sns.scatterplot(x='Model Size', y='Data Size', data=df, s=350,
                color=point_colors,
                zorder=5, ax=ax, legend=False)

# Text-box offsets from each point, in points (see textcoords below).
annotation_offsets = {'Q': (15, -25), 'H': (15, 15), 'F': (15, 15)}
# Walk labels and row values together instead of indexing columns by position.
for label, x_coord, y_coord in zip(point_labels, df['Model Size'], df['Data Size']):
    color = palette[label]
    offset = annotation_offsets[label]
    annotation_text = (
        f"{label}:\n"
        f"{x_coord/1e6:.1f}M params\n"
        f"{y_coord/1e3:.1f}k seqs"
    )
    ax.annotate(annotation_text,
                xy=(x_coord, y_coord), xytext=offset, textcoords='offset points',
                fontsize=12, color=color, ha='left', va='center',
                bbox=dict(boxstyle="round,pad=0.3", fc="white", ec=color, alpha=0.8),
                arrowprops=dict(arrowstyle="->", color=color, connectionstyle="arc3,rad=0.2"),
                zorder=6)

# Generate points for the regression line
x_min_plot = min(df['Model Size'].min() * 0.8, 8e6)  # Start around 8M
x_max_plot = max(df['Model Size'].max() * 1.1, new_model_size * 1.1)
x_line = np.linspace(x_min_plot, x_max_plot, 200)
y_line = slope * x_line + intercept
# The y-axis is set to log scale below, so only the part of the line with
# strictly positive y can be drawn.
positive_y_indices = y_line > 0
x_line_pos = x_line[positive_y_indices]
y_line_pos = y_line[positive_y_indices]

if len(x_line_pos) > 1:
    ax.plot(x_line_pos, y_line_pos, color='red', linestyle='--', label=f'Linear Fit (R²={r_value**2:.2f})', zorder=3)
else:
    print("Warning: Not enough positive points to plot the regression line on the log scale.")

if plot_prediction:
    ax.scatter(new_model_size, predicted_data_size,
               marker='*', s=400, color='purple', zorder=7, label='Prediction (650M params)')

    # Use annotate with an offset so the text sits right by the star
    ax.annotate(
        f"Prediction:\n~{predicted_data_size/1e6:.1f}M seqs",
        xy=(new_model_size, predicted_data_size),
        xytext=(-50, 10),
        textcoords='offset points',
        fontsize=12,
        verticalalignment='bottom',
        horizontalalignment='right',
        color='purple',
        fontweight='bold',
        zorder=12,
        arrowprops=dict(
            arrowstyle="->",
            color='purple',
            connectionstyle="arc3,rad=0.2"
        )
    )


# --- SET LOG SCALE ---
ax.set_xscale('log')
ax.set_yscale('log')


def readable_log_formatter_func(val, pos):
    """Format a tick value compactly with a k / M / G suffix.

    Values of at least 1e9, 1e6 or 1e3 are divided by that scale and given
    the suffix 'G', 'M' or 'k'; smaller values get no suffix. A scaled value
    that is (numerically) a whole number is shown without decimals, anything
    else with one decimal place.

    Args:
        val: The tick value to format.
        pos: Tick position argument of the formatter callback signature;
            unused here.

    Returns:
        The formatted label string, e.g. '650M', '412.2k' or '100'.
    """
    # Largest scale first so each value picks up the biggest suffix that fits.
    for scale, suffix in ((1e9, 'G'), (1e6, 'M'), (1e3, 'k')):
        if val >= scale:
            scaled = val / scale
            nearest = round(scaled)
            if np.isclose(scaled, nearest):
                return f'{nearest:.0f}{suffix}'
            return f'{scaled:.1f}{suffix}'

    # Below one thousand: same whole-number test, no suffix.
    nearest = round(val)
    if np.isclose(val, nearest):
        return f'{nearest:.0f}'
    return f'{val:.1f}'

# Tick configuration is the same for x and y, so both axes are set up in one
# pass: major ticks at powers of 10, minor ticks at 2..9 within each decade,
# and the readable k/M/G formatter on both (minor labels can get crowded).
for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_formatter(FuncFormatter(readable_log_formatter_func))
    axis.set_major_locator(LogLocator(base=10.0, numticks=15))  # numticks is approximate
    axis.set_minor_locator(LogLocator(base=10.0, subs=np.arange(2, 10), numticks=15))
    axis.set_minor_formatter(FuncFormatter(readable_log_formatter_func))
# To hide minor tick labels (less clutter), use NullFormatter() as the minor
# formatter on each axis instead.


# Same label size and rotation for major and minor ticks on both axes
ax.tick_params(axis='both', which='both', labelsize=16, labelrotation=45)


# Add labels and title
ax.set_xlabel("Model Size (Number of Parameters)", fontsize=18)
ax.set_ylabel("Required Paired Ab Sequences ", fontsize=18)
#ax.set_title("Model Size vs. Required Data Size (Log-Log Scale with Linear Fit)", fontsize=14, fontweight='bold')

# --- Legend handling ---
# NOTE(review): no legend is drawn here, although the fit line and the
# prediction marker are given labels earlier in this cell — confirm intended.

# Solid grid lines for major ticks, dotted for minor ticks
for tick_kind, line_style, shade in (("major", "-", '0.75'), ("minor", ":", '0.85')):
    ax.grid(True, which=tick_kind, ls=line_style, c=shade)

plt.tight_layout()

print("\n--- Plotting (Log Scale with Detailed Ticks and Annotations) ---")
print("Generating log-log scale plot (may appear in a separate window)...")
plt.savefig("model_vs_data_loglog_detailed_ticks.png", dpi=300, bbox_inches='tight')
plt.show()
print("-" * 30)

# --- 5. Prediction Reminder ---
reminder_lines = (
    "\nReminder: This prediction is based on a LINEAR model fitted to only three data points.",
    "The log-log plot helps visualize the data distribution, but the prediction still uses the linear fit.",
    "If the data appears linear on the log-log plot, a power law model (linear regression on log-transformed data) might be more appropriate.",
    "The current prediction involves extrapolation. Use with caution.",
    "-" * 30,
)
for reminder in reminder_lines:
    print(reminder)