In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import torch
from torch.utils.data import Dataset, DataLoader
import glob
import matplotlib.pyplot as plt
import time
from datetime import datetime
from SAEModel import SparseAutoencoder
from torch.nn import functional as F
import json
from collections import defaultdict
from tqdm import tqdm

from eval_utils import getDict, most_common_neurons, get_word_list

### Load SAE

In [2]:
input_dim = 4096  # Input and output dimensions
hidden_dim = input_dim * 4  # Hidden layer dimension
sae_model = SparseAutoencoder(input_dim, hidden_dim, k = 24, dead_steps_threshold=1000000) # initial lambda is 0

state_dict = torch.load("/workspace/LLM_interpretability/TopKSAE/experiments/mistral_pile_k24_experiment_20241206/models/model_epoch_30.pt")
K = 24

# Load the weights into the model
sae_model.load_state_dict(state_dict)
sae_model.to('cuda')

  state_dict = torch.load("/workspace/LLM_interpretability/TopKSAE/experiments/mistral_pile_k24_experiment_20241206/models/model_epoch_30.pt")


SparseAutoencoder()

## Modified Dataset to also return text with embedding for ease of visualisation

In [3]:

class EmbeddingDataset(Dataset):
    def __init__(self, path, file_pattern, use_files = 40):
        # Load all `.pt` files based on the pattern
        self.file_path = path
        self.files = sorted(glob.glob(path+"/"+file_pattern))
        print(f"Num of files found : {len(self.files)}")
        print(f"Num of files used : {len(self.files[:use_files])}")
        self.data = []

        # Read and store all embeddings from all files
        for file in self.files[:use_files]:
            batch_data = torch.load(file)
            # Extract embeddings and flatten them into a list
            self.data.extend([(item["embedding"], item["text"]) for item in batch_data])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [4]:
path = "/workspace/mistral_pile"
file_pattern = "*.pt"  # Adjust the path if needed

pile_data = EmbeddingDataset(path, file_pattern , use_files=20)
pile_data_loader = DataLoader(pile_data, batch_size=1024, shuffle=True)
# If certain words evoke feature activations
pile_data_Dict = getDict(pile_data_loader, sae_model, K=K)

Num of files found : 142
Num of files used : 20


  batch_data = torch.load(file)
Processing Batches:   5%|▌         | 1/20 [00:00<00:04,  4.59it/s]

Processed Batch 0...


Processing Batches: 100%|██████████| 20/20 [00:01<00:00, 17.67it/s]


In [5]:
# pile_data_Dict


In [6]:
print(get_word_list(3597, word_active_neurons=pile_data_Dict))

['er', 'the', 'board', 'in', 'making', 'light', 'on', 'material', 'pad', 'ing', 'adjacent', 'abs', 'ate', 'then', 'damage', 'closed', 'metal', 'expansion', 'described', 'from', 'i', 'direction', 'case', 'manufacturing', 'portion', 'or', 'ive', 'ising', 'find', 'remain', 'group', 'FE', ';', 'another', 'via', 'cover', 'between', 'used', 'per', 'ur', 'above', 'ic', 'invention', 'earing', 'extens', 'en', 'circuit', 'tissue', 'member', 'expos', 'method', 'iments', 'irable', 'article', 'formed', 'b', 'inated', 'Such', 'resid', 'es', 'However', 'rad', 'bond', 'manufactured', 'exposed', 'ively', 'release', 'electrical', 'degree', 'object', 'posed', 'low', 'agent', 'amount', 'deter', 'thus', 'expand', 'has', 'shorter', 'ating', 'subject', 'groups', 'least', 'ible', 'sor', 'generally', 'ined', 'pattern', 'covered', 'derivative', 'precip', 'affected', 'Therefore', 'corresponding', 'cause', 'structural', 'contained', 'portions', 'provide', 'properties', 'affect', 'cross', 'ording', 'pre', 'fine', 

### Load LLM model 

In [7]:
# Load Mistral

from transformers import AutoModelForCausalLM, AutoTokenizer
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.messages import UserMessage
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
from datasets import load_dataset
import numpy as np

access_token="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"
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)

