# Setup

In [None]:
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import gc
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import re
import nltk
from nltk.corpus import stopwords
from sklearn.metrics import confusion_matrix, accuracy_score
from collections import defaultdict
import plotly.graph_objects as go
from sae_lens import SAE, HookedSAETransformer
from functools import partial
from transformers import AutoTokenizer, AutoModel
import pickle
from textblob import TextBlob
from nltk.sentiment import SentimentIntensityAnalyzer
from transformers import pipeline

In [None]:
# Log in to the Hugging Face Hub (presumably needed for the Gemma checkpoints loaded later).
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# NOTE(review): the first cell already imports sae_lens/nltk, so on a fresh
# runtime these installs must run before that cell.
!pip install sae_lens

In [None]:
!pip install nltk

In [None]:
!pip install datasets

In [None]:
from datasets import load_dataset

# GoEmotions, "simplified" configuration; its splits are merged in a later cell.
ds = load_dataset("google-research-datasets/go_emotions", "simplified")

In [None]:
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)

In [None]:
from datasets import concatenate_datasets

# Merge train/validation/test into a single dataset.
ds = concatenate_datasets([ds['train'], ds['validation'], ds['test']])
ds

In [None]:
# Label id -> emotion name lookup.
label_names = ds.features["labels"].feature.names


In [None]:
targeted_emotions = ['joy', 'anger', 'disgust', 'sadness', 'love', 'fear', 'excitement']
# Dataset label ids of the targeted emotions.
# NOTE(review): `labels` is reassigned later to per-example label lists.
labels = []
for em in targeted_emotions:
  labels.append(label_names.index(em))

In [None]:
# Keep examples that carry at least one targeted emotion...
ds = ds.filter(lambda x: any(label in labels for label in x['labels']))

In [None]:
# ...and exactly one label in total, so every kept example is single-label.
ds = ds.filter(lambda x: len(x['labels']) == 1)

In [None]:
import numpy as np

ids = np.load("ds_filt.npy")

In [None]:
filtered_ds = ds.filter(lambda x: x["id"] in ids)


In [None]:
ds = filtered_ds

In [None]:
len(filtered_ds)

In [None]:
ds

In [None]:
texts = [s['text'] for s in ds]
labels = [s['labels'] for s in ds]

In [None]:
import nltk
# Fetch NLTK's stop-word corpus (needed by `stopwords.words` below).
nltk.download('stopwords')

In [None]:
from nltk.corpus import stopwords
# English stop words, as a set for O(1) membership tests in the word counter.
stop_words = set(stopwords.words('english'))


In [None]:
stop_words

In [None]:
word_counter = Counter()
for sentence in texts:
    words = re.findall(r'\b\w+\b', sentence.lower())
    filt_words = [w for w in words if(w not in stop_words) and (len(w)>1)]
    word_counter.update(filt_words)

word_freq_df = pd.DataFrame(word_counter.items(), columns=['word', 'count']).sort_values(by='count', ascending=False)

plt.figure(figsize=(12, 6))
sns.barplot(data=word_freq_df.head(30), x='word', y='count', palette='viridis')
plt.xticks(rotation=45)
plt.title('Top 30 Most Frequent Words in GoEmotions Dataset')
plt.xlabel('Word')
plt.ylabel('Frequency')
plt.tight_layout()
plt.show()

In [None]:
word_counter

In [None]:
new = [w for w, c in word_counter.items() if c > 5]

In [None]:
len(new)

## Classification with Gemma

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Causal LM used as the few-shot emotion classifier.
# NOTE(review): this loads "google/gemma-2b", while the SAE sections below use
# "google/gemma-2-2b" — confirm the two different checkpoints are intended.
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
# Inference mode (disables dropout).
model.eval()


In [None]:
def build_prompt(shots=('joy', 'sadness'), prompt_index=1):
    """Return a function that wraps a text in a few-shot emotion-classification prompt.

    Args:
        shots: Emotion names whose built-in example sentence is added as a
            "Context: ... Answer: <emotion>" demonstration, in the given order.
            Names without an example in the pool are silently skipped.
        prompt_index: Index into ``templates`` selecting the instruction header.

    Returns:
        ``f(x) -> str`` producing ``"<header> <demos> Context: x Answer:"`` so the
        model's continuation can be read as the predicted emotion. An empty
        header or an empty demo list is omitted together with its separator.

    Raises:
        IndexError: If ``prompt_index`` is out of range for ``templates``.
    """
    sample_pool = {
        'joy': 'My first child was born.',
        'anger': 'My husband missed an important call, because his phone was on silent AGAIN!',
        'disgust': 'I saw mouldy food.',
        'sadness': 'My dog died last week.',
        'love': 'I told my partner I loved them.',
        'fear': 'I was confronted by a thief.',
        'excitement': 'I got an A in my exam!'
    }

    # Filter and prepare few-shot examples from the selected emotions
    selected_shots = [(sample_pool[e], e) for e in shots if e in sample_pool]

    emotion_list = "Consider this list of emotions: joy, anger, disgust, sadness, love, fear, excitement. "

    templates = [
        "What are the inferred emotions in the following contexts?",
        emotion_list + "What are the inferred emotions in the following contexts?",
        "",
        "Guess the emotion.",
        "Decipher the emotion from the following statements: ",
        "Decipher the label for the following statements: ",
        "What is the label, for the statement? ",
        "What is the label, given the context? ",
        emotion_list + "Decipher the emotion from the following statements: ",
        emotion_list + "Decipher the label for the following statements: ",
    ]

    header = templates[prompt_index].strip()

    # Join the non-empty pieces with single spaces. Plain concatenation
    # (header + " Context: ...") left a leading space when the header is empty
    # (prompt_index=2), which changes the first token the model sees.
    parts = ([header] if header else []) + [
        f"Context: {text} Answer: {emotion}" for text, emotion in selected_shots
    ]
    prefix = " ".join(parts) + " " if parts else ""

    return lambda x: f"{prefix}Context: {x} Answer:"



In [None]:
from tqdm import tqdm

