<a href="https://colab.research.google.com/github/jmerizia/city-circuits/blob/main/gpt2_neurons_faster.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# licenses

In [None]:
# For this notebook

# MIT License

# Copyright (c) 2021 Jacob Merizian

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

In [None]:
# source: https://github.com/graykode/gpt-2-Pytorch/blob/master/LICENSE

# MIT License

# Copyright (c) 2019 OpenAI, HugginFace Inc. team. and TaeHwan Jung

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

In [None]:
# source: https://github.com/openai/gpt-2/blob/master/LICENSE

# Modified MIT License

# Software Copyright (c) 2019 OpenAI

# We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
# We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.

# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
# The above copyright notice and this permission notice need not be included
# with content created by the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.

# Ensure you are using GPU acceleration

In [None]:
!nvidia-smi

# define encoder and other utilities

In [None]:
'''
From https://github.com/openai/gpt-2/blob/master/src/encoder.py
'''

"""Byte pair encoding utilities"""

import os
import json
import regex as re
from functools import lru_cache

@lru_cache()
def bytes_to_unicode():
    """
    Build a reversible map from every byte value (0-255) to a single unicode character.

    Byte-level BPE works on unicode strings, so each raw utf-8 byte needs a
    printable stand-in. Bytes in the printable ranges "!".."~", "¡".."¬" and
    "®".."ÿ" map to themselves. Every other byte (whitespace and control
    characters the BPE code cannot handle) is assigned code points 256, 257, ...
    in increasing byte order.

    Returns a dict {byte value: one-character str}. The printable bytes come
    first in insertion order, followed by the shifted ones.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    already_printable = set(printable)
    mapping = {byte: chr(byte) for byte in printable}
    shift = 0
    for byte in range(256):
        if byte in already_printable:
            continue
        mapping[byte] = chr(256 + shift)
        shift += 1
    return mapping

def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length
    strings). Words with fewer than two symbols, including the empty word,
    have no pairs, so an empty set is returned.
    """
    # zip stops at the shorter sequence, so this works for any length
    return set(zip(word, word[1:]))