## Residual Hook Function
def gather_residual_activations(model, target_layer, inputs):
  target_act = None
  def gather_target_act_hook(mod, inputs, outputs):
    nonlocal target_act # make sure we can modify the target_act from the outer scope
    target_act = outputs[0]
    return outputs
  handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
  _ = model.forward(inputs)
  handle.remove()
  return target_act

def get_word_vector_pairs(tokenizer_output, model_output):
    wordVectorPairs = []
    for i, word in enumerate(tokenizer_output):
        #print(i, word, tokenizer.decode(word), model_output[i, :].shape)
        wordVectorPairs.append({"text":tokenizer.decode(word), "embedding": model_output[i, :]})
    # print(wordVectorPairs)
    return wordVectorPairs


def prepare_batch_embeddings(text, tokenizer, model, gather_residual_activations, layer_index=16):

    inputs = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True).to("cuda")
    
    with torch.no_grad():
        residuals = gather_residual_activations(model, layer_index, inputs)
    
    out = get_word_vector_pairs(inputs[0].cpu().tolist(), residuals[0].cpu())

    batch_text = [item["text"] for item in out]
    batch_embeddings = torch.stack([item["embedding"] for item in out])

    return batch_embeddings, batch_text



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



### Visualize Text

In [67]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
from matplotlib import cm
from IPython.display import HTML, display

def sentence_heatmap_visualization(batch_text, batch_embeddings, feature_index, sae_model):
    sae_model.eval()

    # Forward pass to get latent activations
    with torch.no_grad():
        _, _, _, latents = sae_model(batch_embeddings.cuda())

    # Extract activations for the given feature index
    feature_activations = latents[:, feature_index]

    # Normalize activations
    norm = Normalize(vmin=feature_activations.min().item(), vmax=feature_activations.max().item())
    token_activations = norm(feature_activations.cpu().numpy())

    # Create HTML with tokens highlighted
    sentence_html = ""
    cmap = cm.get_cmap("Reds")

    # Define transparency scaling factor (0.5 for 50% transparency)
    alpha_scaling = 0.7

    for token, act in zip(batch_text, token_activations):
        if act == 0:
            color_hex = "rgba(255, 255, 255, 0)"  # Fully transparent for zero activation
        else:
            color = cmap(act)
            color_hex = f"rgba({int(color[0]*255)}, {int(color[1]*255)}, {int(color[2]*255)}, {color[3] * alpha_scaling})"
        token = token.replace("<", "&lt;").replace(">", "&gt;")  # Sanitize tokens
        sentence_html += f'<span style="background-color: {color_hex}; padding: 2px 4px; margin: 2px; text-decoration: none;">{token}</span> '

    # Render HTML visualizations as a single paragraph
    html_visualization = f"<p style='font-family: monospace; text-decoration: none;'>{sentence_html}</p>"
    display(HTML(html_visualization))

    # Optional: Add a color scale (uncomment if needed)
    # fig, ax = plt.subplots(figsize=(6, 1))
    # fig.subplots_adjust(bottom=0.5)
    # sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    # cbar = plt.colorbar(sm, orientation='horizontal', ax=ax)
    # cbar.set_label('Activation Strength (0 to 1)')
    # plt.show()


In [69]:
# Feature to visualize
feature_index = 3597
text = "The suspect was caught red-handed. Police confirmed the stolen goods were recovered. Crime rates in the city have been steadily increasing."
batch_embeddings, batch_text = prepare_batch_embeddings(text, tokenizer, model, gather_residual_activations, layer_index=16)
batch_embeddings = batch_embeddings.to(torch.float32)
# Visualize the heatmap
sentence_heatmap_visualization(batch_text, batch_embeddings, feature_index, sae_model)


  cmap = cm.get_cmap("Reds")