def classify_with_gemma(ds, tokenizer, model, targeted_emotions, prompt_index=1, shot_emotions=None):
    """Few-shot classify each example's text with a causal LM.

    Args:
        ds: Iterable of examples, each a mapping with a ``'text'`` field.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Causal LM exposing ``generate`` and ``device``.
        targeted_emotions: Candidate emotion names; predictions are reported as
            indices into this list.
        prompt_index: Instruction header passed to ``build_prompt``.
        shot_emotions: Emotions used as few-shot demonstrations; defaults to the
            first seven entries of ``targeted_emotions``.

    Returns:
        One entry per example: the index of the predicted emotion in
        ``targeted_emotions``, or the string ``"other"`` when the first generated
        word is not a candidate emotion or nothing was generated.
    """
    if shot_emotions is None:
        shot_emotions = targeted_emotions[:7]

    prompt_func = build_prompt(shots=shot_emotions, prompt_index=prompt_index)
    preds = []

    for example in tqdm(ds, desc="Classifying with Gemma (few-shot)"):
        prompt = prompt_func(example['text'])

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=3, do_sample=False, pad_token_id=tokenizer.eos_token_id)

        # Decode only the newly generated tokens. Decoding the full sequence and
        # cutting at len(prompt) relies on decode(encode(prompt)) == prompt,
        # which tokenizers do not guarantee.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        generated = tokenizer.decode(new_tokens, skip_special_tokens=True)

        words = generated.strip().split()
        # An empty generation (e.g. immediate EOS) previously raised IndexError;
        # trailing punctuation ("joy.") is stripped so it still matches.
        pred = words[0].lower().strip(".,;:!?\"'()") if words else ""

        if pred in targeted_emotions:
            preds.append(targeted_emotions.index(pred))
        else:
            preds.append("other")

    return preds



In [None]:
# Few-shot classify every example with prompt template 1.
preds_model = classify_with_gemma(ds, tokenizer, model, targeted_emotions, prompt_index=1)

In [None]:
# Convert indices into `targeted_emotions` into dataset label ids; predictions
# that could not be parsed ("other") become "-1" (converted to int below).
# NOTE(review): -1 used as a list index refers to the LAST entry of
# `label_names`, so later cells must special-case it.
preds_filtered = []
for pred in preds_model:
    if isinstance(pred, int):
        emotion = targeted_emotions[pred]
        if emotion in label_names:
            preds_filtered.append(label_names.index(emotion))  # map to ds label ID
        else:
            preds_filtered.append("-1")  # or -1, or skip it
    else:
        preds_filtered.append("-1")

In [None]:
texts = [item['text'] for item in ds ]
labels = [item['labels'] for item in ds ]

In [None]:
# Examples are single-label (see the earlier filter), so take the only label.
labels_flat = [lbl[0] for lbl in labels]

In [None]:
labels_flat = [int(x) for x in labels_flat]
preds_filtered = [int(x) for x in preds_filtered]


In [None]:
# Attach predictions and gold labels, then keep only correctly classified examples.
ds_1 = ds.add_column("pred", preds_filtered)
ds_1 = ds_1.add_column('true', labels_flat)
ds_filtered = ds_1.filter(lambda x: x["true"] == x["pred"])


In [None]:
ds_filtered

In [None]:
ids = [item['id'] for item in ds_filtered]

In [None]:
import numpy as np
from google.colab import files

# Persist ids of correctly classified examples; the earlier cell that loads
# "ds_filt.npy" restricts `ds` to them. `files.download` only works in Colab.
np.save("ds_filt.npy", ids)
files.download("ds_filt.npy")
files.download("ds_filt.npy")

### Visualizations

In [None]:
true = [label_names[l] for l in labels_flat]
pred = [label_names[l] for l in preds_filtered]

In [None]:
conf_mat = confusion_matrix(true, pred, labels=targeted_emotions)

plt.figure(figsize=(10, 7))
sns.heatmap(conf_mat, annot=True,
            xticklabels=targeted_emotions,
            yticklabels=targeted_emotions,
            fmt='d', cmap='Blues')
plt.title(f"Confusion Matrix (Accuracy: {accuracy:.2%})")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.tight_layout()
plt.show()


In [None]:
def plot_sankey(true_indices, pred_indices, targ_em, label_names):
    """Draw a Sankey diagram of how gold labels flow to predicted labels.

    Args:
        true_indices: Gold dataset label ids.
        pred_indices: Predicted dataset label ids, aligned with ``true_indices``.
            Negative ids (this notebook stores unparseable predictions as -1)
            are drawn as an explicit ``"other"`` node.
        targ_em: Currently unused; kept so existing calls keep working.
        label_names: Sequence mapping label id -> emotion name.

    Returns:
        None. Shows the figure, or prints a message and returns early when there
        are no (true, predicted) pairs to plot.
    """
    def _name(i):
        # A negative id would otherwise wrap around to the last label name.
        return label_names[i] if i >= 0 else "other"

    true_labels = [_name(i) for i in true_indices]
    pred_labels = [_name(i) for i in pred_indices]

    all_emotions = sorted(set(true_labels + pred_labels))

    # Left column: true labels; right column: the same names prefixed "pred_".
    label_list = all_emotions + [f"pred_{l}" for l in all_emotions]
    label_idx = {l: i for i, l in enumerate(label_list)}

    emotion_to_color = {
        'joy': "#FF0000",
        'anger': "#FFA500",
        'sadness': "#FFFF00",
        'fear': "#27ae60",
        'disgust': "#7f8c8d",
        'love': "#e91e63",
        'excitement': "#f39c12",
        'neutral': "#bdc3c7"
    }

    node_colors = [emotion_to_color.get(e, "#CCCCCC") for e in all_emotions]
    node_colors += node_colors

    # Number of examples for each (true node, predicted node) pair.
    counter = defaultdict(int)
    for t, p in zip(true_labels, pred_labels):
        counter[(label_idx[t], label_idx[f"pred_{p}"])] += 1

    if not counter:
        # max() below raises on an empty sequence.
        print("plot_sankey: nothing to plot (no label pairs).")
        return

    link_source, link_target, link_value, link_color = [], [], [], []
    max_val = max(counter.values())

    for (src, tgt), val in counter.items():
        link_source.append(src)
        link_target.append(tgt)
        link_value.append(val)

        # Opacity scales with flow size, floored at 0.2 so small flows stay visible.
        norm_alpha = min(1.0, max(0.2, val / max_val))

        # Links take the colour of their source (true-label) node.
        hex_color = node_colors[src].lstrip('#')
        r, g, b = tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
        link_color.append(f'rgba({r},{g},{b},{norm_alpha:.2f})')

    fig = go.Figure(go.Sankey(
        arrangement="snap",
        node=dict(
            pad=15,
            thickness=25,
            line=dict(color="black", width=0.5),
            label=label_list,
            color=node_colors
        ),
        link=dict(
            source=link_source,
            target=link_target,
            value=link_value,
            color=link_color,
        )
    ))

    fig.update_layout(
        title_text="Emotion Classification Flow (True → Predicted)",
        font_size=13,
        margin=dict(l=40, r=40, t=50, b=40),
        width=600,
        height=600
    )

    fig.show()





