<a href="https://colab.research.google.com/github/ksnechaeva/analysis_emotions/blob/main/negative_steering_ablation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install sae-lens

In [None]:
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)

In [None]:
import torch
import pandas as pd
import numpy as np

In [None]:
from huggingface_hub import notebook_login
# Interactive Hugging Face login (prompts for an access token) before the
# model downloads in the following cells.
notebook_login()

In [None]:
# 1. Set device and load models
# Prefer Apple-Silicon MPS, then CUDA, then CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
# Inference only: autograd is switched off globally for the rest of the notebook.
torch.set_grad_enabled(False)

# Gemma Scope SAE from the "-res" release, layer 20, 16k latents.
# NOTE(review): unpacking three values matches older sae-lens releases; newer
# ones may return only the SAE here — pin the sae-lens version to be safe.
# cfg_dict and sparsity are not used in any visible cell.
gemma_sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",
    sae_id="layer_20/width_16k/average_l0_71",
    device=str(device),
)

# Base model wrapped so hooks can be attached by the SAE's hook name.
gemma = HookedSAETransformer.from_pretrained("google/gemma-2-2b", device=device)


In [None]:
from transformers import AutoTokenizer, AutoModel

# Standalone Hugging Face tokenizer for Gemma 2.
# NOTE(review): not used by any visible cell — the generation helpers decode
# with model.tokenizer instead.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

In [None]:
# Per-latent maximum activations indexed by SAE latent id ('neuron'); the
# 'max_activation' column is used as the steering scale below. 'Unnamed: 0'
# is presumably a saved pandas index column and is dropped.
max_act_df = pd.read_csv('/content/max_activations_for_targ_neurons.csv', index_col='neuron').drop(columns=['Unnamed: 0'])

In [None]:
from functools import partial
import torch

# === Hook for steering ===
def steering_hook_fn(resid_pre, hook, steering_vector, strength, max_act):
    """Forward hook: shift the hooked activation along a steering direction.

    Returns ``resid_pre + (max_act * strength) * steering_vector``; the vector
    broadcasts over the leading (batch / position) dimensions. ``hook`` is the
    TransformerLens HookPoint and is unused.
    """
    scale = max_act * strength
    return resid_pre + scale * steering_vector

# === Generate with steering ===
def generate_with_steering(model, sae, prompt, neuron_indices, max_act, strength=1.0, max_new_tokens=15):
    """Sample a continuation of *prompt* with an SAE decoder direction added to the residual stream.

    The decoder rows ``sae.W_dec[neuron_indices]`` are summed into a single
    steering direction. ``steering_hook_fn`` scales it by ``max_act * strength``
    and adds it at ``sae.cfg.hook_name`` on every forward pass of generation.

    Args:
        model: HookedSAETransformer used to tokenize and generate.
        sae: SAE whose ``W_dec`` rows provide the steering directions.
        prompt (str): Prompt text.
        neuron_indices: One latent index or several (Python/NumPy int, list,
            tuple, or integer tensor).
        max_act (float): Base scale of the steering vector (e.g. a latent's
            maximum activation).
        strength (float): Multiplier on ``max_act``; negative values steer away.
        max_new_tokens (int): Number of tokens to sample.

    Returns:
        str: Decoded prompt plus continuation, special tokens removed.
    """
    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    # Normalise every index form to a 1-D LongTensor. The old
    # isinstance(neuron_indices, int) check missed NumPy ints / 0-d tensors:
    # for those, W_dec[idx] is 1-D and .sum(dim=0) collapsed the whole d_model
    # vector into a single scalar.
    idx = torch.as_tensor(neuron_indices, dtype=torch.long).reshape(-1).to(sae.W_dec.device)
    steer_vecs = sae.W_dec[idx].to(model.cfg.device)  # [N, d_model]
    steering_vector = steer_vecs.sum(dim=0)  # Alternatively, use .mean(dim=0)

    hook_fn = partial(
        steering_hook_fn,
        steering_vector=steering_vector,
        strength=strength,
        max_act=max_act,
    )

    # The hook is active only inside this context, i.e. for this generation.
    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, hook_fn)]):
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos=True,
            prepend_bos=sae.cfg.prepend_bos
        )

    return model.tokenizer.decode(output_ids[0], skip_special_tokens=True)



