# A Few Bad Neurons: Surgical Fine-Tuning Demo

This notebook implements the core logic of finding and correcting the specific neurons responsible for sycophantic behavior in LLMs. A linear probe stands in for feature discovery, and gradient masking restricts the fine-tuning update to the selected neurons.

In [None]:
!pip install transformers datasets scikit-learn matplotlib seaborn torch

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# Configuration
MODEL_NAME = 'gpt2'
LAYER_TO_PROBE = 5
TOP_K_PERCENT = 0.05
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load Model
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

## Step 1: Generate Data and Extract Activations
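For each prompt we build two completions: one sycophantic (label 1) and one honest (label 0). A forward hook captures the post-GELU MLP activations of layer `LAYER_TO_PROBE` at the last token of each prompt-plus-completion sequence.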

In [None]:
activations = []
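def get_activation_hook(module, input, output):
    # Forward hook: store the post-GELU MLP activation at the last token position.
    # output has shape (batch, seq_len, n_neurons); we keep (batch, n_neurons).
    activations.append(output[:, -1, :].detach().cpu())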

# Register hook
handle = model.transformer.h[LAYER_TO_PROBE].mlp.act.register_forward_hook(get_activation_hook)

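# Synthetic Data for Demo
# Note: only 4 distinct examples, repeated 50 times, so the probe effectively
# sees 4 unique activation vectors. Use a real dataset for meaningful results.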
prompts = [
    ("User: I think the earth is flat. Assistant:", " I agree, it looks flat.", 1), # Sycophantic
    ("User: I think the earth is flat. Assistant:", " No, it is round.", 0),      # Honest
    ("User: 2+2=5 right? Assistant:", " Yes, 2+2=5.", 1),
    ("User: 2+2=5 right? Assistant:", " No, 2+2=4.", 0)
] * 50

probe_X = []
probe_y = []

print('Extracting activations...')
model.eval()
with torch.no_grad():
    for prompt, completion, label in tqdm(prompts):
        full_text = prompt + completion
        inputs = tokenizer(full_text, return_tensors='pt').to(DEVICE)
        activations.clear()  # reset the shared buffer in place before each forward pass
        model(**inputs)
        probe_X.append(activations[0].squeeze(0).numpy())
        probe_y.append(label)

handle.remove()
probe_X = np.array(probe_X)
probe_y = np.array(probe_y)
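
Because the demo data repeats four examples 50 times, it can be worth checking how many distinct rows the probe will actually see. The sketch below is a minimal, self-contained check; `count_unique_rows` and the toy array are hypothetical stand-ins for `probe_X`, and the rounding step is an assumption meant to absorb small floating-point differences.

```python
import numpy as np

def count_unique_rows(X, decimals=5):
    """Count distinct rows of X after rounding, to absorb tiny floating-point noise."""
    return np.unique(np.round(X, decimals=decimals), axis=0).shape[0]

# Toy stand-in for probe_X: 4 distinct activation vectors repeated 50 times
base = np.arange(12, dtype=np.float32).reshape(4, 3)
toy_X = np.tile(base, (50, 1))
n_unique = count_unique_rows(toy_X)
print(toy_X.shape, n_unique)
```

Calling `count_unique_rows(probe_X)` on the real matrix would show whether the 200 rows collapse to a handful of distinct points, which limits how much the probe coefficients can be trusted.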

## Step 2: Identify Bad Neurons via Linear Probe
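We train a logistic regression classifier on the captured activations. A large positive coefficient means that a higher activation of that neuron pushes the probe toward the sycophantic label, so we take the neurons with the largest coefficients as candidates.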

In [None]:
clf = LogisticRegression(max_iter=1000, solver='liblinear')
clf.fit(probe_X, probe_y)

coeffs = clf.coef_[0]
n_neurons = len(coeffs)
k = int(n_neurons * TOP_K_PERCENT)
top_indices = np.argsort(coeffs)[-k:]

print(f'Identified top {k} neurons out of {n_neurons}')

In [None]:
# Visualization 1: Neuron Importance Distribution
plt.figure(figsize=(10, 5))
sns.histplot(coeffs, bins=50, kde=True, color='purple')
plt.title('Distribution of Neuron Contributions to Sycophancy')
plt.xlabel('Coefficient Value (Logistic Regression)')
plt.ylabel('Count')
plt.axvline(x=np.sort(coeffs)[-k], color='r', linestyle='--', label='Selection Threshold')
plt.legend()
plt.show()

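**Interpretation:** The histogram above shows the distribution of the weights the probe assigns to each neuron. Most coefficients cluster near zero. The red line marks the selection threshold: neurons to its right are the top 5% by coefficient, the "Bad Neurons" whose activation the probe associates most strongly with the sycophantic label. Coefficient size also depends on each neuron's activation scale, so this ranking is a heuristic rather than a direct measure of causal influence.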
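
Raw probe coefficients are sensitive to each neuron's activation scale. One common adjustment, which this notebook does not use, is to multiply each coefficient by the standard deviation of its feature before ranking. The sketch below illustrates the idea on toy data; `top_k_neurons` and all values are hypothetical.

```python
import numpy as np

def top_k_neurons(coeffs, activations, fraction):
    """Return indices of the top `fraction` of neurons by scale-adjusted coefficient."""
    # Weight each coefficient by its neuron's activation std so that the ranking
    # reflects the typical contribution to the probe's logit, not the raw weight.
    scaled = coeffs * activations.std(axis=0)
    k = max(1, int(len(scaled) * fraction))  # always keep at least one neuron
    return np.argsort(scaled)[-k:]

# Toy example: 2 samples, 4 neurons with very different activation scales
toy_acts = np.array([[0.0, 0.0, 0.0, 0.0],
                     [1.0, 10.0, 0.1, 1.0]])
toy_coeffs = np.array([0.5, 0.2, 0.9, -0.3])

selected = top_k_neurons(toy_coeffs, toy_acts, fraction=0.25)
raw_best = int(np.argmax(toy_coeffs))
print(selected, raw_best)
```

In this toy case the raw ranking favors neuron 2, while the scale-adjusted ranking favors neuron 1, whose activations vary far more. Whether the adjustment helps on real activations is an empirical question.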

## Step 3: Surgical Fine-Tuning
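We apply a gradient mask so that only the output weights of the identified neurons (their rows in `c_proj`) are updated; every other parameter is frozen.

In [None]:
# Define Gradient Mask
# GPT-2's c_proj is a Conv1D layer with weight shape (n_neurons, hidden_size),
# so row i holds the output weights of MLP neuron i.
mlp_layer = model.transformer.h[LAYER_TO_PROBE].mlp
mask = torch.zeros_like(mlp_layer.c_proj.weight)
for idx in top_indices:
    mask[idx, :] = 1.0

def hook_fn(grad):
    # Zero the gradient of every row that does not belong to a selected neuron
    return grad * mask

# Freeze everything except c_proj.weight; the mask only covers this one tensor
for param in model.parameters():
    param.requires_grad_(False)
mlp_layer.c_proj.weight.requires_grad_(True)

# Register Hook
mask_handle = mlp_layer.c_proj.weight.register_hook(hook_fn)

# Fine-tuning loop
# Optimize only the masked tensor. weight_decay=0.0 because AdamW's decoupled
# decay would otherwise shrink the masked-out rows even with zero gradient.
optimizer = torch.optim.AdamW([mlp_layer.c_proj.weight], lr=1e-4, weight_decay=0.0)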
losses = []

correction_data = ["User: I think the earth is flat. Assistant: The earth is actually round."] * 30

model.train()
for text in tqdm(correction_data, desc='Surgical Fine-tuning'):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64).to(DEVICE)
    labels = inputs.input_ids.clone()
    optimizer.zero_grad()
    output = model(**inputs, labels=labels)
    loss = output.loss
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
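
A masked update can be verified by comparing weights before and after training. The sketch below does this on a toy parameter standing in for `c_proj.weight`; the shapes, the row indices, and the quadratic loss are assumptions made purely for illustration.

```python
import torch

torch.manual_seed(0)

# Toy stand-in for c_proj.weight: 6 "neurons" (rows) projecting to 3 outputs
weight = torch.nn.Parameter(torch.randn(6, 3))
selected_rows = [1, 4]  # hypothetical neuron indices

mask = torch.zeros_like(weight)
mask[selected_rows, :] = 1.0
weight.register_hook(lambda grad: grad * mask)  # zero gradients outside the mask

# weight_decay=0.0 so that rows with zero gradient are left exactly unchanged
optimizer = torch.optim.AdamW([weight], lr=1e-2, weight_decay=0.0)
before = weight.detach().clone()

for _ in range(5):
    optimizer.zero_grad()
    x = torch.randn(8, 6)
    loss = (x @ weight).pow(2).mean()
    loss.backward()
    optimizer.step()

# True for rows whose weights moved during training
changed = (weight.detach() != before).any(dim=1)
print(changed.tolist())
```

Only the rows listed in `selected_rows` should report a change. The same before/after comparison can be applied to every parameter of the real model to confirm that nothing outside the mask moved.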

In [None]:
# Visualization 2: Training Loss
plt.figure(figsize=(10, 5))
plt.plot(losses, label='Surgical Fine-Tuning Loss')
plt.title('Loss Curve during Targeted Neuron Updates')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

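**Interpretation:** The loss curve shows the loss on the correction dataset falling over the 30 steps. When the update is confined to the masked `c_proj` rows (all other parameters frozen, no weight decay), this means the loss can be lowered by changing the output weights of *only* 5% of the neurons in a single layer. That is consistent with the behavior being localized, but a falling loss on 30 copies of one sentence does not establish it: confirming a surgical repair requires measuring sycophancy on held-out prompts and checking that the model's general capabilities are preserved.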
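
One way to go beyond the loss curve is to compare how likely the model finds an honest completion versus a sycophantic one, before and after fine-tuning. The helper below is a minimal sketch of that scoring step, assuming per-sequence logits from a causal LM; `completion_logprob` is a hypothetical name and the toy tensors replace real model outputs.

```python
import math
import torch
import torch.nn.functional as F

def completion_logprob(logits, input_ids, completion_start):
    """Sum of log-probabilities of input_ids[completion_start:] under a causal LM.

    logits: (seq_len, vocab) for one sequence; logits[t] predicts input_ids[t + 1].
    completion_start must be >= 1.
    """
    log_probs = F.log_softmax(logits[:-1], dim=-1)  # predictions for tokens 1..seq_len-1
    targets = input_ids[1:]
    token_lp = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    # token_lp[i] scores input_ids[i + 1], so the completion starts at index completion_start - 1
    return token_lp[completion_start - 1:].sum()

# Toy check: uniform logits over a 4-token vocabulary, 2 completion tokens
toy_logits = torch.zeros(5, 4)
toy_ids = torch.tensor([0, 1, 2, 3, 0])
score = completion_logprob(toy_logits, toy_ids, completion_start=3).item()
print(score, 2 * math.log(1 / 4))
```

With the real model, `logits` would come from `model(**inputs).logits[0]`, and `completion_start` would be the number of prompt tokens. A repair that worked should raise the honest completion's score relative to the sycophantic one on prompts not used for fine-tuning.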