In [None]:
# Visualise how gold labels flow to Gemma's predictions (dataset label ids on both sides).
plot_sankey(labels_flat, preds_filtered, targeted_emotions, label_names)

# Collecting activations for targeted neurons

In [None]:
from transformers import AutoTokenizer, AutoModel

# Tokenizer matching the HookedSAETransformer checkpoint loaded below.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

In [None]:
from sae_lens import SAE, ActivationsStore, HookedSAETransformer

# Prefer Apple MPS, then CUDA, then CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
# Inference only: disable autograd globally for the rest of the session.
torch.set_grad_enabled(False)

# Gemma Scope SAE at layer 20 with a 16k-latent dictionary (presumably on the
# residual stream, per the "-res" release name).
# NOTE(review): the returned `sparsity` name is reused by the extraction loop below.
gemma_sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",
    sae_id="layer_20/width_16k/average_l0_71",
    device=str(device),
)

gemma = HookedSAETransformer.from_pretrained("google/gemma-2-2b", device=device)

In [None]:
texts = [item['text'] for item in ds ]
labels = [item['labels'] for item in ds ]

In [None]:
with open("emotional_n_02_thresh.pkl", "rb") as f:
    neurons = pickle.load(f)

In [None]:
batch_size = 4
max_length = 400
results = []

for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
    try:
        gc.collect()
        torch.cuda.empty_cache()

        # Get batch
        batch_texts = texts[i:i + batch_size]

        # Tokenize
        tokenized = tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        )
        input_ids = tokenized["input_ids"].to(device)

        # Forward pass with SAE and cache activations
        _, cache = gemma.run_with_cache_with_saes(
            input_ids,
            saes=[gemma_sae],
            stop_at_layer=gemma_sae.cfg.hook_layer + 1,
            names_filter=[f"{gemma_sae.cfg.hook_name}.hook_sae_acts_post"],
        )

        # SAE activations (features)
        sae_acts = cache[f"{gemma_sae.cfg.hook_name}.hook_sae_acts_post"]  # [B, T, F]
        final_acts = sae_acts[:, -1, :].detach().cpu()

        # Sparsity = number of active features
        sparsity = (sae_acts[:, -1, :] > 1).sum(dim=-1)  # [B]

        decoded_tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in input_ids]


        for i in range(len(batch_texts)):
            results.append({
                "input_ids": input_ids[i].detach().cpu(),              # torch.Tensor
                "tokens": decoded_tokens[i],                          # list of strings
                "activation": final_acts[i],                          # torch.Tensor
                "sparsity": int(sparsity[i]),                         # int
            })

        # Cleanup
        del cache, sae_acts, final_acts, sparsity, input_ids, tokenized
        torch.cuda.empty_cache()

    except RuntimeError as e:
        print(f"⚠️ OOM on batch {i}-{i+batch_size}: {e}")
        torch.cuda.empty_cache()

In [None]:
for r, l in zip(results, labels):
  r['label'] = l

In [None]:
len(results)

In [None]:
with open("sae_results.pkl", "wb") as f:
    pickle.dump(results, f)

In [None]:
targ_neurons = [13324, 14857, 2438, 12881, 4560, 1898, 8366, 7077, 8094, 3232, 6953, 6953, 13324, 4456, 7077, 808, 230, 281, 8783, 4305, 7717, 230, 7688, 15261, 4305, 3636, 4326, 11491, 4305, 5413, 9618, 15539]

In [None]:
batch_size = 4
max_length = 40
activation_threshold = 1.0
targ_neurons = set(targ_neurons)
results = []

for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):

        gc.collect()
        torch.cuda.empty_cache()

        batch_texts = texts[i:i + batch_size]

        tokenized = tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        )
        input_ids = tokenized["input_ids"].to(device)

        _, cache = gemma.run_with_cache_with_saes(
            input_ids,
            saes=[gemma_sae],
            stop_at_layer=gemma_sae.cfg.hook_layer + 1,
            names_filter=[f"{gemma_sae.cfg.hook_name}.hook_sae_acts_post"],
        )

        sae_acts = cache[f"{gemma_sae.cfg.hook_name}.hook_sae_acts_post"].detach().cpu()  # [B, T, F]
        final_acts = sae_acts[:, -1, :]  # [B, F]
        decoded_tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in input_ids]

        for j in range(len(batch_texts)):
            token_acts = sae_acts[j]  # [T, F]
            filtered_entries = []

            for t in range(token_acts.shape[0]):  # iterate over token positions
                for f in targ_neurons:
                    act_val = token_acts[t, f].item()
                    if act_val > activation_threshold:
                        filtered_entries.append((t, f, act_val))

            final_filtered = {
                f: final_acts[j, f].item()
                for f in targ_neurons
                if final_acts[j, f].item() > activation_threshold
            }

            results.append({
                "input_ids": input_ids[j].detach().cpu(),
                "tokens": decoded_tokens[j],
                "activation_targeted": final_filtered,
                "active_neurons": filtered_entries,
                "sparsity": len(filtered_entries),
            })

        del cache, sae_acts, final_acts, input_ids, tokenized
        torch.cuda.empty_cache()

In [None]:
# Attach gold labels. The extraction loop above has no skip path, so results
# are in the same order as `labels`.
for r, l in zip(results, labels):
  r['label'] = l

In [None]:
with open("sae_results_tokens_top_k.pkl", "wb") as f:
    pickle.dump(results, f)


# Fixed vocabulary: NRC Emotion Lexicon words plus dataset words

In [None]:
from transformers import AutoTokenizer, AutoModel

# Re-load the same tokenizer / SAE / model as the activation section so this
# section can be run on its own.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

In [None]:
from sae_lens import SAE, ActivationsStore, HookedSAETransformer

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
# Inference only: disable autograd globally.
torch.set_grad_enabled(False)

gemma_sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",
    sae_id="layer_20/width_16k/average_l0_71",
    device=str(device),
)