class Encoder:
    """Byte-level BPE tokenizer (GPT-2): converts text to token ids and back."""

    def __init__(self, encoder, bpe_merges, errors='replace'):
        """
        encoder:    dict mapping BPE token strings to integer ids.
        bpe_merges: sequence of (first, second) symbol pairs; earlier pairs are
                    merged first.
        errors:     `errors=` policy passed to bytes.decode() in decode().
        """
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # rank = position in bpe_merges; a lower rank is merged earlier
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """
        Apply the BPE merges to one pre-tokenized, byte-encoded `token`.

        Repeatedly merges the adjacent pair with the lowest rank until no
        ranked pair remains. Returns the resulting symbols joined by single
        spaces. Results are memoized in self.cache; single-symbol tokens are
        returned unchanged without being cached.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # no further occurrence of `first`: copy the tail and stop
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Split `text` with self.pat, byte-encode each piece, BPE it, and return the list of token ids."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text; invalid utf-8 is handled according to self.errors."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

def get_encoder(encoder_path='./encoder.json', bpe_path='./vocab.bpe'):
    """
    Build the GPT-2 BPE Encoder from its vocabulary and merges files.

    encoder_path: JSON file mapping BPE token strings to token ids.
    bpe_path:     merges file with one "first second" pair per line. The first
                  line and the last element after splitting on newlines are
                  dropped (presumably a version header and the empty string
                  after the trailing newline).
    """
    # The vocabulary keys are byte-level stand-in characters (many non-ASCII),
    # so read as utf-8 explicitly instead of relying on the platform default.
    with open(encoder_path, 'r', encoding='utf-8') as f:
        encoder = json.load(f)
    with open(bpe_path, 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )


In [None]:
'''
Modified from https://github.com/graykode/gpt-2-Pytorch/blob/master/GPT2/utils.py
See above for original license.
'''

'''
    code by TaeHwan Jung(@graykode)
    Original Paper and repository here : https://github.com/openai/gpt-2
    GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
'''
import logging

logger = logging.getLogger(__name__)

def load_weight(model, state_dict):
    """
    Load a GPT-2 checkpoint `state_dict` into `model` and re-tie the embeddings.

    First, keys with TF-style suffixes are renamed in place in the caller's
    dict: ".g" and ".w" become ".weight", and ".b" becomes ".bias". If `model`
    has a `transformer` attribute and no key starts with "transformer.",
    loading starts at `model.transformer`.

    Missing keys, unexpected keys and load errors (e.g. shape mismatches)
    collected by torch are reported through `logging` instead of being
    discarded.

    Returns `model`, after calling `model.set_tied()`.
    """
    renamed = {}
    for key in state_dict.keys():
        if key.endswith((".g", ".w")):
            renamed[key] = key[:-2] + ".weight"
        elif key.endswith(".b"):
            renamed[key] = key[:-2] + ".bias"
    for old_key, new_key in renamed.items():
        state_dict[new_key] = state_dict.pop(old_key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=""):
        # Recursively load each submodule's own parameters/buffers (strict=True).
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    start_model = model
    if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
        start_model = model.transformer
    load(start_model, prefix="")

    log = logging.getLogger(__name__)
    if missing_keys:
        log.warning("Weights not found in state_dict (left at init values): %s", missing_keys)
    if unexpected_keys:
        log.warning("Unused weights in state_dict: %s", unexpected_keys)
    if error_msgs:
        log.error("Errors while loading state_dict:\n\t%s", "\n\t".join(error_msgs))

    # Make sure we are still sharing the output and input embeddings after loading weights
    model.set_tied()
    return model

In [None]:
# -f: fail on HTTP errors instead of saving the error page; -L: follow redirects
!curl -fL -o encoder.json https://raw.githubusercontent.com/jmerizia/city-circuits/main/encoder.json
!curl -fL -o vocab.bpe https://raw.githubusercontent.com/jmerizia/city-circuits/main/vocab.bpe

# GPT-2 model

In [None]:
'''
Modified from https://github.com/graykode/gpt-2-Pytorch/blob/master/GPT2/config.py
See above for original license.
'''

'''
    code by TaeHwan Jung(@graykode)
    Original Paper and repository here : https://github.com/openai/gpt-2
    GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
'''
class GPT2Config(object):
    """Plain hyper-parameter container for GPT2Model; the defaults describe the 48-layer, 1600-dim model."""

    def __init__(
            self,
            vocab_size_or_config_json_file=50257,
            n_positions=1024,
            n_ctx=1024,
            n_embd=1600,
            n_layer=48,
            n_head=25,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
    ):
        # Despite its name, the first argument is stored directly as the vocabulary size.
        settings = (
            ('vocab_size', vocab_size_or_config_json_file),
            ('n_ctx', n_ctx),
            ('n_positions', n_positions),
            ('n_embd', n_embd),
            ('n_layer', n_layer),
            ('n_head', n_head),
            ('layer_norm_epsilon', layer_norm_epsilon),
            ('initializer_range', initializer_range),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)


In [None]:
'''
Modified from https://github.com/graykode/gpt-2-Pytorch/blob/master/GPT2/model.py
See above for original license.
'''

'''
    code by TaeHwan Jung(@graykode)
    Original Paper and repository here : https://github.com/openai/gpt-2
    GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
'''
import copy
import torch
import math
import torch.nn as nn
from torch.nn.parameter import Parameter

ACTIVATION_THRESHOLD = 2  # NOTE(review): not referenced by any visible cell; neuron extraction uses `activation_threshold` from the settings cell

def neuron_name_matches_specifier(name, specifier):
    """
    Return True if each ':'-separated part of `name` equals the corresponding
    part of `specifier`. A '*' part in the specifier matches anything.

    Only the overlapping parts are compared: extra parts on the longer of the
    two strings are ignored.
    """
    return all(
        spec_part in ('*', name_part)
        for name_part, spec_part in zip(name.split(':'), specifier.split(':'))
    )

def gelu(x):
    """Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))

