In [None]:
import json
import matplotlib.pyplot as plt

# Load every ablation record produced by the ablation run.
with open('ablation_results.json', 'r') as results_file:
    results = json.load(results_file)

# Keep only records that ablate exactly one submodule. (These are expected to be
# submodule indices 0-15, but the filter itself does not check the index range.)
single_ablations = list(filter(lambda rec: len(rec['submodules']) == 1, results))

# Parallel per-record series: point k is (ground truth acc, spurious acc) of record k,
# labelled with the index of the single submodule that record ablated.
ground_truth_acc = [rec['ground_truth_accuracy'] for rec in single_ablations]
spurious_acc = [rec['spurious_accuracy'] for rec in single_ablations]
submodule_ids = [rec['submodules'][0] for rec in single_ablations]

fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(ground_truth_acc, spurious_acc)

# Tag each point with its submodule index, offset 5pt up/right so the text
# does not sit on top of the marker.
for sub_id, gt, sp in zip(submodule_ids, ground_truth_acc, spurious_acc):
    ax.annotate(sub_id, (gt, sp), xytext=(5, 5), textcoords='offset points')

ax.set_xlabel('Ground Truth Accuracy')
ax.set_ylabel('Spurious Accuracy')
ax.set_title('Ground Truth vs Spurious Accuracy for Single Submodule Ablations')

# y = x reference line: points above it have higher spurious than ground-truth accuracy.
ax.plot([0, 1], [0, 1], 'r--', alpha=0.5)

# Both axes span [0, 1] so the diagonal runs corner to corner.
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

ax.grid(True, linestyle='--', alpha=0.7)

fig.savefig('ablation_scatter_plot.png', dpi=300, bbox_inches='tight')
print("Scatter plot saved as 'ablation_scatter_plot.png'")

# Display the figure (optional; comment out when running without a display).
plt.show()

In [None]:
import json
import matplotlib.pyplot as plt

# Load every ablation record produced by the ablation run.
with open('ablation_results.json', 'r') as results_file:
    results = json.load(results_file)

# Single-submodule ablations only, ordered by the ablated submodule's index.
# (Expected to cover indices 0-15; the filter only checks that one submodule was ablated.)
single_ablations = sorted(
    (rec for rec in results if len(rec['submodules']) == 1),
    key=lambda rec: rec['submodules'][0],
)

indices = [rec['submodules'][0] for rec in single_ablations]
ground_truth_acc = [rec['ground_truth_accuracy'] for rec in single_ablations]
spurious_acc = [rec['spurious_accuracy'] for rec in single_ablations]

fig, ax = plt.subplots(figsize=(12, 6))

# One line per metric, sharing the submodule-index x-axis:
# (values, format string, legend label, marker).
series = [
    (ground_truth_acc, 'b-', 'Ground Truth Accuracy', 'o'),
    (spurious_acc, 'r-', 'Spurious Accuracy', 's'),
]
for values, fmt, label, marker in series:
    ax.plot(indices, values, fmt, label=label, marker=marker)

ax.set_xlabel('Submodule Index')
ax.set_ylabel('Accuracy')
ax.set_title('Ground Truth and Spurious Accuracy for Single Submodule Ablations')

# One tick per ablated submodule rather than matplotlib's automatic tick spacing.
ax.set_xticks(indices)

ax.legend()
ax.grid(True, linestyle='--', alpha=0.7)

fig.savefig('ablation_line_plot.png', dpi=300, bbox_inches='tight')
print("Line plot saved as 'ablation_line_plot.png'")

# Display the figure (optional; comment out when running without a display).
plt.show()

In [None]:

import json
import matplotlib.pyplot as plt

# Submodule index layout assumed by this plot
# (NOTE(review): confirm against the script that wrote ablation_results.json):
#   index 0            -> embedding
#   indices 1, 2, 3    -> layer 1: attention, MLP, residual post
#   indices 4, 5, 6    -> layer 2: attention, MLP, residual post, ...
# i.e. for index > 0, (index - 1) % SUBMODULES_PER_LAYER selects the submodule type.
SUBMODULE_STYLES = [  # (legend name, matplotlib format string), in per-layer order
    ('Attention', 'r-o'),
    ('MLP', 'g-s'),
    ('Residual Post', 'b-^'),
]
SUBMODULES_PER_LAYER = len(SUBMODULE_STYLES)

# Read the JSON file
with open('ablation_results.json', 'r') as f:
    results = json.load(f)

