In [1]:
import torch
import json
import numpy as np
import tiktoken

from gpt_model import GPTModel
from sparse_auto_encoder import SparseAutoencoder

In [2]:
# Single place where the compute device for this notebook is chosen.
device = "cpu"

# Configuration dictionary handed to GPTModel when the network is built.
# "context_length" is also read by the generation loop to crop the prompt.
GPT_CONFIG_124M = dict(
    vocab_size=50257,
    context_length=256,
    emb_dim=768,
    n_heads=12,
    n_layers=12,
    drop_rate=0.2,
    qkv_bias=True,
    device=device,  # same value as the module-level `device`
)

In [3]:
# Instantiate the network from the config, then restore the saved weights.
model = GPTModel(GPT_CONFIG_124M)
checkpoint = torch.load(
    "model_768_12_12_old_tok.pth",
    map_location="cpu",   # deserialize tensors onto the CPU
    weights_only=True,    # restrict unpickling to tensors / plain containers
)

# The checkpoint is a dict; the weights live under "model_state_dict".
model.load_state_dict(checkpoint["model_state_dict"])

# Move to the target device and switch to inference mode.
model = model.to(device).eval()

In [4]:
# Load the tiktoken encoding registered under the name "gpt2"; this object is
# passed to text_to_token_ids / token_ids_to_text for encoding and decoding.
tokenizer = tiktoken.get_encoding("gpt2")