class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, TF style (epsilon inside the square root)."""

    def __init__(self, hidden_size, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        # biased variance of the last dimension, as in the original TF code
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normed + self.bias

class Conv1D(nn.Module):
    """
    GPT-2's "Conv1D": an affine map over the last dimension, y = x @ weight + bias.

    The weight is stored as (nx, nf), i.e. transposed relative to nn.Linear,
    and is initialized from N(0, 0.02^2). The bias starts at zero.
    """

    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        init_weight = torch.empty(nx, nf)
        nn.init.normal_(init_weight, std=0.02)
        self.weight = Parameter(init_weight)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        # flatten all leading dims, apply bias + x @ W, then restore them
        *leading, width = x.size()
        flat = x.view(-1, width)
        projected = torch.addmm(self.bias, flat, self.weight)
        return projected.view(*leading, self.nf)

class Attention(nn.Module):
    """Causal multi-head self-attention with a fused query/key/value projection (c_attn) and output projection (c_proj)."""
    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # n_state == nx == n_embd here
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Lower-triangular (1, 1, n_ctx, n_ctx) mask of ones: query position i may attend to key positions j <= i.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)  # produces q, k, v concatenated on the last dim
        self.c_proj = Conv1D(n_state, nx)

    def _attn(self, q, k, v):
        """
        Masked (optionally scaled) dot-product attention.

        q: (batch, head, seq_q, head_features)
        k: (batch, head, head_features, seq_k)  -- already transposed by split_heads(k=True)
        v: (batch, head, seq_k, head_features)
        Returns (batch, head, seq_q, head_features).
        """
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        # Rows of the causal mask for the newest nd queries out of ns keys
        # (ns > nd when cached keys from layer_past were concatenated).
        b = self.bias[:, :, ns-nd:ns, :ns]
        # Masked-out scores become about -1e10, i.e. ~0 weight after softmax.
        w = w * b - 1e10 * (1 - b)
        w = nn.Softmax(dim=-1)(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        # (batch, head, seq, head_features) -> (batch, seq, head * head_features)
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        # (batch, seq, features) -> per-head layout; keys are returned pre-transposed for q @ k
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None):
        """
        x: assumed (batch, seq, nx) -- the split on dim=2 implies a 3-D input.
        layer_past: optional `present` from a previous call, used as a key/value cache.
        Returns (output of c_proj, present). `present` stacks keys (transposed back to
        (batch, head, seq_total, head_features)) and values, giving shape
        (2, batch, head, seq_total, head_features).
        """
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present

class MLP(nn.Module):
    """GPT-2 feed-forward sub-layer: c_fc (n_embd -> n_state), GELU, then c_proj (n_state -> n_embd)."""

    def __init__(self, n_state, config):  # Block passes n_state = 4 * n_embd
        super(MLP, self).__init__()
        width = config.n_embd
        self.config = config
        self.c_fc = Conv1D(n_state, width)
        self.c_proj = Conv1D(width, n_state)
        self.act = gelu

    def forward(self, x, name_prefix=''):
        # name_prefix is accepted for the caller's convenience but not used here.
        # The input to c_proj (the post-GELU activations) is what the forward hooks record.
        return self.c_proj(self.act(self.c_fc(x)))

class Block(nn.Module):
    """One pre-LayerNorm transformer layer: x += attn(ln_1(x)); x += mlp(ln_2(x))."""

    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        width = config.n_embd
        self.ln_1 = LayerNorm(width, eps=config.layer_norm_epsilon)
        self.attn = Attention(width, n_ctx, config, scale)
        self.ln_2 = LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * width, config)

    def forward(self, x, layer_past=None, name_prefix=''):
        """Return (updated residual stream, present) where `present` is the attention's key/value stack."""
        attn_out, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + attn_out
        x = x + self.mlp(self.ln_2(x), name_prefix=name_prefix + ':mlp')
        return x, present

class GPT2Model(nn.Module):
    """GPT-2 transformer body: token + position embeddings, n_layer Blocks, and a final LayerNorm (ln_f)."""
    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.vocab_size

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        # All layers start as deep copies of one Block (identical init); real weights come from load_weight.
        block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def set_embeddings_weights(self, model_embeddings_weights):
        # NOTE(review): not called by any visible cell; the LM head's tying lives in GPT2LMHead.
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        """
        Returns (hidden states after ln_f, shaped input_ids.shape + (n_embd,), list of per-layer presents).

        past: optional list of per-layer `present` tensors from a previous call. Positions
        continue from the cached length past[0][0].size(-2), i.e. the cached key sequence length.
        token_type_ids, when given, are embedded with the token table `wte` (there is no
        separate token-type embedding).
        """
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long,
                                        device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        input_shape = input_ids.size()
        # flatten any leading dims into a single batch dim
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        presents = []
        for layer_idx, (block, layer_past) in enumerate(zip(self.h, past)):
            hidden_states, present = block(hidden_states, layer_past, name_prefix=f'block{layer_idx}')
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents

class GPT2LMHead(nn.Module):
    """Bias-free projection from hidden states to vocabulary logits, sharing its weight with the token embedding."""

    def __init__(self, model_embeddings_weights, config):
        super(GPT2LMHead, self).__init__()
        self.n_embd = config.n_embd
        self.set_embeddings_weights(model_embeddings_weights)

    def set_embeddings_weights(self, model_embeddings_weights):
        """(Re)build the decoder Linear and point its weight at `model_embeddings_weights` (vocab, n_embd)."""
        shape = model_embeddings_weights.shape
        vocab_size, embd_dim = shape[0], shape[1]
        self.decoder = nn.Linear(embd_dim, vocab_size, bias=False)
        self.decoder.weight = model_embeddings_weights  # the same Parameter object, not a copy

    def forward(self, hidden_state):
        """Project hidden states (..., n_embd) to logits (..., vocab) for every position."""
        return self.decoder(hidden_state)

class GPT2LMHeadModel(nn.Module):
    """GPT2Model body plus an LM head whose weight is tied to the token embedding."""

    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)

    def set_tied(self):
        """Re-point the LM head at the current token-embedding weight (call after loading weights)."""
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
        """
        Without lm_labels: return (lm_logits, presents).
        With lm_labels: return the cross-entropy loss between the logits and the
        labels, compared position by position (no shift is applied here); labels
        equal to -1 are ignored.
        """
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)
        if lm_labels is None:
            return lm_logits, presents
        loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        return loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))

In [None]:
'''
Modified from https://github.com/graykode/gpt-2-Pytorch/blob/master/GPT2/sample.py
See above for original license.
'''

'''
    code by TaeHwan Jung(@graykode)
    Original Paper and repository here : https://github.com/openai/gpt-2
    GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
'''
import torch
import torch.nn.functional as F

def top_k_logits(logits, k):
    """
    Keep the k largest logits in each row and replace the rest with -1e10.

    logits: (..., vocab) tensor. k == 0 disables filtering and returns `logits`
    itself. k larger than the vocabulary is clamped, so nothing is masked.
    Ties with the k-th largest value are kept.
    """
    if k == 0:
        return logits
    k = min(k, logits.size(-1))
    values, _ = torch.topk(logits, k)
    # Keep the trailing dim so each row is compared with its own k-th value.
    # values[:, -1] would broadcast against the vocab axis and break for batch > 1.
    min_values = values[..., -1:]
    return torch.where(logits < min_values, torch.full_like(logits, -1e10), logits)

def sample_sequence(model, length, context, batch_size=None, temperature=1, top_k=0, device='cuda', sample=True):
    """
    Generate `length` tokens after `context`, feeding only the newest token each
    step and reusing the model's key/value cache (`past`).

    context:     token ids, either one sequence (list or 1-D), which is replicated
                 `batch_size` times (default 1), or a (batch, seq) nested list.
    temperature: the last-position logits are divided by this before filtering.
    top_k:       keep only the k largest logits per row (0 disables filtering).
    sample:      draw from the softmax distribution if True, otherwise take the argmax.

    Returns a (batch, len(context) + length) LongTensor: the context followed by
    the generated ids.
    """
    context = torch.tensor(context, device=device, dtype=torch.long)
    if context.dim() == 1:
        # single sequence: add the batch dimension the model expects
        context = context.unsqueeze(0).repeat(batch_size or 1, 1)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for _ in range(length):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(probs, num_samples=1)
            else:
                _, prev = torch.topk(probs, k=1, dim=-1)
            output = torch.cat((output, prev), dim=1)
    return output

# load the model

In [None]:
# Download the GPT-2 XL model weights.
# -f: fail on HTTP errors instead of saving the error page as gpt2-model.bin; -L: follow redirects
!curl -fL -o gpt2-model.bin https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-pytorch_model.bin

In [None]:
import gc

# Free the previous model (and its cached GPU memory) if this cell is re-run.
if 'model' in globals():
    print('Cleaning up old model')
    del model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enc = get_encoder()
config = GPT2Config(
    n_positions=1024,
    n_ctx=1024,
    n_embd=1600,
    n_layer=48,
    n_head=25,
)
model = GPT2LMHeadModel(config).to(device)
# The raw checkpoint is cached in `state_dict` across re-runs so it is not
# re-read from disk. Keep that cache on the CPU: load_weight copies every
# tensor into the model's own (already on-device) parameters, so a GPU copy of
# the checkpoint would only double the VRAM in use.
if 'state_dict' not in globals():
    state_dict = torch.load('gpt2-model.bin', map_location='cpu')