gemma = HookedSAETransformer.from_pretrained("google/gemma-2-2b", device=device)

In [None]:
vocab_set = set()

with open("NRC-Emotion-Lexicon-Wordlevel-v0.92.txt", "r", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split()
        if parts:
            vocab_set.add(parts[0])

print("Total unique words:", len(vocab_set))
print("Sample:", list(vocab_set)[:10])


In [None]:
# version where at least 1 emotion labeled

filtered_vocab = set()

with open("NRC-Emotion-Lexicon-Wordlevel-v0.92.txt", "r") as f:
    for line in f:
        word, emotion, label = line.strip().split("\t")
        if label == "1":
            filtered_vocab.add(word)

In [None]:
len(filtered_vocab)

In [None]:
vocab_set = filtered_vocab


In [None]:
for word, count in word_counter.items():
  filtered_vocab.add(word)


filtered_vocab = list(filtered_vocab)
vocab_set = filtered_vocab

In [None]:
len(vocab_set)

In [None]:
with open("emotional_n_02_thresh.pkl", "rb") as f:
    # Only unpickle trusted local files: pickle can execute arbitrary code.
    neurons = pickle.load(f)

In [None]:
# NOTE(review): these globals are not passed to collect_word_activations below
# (its call uses literal arguments, e.g. activation_threshold=1.0), and that
# function rebuilds the same inverted index internally.
batch_size = 4
max_length = 40
activation_threshold = 0.0
word_to_activations = defaultdict(list)

texts = [s["text"] for s in ds]

# One case-insensitive, whole-word alternation over the entire vocabulary.
word_pattern = re.compile(r'\b({})\b'.format('|'.join(map(re.escape, vocab_set))), flags=re.IGNORECASE)
word_index_map = defaultdict(list)

# Build inverted index: word -> list of sentence indices where it appears
for idx, sentence in tqdm(enumerate(texts), total=len(texts), desc="Indexing sentences"):
    matches = word_pattern.findall(sentence)
    for word in set(matches):
        word_index_map[word.lower()].append(idx)


In [None]:
def collect_word_activations(texts, vocab_set, tokenizer, gemma, gemma_sae, device, target_neurons,
                             batch_size=4, max_length=40, activation_threshold=1.0, top_k=5,
                             max_sentences_per_word=20):
    """Profile each vocabulary word by the SAE latents active on its tokens.

    For every word in ``vocab_set`` (lower-cased, deduplicated), up to
    ``max_sentences_per_word`` sentences from ``texts`` containing it as a whole
    word (case-insensitive) are run through ``gemma`` up to the SAE's layer. The
    ``hook_sae_acts_post`` activations at the tokens covering that word are
    then read out.

    Args:
        texts: Sentences to search.
        vocab_set: Iterable of words to profile.
        tokenizer: Hugging Face tokenizer for ``gemma``. It must support
            ``return_offsets_mapping`` (a "fast" tokenizer); character offsets
            locate the tokens that cover the word.
        gemma: Model providing ``run_with_cache_with_saes``.
        gemma_sae: SAE whose post-activation features are read.
        device: Device the input ids are moved to.
        target_neurons: SAE latent ids to keep, in output-column order.
        batch_size: Sentences per forward pass.
        max_length: Token truncation length; occurrences beyond it are ignored.
        activation_threshold: Activations strictly above this are logged in
            ``word_to_high_acts``.
        top_k: Number of highest-L2-norm per-sentence vectors averaged per word.
        max_sentences_per_word: Cap on sentences per word (previously fixed at 20).

    Returns:
        ``(word_to_sae_vec, word_to_high_acts)``:
        - ``word_to_sae_vec`` maps each word to an ``np.ndarray`` of length
          ``len(target_neurons)``.
        - ``word_to_high_acts`` is a ``defaultdict(list)`` mapping each word to
          ``(token, latent_id, activation)`` tuples.
    """
    word_to_activations = defaultdict(list)
    word_to_high_acts = defaultdict(list)

    # Lower-case and deduplicate once; the index below is keyed by lower-case words.
    words = list(dict.fromkeys(w.lower() for w in vocab_set))
    if not words:
        # An empty alternation would make the regex match the empty string everywhere.
        return {}, word_to_high_acts

    word_pattern = re.compile(r'\b({})\b'.format('|'.join(map(re.escape, words))), flags=re.IGNORECASE)

    # Inverted index: word -> indices of sentences containing it (each sentence
    # at most once per word, even if it contains several case variants).
    word_index_map = defaultdict(list)
    for idx, sentence in tqdm(enumerate(texts), total=len(texts), desc="Indexing sentences"):
        for word in {m.lower() for m in word_pattern.findall(sentence)}:
            word_index_map[word].append(idx)

    acts_hook = f"{gemma_sae.cfg.hook_name}.hook_sae_acts_post"

    for word in tqdm(words, desc="Processing vocab words"):
        indices = word_index_map.get(word, [])[:max_sentences_per_word]
        if not indices:
            continue
        # Whole-word, case-insensitive occurrences of this word in a sentence.
        word_re = re.compile(r'\b{}\b'.format(re.escape(word)), flags=re.IGNORECASE)

        for start in range(0, len(indices), batch_size):
            batch_texts = [texts[j] for j in indices[start:start + batch_size]]

            gc.collect()
            torch.cuda.empty_cache()

            tokenized = tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
                return_offsets_mapping=True,
            )
            input_ids = tokenized["input_ids"].to(device)
            offsets = tokenized["offset_mapping"].tolist()  # [B][T] of (char_start, char_end)

            try:
                _, cache = gemma.run_with_cache_with_saes(
                    input_ids,
                    saes=[gemma_sae],
                    stop_at_layer=gemma_sae.cfg.hook_layer + 1,
                    names_filter=[acts_hook],
                )
            except Exception as e:
                # Best effort: skip batches that fail (e.g. out of memory).
                print(f"Skipping batch due to error: {e}")
                continue

            sae_acts = cache[acts_hook].detach().cpu()  # [B, T, F]

            for b, sentence in enumerate(batch_texts):
                tokens = tokenizer.convert_ids_to_tokens(tokenized["input_ids"][b])
                token_acts = sae_acts[b]  # [T, F]
                spans = [m.span() for m in word_re.finditer(sentence)]
                # Tokens whose character span overlaps an occurrence of the word.
                # A substring test on token strings (`word in token`) also hit
                # unrelated subwords (e.g. "joy" inside an "enjoy" piece, "pad"
                # inside "<pad>"). It also missed words split across several
                # tokens. Special and padding tokens have empty spans and are
                # excluded by `tok_end > tok_start`.
                word_pos = [
                    t for t, (tok_start, tok_end) in enumerate(offsets[b])
                    if tok_end > tok_start
                    and any(tok_start < m_end and tok_end > m_start for m_start, m_end in spans)
                ]
                if not word_pos:
                    continue

                # Mean over the word's token positions -> one vector per sentence.
                mean_vector = token_acts[word_pos].mean(dim=0)[target_neurons].numpy()
                word_to_activations[word].append(mean_vector)

                # Record every strong activation for interpretability.
                for t in word_pos:
                    for f in target_neurons:
                        val = token_acts[t, f].item()
                        if val > activation_threshold:
                            word_to_high_acts[word].append((tokens[t], f, val))

    def topk_mean(arrs, k=top_k):
        """Mean of the ``k`` highest-L2-norm vectors in ``arrs`` (all of them if there are <= k)."""
        if len(arrs) <= k:
            return np.mean(arrs, axis=0)
        scores = [np.linalg.norm(vec) for vec in arrs]
        topk_indices = np.argsort(scores)[-k:]
        return np.mean([arrs[i] for i in topk_indices], axis=0)

    # One summary vector per word that had at least one located occurrence.
    word_to_sae_vec = {
        w: topk_mean(vectors)
        for w, vectors in word_to_activations.items()
        if vectors
    }

    return word_to_sae_vec, word_to_high_acts