In [68]:
# Feature to visualize
feature_index = 3597

# Process a single batch from the DataLoader
batch_embeddings, batch_text = next(iter(pile_data_loader))
batch_embeddings = batch_embeddings.to(torch.float32)
print(len(batch_text))
# Visualize the heatmap
sentence_heatmap_visualization(batch_text, batch_embeddings, feature_index, sae_model)


1024


  cmap = cm.get_cmap("Reds")


### Text generation and grouping them

In [83]:
def analyze_feature_responses_grouped_and_visualize(theme_texts, tokenizer, model, sae_model, gather_residual_activations, layer_index=16, activation_threshold=0.5, num_activations=0.8):
    feature_to_text_map = defaultdict(list)  # Map each feature to sentences that activate it

    # Process each text and track feature activations
    for text_id, text in enumerate(theme_texts):
        # Prepare batch embeddings and text
        batch_embeddings, batch_text = prepare_batch_embeddings(text, tokenizer, model, gather_residual_activations, layer_index)
        batch_embeddings = batch_embeddings.to(torch.float32)

        # Forward pass through SAE model
        sae_model.eval()
        with torch.no_grad():
            _, _, _, latents = sae_model(batch_embeddings.cuda())

        # Identify features with activations above threshold
        activated_features = (latents > activation_threshold).nonzero(as_tuple=True)[1].tolist()
        
        # Avoid duplicates for the same feature
        unique_features = set(activated_features)  # Deduplicate feature indices

        # Track sentences for each feature
        for feature_index in unique_features:
            # Add the sentence only if it's not already added
            if (text, batch_embeddings, batch_text) not in feature_to_text_map[feature_index]:
                feature_to_text_map[feature_index].append((text, batch_embeddings, batch_text))

    # Debugging: Print a specific feature's map
    # print(f"Feature 307 Text Map: {feature_to_text_map[307]}")

    # Filter features activated by more than n texts

    
    active_features = [feature for feature, texts in feature_to_text_map.items() if len(texts)/len(theme_texts) > num_activations]
    print(f"Features activated by more than {num_activations} texts: {active_features}")
    
    run = 0
    # Visualize for each feature
    for feature_index in active_features:
        if run == 3:
            break
        print(f"\nVisualizing Feature {feature_index}")
        for text_id, (_, batch_embeddings, batch_text) in enumerate(feature_to_text_map[feature_index]):
            sentence_heatmap_visualization(batch_text, batch_embeddings, feature_index, sae_model)
        run += 1
        
    return active_features


In [84]:
# Crime related
theme_texts = [
    "The suspect was caught red-handed. Police confirmed the stolen goods were recovered.",
    "The crime rate in urban areas has steadily increased over the past decade.",
    "Law enforcement officers are working hard to combat rising theft cases.",
    "Recent statistics show an alarming rise in burglary incidents in the neighborhood.",
    "The investigation into the high-profile murder case is still ongoing.",
    "Cybercrime is becoming a significant threat to online businesses and individuals.",
    "Robbery attempts in broad daylight have left citizens worried about their safety.",
    "A series of fraud cases involving credit cards has been reported recently.",
    "The police successfully thwarted a planned heist at the national bank.",
    "Gang violence in the city is being addressed with increased police patrols."
]

# Example usage
active_features = analyze_feature_responses_grouped_and_visualize(
    theme_texts=theme_texts,
    tokenizer=tokenizer,
    model=model,
    sae_model=sae_model,
    gather_residual_activations=gather_residual_activations,
    layer_index=16,
    activation_threshold=0.95,
    num_activations=0.99
)

Features activated by more than 0.99 texts: [13313, 7179, 12304, 4122, 546, 3625, 11306, 9266, 13875, 4154, 8763, 10820, 580, 1621, 7262, 6757, 6251, 10861, 1661, 6282, 9386, 14002, 4791, 15033, 2749, 12996, 13509, 8397, 3792, 12506, 9450, 12526, 12529, 3842, 4361, 6413, 6933, 7960, 4890, 1823, 14622, 16161, 12068, 2349, 307, 11571, 15161, 15193, 12121, 2908, 8547, 7012, 11110, 881, 3955, 10613, 3965, 15747, 13193, 397, 398, 1433, 13210, 5032, 10154, 13740, 8649, 11230, 15851, 7660, 12786, 14323, 5110]

