# Init

In [2]:
import numpy as np
from copy import deepcopy
import gc
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from torch import nn
import torch 
import torch.nn.functional as F
from torch.utils.data import DataLoader
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

from speaking_probes.generate import (extract_gpt_parameters, extract_gpt_j_parameters, speaking_probe, encode)
from speaking_probes.extra import top_tokens, corpus_search
# Reload edited modules (e.g. the local `speaking_probes` package) before each cell
# executes, so source changes are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [3]:
# Checkpoint under study; 'EleutherAI/gpt-j-6B' is the alternative this notebook was written against.
model_name = 'gpt2-medium'
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse end-of-sequence as the padding token (presumably required by the batched
# helpers in speaking_probes).
tokenizer.pad_token = tokenizer.eos_token

## Extract Parameters

In [4]:
model_params = extract_gpt_parameters(model, full=True)

# Lift the extracted weights into top-level names used by the cells below.
emb, d_int = model_params.emb, model_params.d_int
ff_keys, ff_values = model_params.ff_keys, model_params.ff_values
W_Q_heads, W_K_heads, W_V_heads, W_O_heads = (
    model_params.W_Q_heads,
    model_params.W_K_heads,
    model_params.W_V_heads,
    model_params.W_O_heads,
)
# Mean over dim 1 of the embedding matrix. Later cells index it as emb[:, token_id],
# so this presumably averages over the vocabulary -- confirm against extract_gpt_parameters.
mu = emb.mean(1)

# Speaking Probes

In [5]:
# Spread the transformer layers across all visible GPUs. `parallelize()` derives its
# device map from torch.cuda.device_count() and fails when that count is 0, so on a
# CPU-only machine the call is skipped and the model stays where it was loaded.
# NOTE(review): `parallelize()` is deprecated in recent transformers releases;
# consider `from_pretrained(..., device_map="auto")` instead.
if torch.cuda.device_count() > 0:
    model.parallelize()

In [12]:
# Probe template passed to `speaking_probe`; the ` <neuron>` placeholder presumably
# marks where the probed vector is injected -- confirm in speaking_probes.generate.
prompt = 'The term " <neuron>" means'

In [14]:
print(prompt)
# Generation uses do_sample=True, so seed torch's RNG (all devices) to make the
# recorded generations reproducible on re-run.
torch.manual_seed(0)
# Probed vector: sum of the embedding columns for ' children' and ' dog'
# (only the first sub-token of each word is used).
probe_vector = (emb[:, tokenizer.encode(' children')[0]].to(model.device)
                + emb[:, tokenizer.encode(' dog')[0]].to(model.device))
decoded = speaking_probe(model, model_params, tokenizer, prompt, probe_vector,
                         repetition_penalty=2., num_generations=20,
                         min_length=1, do_sample=True,
                         max_new_tokens=50, temperature=.5)

print('\n')
for generation in decoded:
    print("generate:", generation)
    print('')
    print("""~~~~~~~~~~~~~~~~~~""" * 5)
    print('')

The term " <neuron>" means


generate: The term " <neuron>" means a person under the age of 18 years.
 The terms are used in this chapter as follows: (1) A child shall be considered to have attained legal maturity when he or she has reached an understanding with his parent, guardian and other persons that

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

generate: The term " <neuron>" means any animal that is capable of walking on two legs, but not a dog.
            The phrase 'the children' refers to those who were born before the adoption law was changed in 1996 and are now considered as age 18 or 19 years old

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

generate: The term " <neuron>" means any member of the genus Canis familiarus, but does not include chimpanzees or gorillas.
 (2) (a)(i), in paragraph 1, is used to mean a dog that has been domesticated by human beings and whose natural

~~~~~

## Comparison

### Embedding Space

In [229]:
# Feed-forward key vector at index [21, 86] (presumably layer 21, hidden unit 86 --
# confirm with extract_gpt_parameters), placed on the embedding matrix's device so
# it can be multiplied against `emb` below.
neuron = ff_keys[21][86].to(device=emb.device)

In [230]:
# Score every vocabulary token by its dot product with the neuron (columns of `emb`
# are token embeddings) and display the 40 tokens top_tokens selects.
top_tokens(torch.matmul(neuron, emb), k=40, tokenizer=tokenizer)

['IDs',
 'identifiers',
 'surname',
 'surn',
 'identifier',
 'initials',
 '#Registered',
 'NAME',
 '#names',
 'pseudonym',
 '#codes',
 'nomine',
 'names',
 'username',
 '#IDs',
 'ID',
 'registration',
 '#76561',
 '#soDeliveryDate',
 '#ADRA',
 'CLSID',
 'numbering',
 '#ername',
 '#address',
 'addresses',
 'codes',
 '#Names',
 'regist',
 'name',
 'Names',
 '#Address',
 '#name',
 'Registration',
 '#Profile',
 'Codes',
 'aliases',
 '#registered',
 'nickname',
 '#wcsstore',
 '#isSpecialOrderable']

In [153]:
# Same projection against unit-length token embeddings: F.normalize over dim 0 rescales
# each column emb[:, t], so the ranking depends only on the angle between the neuron
# and each token embedding, not on embedding norm.
# NOTE(review): the saved output below was produced by In [153], before `neuron` was
# reassigned in In [229]; it does not describe the current `neuron` -- re-run.
top_tokens(torch.matmul(neuron, F.normalize(emb, p=2.0, dim=0)), tokenizer=tokenizer)

['purchasing',
 'installation',
 '#buy',
 'font',
 'execution',
 'clock',
 'buying',
 'buys',
 'money',
 'sale',
 'purchase',
 'buy',
 'button',
 'clicking',
 'connecting',
 'committee',
 'connect',
 'transmission',
 'interaction',
 '#budget']

### Corpus Search

In [8]:
# Text corpus for activation search: a reproducible random sample of articles from
# the English Wikipedia "20220301.en" config, keeping only the raw `text` column.
CORPUS_SIZE = 12_000   # number of articles to sample
CORPUS_SEED = 0        # fixed shuffle seed so re-runs search the same articles
corpus = (
    load_dataset("wikipedia", "20220301.en", split='train')
    .shuffle(seed=CORPUS_SEED)
    .select(range(CORPUS_SIZE))['text']
)

Found cached dataset wikipedia (/home/guydar/.cache/huggingface/datasets/wikipedia/20220301.en/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)


In [12]:
# Find corpus contexts for a single MLP hidden unit.
#   i1: transformer layer index -> hooks `transformer.h.{i1}.mlp.c_fc` (GPT-2 module
#       naming; the commented GPT-J checkpoint would need `fc_in` instead).
#   i2: index of the unit within that layer, passed as `dim_idx`.
# corpus_search presumably returns the top-k (sentence, token index) pairs, as the
# next cell unpacks them -- confirm in speaking_probes.extra.
i1, i2 = 2, 1
# Uses the GPT-2 `model` loaded at the top of this notebook (the only model defined
# here, and the one whose layers expose `mlp.c_fc`).
results = corpus_search(model, tokenizer, corpus, layer_path=f'transformer.h.{i1}.mlp.c_fc',
                        dim_idx=i2, k=100, batch_size=64, normalize=True)

  0%|          | 0/188 [00:00<?, ?it/s]

removing hook..


In [13]:
# Print each hit with its activating token wrapped in '@@' markers, along with up to
# 20 tokens of left context and up to 9 tokens of right context.
for sent, idx in results:
    # corpus_search may hand back tensor indices; normalize to a plain int for indexing.
    idx = int(idx)
    sent_tokens = tokenizer.tokenize(sent)
    min_idx = max(0, idx - 20)
    max_idx = min(len(sent_tokens), idx + 10)

    window = [*sent_tokens[min_idx:idx], '@@', sent_tokens[idx], '@@', *sent_tokens[idx + 1:max_idx]]
    print(tokenizer.convert_tokens_to_string(window))
    print("""~~~~~~~~~~~~~~~~~~~~""" * 5)
    print("""~~~~~~~~~~~~~~~~~~~~""" * 5)

Token indices sequence length is longer than the specified maximum sequence length for this model (3467 > 1024). Running this sequence through the model will result in indexing errors


 of a semi-automatic firearm to fire ammunition cartridges in rapid succession.

The legality of bump@@ stocks@@ in the United States came under question following the
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 (born March 8, 1973) is an American producer, mixer, engineer, re-recording@@ mixer@@, sound designer and musician.   He
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
, the term refers to non-criminal law. The law relating to civil wrongs and quasi-@@contract@@s is part of the civil law. The
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Torch is a hardcore/metal band from Tr@@ond@@heim, Norway.

Musical style
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 rules (informal rules-in-use and norms) and also theoretical rules (formal@@ rules@@ and law). This field