In [None]:
# Profile every vocabulary word against the loaded latents (`neurons`); the
# keyword arguments here are literals, not the globals set in the indexing cell.
word_to_sae_vec, word_to_high_acts = collect_word_activations(
    texts=texts,
    vocab_set=vocab_set,
    tokenizer=tokenizer,
    gemma=gemma,
    gemma_sae=gemma_sae,
    device=device,
    target_neurons=list(neurons),
    batch_size=4,
    max_length=40,
    activation_threshold=1.0,
    top_k=5
)


In [None]:
from sklearn.metrics.pairwise import cosine_similarity

def extract_top_words_per_neuron(word_to_sae_vec, target_neurons, top_k=10, diversity=True, min_score_threshold=5.0,
                                 similarity_threshold=0.9):
    """Rank vocabulary words by their averaged activation for each SAE latent.

    Args:
        word_to_sae_vec: Mapping word -> 1-D vector with one entry per element of
            ``target_neurons``, in the same order.
        target_neurons: Latent ids, in the column order of the vectors.
        top_k: Maximum number of words kept per latent.
        diversity: If true, skip a word whose full vector has cosine similarity
            >= ``similarity_threshold`` with any already selected word's vector.
        min_score_threshold: Words scoring below this are not considered.
        similarity_threshold: Cosine cut-off used by the diversity filter.

    Returns:
        Dict latent id -> list of ``(word, score)`` in descending score order.
        Latents with no word at or above ``min_score_threshold`` are omitted.
    """
    top_words_per_neuron = {}
    if not word_to_sae_vec:
        # np.stack() raises on an empty list; there is nothing to rank.
        return top_words_per_neuron

    # Materialise once: the ids are iterated twice (index map and main loop),
    # which silently produced nothing for one-shot iterables.
    target_neurons = list(target_neurons)
    neuron_to_index = {n: i for i, n in enumerate(target_neurons)}
    all_words = list(word_to_sae_vec.keys())
    vectors = np.stack([word_to_sae_vec[w] for w in all_words])  # shape: [#words, #neurons]

    for neuron in target_neurons:
        idx = neuron_to_index[neuron]
        scores = vectors[:, idx]
        sorted_indices = np.argsort(-scores)

        top_words = []
        used_vecs = []

        for i in sorted_indices:
            score = scores[i]
            if score < min_score_threshold:
                break  # scores are sorted descending, so nothing later qualifies

            word = all_words[i]
            vec = word_to_sae_vec[word]

            # Diversity filter: skip near-duplicates of an already kept word.
            if diversity and used_vecs:
                sims = cosine_similarity([vec], used_vecs)[0]
                if np.max(sims) >= similarity_threshold:
                    continue

            top_words.append((word, score))
            used_vecs.append(vec)

            if len(top_words) >= top_k:
                break

        if top_words:
            print(f"Neuron {neuron} → Collected {len(top_words)} words (min activation = {top_words[-1][1]:.4f})")
            top_words_per_neuron[neuron] = top_words
        else:
            print(f"Neuron {neuron} → No words above threshold ({min_score_threshold})")

    return top_words_per_neuron





In [None]:
top_words_per_neuron = extract_top_words_per_neuron(word_to_sae_vec, neurons, top_k=30, diversity=True)

In [None]:
import json

with open("word_to_high_acts_drop_stop_30.json", "w") as f:
    json.dump(word_to_high_acts, f)

with open("word_to_sae_vec_reduced_drop_stop_30.json", "w") as f:
    json.dump({k: v.tolist() for k, v in word_to_sae_vec.items()}, f)


In [None]:
word_to_sae_vec

In [None]:
import json

# Convert activation values to float and structure for JSON
top_words_serializable = {
    int(neuron): [{"word": w, "activation": float(a)} for w, a in word_list]
    for neuron, word_list in top_words_per_neuron.items()
}

# Save to file
with open("top_words_per_neuron_drop_stop_30.json", "w") as f:
    json.dump(top_words_serializable, f, indent=2)


In [None]:
top_words_per_neuron

# Steering

In [None]:
# 1. Set device and load models
# (Same SAE and model as the earlier sections, re-loaded so steering can run on its own.)
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)

gemma_sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",
    sae_id="layer_20/width_16k/average_l0_71",
    device=str(device),
)

gemma = HookedSAETransformer.from_pretrained("google/gemma-2-2b", device=device)


In [None]:
# Per-latent maximum activations, presumably precomputed elsewhere (Colab path).
max_act_df = pd.read_csv('/content/max_activations_for_targ_neurons.csv')

In [None]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

In [None]:
# Disabled (kept as an unevaluated string literal): ActivationsStore set-up that
# `find_max_activation` below would need as its `act_store` argument.
'''
from sae_lens import ActivationsStore

# Create the activation store from your dataset
activation_store = ActivationsStore.from_sae(
    model=gemma,
    sae=gemma_sae,
    streaming=True,
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=4,
    device=gemma.cfg.device,
)
'''