Visualizing Feature 13313


  cmap = cm.get_cmap("Reds")



Visualizing Feature 7179



Visualizing Feature 12304


In [85]:
# Gender bias related
theme_texts = [
    "She is a nurse who takes care of her patients with great dedication.",
    "He is a firefighter who risked his life to save the family from the burning house.",
    "The CEO of the company announced her vision for the next quarter.",
    "The construction worker showed his team how to operate the heavy machinery.",
    "She is a teacher who always inspires her students to achieve their best.",
    "He is a stay-at-home dad who loves spending time with his children.",
    "The female scientist received a prestigious award for her groundbreaking research.",
    "The male nurse was praised for his compassion and professionalism in the emergency room.",
    "She is a pilot who successfully landed the aircraft during a storm.",
    "The father cooked dinner while the mother worked late at the office.",
    "He is a hairstylist with a reputation for creativity and attention to detail.",
    "She is a mechanic who repaired the car quickly and efficiently.",
    "The mother dropped her child off at school before heading to the board meeting.",
    "The male kindergarten teacher sang songs with his students during storytime.",
    "She is a soldier who served her country with bravery and honor.",
    "He is a chef who specializes in creating intricate desserts and pastries.",
    "The female athlete broke several world records at the international competition.",
    "He is a dancer who performs both classical and contemporary styles.",
    "She is an engineer who designed an innovative bridge to reduce traffic congestion.",
    "The male librarian recommended a fascinating novel to the students."
]


# Example usage
active_features = analyze_feature_responses_grouped_and_visualize(
    theme_texts=theme_texts,
    tokenizer=tokenizer,
    model=model,
    sae_model=sae_model,
    gather_residual_activations=gather_residual_activations,
    layer_index=16,
    activation_threshold=0.9,
    num_activations=0.8
)

Features activated by more than 0.8 texts: [13313, 12304, 4122, 546, 3625, 11306, 9266, 13875, 4154, 4163, 10820, 580, 1621, 3165, 7262, 12383, 6757, 6251, 10861, 3699, 1661, 16002, 6282, 151, 12961, 9386, 11441, 14002, 4791, 15033, 2749, 13509, 16073, 3792, 12506, 12526, 16112, 12529, 3842, 4361, 6413, 7442, 6933, 7960, 4890, 14622, 1823, 7968, 16161, 12068, 2349, 307, 15161, 15193, 12121, 2908, 8547, 7012, 8035, 11110, 881, 3955, 3965, 15747, 4489, 13193, 397, 398, 916, 1433, 13210, 3489, 5032, 10154, 13740, 4528, 10691, 8649, 11230, 6624, 15851, 7660, 12783, 12786, 14323, 5110, 4089, 12796, 9726, 8397, 11571]

Visualizing Feature 13313


  cmap = cm.get_cmap("Reds")



Visualizing Feature 12304



Visualizing Feature 4122


In [13]:
import csv
csv_file = './experiments/K_128_20241202/word_active_neurons.csv'

# Open the CSV file for writing
with open(csv_file, mode='w', newline='') as f:
    writer = csv.writer(f)
    
    # Write the header (optional, for clarity)
    header = ['word'] + [f'num{i+1}' for i in range(128)]  # Generate column names like num1, num2, ...
    writer.writerow(header)
    
    # Write each word and its corresponding list of 128 numbers
    for word, numbers in word_active_neurons.items():
        writer.writerow([word] + numbers)

print(f"Data successfully saved to {csv_file}")

FileNotFoundError: [Errno 2] No such file or directory: './experiments/K_128_20241202/word_active_neurons.csv'