# Intuition plots

This notebook generates the plots for the paper's intuition figure.

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT

def create_gaussian(x, mu, sigma):
    """Evaluate the normal probability density with mean ``mu`` and std ``sigma`` at ``x``.

    Operates element-wise, so ``x`` may be a scalar or a NumPy array.
    """
    exponent = -((x - mu) ** 2) / (2 * sigma**2)
    normalizer = sigma * np.sqrt(2 * np.pi)
    return np.exp(exponent) / normalizer


# Shared x grid for both idealized projection histograms.
x_range = np.linspace(-10, 30, 1000)

# Each curve is a tall peak at 0 (weight 2) plus a broader bump at 15;
# the feature-mixing SAE curve has both components widened.
correct_model_vals = (
    2 * create_gaussian(x_range, mu=0, sigma=1)
    + create_gaussian(x_range, mu=15, sigma=5)
)
incorrect_model_vals = (
    2 * create_gaussian(x_range, mu=0, sigma=2)
    + create_gaussian(x_range, mu=15, sigma=7)
)

# Output directory for the saved figure.
Path("plots").mkdir(exist_ok=True, parents=True)


def find_area_threshold(x_vals, y_vals, target_fraction=0.4):
    """Return the x value that marks off the rightmost ``target_fraction`` of the area.

    The area under ``y_vals`` (sampled at ``x_vals``) is measured with the
    trapezoidal rule. Scanning samples from right to left, the first
    ``x_vals[i]`` (for ``i >= 1``) whose tail area over ``x_vals[i:]`` is at
    least ``target_fraction`` of the total area is returned. If no such sample
    exists, ``x_vals[0]`` is returned.

    The tail areas come from a single reversed cumulative sum rather than
    re-integrating every suffix (previously O(n**2)). This also avoids
    ``np.trapz``, which NumPy 2.0 deprecates in favour of ``np.trapezoid``.
    """
    x_vals = np.asarray(x_vals, dtype=float)
    y_vals = np.asarray(y_vals, dtype=float)

    # Trapezoid area of each interval [x_vals[k], x_vals[k + 1]].
    segment_areas = 0.5 * (y_vals[1:] + y_vals[:-1]) * np.diff(x_vals)
    # tail_areas[i] == area from x_vals[i] to the right edge; the last entry is 0.
    tail_areas = np.append(np.cumsum(segment_areas[::-1])[::-1], 0.0)
    target_area = tail_areas[0] * target_fraction

    # Indices i >= 1 whose tail area reaches the target. The largest one is what
    # a right-to-left scan stopping at the first hit would return.
    hits = np.flatnonzero(tail_areas[1:] >= target_area) + 1
    if hits.size:
        return x_vals[hits[-1]]
    return x_vals[0]


plt.rcParams.update({"figure.dpi": 150})
with plt.rc_context(SEABORN_RC_CONTEXT):
    plt.figure(figsize=(5, 2))

    # x positions bounding the rightmost 40% (by area) of each curve
    correct_threshold = find_area_threshold(x_range, correct_model_vals)
    incorrect_threshold = find_area_threshold(x_range, incorrect_model_vals)

    # Plot lines
    line1 = plt.plot(x_range, correct_model_vals, label="Disentangled SAE", linewidth=0.5)
    line2 = plt.plot(x_range, incorrect_model_vals, label="SAE mixing correlated features", linewidth=0.5)

    # Reuse the line colors so fills and annotations match their curve
    color1 = line1[0].get_color()
    color2 = line2[0].get_color()

    # Fill entire area with light transparency
    plt.fill_between(x_range, correct_model_vals, alpha=0.3, color=color1)
    plt.fill_between(x_range, incorrect_model_vals, alpha=0.3, color=color2)

    # Overlay a darker fill on the rightmost 40% by area
    plt.fill_between(x_range, correct_model_vals,
                     where=(x_range >= correct_threshold),
                     alpha=0.6, color=color1, interpolate=True)
    plt.fill_between(x_range, incorrect_model_vals,
                     where=(x_range >= incorrect_threshold),
                     alpha=0.6, color=color2, interpolate=True)

    # Curve heights at the thresholds, used as arrow targets. The thresholds are
    # sample points of x_range, so the nearest index is the exact one.
    correct_idx = np.argmin(np.abs(x_range - correct_threshold))
    incorrect_idx = np.argmin(np.abs(x_range - incorrect_threshold))
    correct_y = correct_model_vals[correct_idx]
    incorrect_y = incorrect_model_vals[incorrect_idx]

    # Add arrows with labels
    plt.annotate('$s_n^{dec}$',
                xy=(correct_threshold, correct_y),
                xytext=(correct_threshold + 4, correct_y + 0.15),
                arrowprops=dict(arrowstyle='->', color=color1, lw=0.5),
                fontsize=10, ha='center', color=color1)

    plt.annotate('$s_n^{dec}$',
                xy=(incorrect_threshold, incorrect_y),
                xytext=(incorrect_threshold + 4, incorrect_y + 0.15),
                arrowprops=dict(arrowstyle='->', color=color2, lw=0.5),
                fontsize=10, ha='center', color=color2)

    # plt.xlabel("Decoder projection on input activations")
    plt.title("Idealized decoder projection histogram")
    legend = plt.legend()
    # Thicken legend swatches; the plotted lines themselves stay at 0.5
    for line in legend.get_lines():
        line.set_linewidth(1)
    # Customize legend border
    legend.get_frame().set_edgecolor('lightgray')
    legend.get_frame().set_linewidth(0.5)

    # Strip the axes down to a single x tick at 0. This still runs inside the
    # rc_context; explicit Axes calls override the rc defaults.
    ax = plt.gca()
    ax.grid(False)
    ax.set_xticks([0])
    ax.set_xticklabels(['0'])
    ax.set_yticks([])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.set_facecolor('white')
    ax.tick_params(axis='x', which='major', length=4, width=0.5, color='black')
    ax.tick_params(axis='y', which='both', left=False, labelleft=False)
    ax.set_ylim(bottom=0)

    plt.savefig("plots/intuition_plot.pdf")

    plt.show()



In [None]:
# Display the seaborn rc settings that the figure above was rendered with.
SEABORN_RC_CONTEXT