In [None]:
# Latent index treated as the "anger" feature in the steering experiments.
anger_neuron_idx = 2438  # replace with actual anger-selective neuron index

# === Scan activation batches for the largest value of one latent ===
def find_max_activation(model, sae, act_store, neuron_idx, num_batches=100):
    """Return the largest activation of SAE latent ``neuron_idx`` over ``num_batches`` batches.

    Each batch comes from ``act_store.get_batch_tokens()`` and is run through
    ``model`` only up to the SAE's hook layer. The running maximum starts at 0.0
    and is shown in the progress-bar description.
    """
    hook_key = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    best = 0.0
    progress = tqdm(range(num_batches), desc="Finding max activation")

    for _ in progress:
        _, cache = model.run_with_cache_with_saes(
            act_store.get_batch_tokens(),
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[hook_key]
        )
        latent_column = cache[hook_key].flatten(0, 1)[:, neuron_idx]
        best = max(best, latent_column.max().item())
        progress.set_description(f"Max activation: {best:.2f}")

    return best

# === Hook for steering ===
def steering_hook_fn(resid_pre, hook, steering_vector, strength, max_act):
    """Shift the hooked activation by ``max_act * strength * steering_vector``.

    ``hook`` is the hook-point object supplied by the model and is not used.
    """
    scale = max_act * strength
    return resid_pre + scale * steering_vector

# === Generate with steering ===
def generate_with_steering(model, sae, prompt, neuron_idx, max_act, strength=1.0, max_new_tokens=10):
    """Sample a continuation of ``prompt`` while steering along one SAE decoder direction.

    While generating, a hook on ``sae.cfg.hook_name`` adds
    ``max_act * strength * sae.W_dec[neuron_idx]`` to the activation (see
    ``steering_hook_fn``). Sampling uses temperature 0.7 and top-p 0.9.

    Returns:
        The decoded output including the prompt; special tokens are kept.
    """
    tokens = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    direction = sae.W_dec[neuron_idx].to(model.cfg.device)
    add_direction = partial(steering_hook_fn, steering_vector=direction, strength=strength, max_act=max_act)

    hooks = [(sae.cfg.hook_name, add_direction)]
    with model.hooks(fwd_hooks=hooks):
        generated = model.generate(
            tokens,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos=True,
            prepend_bos=sae.cfg.prepend_bos,
        )

    return model.tokenizer.decode(generated[0])

from functools import partial

# === Hook for ablation ===
def ablation_hook_fn(resid_pre, hook, ablate_vector):
    """Return the hooked activation minus ``ablate_vector`` (``hook`` is unused)."""
    shifted = resid_pre - ablate_vector
    return shifted

# === Generate with ablation ===
def generate_with_ablation(model, sae, prompt, neuron_idx, max_act, max_new_tokens=10):
    """Sample a continuation of ``prompt`` with a fixed SAE decoder direction subtracted.

    A hook on ``sae.cfg.hook_name`` subtracts ``max_act * sae.W_dec[neuron_idx]``
    from the whole activation tensor on every forward pass during generation
    (via ``ablation_hook_fn``). The offset is constant; it does not depend on
    whether the latent actually fires.

    Returns:
        The decoded output including the prompt, with special tokens removed.
    """
    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    offset = max_act * sae.W_dec[neuron_idx].to(model.cfg.device)
    subtract_offset = partial(ablation_hook_fn, ablate_vector=offset)

    hooks = [(sae.cfg.hook_name, subtract_offset)]
    with model.hooks(fwd_hooks=hooks):
        generated = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos=True,
            prepend_bos=sae.cfg.prepend_bos,
        )

    return model.tokenizer.decode(generated[0], skip_special_tokens=True)

from functools import partial

# === Hook that zeroes out specific neuron(s) in SAE latent space ===
def latent_ablation_hook_fn(sae_acts, hook, neuron_idx):
    """Zero latent(s) ``neuron_idx`` at the final position, in place, and return the tensor.

    ``hook`` is unused; earlier positions are left untouched.
    """
    last_position = sae_acts[:, -1]  # view: writes below modify sae_acts
    last_position[:, neuron_idx] = 0
    return sae_acts

# === Run generation with latent ablation ===
def generate_with_sae_ablation(model, sae, prompt, neuron_idx, max_new_tokens=20):
    """Sample a continuation of ``prompt`` with SAE latent(s) ``neuron_idx`` ablated.

    On every forward pass during generation, a hook on ``sae.cfg.hook_name``
    encodes the activation with ``sae``. It then subtracts each ablated
    latent's contribution (latent activation x its decoder row) at every
    position. The rest of the activation, including the SAE's reconstruction
    error, is left unchanged.

    Args:
        model: HookedSAETransformer providing ``to_tokens``, ``hooks`` and ``generate``.
        sae: SAE whose latents are ablated; must accept the hooked activation in ``encode``.
        prompt: Text prompt.
        neuron_idx: A latent index or a sequence of indices.
        max_new_tokens: Number of tokens to sample (temperature 0.7, top-p 0.9).

    Returns:
        The decoded output including the prompt, with special tokens removed.
    """
    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    latent_ids = np.atleast_1d(neuron_idx).tolist()
    decoder_rows = sae.W_dec[latent_ids]  # [n_latents, d_model]

    # The previous version added `acts @ W_dec` (roughly the whole SAE
    # reconstruction minus one latent) on top of the residual at the last
    # position, which distorts rather than ablates. Under generation it also
    # re-applied that prompt-derived patch at every new token. Encoding inside
    # the hook instead recomputes the latent for whatever positions are
    # currently being processed.
    def ablate_latents(resid, hook):
        latents = sae.encode(resid.to(decoder_rows.dtype))[..., latent_ids]  # [batch, pos, n_latents]
        return resid - (latents @ decoder_rows).to(resid.dtype)

    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, ablate_latents)]):
        output = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos=True,
            prepend_bos=sae.cfg.prepend_bos
        )

    return model.tokenizer.decode(output[0], skip_special_tokens=True)


In [None]:
# VADER lexicon, required by nltk's SentimentIntensityAnalyzer used later.
nltk.download('vader_lexicon')

