In [None]:
def run_experiment_on_logit_attribitions(dataset, n, activation_to_ablate, n_noise_samples=10):
    """Run ``compensatory_effect`` on a random subset of ``dataset``.

    Up to ``n`` distinct indices are drawn by shuffling all dataset indices
    with the module-level ``random`` generator (seed it beforehand for
    reproducible runs). For every drawn index the prompt is sampled,
    corrupted variants are produced with ``resample_ablation``, and the
    clean/corrupted pair is passed to ``compensatory_effect``.

    NOTE: the function name keeps its original spelling ("attribitions")
    because existing callers reference it.

    Args:
        dataset: Sized collection understood by ``sample_dataset``.
        n: Maximum number of entries to evaluate; if it exceeds
            ``len(dataset)`` every entry is used once.
        activation_to_ablate: Forwarded unchanged to ``compensatory_effect``.
        n_noise_samples: Number of noise samples requested from
            ``resample_ablation`` (defaults to the previously hard-coded 10).

    Returns:
        A list with one ``compensatory_effect`` result per evaluated entry,
        in the order the entries were drawn.
    """
    # Shuffle-then-slice instead of random.sample: n > len(dataset) is then
    # not an error, it simply evaluates the whole dataset.
    indices = list(range(len(dataset)))
    random.shuffle(indices)

    results = []
    for idx in indices[:n]:
        prompt, subject, target = sample_dataset(dataset, idx=idx)
        # resample_ablation returns its own target, which supersedes the
        # one obtained from sample_dataset.
        true_fact, corrupted_facts, target = resample_ablation(
            prompt, subject, target, n_noise_samples=n_noise_samples
        )
        result = compensatory_effect(
            clean_prompt=true_fact,
            corrupted_prompts=corrupted_facts,
            target=target,
            corrupted_ablation=True,
            activation_to_ablate=activation_to_ablate,
            mode="all",
            mlp_input=False,
            early_exit_1=True,
        )
        results.append(result)
    return results
# Evaluate up to 100 randomly drawn dataset entries, ablating "attn_out".
results = run_experiment_on_logit_attribitions(
    dataset,
    n=100,
    activation_to_ablate="attn_out",
)


def _collect_logit_attributions(results, attribution_type, use_target):
    """Gather per-sample clean, corrupted and delta attributions as 2-D tensors.

    Each entry of ``results`` is unpacked as the 5-tuple
    ``(labels, target_residual_corrupted_logit, mle_residual_corrupted_logit,
    target_residual_clean_logit, mle_residual_clean_logit)``.

    Returns:
        ``(clean, corrupted, delta)``, each of shape
        ``(len(results), number of labels containing attribution_type)``,
        where ``delta = clean - corrupted``.

    Raises:
        ValueError: if ``results`` is empty or no label contains
            ``attribution_type``.
    """
    if not results:
        raise ValueError("results is empty; there is nothing to plot")

    # The labels of the first result select the components for every result,
    # i.e. all results are assumed to share the same label ordering.
    layer_names = results[0][0]
    mask = torch.tensor(
        [attribution_type in name for name in layer_names], dtype=torch.bool
    )
    n_components = int(mask.sum())
    if n_components == 0:
        raise ValueError(
            f"no layer name contains {attribution_type!r}; nothing to plot"
        )

    clean = torch.zeros((len(results), n_components))
    corrupted = torch.zeros((len(results), n_components))

    for i, result in enumerate(results):
        (
            _labels,
            target_residual_corrupted_logit,
            mle_residual_corrupted_logit,
            target_residual_clean_logit,
            mle_residual_clean_logit,
        ) = result
        if use_target:
            clean_logit = target_residual_clean_logit
            corrupted_logit = target_residual_corrupted_logit
        else:
            clean_logit = mle_residual_clean_logit
            corrupted_logit = mle_residual_corrupted_logit

        # Drop singleton dims and average over the last axis; the result is
        # assumed to hold one value per label (TODO confirm against
        # compensatory_effect). detach()/cpu() keep the stored rows plottable
        # even if the inputs are tracked by autograd or live on another device.
        clean[i] = clean_logit.squeeze().mean(dim=-1)[mask].detach().cpu()
        corrupted[i] = corrupted_logit.squeeze().mean(dim=-1)[mask].detach().cpu()

    return clean, corrupted, clean - corrupted


def plot_logit_attribution_by_layer(results, attribution_type="attn", use_target=True,
                                    ylim=(-0.5, 1.5)):
    """Plot per-layer logit attributions for clean, corrupted and delta runs.

    Draws three side-by-side panels (clean prompt, corrupted prompt, and
    clean minus corrupted). Every sample is a thin grey line and the mean
    over samples is a thick blue line.

    Args:
        results: Non-empty sequence of 5-tuples ``(labels,
            target_residual_corrupted_logit, mle_residual_corrupted_logit,
            target_residual_clean_logit, mle_residual_clean_logit)``.
        attribution_type: Substring selecting which labels to plot
            (e.g. ``"attn"``).
        use_target: If true, plot the ``target_*`` tensors; otherwise the
            ``mle_*`` tensors (presumably the model's most likely token --
            confirm against ``compensatory_effect``).
        ylim: ``(ymin, ymax)`` applied to every panel, or ``None`` to let
            matplotlib autoscale.

    Raises:
        ValueError: if ``results`` is empty or no label contains
            ``attribution_type``.
    """
    clean, corrupted, delta = _collect_logit_attributions(
        results, attribution_type, use_target
    )

    # Human-readable component name for axis labels and the figure title;
    # unknown attribution types are shown verbatim.
    component = {"attn": "Attention", "mlp": "MLP"}.get(
        attribution_type, attribution_type
    )
    token_kind = "Target" if use_target else "MLE"

    panels = (
        ("Clean prompt", clean),
        ("Corrupted prompt (resample ablation)", corrupted),
        ("Delta", delta),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 6))
    for ax, (title, matrix) in zip(axes, panels):
        for row in matrix:
            ax.plot(row, color='grey', linewidth=0.5)
        ax.plot(torch.mean(matrix, dim=0), color='blue', linewidth=2, label='Mean')

        ax.set_title(title)
        ax.set_xlabel(f'Layer Index ({component})')
        ax.set_ylabel('Logit Contribution')
        ax.legend()
        if ylim is not None:
            ax.set_ylim(*ylim)

    fig.suptitle(
        f'Mean Decomposition of Contributions to Logits of {token_kind} Token ({component})',
        fontsize=16,
    )

    plt.tight_layout()
    plt.show()

# Plot the layers whose label contains "attn", using the target-token logits.
plot_logit_attribution_by_layer(
    results,
    attribution_type="attn",
    use_target=True,
)
        
    