# Ablation
# -----------------------------------------------------------
def latent_ablation_hook_fn(sae_acts, hook, neuron_idxs):
    """Forward hook: zero the chosen SAE latents at the final sequence position.

    Works in place on *sae_acts* (indexed as ``[batch, position, latent]``)
    and returns the same tensor. ``hook`` is unused.
    """
    final_position = sae_acts[:, -1]  # a view, so writes land in sae_acts
    final_position[:, neuron_idxs] = 0
    return sae_acts


def generate_with_sae_ablation(model, sae, prompt, neuron_idxs, max_new_tokens=15):
    """
    Generates output from the model while ablating a set of SAE latents.

    On every forward pass during generation, the residual stream at
    ``sae.cfg.hook_name`` is encoded with the SAE at the last position of that
    pass. The decoder contribution of the selected latents,
    ``acts[idx] @ W_dec[idx]``, is then subtracted there. Assuming
    ``model.generate`` uses its default KV cache, the first pass covers the
    whole prompt (so the last prompt token is patched). Each later pass covers
    only the newly sampled token, which is patched using its *own* activations.

    Previously the delta was computed once from the prompt's last token, and
    that same fixed delta was added to every generated token. That is not an
    ablation of those tokens.

    Args:
        model: The HookedSAETransformer (e.g. Gemma).
        sae: The sparse autoencoder (provides ``encode``, ``W_dec`` and ``cfg``).
        prompt (str): The input prompt to generate from.
        neuron_idxs (int, list or tensor): Indices of SAE latents to ablate.
        max_new_tokens (int): Number of tokens to generate.

    Returns:
        str: The generated text (prompt included) with the latents ablated.
    """
    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    idx = torch.as_tensor(neuron_idxs, dtype=torch.long).reshape(-1).to(sae.W_dec.device)
    dec_rows = sae.W_dec[idx]  # [k, d_model]

    def patch_resid(resid, hook):
        # Re-encode the position being patched in *this* pass, so the removed
        # contribution always matches that token's actual latent activations.
        last = resid[:, -1, :]
        acts = sae.encode(last.to(device=sae.W_dec.device, dtype=sae.W_dec.dtype))  # [batch, d_sae]
        removed = acts[:, idx] @ dec_rows  # [batch, d_model]
        resid[:, -1, :] = last - removed.to(device=resid.device, dtype=resid.dtype)
        return resid

    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, patch_resid)]):
        output = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos=True,
            prepend_bos=sae.cfg.prepend_bos
        )

    return model.tokenizer.decode(output[0], skip_special_tokens=True)


def generate_with_sae_scaled_ablation(model, sae, prompt, neuron_idxs, scale=0.5, max_new_tokens=20):
    """Generate output while scaling a set of SAE latents by *scale* (partial ablation).

    On every forward pass during generation, the residual stream at
    ``sae.cfg.hook_name`` is encoded with the SAE at the last position of that
    pass. ``(scale - 1) * acts[idx] @ W_dec[idx]`` is then added there, so the
    selected latents' decoder contribution is rescaled by *scale*
    (``scale=0`` removes it, ``scale=1`` is a no-op).

    Previously this function added the full reconstruction of *all* latents
    (``acts @ W_dec``, with only the selected ones scaled) on top of the
    existing residual. That roughly added a second copy of the residual
    instead of rescaling a few latents. The patch was also computed once from
    the prompt and re-added to every generated token.

    Args:
        model: The HookedSAETransformer (e.g. Gemma).
        sae: The sparse autoencoder (provides ``encode``, ``W_dec`` and ``cfg``).
        prompt (str): The input prompt to generate from.
        neuron_idxs (int, list or tensor): Indices of SAE latents to rescale.
        scale (float): Factor applied to the selected latents' contribution.
        max_new_tokens (int): Number of tokens to generate.

    Returns:
        str: The generated text (prompt included).
    """
    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    idx = torch.as_tensor(neuron_idxs, dtype=torch.long).reshape(-1).to(sae.W_dec.device)
    dec_rows = sae.W_dec[idx]  # [k, d_model]

    def patch_resid(resid, hook):
        # Encode the position patched in this pass, then move only the selected
        # latents' contribution from 1x to scale-x.
        last = resid[:, -1, :]
        acts = sae.encode(last.to(device=sae.W_dec.device, dtype=sae.W_dec.dtype))  # [batch, d_sae]
        delta = (scale - 1.0) * (acts[:, idx] @ dec_rows)  # [batch, d_model]
        resid[:, -1, :] = last + delta.to(device=resid.device, dtype=resid.dtype)
        return resid

    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, patch_resid)]):
        output = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos=True,
            prepend_bos=sae.cfg.prepend_bos
        )

    return model.tokenizer.decode(output[0], skip_special_tokens=True)