In [None]:
# === Hook for steering ===
def steering_hook_fn(resid_pre, hook, steering_vector, strength, max_act):
    """Steering hook: add ``max_act * strength * steering_vector`` to the activation (``hook`` unused)."""
    offset = max_act * strength * steering_vector
    return resid_pre + offset

# === Generate with steering ===
def generate_with_steering(model, sae, prompt, neuron_indices, max_act, strength=1.0, max_new_tokens=10):
    """Sample a continuation of ``prompt`` while steering along SAE decoder direction(s).

    Args:
        model: HookedSAETransformer providing ``to_tokens``, ``hooks`` and ``generate``.
        sae: SAE whose decoder rows define the steering direction.
        prompt: Text prompt.
        neuron_indices: A latent index (Python/NumPy int or 0-d tensor) or a
            sequence of indices; their decoder rows are summed.
        max_act: Scale applied to the summed direction.
        strength: Additional multiplier on ``max_act``.
        max_new_tokens: Number of tokens to sample (temperature 0.7, top-p 0.9).

    Returns:
        The decoded output including the prompt, with special tokens removed.
    """
    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    # Normalise to a flat list of ints. `isinstance(x, int)` missed NumPy
    # integers (e.g. values read via pandas). For those, W_dec[k] is a single
    # [d_model] row, and `.sum(dim=0)` below collapsed it to a scalar.
    neuron_indices = torch.as_tensor(neuron_indices).reshape(-1).tolist()

    # Combine decoded vectors of all neurons
    steer_vecs = sae.W_dec[neuron_indices].to(model.cfg.device)  # [N, d_model]
    steering_vector = steer_vecs.sum(dim=0)  # Alternatively, use .mean(dim=0)

    hook_fn = partial(
        steering_hook_fn,
        steering_vector=steering_vector,
        strength=strength,
        max_act=max_act
    )

    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, hook_fn)]):
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos=True,
            prepend_bos=sae.cfg.prepend_bos
        )

    return model.tokenizer.decode(output_ids[0], skip_special_tokens=True)


In [None]:
def clean_and_shorten(text):
    # Remove special tokens like <bos> and the prompt
    cleaned = text.replace("<bos>", "").strip()
    prompt_prefix = "I can't believe that you said it to me:"
    if cleaned.lower().startswith(prompt_prefix.lower()):
        cleaned = cleaned[len(prompt_prefix):].strip()

    match = re.search(r"^(.*?[.!?])(?=\s|\n|$)", cleaned + (" " or "\n"))
    return match.group(1).strip() if match else cleaned


In [None]:
# Pin the checkpoint explicitly. With only a task name, transformers falls back
# to a library-chosen default and warns that this is not recommended, so scores
# could silently change across transformers versions. The model named here is
# the one recent transformers releases pick by default for this task.
# NOTE(review): consider also pinning `revision=` for full reproducibility.
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
)

In [None]:
nltk.download(info_or_id='vader_lexicon')

# Analyzer used both for the lexicon patching and for scoring generations.
sia = SentimentIntensityAnalyzer()
# This is the analyzer's own word -> valence dict, not a copy, so changes made
# through `sia.lexicon` are visible through `vader_lex` as well.
vader_lex = sia.lexicon

# 1. Extended slang/swear mapping: variant → base word in VADER.
# Each variant is meant to inherit its base word's VADER valence. Keep keys
# unique: a repeated key in a dict literal silently overrides the earlier entry.
variant_to_base = {
    # Strong Negative
    'fucking': 'fuck',
    'fucked': 'fuck',
    'motherfucker': 'fuck',
    'asshole': 'shit',
    'douche': 'shit',
    'douchebag': 'shit',
    'bullshit': 'shit',
    'jerk': 'jerk',
    'bitches': 'bitch',
    'cunt': 'bitch',
    'slut': 'bitch',
    'whore': 'bitch',
    'twat': 'bitch',
    'pussy': 'bitch',
    'moron': 'idiot',
    'retard': 'idiot',
    'stupid': 'idiot',
    'dumbass': 'idiot',
    'loser': 'idiot',
    'trash': 'idiot',
    'cringe': 'lame',
    'pathetic': 'lame',
    'toxic': 'bad',
    'ew': 'bad',
    'meh': 'bad',
    'wtf': 'damn',
    'creepy': 'scary',
    'ugly': 'bad',
    'nasty': 'bad',
    'deadinside': 'depressing',

    # Positive Slang (re-mapped to valid VADER bases)
    'queen': 'amazing',
    'king': 'amazing',
    'slay': 'amazing',
    'boss': 'amazing',
    'icon': 'amazing',
    'legend': 'amazing',
    'goddess': 'amazing',
    'goat': 'great',
    'goated': 'great',
    'banger': 'awesome',
    'fire': 'awesome',
    'based': 'awesome',
    'lit': 'awesome',
    'dope': 'awesome',
    'hella': 'good',
    'savage': 'strong',
    'cute': 'sweet',
    'adorable': 'sweet',
    'fine': 'nice',
    'hot': 'nice',
    'sexy': 'nice',
    'clean': 'nice',
    'smooth': 'nice',
    'beautiful': 'nice',
    'pretty': 'sweet',

    # Love/excitement slang
    'loveee': 'love',
    'lovin': 'love',
    'obsessed': 'love',
    'crushing': 'love',
    'crushin': 'love',
    'inlove': 'love',
    'cutie': 'sweet',
    'sweetie': 'sweet',
    'bby': 'sweet',
    'boo': 'sweet',
    'bae': 'sweet',
    'ily': 'love',
    'ily2': 'love',
    'xoxo': 'love',

    # Casual/slang humor or approval
    'deadass': 'serious',
    'fr': 'serious',
    'bruh': 'funny',
    'lmao': 'funny',
    'rofl': 'funny',
    'lol': 'funny',
    'omg': 'wow',
    'vibing': 'happy',
    'vibe': 'happy',
    'energy': 'happy',

    # Sadness / Depression (slangified)
    # ('deadinside' is listed once, under Strong Negative.)
    'sadge': 'sad',
    'cryin': 'sad',
    'cryinggg': 'sad',
    'sobbing': 'sad',
    'nooo': 'sad',
    'ughhh': 'sad',
    'mentallyill': 'depressing',
    'depr3ssed': 'depressing',
    'downbad': 'sad',
    'voidcore': 'depressing',
    'brainrotted': 'depressing',
    'overit': 'sad',
    'can’ttakeit': 'depressing',
    'emptyaf': 'sad',
    'selfhatin': 'bad',

    # Anxiety / Fear / Panic (slangified)
    'scaredaf': 'scary',
    'panikin': 'scary',
    'anxiousss': 'scary',
    'stressing': 'scary',
    'freakinout': 'scary',
    'paranoidd': 'scary',
    'helplessss': 'sad',
    'losingit': 'scary',
    'nervousaf': 'scary',
    'shaking': 'scary',
    'brainmelting': 'scary',

    # Disgust / Repulsion (slangified)
    'eww': 'gross',
    'vom': 'gross',
    'nastyyy': 'gross',
    'disgustinn': 'gross',
    'cringeaf': 'gross',
    'icky': 'gross',
    'yuck': 'gross',
    'throwingup': 'gross',
    'grossedout': 'gross',
    'gagging': 'gross',

    # Joy / Affection / Love / Excitement (slangified)
    # ('ily2' is listed once, under Love/excitement slang.)
    'adorbs': 'sweet',
    'cutiepie': 'sweet',
    'angelbaby': 'sweet',
    'sunshiny': 'happy',
    'preciousaf': 'sweet',
    'ilysm': 'love',
    'lovinggg': 'love',
    'obsessssed': 'love',
    'snuggly': 'love',
    'heartmelt': 'love',
    'blessedaf': 'grateful',
    'hypeddd': 'excited',
    'vibinggg': 'happy',
    'ecstaticcc': 'happy',
    'excitedd': 'excited',
    'inloveee': 'love',
    'crushinnn': 'love',
    'cutenessoverload': 'sweet',
    'hearteyes': 'love',
}