model = load_weight(model, state_dict)
model.eval()
print('Loaded weights')

# load the dataset

Upload `dataset.json` (produced by the other notebook) to this runtime before running the next cell.

In [None]:
import json
from tqdm import tqdm
import gc
import random
import os

assert os.path.exists('dataset.json'), 'Missing dataset.json! Did you upload it from the other notebook?'

# Presumably a list of rows; the cells below index it by position and read row['text'].
with open('dataset.json', 'r') as dataset_file:
    dataset = json.load(dataset_file)
print('Dataset has', len(dataset), 'examples.')

# register hooks

In [None]:
import pickle
import torch
import gc
from tqdm import tqdm
from functools import partial


# Remove hooks left over from an earlier run of this cell so they don't stack up.
if 'handles' in globals():
    print('Removing existing hooks...')
    for stale_handle in handles:
        stale_handle.remove()

# The input to each mlp.c_proj is the post-GELU MLP activation (one value per
# neuron). Each block's output tuple holds the residual stream used for the logit lens.
module_names_to_track_for_activations = [f'transformer.h.{i}.mlp.c_proj' for i in range(config.n_layer)]
module_names_to_track_for_logits = [f'transformer.h.{i}' for i in range(config.n_layer)]

hidden_states = dict()
def save_hidden_states(name, module, input, output):
    """Forward hook: overwrite hidden_states[name] with this call's input[0] and output."""
    # we care about input for activations, and output for logit lens
    hidden_states[name] = {'input': input[0], 'output': output}

handles = []
cnt = 0
for name, m in model.named_modules():
    if name not in module_names_to_track_for_activations and name not in module_names_to_track_for_logits:
        continue
    handles.append(m.register_forward_hook(partial(save_hidden_states, name)))
    cnt += 1

print('Tracking', cnt, 'layers.')

In [None]:
%pip install -q jsonlines

In [None]:
# Settings used by the tracking loop.

num_tokens = 100  # each example is truncated to its first num_tokens BPE tokens
top_k = 5  # logit-lens candidates kept per (layer, position)
activation_threshold = 3  # a neuron is recorded if it exceeds this at any position after the first

In [None]:
import jsonlines


def run_model_with_tracking(tokens):
    """
    Run a single no-grad forward pass over `tokens` as a batch of one sequence.

    The side effect is the point: the registered forward hooks overwrite
    `hidden_states` with this pass's activations. Returns the LM logits.
    """
    with torch.no_grad():
        input_ids = torch.tensor([tokens], device=device, dtype=torch.long)
        lm_logits, _presents = model(input_ids, past=None)
    return lm_logits


def extract_neuron_values(threshold=5):
    """
    Find the MLP neurons whose activation exceeds `threshold` at any token
    position except the first. Position 0 is ignored when selecting neurons
    because it is assumed to be noise.

    Reads the activations stored in `hidden_states` by the forward hooks during
    the most recent forward pass (first batch element only).

    Returns a list of dicts, one per selected neuron, in the order a
    position-by-position scan would first encounter them, i.e. sorted by
    (layer, first position above threshold, feature):
        l: the layer of the neuron
        f: the index of the neuron in the feature dimension
        a: its activations at every position of the sequence (position 0 included)
    """
    # (n_layers, seq_len, n_features): input to each layer's mlp.c_proj
    neurons = torch.stack([hidden_states[name]['input'][0] for name in module_names_to_track_for_activations])
    above = neurons[:, 1:, :] > threshold  # ignore first activations (too noisy!)
    # Each fired (layer, feature) pair once, in (layer, feature) order, computed
    # on-device instead of one .item() sync per above-threshold element.
    layer_ids, feature_ids = above.any(dim=1).nonzero(as_tuple=True)
    if layer_ids.numel() == 0:
        return []
    # argmax over a 0/1 tensor returns the first position where the neuron fired
    first_pos = above[layer_ids, :, feature_ids].int().argmax(dim=1)
    # (n_fired, seq_len) activation traces, moved to the host in one transfer
    traces = neurons[layer_ids, :, feature_ids].cpu().tolist()

    layers = layer_ids.tolist()
    features = feature_ids.tolist()
    firsts = first_pos.tolist()
    # stable sort keeps feature order for ties, reproducing the (layer, pos, feature) scan order
    order = sorted(range(len(layers)), key=lambda i: (layers[i], firsts[i]))
    return [{'l': layers[i], 'f': features[i], 'a': traces[i]} for i in order]


