<a href="https://colab.research.google.com/github/davidbau/rome/blob/gpt_stats/notebooks/statistics_of_GPT_states.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [1]:
%%bash
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
cd /content && rm -rf /content/rome
git clone https://github.com/davidbau/rome -b gpt_stats rome > install.log 2>&1
pip install -r /content/rome/scripts/colab_reqs/rome.txt >> install.log 2>&1

In [2]:
IS_COLAB = False
try:
    import google.colab, torch, sys
    if not torch.cuda.is_available():
        print("Change runtime type to include a GPU.")
    sys.path.append('/content/rome')
    IS_COLAB = True
except ImportError:
    # Not running in Colab (or torch is missing); keep IS_COLAB = False.
    pass

# Statistics of GPT states

Here we load GPT-2-XL and examine the statistics of its internal states.

We are particularly interested in how close the output values are to Gaussian.
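
One simple way to quantify "how close to Gaussian" is to compare higher moments: skewness and excess kurtosis are both 0 for an exact Gaussian. The sketch below is only an illustration under that assumption, not necessarily the measure used later in this notebook; `gaussianity_summary` is a hypothetical helper, and it is applied to synthetic samples rather than to model states.

```python
import math
import random

def gaussianity_summary(values):
    """Return (skewness, excess kurtosis) of a sequence of floats.

    Both are 0 for an exact Gaussian, so values far from 0 indicate
    asymmetric or heavy-tailed data.
    """
    n = len(values)
    mean = sum(values) / n
    centered = [v - mean for v in values]
    var = sum(c * c for c in centered) / n
    std = math.sqrt(var)
    skew = sum(c ** 3 for c in centered) / (n * std ** 3)
    excess_kurtosis = sum(c ** 4 for c in centered) / (n * var ** 2) - 3.0
    return skew, excess_kurtosis

rng = random.Random(0)
gaussian_sample = [rng.gauss(0.0, 1.0) for _ in range(20000)]
# An exponential with a random sign is Laplace-distributed, which is
# heavier-tailed than a Gaussian (theoretical excess kurtosis of 3).
laplace_sample = [rng.expovariate(1.0) * rng.choice((-1.0, 1.0)) for _ in range(20000)]

print(gaussianity_summary(gaussian_sample))
print(gaussianity_summary(laplace_sample))
```

For model states, one would pass in the flattened values of a layer's output; how to collect those is left to the cells below.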

In [3]:
%load_ext autoreload
%autoreload 2

We build upon the ROME codebase.

The `experiments.causal_trace` module contains some useful utility functions for loading large transformer models in Colab.

In [4]:
import os, re, json, math
import torch, numpy
from collections import defaultdict
from experiments.causal_trace import ModelAndTokenizer, predict_token
from baukit import show, nethook

Now we load GPT-2-XL and its tokenizer. Gradients are disabled because we only run inference.

In [5]:
torch.set_grad_enabled(False)
mt = ModelAndTokenizer('gpt2-xl', low_cpu_mem_usage=IS_COLAB)

Next, for neutral text samples, we load some random texts from Wikipedia. Just a few (one batch for now).

In [6]:
from datasets import load_dataset
from baukit import TokenizedDataset, length_collation, flatten_masked_batch

def get_ds(ds_name='wikipedia', tokenizer=None, batch_tokens=None):
    raw_ds = load_dataset(
        ds_name,
        dict(wikitext="wikitext-103-raw-v1", wikipedia="20200501.en")[ds_name],
    )
    maxlen = mt.model.config.n_positions
    if batch_tokens is not None and batch_tokens < maxlen:
        maxlen = batch_tokens
    return TokenizedDataset(raw_ds["train"], tokenizer, maxlen=maxlen)
ds = get_ds(tokenizer=mt.tokenizer, ds_name='wikipedia')

Reusing dataset wikipedia (/home/davidbau/.cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475)

In [7]:
mt.tokenizer.decode(ds[0]['input_ids'])

'Yangliuqing () is a market town in Xiqing District, in the western suburbs of Tianjin, People\'s Republic of China. Despite its relatively small size, it has been named since 2006 in the "famous historical and cultural market towns in China".\n\nIt is best known in China for creating nianhua or Yangliuqing nianhua. For more than 400 years, Yangliuqing has in effect specialised in the creation of these woodcuts for the New Year.  wood block prints using vivid colourschemes to portray traditional scenes of children\'s games often interwoven with auspiciouse objects.\n\n, it had 27 residential communities () and 25 villages under its administration.\n\nShi Family Grand Courtyard\n\nShi Family Grand Courtyard (Tiānjīn Shí Jiā Dà Yuàn, 天津石家大院) is situated in Yangliuqing Town of Xiqing District, which is the former residence of wealthy merchant Shi Yuanshi - the 4th son of Shi Wancheng, one of the eight great masters in Tianjin. First built in 1875, it covers over 6,000 square meters, inclu

In [23]:
if False:
    from baukit import show
    item = ds[0]
    cuda_batch = {k: v[None].cuda() for k, v in item.items()}
    logits = mt.model(**cuda_batch)['logits']
    probs = torch.nn.functional.softmax(logits, dim=2)
    input_ids = item['input_ids']
    preds = probs.argmax(dim=2)[0]
    baseline_probs = probs[0, torch.arange(len(input_ids) - 1), input_ids[1:]]
    show([[mt.tokenizer.decode(t), f'{p.item():.3g}', mt.tokenizer.decode(i), mt.tokenizer.decode(j)]
      for t, i, p, j in zip(input_ids[:-1], input_ids[1:], baseline_probs, preds)])
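
The disabled cell above reads off, for each position, the probability the model assigns to the token that actually comes next. As a minimal sketch of that indexing, here is the same idea in plain Python on toy numbers; `softmax` and `next_token_probs` are hypothetical helpers written for this illustration, and the logits are made up.

```python
import math

def softmax(row):
    # Subtract the max for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def next_token_probs(logits, input_ids):
    """logits[t] scores the token at position t + 1; return, for each t,
    the probability assigned to the token that actually follows."""
    return [softmax(logits[t])[input_ids[t + 1]] for t in range(len(input_ids) - 1)]

# Toy vocabulary of 3 tokens and a 3-token sequence.
toy_logits = [[0.0, 2.0, 0.0], [1.0, 1.0, 1.0], [3.0, 0.0, 0.0]]
toy_ids = [0, 1, 2]
toy_probs = next_token_probs(toy_logits, toy_ids)
print(toy_probs)
```

The last row of logits is unused because there is no token after the final position, which is why the cell indexes with `len(input_ids) - 1` and `input_ids[1:]`.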

In [132]:
noise_const = torch.from_numpy(numpy.random.RandomState(1).randn(1024 * 1024)).float().cuda()
PLAIN = show.style()
RED = show.style(background='red')
BLUE = show.style(background='skyblue')
HEADING = show.style(fontWeight='bold', borderBottom='1px solid black')

def run_with_noise(datum, index, amplitude=0.1):
    # Compare next-token predictions with and without fixed noise added to
    # the embedding of the token at position `index`.
    maxlen = mt.model.config.n_positions
    input_ids = datum['input_ids'][:maxlen]
    # Truncate along the sequence dimension first, then add the batch dimension.
    cuda_batch = {k: v[:maxlen][None].cuda() for k, v in datum.items()}
    logits = mt.model(**cuda_batch)['logits']
    probs = torch.nn.functional.softmax(logits, dim=2)
    preds = probs.argmax(dim=2)[0]
    # Probability assigned to each actual next token (the ground truth).
    baseline_probs = probs[0, torch.arange(len(input_ids) - 1), input_ids[1:]]
    def add_noise(x):
        # Perturb the embedding at `index` in place and return the result.
        numel = x[:, index].numel()
        shape = x[:, index].shape
        x[:, index] += amplitude * noise_const[:numel].view(shape)
        return x
    with nethook.Trace(mt.model, 'transformer.wte', edit_output=add_noise):
        n_logits = mt.model(**cuda_batch)['logits']
    n_probs = torch.nn.functional.softmax(n_logits, dim=2)
    # Predictions from the corrupted run (previously taken from `probs` by mistake).
    n_preds = n_probs.argmax(dim=2)[0]
    n_baseline_probs = n_probs[0, torch.arange(len(input_ids) - 1), input_ids[1:]]
    ratio = baseline_probs / n_baseline_probs
    # Display only when the noise cuts a ground-truth probability by more
    # than 4x at least 5 tokens after the noised position.
    if ratio[index + 5:].max() > 4.0:
        last_index = (ratio > 4.0).nonzero()[-1, 0] + 5
        show.bare(show.style(display='flex'), [[HEADING, ['tok', 'gtprob', 'corrupted', 'gt', 'pred', 'corrupted']] +
             [[show.Tag('p'), (BLUE if ind == index else PLAIN), mt.tokenizer.decode(t),
               f'{p.item():.3g}',
               (RED if n_p.item() * 4 < p.item() else PLAIN), f'{n_p.item():.3g}',
               (RED if n_p.item() * 4 < p.item() else PLAIN), mt.tokenizer.decode(i),
               mt.tokenizer.decode(j),
               mt.tokenizer.decode(nj)]
          for ind, (t, i, p, j, n_p, nj) in enumerate(zip(input_ids[:-1], input_ids[1:last_index], baseline_probs, preds, n_baseline_probs, n_preds))]])
run_with_noise(ds[7], 3)
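
The cell above only displays a document when noising one token reduces the ground-truth probability of some token at least 5 positions later by more than a factor of 4. The following sketch isolates that selection rule in plain Python on made-up probabilities; `long_range_damage` is a hypothetical helper, and the defaults `min_gap=5` and `factor=4.0` simply mirror the constants used in the cell.

```python
def long_range_damage(baseline_probs, noised_probs, index, min_gap=5, factor=4.0):
    """Return positions >= index + min_gap where the ground-truth probability
    dropped by more than `factor` after noising the token at `index`."""
    return [
        pos
        for pos, (p, n_p) in enumerate(zip(baseline_probs, noised_probs))
        if pos >= index + min_gap and p > factor * n_p
    ]

# Made-up probabilities: position 1 is hurt immediately (too close to count),
# while position 7 is hurt six tokens after the noised token.
baseline = [0.50, 0.40, 0.30, 0.60, 0.20, 0.70, 0.80, 0.90, 0.10, 0.45]
noised = [0.50, 0.05, 0.30, 0.60, 0.20, 0.70, 0.80, 0.10, 0.10, 0.45]
flagged = long_range_damage(baseline, noised, index=1)
print(flagged)  # [7]
```

Requiring a gap between the noised token and the damaged prediction filters out the trivial case where corrupting a token only hurts the prediction right next to it.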

In [109]:
for doc in range(10):
    for tok in range(10):
        run_with_noise(ds[doc], tok)

In [10]:
ds[10]['input_ids'].shape

torch.Size([47])