In [5]:
def _load_sae(weights_path, input_dim=768, hidden_dim=3072):
    """
    Build a SparseAutoencoder and restore its weights from ``weights_path``.

    Args:
    - weights_path (str): Path to a file holding the autoencoder's state dict.
    - input_dim (int): Width of the vectors the autoencoder consumes.
    - hidden_dim (int): Number of latent features in the autoencoder.

    Returns:
    - SparseAutoencoder: Model on ``device``, switched to eval mode.
    """
    sae = SparseAutoencoder(input_dim=input_dim, hidden_dim=hidden_dim).to(device)
    # weights_only=True matches how the GPT checkpoint is loaded and keeps
    # torch.load from unpickling arbitrary objects out of the file.
    state_dict = torch.load(
        weights_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    sae.load_state_dict(state_dict)
    sae.eval()
    return sae


# Two autoencoders with identical shapes, restored from separate weight files.
# NOTE(review): the 6 / 12 in the names presumably refer to the transformer
# layer each autoencoder was trained on - confirm against the training code.
sae_6 = _load_sae("sae_model_6_3072.pth")
sae_12 = _load_sae("sae_model_12_3072.pth")

In [6]:
def text_to_token_ids(text, tokenizer):
    """
    Encode ``text`` into a batch-of-one tensor of token ids.

    Args:
    - text (str): Input text; the '<|endoftext|>' special token is allowed.
    - tokenizer: Object exposing ``encode(text, allowed_special=...)`` that
      returns a list of integer token ids.

    Returns:
    - torch.Tensor: Integer (``torch.long``) tensor of shape (1, num_tokens).
    """
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    # Pin the dtype: torch.tensor([]) would otherwise be inferred as float32,
    # so an empty text would produce a tensor that cannot be used as indices.
    encoded_tensor = torch.tensor(encoded, dtype=torch.long).unsqueeze(0)  # add batch dimension
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    """
    Turn a batch-of-one tensor of token ids back into text.

    The leading (batch) axis is squeezed away and the remaining ids are handed
    to ``tokenizer.decode`` as a plain Python list.
    """
    id_list = token_ids.squeeze(0).tolist()  # drop batch axis, tensor -> list
    return tokenizer.decode(id_list)

In [7]:
def get_token_embeddings(text, model, tokenizer, layers=(6, 12)):
    """
    Extracts token embeddings from specified transformer layers.

    Args:
    - text (str): Input text.
    - model: Custom GPT model; called as ``model(input_ids, output_hidden_states=True)``
      and expected to return ``(output, hidden_states)``.
    - tokenizer: tiktoken encoding object.
    - layers (iterable of int): 1-based layer numbers to extract; layer ``n``
      is read from ``hidden_states[n - 1]``.

    Returns:
    - dict: Layer-wise token embeddings {layer_number: numpy array}, with the
      batch axis squeezed out. Layers outside ``1..len(hidden_states)`` are
      skipped with a printed warning and are absent from the dict.
    """

    input_ids = text_to_token_ids(text, tokenizer).to(device)

    # Inference only: no autograd graph is needed for extraction.
    with torch.no_grad():
        _, hidden_states = model(input_ids, output_hidden_states=True)

    num_layers = len(hidden_states)
    embeddings = {}
    for layer in layers:
        # Require layer >= 1 as well: without the lower bound, layer 0 (or a
        # negative number) would wrap around via negative indexing and
        # silently return a different layer's states.
        if 1 <= layer <= num_layers:
            embeddings[layer] = hidden_states[layer - 1].squeeze(0).cpu().numpy()
        else:
            print(f"⚠️ Warning: Layer {layer} is out of range (max index {len(hidden_states) - 1})")

    return embeddings

In [8]:
from collections import defaultdict

def find_top_activating_neurons(concept_to_texts, model, tokenizer, sae, get_token_embeddings, layer=6, top_k=5, device='cpu'):
    """
    Rank SAE features by how often they are the strongest feature for a token.

    For every sentence of a concept, the token embeddings of ``layer`` are
    passed through ``sae``; each token votes for the feature with the largest
    encoded value (argmax). Features are then ranked by vote count.

    Args:
    - concept_to_texts (dict): {concept_name: iterable of sentences}.
    - model, tokenizer: Forwarded unchanged to ``get_token_embeddings``.
    - sae: Callable returning ``(decoded, encoded)`` where ``encoded`` has
      shape (seq_len, n_features).
    - get_token_embeddings: Callable ``(sentence, model, tokenizer, layers=[layer])``
      returning a dict that maps ``layer`` to a (seq_len, dim) array.
    - layer (int): Layer number whose embeddings are analysed.
    - top_k (int): Number of features to keep per concept.
    - device (str): Device the embeddings are moved to before the SAE call.

    Returns:
    - dict: {concept_name: list of up to ``top_k`` feature indices (plain
      Python ints), most frequent first; ties keep first-seen order}.
    """
    concept_top_neurons = {}

    for concept, sentences in concept_to_texts.items():
        print(f"Processing concept: {concept}")
        neuron_activation_counts = defaultdict(int)

        for sentence in sentences:
            embeddings_np = get_token_embeddings(sentence, model, tokenizer, layers=[layer])[layer]
            embeddings = torch.tensor(embeddings_np, dtype=torch.float32).to(device)

            # Pure inference: skip building an autograd graph through the SAE.
            with torch.no_grad():
                _, encoded = sae(embeddings)  # encoded shape: (seq_len, n_features)

            # .tolist() yields plain Python ints, so the returned indices are
            # ordinary ints (JSON-serializable) rather than numpy scalars.
            for idx in torch.argmax(encoded, dim=1).tolist():
                neuron_activation_counts[idx] += 1

        # Ranking by raw count is equivalent to ranking by count / total_tokens
        # (a positive constant per concept); sorted() is stable, so ties keep
        # the order in which features were first seen.
        ranked = sorted(neuron_activation_counts.items(), key=lambda item: item[1], reverse=True)
        concept_top_neurons[concept] = [neuron for neuron, _ in ranked[:top_k]]

    return concept_top_neurons

In [9]:
# Load the mapping {concept name -> example sentences} from disk.
with open("concepts_to_text.json", encoding="utf-8") as concepts_file:
    concept_to_texts = json.load(concepts_file)

# For each concept, rank the features of sae_6 on layer 6 and keep the top 5.
top_neurons = find_top_activating_neurons(
    concept_to_texts,
    model,
    tokenizer,
    sae_6,
    get_token_embeddings,
    layer=6,
    top_k=5,
    device='cpu',
)

print(top_neurons)

Processing concept: marriage_as_duty
Processing concept: romantic_love
Processing concept: social_class
Processing concept: moral_superiority
Processing concept: stigma_of_spinsterhood
Processing concept: wealth_and_inheritance
Processing concept: female_professions
Processing concept: male_professions
Processing concept: reputation_and_gossip
Processing concept: truth_and_honesty
Processing concept: vanity_and_appearance
Processing concept: matchmaking_positive
Processing concept: matchmaking_negative
Processing concept: social_hierarchy
{'marriage_as_duty': [720, 639, 443, 2899, 1996], 'romantic_love': [720, 1091, 639, 1313, 1228], 'social_class': [720, 1996, 1080, 639, 779], 'moral_superiority': [1080, 1091, 2899, 639, 443], 'stigma_of_spinsterhood': [1080, 639, 1313, 317, 811], 'wealth_and_inheritance': [720, 1313, 604, 1091, 2604], 'female_professions': [779, 1441, 720, 2604, 5], 'male_professions': [779, 1441, 720, 2755, 2604], 'reputation_and_gossip': [2755, 720, 1441, 1313, 443

In [10]:
# Greedy generation of 10 tokens while adding a constant offset to one
# coordinate of the hidden state that is re-injected into the model.
idx = text_to_token_ids("Marriage is", tokenizer).to(device)

# NOTE(review): 720 appears in the SAE feature ranking printed earlier (an
# index into the 3072-wide SAE encoding), but here it indexes the model's
# hidden dimension directly - confirm this is the intended intervention.
# NOTE(review): this cell pairs hiddens[5] with intervene_layer=6, whereas the
# SAE-based cell pairs hidden_states[layer - 1] with intervene_layer=layer - 1;
# check GPTModel's indexing convention to see which pairing is right.

# Inference only: no_grad avoids building an autograd graph on every step.
with torch.no_grad():
    for _ in range(10):
        # Keep at most `context_length` most recent tokens as the prompt.
        idx_cond = idx[:, -GPT_CONFIG_124M['context_length']:]

        # 1. Run forward to get the hidden states and copy the one to edit.
        _, hiddens = model(idx_cond, output_hidden_states=True)
        layer6_hidden = hiddens[5].clone()

        # 2. Add a constant to coordinate 720 at every position.
        layer6_hidden[:, :, 720] += 5.0

        # 3. Re-run the model with the edited hidden state substituted in.
        logits = model(
            idx_cond,
            intervene_layer=6,
            edited_hidden=layer6_hidden,
            output_hidden_states=False
        )

        # Greedy pick from the last position's logits -> (batch_size, 1).
        idx_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)

        # Append the chosen token to the running sequence.
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

# Decode once, after generation, instead of on every iteration.
ouput_text = token_ids_to_text(idx, tokenizer)

In [11]:
# Display the generated text stored in `ouput_text` by the generation cell.
ouput_text

'Marriage is a very good deal of the world. I am'

In [35]:
# Greedy generation of 10 tokens while steering through a sparse autoencoder:
# the hidden state is passed to sae_12.intervene_and_decode with a boosted
# feature, and the result is re-injected into the model.
idx = text_to_token_ids("Marriage is", tokenizer).to(device)
boost_value = 5.0
layer = 6

# NOTE(review): hidden states are taken from `layer` = 6 but are passed through
# `sae_12`; if that autoencoder was trained on a different layer, the
# reconstruction fed back into the model will be off - confirm the pairing.

with torch.no_grad():
    for _ in range(10):
        _, hidden_states = model(idx, output_hidden_states=True)
        layer_hidden = hidden_states[layer - 1].squeeze(0)  # drop batch axis for the SAE

        # Use the configured boost instead of a separate hard-coded literal.
        intervened_hidden = sae_12.intervene_and_decode(layer_hidden, 1090, boost=boost_value)

        logits = model(
            idx,
            intervene_layer=(layer - 1),
            edited_hidden=intervened_hidden.unsqueeze(0),  # restore batch axis
        )
        logits = logits[:, -1, :]  # keep only the last position

        # Greedy decoding: take the highest-scoring token.
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Append the chosen token to the running sequence.
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

# Decode once, after generation, instead of on every iteration.
ouput_text = token_ids_to_text(idx, tokenizer)

ouput_text

'Marriage is parallels\nanthafield Parkynch Hallrietriet'