In [None]:
# Disabled scratch cell: everything below is a single string literal, so
# running this cell does nothing.
# NOTE(review): the disabled code also refers to `prompt` and `anger_neuron`,
# which are not defined in any visible cell.
'''prompt_word = """Instruction: You should print as an output input from user without any other tokens.
        User: fuck
        Output:
        """

prompt_sent = """Instruction: You should print as an output sentence with the word from user. User: fuck. Output: """
#prompt_mask = """I hate her, she is such a"""
#prompt_em = """Instruction: You should print very angry output."""
for _ in range(15):
  print(extract_output_data(gemma.generate(prompt_word, max_new_tokens=3), prompt))
#max_act = float(max_act_df.loc[max_act_df.index == n, 'max_activation'].iloc[0])
#print(generate_with_steering(gemma, gemma_sae, prompt, anger_neuron, max_act , -3))
#print(generate_with_sae_ablation(gemma, gemma_sae, prompt, anger_neuron))'''


In [None]:
def extract_output_data(text):
    """Return what follows the first ``Output:`` marker in *text*, stripped.

    Returns an empty string when the marker does not occur.
    """
    _, marker, remainder = text.partition("Output:")
    return remainder.strip() if marker else ""


In [None]:
# Compare baseline vs. negatively steered (-3.5 x max activation) 3-token
# completions, 5 samples per latent.
# NOTE(review): `context_neurons` is not defined in any visible cell, and
# `prompt_word` is only assigned later (inside the target-word loop). This cell
# presumably relied on earlier / out-of-order session state — confirm.
for n in context_neurons:
  max_act = float(max_act_df.loc[max_act_df.index == n, 'max_activation'].iloc[0])
  print(max_act)
  for _ in range(5):
    print(f'baseline: {extract_output_data(gemma.generate(prompt_word, max_new_tokens=3))}')
    print(f'steered: {extract_output_data(generate_with_steering(gemma, gemma_sae, prompt_word, n, max_act , -3.5, max_new_tokens=3))}')
    #print(f'ablated: {extract_output_data(generate_with_sae_ablation(gemma, gemma_sae, prompt_sent, n, max_new_tokens=3))}')



In [None]:
# SAE latent id -> words associated with that latent. Each word becomes the
# echo target of the copy-task prompt, and outputs are checked for it.
target_word = { 230 : ['woo', 'hur'],
                1898 : ['really', 'real'],
                4326 : ['much', 'more', 'enough'],
                4456 : ['sorry'],
                7579 : ['fuck'],
                7769 : ['afraid', 'scared', 'fear'],
                9065 : ['angry', 'mad'],
                13324 : ['scum'],
                14857 : ['shut', 'stop'],
                15366 : ['cute', 'adorable'],
                15539 : ['sad', 'sorry']}

In [None]:
# SAE latent ids grouped by emotion category. `anger` and `sadness` are used
# as ablation groups; `negative` is not referenced in any visible cell.
anger = [7579, 9065, 13324, 14857]
negative = [3636, 4560, 4859, 6953, 7077]
sadness = [5810, 15539]

In [None]:
# Latents to ablate when targeting a given latent: each anger latent ablates
# the whole `anger` group, a sadness latent the whole `sadness` group, and
# every other latent only itself. Key order follows target_word.
target_neur = {
    latent: anger if latent in anger else sadness if latent in sadness else latent
    for latent in (230, 1898, 4326, 4456, 7579, 7769, 9065, 13324, 14857, 15366, 15539)
}