# 2. Build adjusted lexicon using VADER scores. Each variant inherits the
# valence of its base word. Variants VADER already scores, and variants whose
# base word VADER lacks, are reported and left out.
adjusted_custom_lexicon = {}

for word, base in variant_to_base.items():
    base_score = vader_lex.get(base)
    if base_score is None:
        print(f"⚠️ Base word '{base}' not found in VADER for '{word}' — skipping.")
    elif word in vader_lex:
        print(f"⏩ '{word}' already in VADER — skipping override.")
    else:
        adjusted_custom_lexicon[word] = base_score


# 3. Update VADER in place with these custom words.
sia.lexicon.update(adjusted_custom_lexicon)

# 4. Smoke-test the patched analyzer on a few slang-heavy sentences.
examples = [
    "You are a fucking asshole.",
    "That song is an absolute banger!",
    "Stop being so cringe.",
    "She's a queen. Totally goated.",
    "This is such bullshit.",
    "I'm deadass serious.",
    "That guy is a total douchebag.",
]

for ex in examples:
    # Heading line, VADER scores dict, separator — one item per line.
    print(f"→ {ex}", sia.polarity_scores(ex), '-' * 40, sep='\n')



In [None]:
# SAE latent ids grouped by the emotion label each list is named after.
# Some ids appear in more than one group: 5810 (sad, love), 13324 and 14857
# (disgust, anger), and 16148 (exc, other).
disgust_neurons = [
    4456, 6953, 13324, 14857,
]
exc_neurons = [
    230, 16148,
]
sad_neurons = [
    5810, 15539,
]
fear_neurons = [
    7769,
]
# NOTE: singular name, but still a list of ids.
love_neuron = [
    2249, 4326, 5810, 15366,
]
other_neurons = [
    1898, 3636, 7077, 16148,
]
# NOTE: singular name, but still a list of ids.
anger_neuron = [
    2438, 4560, 4859, 7579, 9065, 13324, 14857,
]

In [None]:
# Max-activation table indexed by latent id. "Unnamed: 0" is presumably the
# index column left over from a CSV round-trip, so it is dropped.
ma = (
    max_act_df
    .drop(columns="Unnamed: 0")
    .set_index('neuron')
)

In [None]:
from collections import defaultdict
from textblob import TextBlob
from nltk.sentiment import SentimentIntensityAnalyzer

# Per-neuron results of steered generation.
# pol[n] is a flat list with 5 entries per response, in this exact order:
#   TextBlob polarity, TextBlob subjectivity, VADER polarity_scores dict,
#   transformer label, transformer score.
# The DataFrame-building cell reads it with a stride of 5, so keep the order
# and the count in sync with that cell.
pol = defaultdict(list)
# NOTE(review): this rebinds the module-level `texts` (earlier the list of
# dataset texts) to a per-neuron dict of generated responses. Anything that
# still needs the dataset texts must run before this cell.
texts = defaultdict(list)

prompt = """I can't believe that you said it to me:"""
for c, n in enumerate(other_neurons, start=1):
    # Steering scale for this latent, looked up from the precomputed table.
    max_act = float(ma.loc[ma.index == n, 'max_activation'].iloc[0])

    for _ in range(30):
        resp = generate_with_steering(gemma, gemma_sae, prompt, n, max_act, strength=3, max_new_tokens=20)
        r = clean_and_shorten(resp)
        # Analyse once; polarity and subjectivity come from the same result.
        blob_sentiment = TextBlob(r).sentiment
        pol[n].extend((blob_sentiment.polarity, blob_sentiment.subjectivity, sia.polarity_scores(r)))
        texts[n].append(r)
        sent = sentiment_pipeline(r)[0]
        pol[n].extend((sent['label'], sent['score']))
    # Progress: number of neurons finished so far.
    print(c)


In [None]:
import pandas as pd
import json

# Flatten the per-neuron results into one row per generated response.
# pol[n] holds five values per response (TextBlob polarity, TextBlob
# subjectivity, VADER score dict, transformer label, transformer score), so
# response i occupies pol[n][5*i] .. pol[n][5*i + 4].
combined = []

for n, responses in texts.items():
    for i, response in enumerate(responses):
        base = 5 * i
        vader = pol[n][base + 2]
        record = {
            'neuron': n,
            'response_index': i,
            'text': response,
            'polarity': pol[n][base],          # TextBlob polarity
            'subjectivity': pol[n][base + 1],  # TextBlob subjectivity
            'label_sent': pol[n][base + 3],    # Transformer label
            'score_sent': pol[n][base + 4],    # Transformer score
            **vader,                           # VADER neg/neu/pos/compound columns
        }
        combined.append(record)


df = pd.DataFrame(combined)

df.to_csv('other_max_3.csv', index=False)



In [None]:
!pip install vaderSentiment