# Keep only single-submodule ablations, ordered by the ablated submodule's index.
single_ablations = [r for r in results if len(r['submodules']) == 1]
single_ablations.sort(key=lambda x: x['submodules'][0])
if not single_ablations:
    raise ValueError("No single-submodule ablations found in 'ablation_results.json'")

indices = [r['submodules'][0] for r in single_ablations]
ground_truth_acc = [r['ground_truth_accuracy'] for r in single_ablations]

plt.figure(figsize=(15, 8))

# Assign each point to a series by its submodule *index*, not its list position,
# so a missing index cannot shift every later point into the wrong series.
for offset, (name, fmt) in enumerate(SUBMODULE_STYLES):
    points = [
        (idx, acc)
        for idx, acc in zip(indices, ground_truth_acc)
        if idx > 0 and (idx - 1) % SUBMODULES_PER_LAYER == offset
    ]
    if points:
        xs, ys = zip(*points)
        plt.plot(xs, ys, fmt, label=name)

# The embedding (index 0) is drawn as a separate, larger marker.
embedding_acc = [acc for idx, acc in zip(indices, ground_truth_acc) if idx == 0]
if embedding_acc:
    plt.plot([0] * len(embedding_acc), embedding_acc, 'mo', markersize=10, label='Embedding')

plt.xlabel('Layer')
plt.ylabel('Ground Truth Accuracy')
plt.title('Ground Truth Accuracy for Single Submodule Ablations')

# One tick for the embedding plus one per layer, centred on that layer's middle
# submodule. Layers are numbered from 1 (layer k spans indices 3k-2 .. 3k), and a
# partially ablated final layer still gets a tick (ceil division).
n_layers = -(-max(indices) // SUBMODULES_PER_LAYER)
layer_numbers = range(1, n_layers + 1)
x_ticks = [0] + [(k - 1) * SUBMODULES_PER_LAYER + 1 + SUBMODULES_PER_LAYER // 2 for k in layer_numbers]
x_labels = ['Embedding'] + [f'Layer {k}' for k in layer_numbers]
plt.xticks(x_ticks, x_labels)

plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)

# Start the y-axis at 0 with 10% headroom for the value labels. Fall back to [0, 1]
# when every accuracy is 0, because ylim(0, 0) is a singular range.
top = max(ground_truth_acc)
plt.ylim(0, top * 1.1 if top > 0 else 1)

# Print each point's accuracy just above it.
for idx, acc in zip(indices, ground_truth_acc):
    plt.text(idx, acc, f'{acc:.3f}', ha='center', va='bottom')

plt.savefig('ablation_layered_plot.png', dpi=300, bbox_inches='tight')
print("Line plot saved as 'ablation_layered_plot.png'")

# Display the figure (optional; comment out when running without a display).
plt.show()

In [None]:
import json
import matplotlib.pyplot as plt

# Read the JSON file
with open('ablation_results.json', 'r') as f:
    results = json.load(f)

# Keep only single-submodule ablations (expected to cover submodule indices 0-15;
# the filter itself only checks that exactly one submodule was ablated),
# ordered by the ablated submodule's index.
single_ablations = [r for r in results if len(r['submodules']) == 1]
single_ablations.sort(key=lambda x: x['submodules'][0])

indices = [r['submodules'][0] for r in single_ablations]
ground_truth_acc = [r['ground_truth_accuracy'] for r in single_ablations]

plt.figure(figsize=(12, 6))
plt.plot(indices, ground_truth_acc, 'b-', label='Ground Truth Accuracy', marker='o')

plt.xlabel('Submodule Index')
plt.ylabel('Ground Truth Accuracy')
plt.title('Ground Truth Accuracy for Single Submodule Ablations')

# One tick per ablated submodule rather than matplotlib's automatic tick spacing.
plt.xticks(indices)

plt.grid(True, linestyle='--', alpha=0.7)

# The y-axis is deliberately left to matplotlib's autoscaling: it spans the data
# range and does NOT start at 0 (unlike the layered plot).

# Print each point's accuracy just above it.
for idx, acc in zip(indices, ground_truth_acc):
    plt.text(idx, acc, f'{acc:.3f}', ha='center', va='bottom')

plt.savefig('ablation_ground_truth_plot.png', dpi=300, bbox_inches='tight')
print("Line plot saved as 'ablation_ground_truth_plot.png'")

# Show the plot (optional, comment out if running on a server without display)
plt.show()