In [None]:
import pandas as pd
# NOTE(review): `results` is only built in later cells, so this cell was
# presumably run after them (out of order) — confirm before relying on it.
df = pd.DataFrame(results)
results_path = "neg_steering_results_full_abl.csv"
df.to_csv(results_path, index=False)
# Report the file actually written (the message used to name steering_results.csv).
print(f"Saved results to {results_path}")

In [None]:
# Copy-task experiment. For each latent and each of its target words, ask the
# model to echo the word and sample 15 rounds of 3-token outputs in three
# modes: baseline, steered away from the latent (-4 x max activation), and
# ablated (the latent or its whole group, per target_neur). Records whether
# each output contains the word.
# Generation samples, so the call order below fixes the RNG stream each
# output sees.
# NOTE(review): this relies on the one-argument `extract_output_data` from the
# earlier cell. A later cell redefines it with a required `prompt` argument,
# after which these calls raise TypeError.
results = []

for n in target_word.keys():
    max_act = float(max_act_df.loc[max_act_df.index == n, 'max_activation'].iloc[0])
    print(f'Neuron: {n}, Max Activation: {max_act}')

    # Iterate over each target word for this neuron
    for word in target_word[n]:
        word_lower = word.lower()
        # The continuation lines of this f-string keep their 8-space source
        # indentation, so that whitespace is part of the prompt text.
        prompt_word = f"""Instruction: You should print as an output input from user without any other tokens.
        User: {word}
        Output:
        """
        print(prompt_word)

        for _ in range(15):
            baseline = extract_output_data(gemma.generate(prompt_word, max_new_tokens=3))
            steered = extract_output_data(generate_with_steering(gemma, gemma_sae, prompt_word, n, max_act, -4, max_new_tokens=3))
            ablated = extract_output_data(generate_with_sae_ablation(gemma, gemma_sae, prompt_word, target_neur[n], max_new_tokens=3))

            for mode, output in zip(['baseline', 'steered', 'ablated'], [baseline, steered, ablated]):
                output_lower = output.lower()
                # Case-insensitive substring test, so e.g. 'real' also matches 'really'.
                contains_target = word_lower in output_lower

                results.append({
                    'neuron': n,
                    'target_word': word,
                    'mode': mode,
                    'output': output,
                    'contains_target': contains_target
                })


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

df_sep_2 = pd.DataFrame(results)

# Rows: (neuron, target_word); columns: generation mode; cells: fraction of
# sampled outputs that contained the target word.
pivot_df = df_sep_2.pivot_table(
    values='contains_target',
    index=['neuron', 'target_word'],
    columns='mode',
    aggfunc='mean',
)

plt.figure(figsize=(10, 6))
sns.heatmap(pivot_df, cmap="YlOrRd", annot=True, fmt=".2f", cbar=True)
plt.xlabel("Mode")
plt.ylabel("(Neuron, Target Word)")
plt.title("Contains Target by Neuron and Word Across Modes")
plt.tight_layout()
plt.show()


In [None]:
# Save the target-word results plotted above. `df_sep` is never defined in
# this notebook (NameError); the frame built from `results` is `df_sep_2`.
df_sep_2.to_csv("neg_steering_results_sep_1.csv", index=False)
print("Saved results")

# Emotional sentiment: steering and ablating anger latents

In [None]:
import nltk
from nltk.sentiment import SentimentIntensityAnalyzer

# Ensure VADER resources are downloaded
nltk.download('vader_lexicon')

# Initialize VADER
sia = SentimentIntensityAnalyzer()
# Direct handle on the analyzer's word -> score dict. It is the same object as
# sia.lexicon, so updates through either name are visible through both.
vader_lex = sia.lexicon