def extract_logit_lens(k=10):
    """
    Logit lens: decode every tracked block's output through the final layer
    norm and the LM head, and keep the top-k tokens at each position.

    ln_f is applied before lm_head, exactly as in GPT2Model.forward. Without
    it, even the last layer's predictions would not match the model's output.

    Returns a nested list of shape [n_layers, n_seq, k] whose elements are dicts:
        tok:  the decoded candidate token
        prob: softmax over the k top logits only, i.e. the probability
              renormalized among the top-k candidates (sums to 1 over k).
    """
    per_layer_tokens = []
    with torch.no_grad():
        for name in module_names_to_track_for_logits:
            h2 = hidden_states[name]['output'][0]  # block output is (x, present)
            layer_logits = model.lm_head(model.transformer.ln_f(h2))[0]  # (seq, vocab)
            values, indices = torch.topk(layer_logits, k=k)
            norm_values = F.softmax(values, dim=-1)
            indices = indices.cpu().tolist()
            norm_values = norm_values.cpu().tolist()
            per_layer_tokens.append([
                [{'tok': enc.decode([tok]), 'prob': prob} for tok, prob in zip(toks, probs)]
                for toks, probs in zip(indices, norm_values)
            ])
    return per_layer_tokens


model.eval()

# One JSON line per dataset row, in dataset order. Later cells pair line i with dataset[i].
with torch.no_grad(), jsonlines.open('neurons.jsonl', 'w') as writer:
    for row in tqdm(dataset):
        tokens = enc.encode(row['text'])[:num_tokens]
        run_model_with_tracking(tokens)
        writer.write({
            'activations': extract_neuron_values(threshold=activation_threshold),
            'logits': extract_logit_lens(k=top_k),  # [n_layers, seq, top_k]
        })

# upload the neuron file

In [None]:
%pip install -q awscli

In [None]:
!aws configure

In [None]:
!aws s3 cp neurons.jsonl s3://gpt2-neurons/wikipedia-first-lines/

# cluster the neurons

In [None]:
from collections import defaultdict
import os
import json
from tqdm import tqdm
import jsonlines
import torch


def iter_neuron_records():
    """Yield the records of neurons.jsonl in file order (written one per dataset row)."""
    with jsonlines.open('neurons.jsonl') as reader:
        yield from reader

def iter_neurons(n_layers=48, n_features=1600 * 4):
    """
    Yield every (layer, feature) MLP neuron id in row-major order.

    The defaults match the GPT2Config used in this notebook: 48 layers and
    4 * n_embd = 6400 MLP features per layer.
    """
    for l in range(n_layers):
        for f in range(n_features):
            yield l, f

def first_n(gen, n):
    """
    Return a list of the first `n` items of the iterable `gen` (fewer if it runs out).

    Uses itertools.islice, which consumes exactly the items returned.
    zip(gen, range(n)) would silently pull and discard one extra item from an
    iterator, even for n == 0.
    """
    from itertools import islice
    return list(islice(gen, n))

def normalize(v):
    """
    Clip negative values to 0 and scale so the largest value becomes 1.0.

    Returns [] for an empty input, and all zeros when no value is positive
    (instead of dividing by zero).
    """
    clipped = [max(0, e) for e in v]
    peak = max(clipped, default=0)
    if peak == 0:
        return [0.0] * len(clipped)
    return [e / peak for e in clipped]

# For each neuron, take the top N examples it responds to most strongly.
# For each such example, sum the example's embedded tokens weighted by the
# neuron's normalized activations; the per-example vectors, concatenated,
# form the neuron's latent vector (computed in the next cell).

neuron_to_example_indices = defaultdict(list)
neuron_records = iter_neuron_records()
num_examples = len(dataset)
for idx in tqdm(range(num_examples)):
    neuron_record = next(neuron_records)  # record idx belongs to dataset[idx]
    for record in neuron_record['activations']:
        if max(record['a']) >= 2:
            neuron_to_example_indices[(record['l'], record['f'])].append({ 'activations': record['a'], 'exampleIdx': idx })