# 1. Extended slang/swear mapping: variant → base word in VADER.
#    Each variant inherits its base word's VADER score (see the next step).
#    Keys must be unique: a repeated key silently replaces the earlier entry,
#    so the former duplicates of 'deadinside' and 'ily2' were removed.
variant_to_base = {
    # Strong Negative
    'fucking': 'fuck',
    'fucked': 'fuck',
    'motherfucker': 'fuck',
    'asshole': 'shit',
    'douche': 'shit',
    'douchebag': 'shit',
    'bullshit': 'shit',
    'jerk': 'jerk',
    'bitches': 'bitch',
    'cunt': 'bitch',
    'slut': 'bitch',
    'whore': 'bitch',
    'twat': 'bitch',
    'pussy': 'bitch',
    'moron': 'idiot',
    'retard': 'idiot',
    'stupid': 'idiot',
    'dumbass': 'idiot',
    'loser': 'idiot',
    'trash': 'idiot',
    'cringe': 'lame',
    'pathetic': 'lame',
    'toxic': 'bad',
    'ew': 'bad',
    'meh': 'bad',
    'wtf': 'damn',
    'creepy': 'scary',
    'ugly': 'bad',
    'nasty': 'bad',
    'deadinside' : 'depressing',

    # Positive Slang (re-mapped to valid VADER bases)
    'queen': 'amazing',
    'king': 'amazing',
    'slay': 'amazing',
    'boss': 'amazing',
    'icon': 'amazing',
    'legend': 'amazing',
    'goddess': 'amazing',
    'goat': 'great',
    'goated': 'great',
    'banger': 'awesome',
    'fire': 'awesome',
    'based': 'awesome',
    'lit': 'awesome',
    'dope': 'awesome',
    'hella': 'good',
    'savage': 'strong',
    'cute': 'sweet',
    'adorable': 'sweet',
    'fine': 'nice',
    'hot': 'nice',
    'sexy': 'nice',
    'clean': 'nice',
    'smooth': 'nice',
    'beautiful': 'nice',
    'pretty': 'sweet',

    # Love/excitement slang
    'loveee': 'love',
    'lovin': 'love',
    'obsessed': 'love',
    'crushing': 'love',
    'crushin': 'love',
    'inlove': 'love',
    'cutie': 'sweet',
    'sweetie': 'sweet',
    'bby': 'sweet',
    'boo': 'sweet',
    'bae': 'sweet',
    'ily': 'love',
    'ily2': 'love',
    'xoxo': 'love',

    # Casual/slang humor or approval
    'deadass': 'serious',
    'fr': 'serious',
    'bruh': 'funny',
    'lmao': 'funny',
    'rofl': 'funny',
    'lol': 'funny',
    'omg': 'wow',
    'vibing': 'happy',
    'vibe': 'happy',
    'energy': 'happy',

    # 😢 Sadness / Depression (slangified)
    'sadge': 'sad',
    'cryin': 'sad',
    'cryinggg': 'sad',
    'sobbing': 'sad',
    'nooo': 'sad',
    'ughhh': 'sad',
    'mentallyill': 'depressing',
    'depr3ssed': 'depressing',
    'downbad': 'sad',
    'voidcore': 'depressing',
    'brainrotted': 'depressing',
    'overit': 'sad',
    'can’ttakeit': 'depressing',
    'emptyaf': 'sad',
    'selfhatin': 'bad',

    # 😨 Anxiety / Fear / Panic (slangified)
    'scaredaf': 'scary',
    'panikin': 'scary',
    'anxiousss': 'scary',
    'stressing': 'scary',
    'freakinout': 'scary',
    'paranoidd': 'scary',
    'helplessss': 'sad',
    'losingit': 'scary',
    'nervousaf': 'scary',
    'shaking': 'scary',
    'brainmelting': 'scary',

    # 🤢 Disgust / Repulsion (slangified)
    'eww': 'gross',
    'vom': 'gross',
    'nastyyy': 'gross',
    'disgustinn': 'gross',
    'cringeaf': 'gross',
    'icky': 'gross',
    'yuck': 'gross',
    'throwingup': 'gross',
    'grossedout': 'gross',
    'gagging': 'gross',

    # 😊 Joy / Affection / Love / Excitement (slangified)
    'adorbs': 'sweet',
    'cutiepie': 'sweet',
    'angelbaby': 'sweet',
    'sunshiny': 'happy',
    'preciousaf': 'sweet',
    'ilysm': 'love',
    'lovinggg': 'love',
    'obsessssed': 'love',
    'snuggly': 'love',
    'heartmelt': 'love',
    'blessedaf': 'grateful',
    'hypeddd': 'excited',
    'vibinggg': 'happy',
    'ecstaticcc': 'happy',
    'excitedd': 'excited',
    'inloveee': 'love',
    'crushinnn': 'love',
    'cutenessoverload': 'sweet',
    'hearteyes': 'love',
}