# neurons with at least one example, in (layer, feature) order
neurons = [neuron for neuron in iter_neurons() if neuron in neuron_to_example_indices]

In [None]:
def embed_tokens(tokens):
    """
    Return the model's input embedding for one token list: token embedding
    plus learned position embedding (positions 0..len-1), shape
    (1, len(tokens), n_embd). No transformer layers are applied.
    """
    with torch.no_grad():
        ids = torch.tensor([tokens], device=device, dtype=torch.long)
        positions = torch.arange(0, ids.size(-1), dtype=torch.long, device=ids.device).unsqueeze(0)
        return model.transformer.wte(ids) + model.transformer.wpe(positions)


TOP_N_EXAMPLES = 1  # strongest-activating examples used per neuron

latent_vectors = []
for layer_idx, feature_idx in tqdm(neurons):
    # Strongest examples first (same ordering as the index cell). An ascending
    # sort here would pick the neuron's *weakest* example.
    top_examples = first_n(
        sorted(
            neuron_to_example_indices[(layer_idx, feature_idx)],
            key=lambda x: max(x['activations']),
            reverse=True,
        ),
        TOP_N_EXAMPLES,
    )
    latent_vector_parts = []
    for top_example in top_examples:
        weights = torch.tensor(normalize(top_example['activations']))
        seq_len = weights.shape[0]
        weights = weights.reshape((seq_len, 1))
        tokens = enc.encode(dataset[top_example['exampleIdx']]['text'])
        # truncate to the tracked length so tokens line up with the activations
        embedded_tokens = embed_tokens(tokens[:seq_len]).detach().cpu()[0]
        latent_vector_parts.append(torch.multiply(embedded_tokens, weights).sum(dim=0))
    # Concatenate the per-example vectors. With TOP_N_EXAMPLES > 1, every neuron
    # must have that many examples for the stack below to succeed.
    latent_vectors.append(torch.cat(latent_vector_parts))
latent_vectors = torch.stack(latent_vectors)
latent_vectors.shape

In [None]:
# First compress the latent vectors down to 100 dimensions with PCA,
# then (next cells) down to 2 with t-SNE.
#
# PCA is fit on the *transposed* matrix: the embedding dimensions are the
# samples and the neurons are the features. So `pca.components_` has shape
# (100, n_neurons), and its transpose is used downstream as a 100-d coordinate
# per neuron. random_state pins the (possibly randomized) SVD solver so re-runs agree.

from sklearn.decomposition import PCA

pca = PCA(n_components=100, whiten=True, random_state=42)
pca.fit(latent_vectors.T)
pca.components_.shape

In [None]:
%pip install -q opentsne

In [None]:
from openTSNE import TSNE
import numpy as np
import time

TSNE_SEED = 42
TSNE_MAX_FIT_SAMPLES = 30000  # neurons used to fit the embedding

tsne = TSNE(
    perplexity=30,
    metric="euclidean",
    n_jobs=2,
    random_state=TSNE_SEED,
    verbose=True,
)

# Fit the 2-D embedding on a seeded random subset of the neurons (at most
# TSNE_MAX_FIT_SAMPLES, so it also works with fewer neurons), then project all
# neurons into it.
neuron_coords = np.asarray(pca.components_.T)  # (n_neurons, n_components)
rng = np.random.default_rng(TSNE_SEED)
fit_size = min(TSNE_MAX_FIT_SAMPLES, len(neuron_coords))
fit_idx = rng.choice(len(neuron_coords), size=fit_size, replace=False)

st = time.time()
embedding = tsne.fit(neuron_coords[fit_idx])
X = np.asarray(embedding.transform(neuron_coords)).T  # (2, n_neurons), aligned with `neurons`
en = time.time()
print(en-st)
X.shape

In [None]:
import plotly.express as px
import plotly.graph_objects as go

# One marker per neuron at its t-SNE coordinates; hovering shows its (layer, feature) id.
hover_labels = [str(neuron) for neuron in neurons]
scatter = go.Scatter(
    x=X[0],
    y=X[1],
    mode='markers',
    marker={'size': 2},
    hovertext=hover_labels,
    hoverinfo="text",
    showlegend=False,
)
fig = go.Figure(data=[scatter])
fig.show()

# Create the index for fast viewing in the browser