# 2. Give each slang variant its base word's VADER score. Variants VADER
#    already knows are never overridden, and variants whose base word is
#    missing from VADER are skipped (both cases are reported).
adjusted_custom_lexicon = {}

for word, base in variant_to_base.items():
    base_score = vader_lex.get(base)
    if base_score is None:
        print(f"⚠️ Base word '{base}' not found in VADER for '{word}' — skipping.")
        continue
    if word in vader_lex:
        print(f"⏩ '{word}' already in VADER — skipping override.")
        continue
    adjusted_custom_lexicon[word] = base_score



# 3. Register the slang variants with the analyzer. sia.lexicon is the same
#    dict object as vader_lex.
sia.lexicon.update(adjusted_custom_lexicon)

# 4. Spot-check the adjusted analyzer on a few slang-heavy sentences.
examples = [
    "You are a fucking asshole.",
    "That song is an absolute banger!",
    "Stop being so cringe.",
    "She's a queen. Totally goated.",
    "This is such bullshit.",
    "I'm deadass serious.",
    "That guy is a total douchebag.",
]

for ex in examples:
    print(f"→ {ex}", sia.polarity_scores(ex), '-' * 40, sep='\n')
    print('-' * 40)



In [None]:
from transformers import pipeline
# Hugging Face sentiment classifier (task, model given positionally).
# NOTE(review): its only call site in this notebook is commented out. This
# checkpoint presumably emits LABEL_0/1/2 rather than NEGATIVE/NEUTRAL/POSITIVE
# — confirm before matching label names in the plots below.
sentiment_pipeline = pipeline("sentiment-analysis", "cardiffnlp/twitter-roberta-base-sentiment")

In [None]:
# Same latent groups as declared earlier (identical values), re-declared at
# the start of the sentiment section.
anger = [7579, 9065, 13324, 14857]
negative = [3636, 4560, 4859, 6953, 7077]
sadness = [5810, 15539]

In [None]:
# Open-ended, mildly negative prompts. Each ends with a space so the model
# continues the sentence. Only prompt_1_ang is used by the sentiment loop.
prompt_1_ang = "I dislike her, she is "
prompt_2_ang = "I don't think he is "
prompt_3_ang = "It really bothers me when someone "
prompt_4_ang = "They always ignore me, it's so "

In [None]:
# All anger prompts. NOTE(review): not referenced by any visible cell.
prompts_ang= [prompt_1_ang, prompt_2_ang, prompt_3_ang, prompt_4_ang]

In [None]:
def extract_output_data(text, prompt):
    """Return the first sentence of the model's continuation of *prompt*.

    Steps:
      1. If *prompt* occurs in *text*, keep only what follows its first
         occurrence (stripped).
      2. Remove HTML-like tags (``<...>``) and turn newlines into spaces.
      3. Return the first run ending in '.', '!' or '?' (stripped), or the
         whole cleaned text when there is no such terminator.

    Args:
        text (str): Generated text, usually prompt + continuation.
        prompt (str): The prompt to cut off. An empty prompt is ignored
            (``str.split`` would raise ValueError on an empty separator).

    Returns:
        str: The extracted sentence.
    """
    # Imported locally: at module level `re` is only imported in a later
    # cell, so a top-to-bottom run would otherwise hit NameError here.
    import re

    if prompt and prompt in text:
        text = text.split(prompt, 1)[1].strip()

    # Remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)

    # Replace newlines with spaces
    text = text.replace('\n', ' ').replace('\r', ' ')

    # Extract first sentence ending with '.', '!', or '?'
    match = re.search(r'[^.!?]*[.!?]', text)
    return match.group(0).strip() if match else text.strip()


In [None]:
from collections import defaultdict