In [None]:
from collections import defaultdict
import os
import json
from tqdm import tqdm
import jsonlines


def iter_neuron_records():
    """Stream the per-example records stored in neurons.jsonl, in file order."""
    reader = jsonlines.open('neurons.jsonl')
    with reader:
        for record in reader:
            yield record

def iter_neurons(n_layers=48, n_features=1600 * 4):
    """
    Yield every (layer, feature) MLP neuron id in row-major order.

    The defaults match the GPT2Config used in this notebook: 48 layers and
    4 * n_embd = 6400 MLP features per layer.
    """
    for l in range(n_layers):
        for f in range(n_features):
            yield l, f


base_path = 'index'
os.makedirs(base_path, exist_ok=True)

num_examples = len(dataset)

# A single pass over neurons.jsonl: write each example's file and, at the same
# time, collect for every neuron the examples in which it fired.
print(f'Generating example-level indices into the {base_path} folder...')
neuron_to_example_indices = defaultdict(list)
neuron_records = iter_neuron_records()
for idx in tqdm(range(num_examples)):
    example = dataset[idx]
    neuron_record = next(neuron_records)  # record idx belongs to dataset[idx]
    # Truncate exactly as the tracking loop did, so tokens[i] lines up with
    # activations[...]['a'][i] and logits[layer][i].
    tokens = enc.encode(example['text'])[:num_tokens]
    with open(os.path.join(base_path, f'example-{idx:05}.json'), 'w') as example_file:
        example_file.write(json.dumps({
            'example': example,
            'activations': neuron_record['activations'],
            'logits': neuron_record['logits'],
            'tokens': [enc.decode([t]) for t in tokens]
        }))
    for record in neuron_record['activations']:
        if max(record['a']) >= 2:
            neuron_to_example_indices[(record['l'], record['f'])].append({ 'activations': record['a'], 'exampleIdx': idx })
neuron_records.close()  # closes neurons.jsonl

print(f'Generating neuron-level indices into the {base_path} folder...')
for (l, f), examples in sorted(neuron_to_example_indices.items()):
    examples.sort(key=lambda e: max(e['activations']), reverse=True)  # strongest first
    with open(os.path.join(base_path, f'neuron-{l}-{f}.json'), 'w') as file:
        file.write(json.dumps(examples))

print(f'Generating neuron-cluster indices into the {base_path} folder...')
assert X.shape[1] == len(neurons), 'cluster coordinates are not aligned with `neurons`'
with open(os.path.join(base_path, 'cluster.json'), 'w') as fp:
    obj = {
        # float() so numpy scalars (e.g. float32) are always JSON-serializable
        'x': [round(float(e), 3) for e in X[0]],
        'y': [round(float(e), 3) for e in X[1]],
        'neurons': neurons,
    }
    json.dump(obj, fp)

In [None]:
!du -h {base_path}/cluster.json

In [None]:
# zip the files
!zip -rq index.zip ./index

In [None]:
!aws s3 cp index.zip s3://gpt2-neurons/wikipedia-first-lines/index2.zip

# (You can stop here; the rest just computes statistics on the data.)

In [None]:
from collections import defaultdict

# Stream the records from neurons.jsonl. No cell defines a `records` variable,
# and only the 'activations' field is needed here.
unique_neurons = set()
per_layer_totals = defaultdict(list)
for record in iter_neuron_records():
    per_layer_counts = defaultdict(int)
    for neuron in record['activations']:
        unique_neurons.add((neuron['l'], neuron['f']))
        per_layer_counts[neuron['l']] += 1
    for layer, count in per_layer_counts.items():
        per_layer_totals[layer].append(count)
print(len(unique_neurons))

# Mean number of recorded neurons per layer, averaged only over the examples in
# which that layer recorded at least one neuron (examples with none are not
# counted as zeros).
for layer, counts in sorted(per_layer_totals.items()):
    print(layer, ':', sum(counts) / len(counts))

In [None]:
# Re-export neurons.jsonl as a single JSON array. Records are streamed one at a
# time so the whole file is never held in memory; the ', ' separator makes the
# output identical to json.dumps(list_of_records).
with open('neurons.json', 'w') as f:
    f.write('[')
    for i, record in enumerate(iter_neuron_records()):
        if i:
            f.write(', ')
        f.write(json.dumps(record))
    f.write(']')

In [None]:
!du -h neurons.json