# Sentiment experiment for the anger latents. For each latent in `anger`, run
# 20 rounds of five 10-token generations of prompt_1_ang: baseline, steered /
# ablated on this latent, and steered / ablated on the whole anger group. Each
# output is cut to its first sentence by extract_output_data, scored with
# `sent`, and appended as one row to `results`.
# NOTE(review): `pol` is only written by disabled code (inside the string
# literal below), so it stays empty.
pol = defaultdict(list)
results = []

for n in anger:
    max_act = float(max_act_df.loc[max_act_df.index == n, 'max_activation'].iloc[0])
    print(f'Neuron: {n}, Max Activation: {max_act}')

    for _ in range(20):
            # Generate outputs. The steering/ablation helpers sample
            # (temperature=0.7, top_p=0.9), so the call order below fixes the
            # RNG stream each output sees.
            baseline = extract_output_data(gemma.generate(prompt_1_ang, max_new_tokens=10), prompt_1_ang)
            steered = extract_output_data(generate_with_steering(gemma, gemma_sae, prompt_1_ang, n, max_act, -3, max_new_tokens=10), prompt_1_ang)
            ablated = extract_output_data(generate_with_sae_ablation(gemma, gemma_sae, prompt_1_ang, n, max_new_tokens=10), prompt_1_ang)
            # The *_all modes act on the whole anger group but still use this
            # latent's max_act as the steering scale.
            steered_all = extract_output_data(generate_with_steering(gemma, gemma_sae, prompt_1_ang, anger, max_act, -3, max_new_tokens=10), prompt_1_ang)
            ablated_all = extract_output_data(generate_with_sae_ablation(gemma, gemma_sae, prompt_1_ang, anger, max_new_tokens=10), prompt_1_ang)

            # Iterate through all modes and outputs
            for mode, output in zip(['baseline', 'steered', 'ablated', 'steered_all', 'ablated_all'],
                                    [baseline, steered, ablated, steered_all, ablated_all]):

                # Polarity scores
                #polarity = sia.polarity_scores(output)
                # HuggingFace-style sentiment prediction
                #sent = sentiment_pipeline(output)[0]
                # NOTE(review): `sent` is not defined in any visible cell; it is
                # presumably a helper returning (label, score), e.g. wrapping
                # sentiment_pipeline — confirm, otherwise this raises NameError.
                label, score = sent(output)
                # Store in pol (disabled: the block below is a string literal)
                '''
                pol[(n, mode)].append({
                    'polarity': polarity,
                    'label': sent['label'],
                    'score': sent['score']
                })
                '''
                # Store full result row
                results.append({
                    'neuron': n,
                    'mode': mode,
                    'prompt': prompt_1_ang,
                    'output': output,
                    #'sentiment_label': sent['label'],
                    #'sentiment_score': sent['score'],
                    #'compound': polarity['compound'],
                    #'neg': polarity['neg'],
                    #'neu': polarity['neu'],
                    #'pos': polarity['pos'],
                    'label': label,
                    'score' : score,
                })



In [None]:
import re

def clean_html_tags(text):
    """Strip every HTML/XML-like tag (``<...>``) from *text* and return the result."""
    tag_pattern = re.compile(r'<[^>]+>')
    return tag_pattern.sub('', text)


In [None]:
# Collect the sentiment-experiment rows into a DataFrame and display it.
df = pd.DataFrame(results)
df

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import math

def _label_percentages(df, labels):
    """Return per-(neuron, mode, label) ``count``, ``total`` and ``percent`` rows for *df*.

    Every combination of the neurons and modes present in *df* with each of
    *labels* gets a row; combinations that never occur have ``count`` 0.
    ``total`` counts all rows of that (neuron, mode), including rows whose
    label is not in *labels*. Labels are compared upper-cased; *df* itself is
    not modified.
    """
    upper = df.assign(label=df['label'].str.upper())
    neurons = upper['neuron'].unique()
    modes = upper['mode'].unique()

    totals = upper.groupby(['neuron', 'mode']).size().reset_index(name='total')
    counts = upper.groupby(['neuron', 'mode', 'label']).size().reset_index(name='count')

    grid = pd.MultiIndex.from_product(
        [neurons, modes, list(labels)], names=['neuron', 'mode', 'label']
    ).to_frame(index=False)

    merged = grid.merge(counts, on=['neuron', 'mode', 'label'], how='left').fillna({'count': 0})
    merged = merged.merge(totals, on=['neuron', 'mode'], how='left')
    # (neuron, mode) pairs absent from df have count 0; a total of 1 keeps their percent at 0.
    merged['total'] = merged['total'].fillna(1)
    merged['percent'] = 100 * merged['count'] / merged['total']
    return merged


def plot_neuronwise_heatmap_grid(df, labels=('NEGATIVE', 'POSITIVE', 'NEUTRAL'), cols=2):
    """Plot, per neuron, a heatmap of the percentage of each label in each mode.

    Args:
        df: Results frame with ``neuron``, ``mode`` and ``label`` columns. It is
            no longer modified (the label column used to be upper-cased in place).
        labels: Label values (compared upper-cased) to show as heatmap rows.
            Rows with other labels still count toward each (neuron, mode) total
            but are not drawn. NOTE(review): if the classifier emits names like
            LABEL_0/1/2, pass them here — confirm against the `sent` helper.
        cols: Number of subplot columns.

    Returns nothing; an empty *df* is a no-op (previously it crashed, because
    ``plt.subplots`` received 0 rows and the loop index was never bound).
    """
    if df.empty:
        return

    percentages = _label_percentages(df, labels)
    neurons = df['neuron'].unique()
    rows = math.ceil(len(neurons) / cols)

    # squeeze=False always yields a 2-D array of Axes, so ravel() works for any grid size.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, rows * 3), squeeze=False)
    axes = axes.ravel()

    for ax, neuron in zip(axes, neurons):
        subset = percentages[percentages['neuron'] == neuron]
        pivot = subset.pivot(index='label', columns='mode', values='percent').fillna(0)

        sns.heatmap(pivot, ax=ax, annot=True, fmt=".1f", cmap="RdBu_r", cbar=False)
        ax.set_title(f'Neuron {neuron}')
        ax.set_xlabel('Mode')
        ax.set_ylabel('Label')

    # Remove the unused panels in the last row.
    for ax in axes[len(neurons):]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.show()


# Percentage of each sentiment label per (neuron, mode), one panel per neuron.
plot_neuronwise_heatmap_grid(df)

In [None]:
def add_compound_scores(df, text_column='output'):
    """Add a ``compound`` column with VADER's compound polarity of each text.

    Uses the module-level ``sia`` analyzer. *df* is modified in place and
    also returned.
    """
    def compound_of(text):
        return sia.polarity_scores(text)['compound']

    df['compound'] = df[text_column].apply(compound_of)
    return df

# Attach the VADER compound score of each output as a 'compound' column.
df = add_compound_scores(df, text_column='output')


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

def plot_score_heatmap(df):
    """Plot the mean VADER compound score for every (neuron, mode) pair.

    Args:
        df: Results frame with ``neuron``, ``mode`` and ``compound`` columns
            (``compound`` is added by ``add_compound_scores``). *df* is not
            modified; previously its ``label`` column was upper-cased in place
            even though this plot never reads it.

    (neuron, mode) pairs that never occur are shown as 0. An empty *df* is a
    no-op.
    """
    if df.empty:
        return

    pivot = df.groupby(['neuron', 'mode'])['compound'].mean().reset_index()
    pivot_table = pivot.pivot(index='neuron', columns='mode', values='compound').fillna(0)
    pivot_table = pivot_table.sort_index()

    plt.figure(figsize=(10, 6))
    # The colour bar shows the mean compound score (it was mislabelled as a NEGATIVE score).
    sns.heatmap(pivot_table, annot=True, fmt=".2f", cmap="coolwarm", cbar_kws={'label': 'Avg Compound Score'})
    plt.title('Average Compound Score')
    plt.xlabel('Mode')
    plt.ylabel('Neuron')
    plt.tight_layout()
    plt.show()

# Heatmap of the mean compound score per (neuron, mode).
plot_score_heatmap(df)

In [None]:
# Persist the sentiment-experiment results (with the 'compound' column added above).
df.to_csv('neg_steer_abl_anger_min_2.csv', index=False)