In [1]:
import torch, baukit
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_NAME = "gpt2-xl"  # gpt2-xl or EleutherAI/gpt-j-6B
model, tok = (
    AutoModelForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=False).to("cuda"),
    AutoTokenizer.from_pretrained(MODEL_NAME)
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f641c467250>

In [3]:
input = {k: torch.tensor(v)[None].cuda() for k, v in tok('Hello - nice to meet you.  My full name is David Bau.  I work as a software engineer at a company located in the').items()}

In [4]:
input

{'input_ids': tensor([[15496,   532,  3621,   284,  1826,   345,    13,   220,  2011,  1336,
           1438,   318,  3271,   347,   559,    13,   220,   314,   670,   355,
            257,  3788, 11949,   379,   257,  1664,  5140,   287,   262]],
        device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1]], device='cuda:0')}

In [5]:
baukit.set_requires_grad(False, model)

In [6]:
logits = model(**input)['logits']

In [7]:
output_scores = logits[0, -1, :]

In [8]:
input

{'input_ids': tensor([[15496,   532,  3621,   284,  1826,   345,    13,   220,  2011,  1336,
           1438,   318,  3271,   347,   559,    13,   220,   314,   670,   355,
            257,  3788, 11949,   379,   257,  1664,  5140,   287,   262]],
        device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1]], device='cuda:0')}

In [9]:
def generate(model, tok, prefix, n=10):
    inp = {k: torch.tensor(v)[None].cuda() for k, v in tok(prefix).items()}
    initial_length = len(inp['input_ids'])
    pkv = None
    for _ in range(n):
        full_out = model(**inp)
        out = full_out['logits']
        pred = out[0, -1].argmax()
        inp['input_ids'] = torch.cat((inp['input_ids'], torch.tensor([pred])[None].cuda()), dim=1)
        inp['attention_mask'] = torch.cat((inp['attention_mask'], torch.ones(1, 1).cuda()), dim=1)
    return tok.decode(inp['input_ids'][0, initial_length:])
generate(model, tok, 'In his NBA career, KC Jones played', n=100)
    

' his NBA career, KC Jones played for the New York Knicks, Golden State Warriors, Los Angeles Clippers, and the Houston Rockets. He was a member of the NBA All-Rookie First Team in 1996 and was named to the NBA All-Rookie Second Team in 1997. He was also named to the NBA All-Rookie Third Team in 1997.\n\nJones was drafted by the Houston Rockets in the first round (11th overall) of the 1996 NBA Draft. He played for the Rockets for two seasons, averaging 7'

In [10]:
from baukit import TraceDict

def get_hidden_states(model, tok, prefix, layers=[]):
    inp = {k: torch.tensor(v)[None].cuda() for k, v in tok(prefix).items()}
    layer_names = [f'transformer.h.{i}' for i in layers]
    with TraceDict(model, layer_names) as tr:
        logits = model(**inp)['logits']
    return torch.stack([tr[ln].output[0] for ln in layer_names])

prompt = 'Hello, my name is also'
hs = get_hidden_states(model, tok, prompt, list(range(48)))

In [11]:
def show_logit_lens(model, tok, prefix, layers=None, topk=5, color=None, hs=None):
    from baukit import show
    import re
    if layers is None:
        layers = list(range(
            len([n for n, _ in model.named_modules()
             if re.match('^transformer.h.\d+$', n)])))
    if hs is None:
        hs = get_hidden_states(model, tok, prefix, layers)
    elif callable(hs):
        hs = hs(model, tok, prefix, layers)
    decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head)
    prompt_tokens = [tok.decode(t) for t in tok.encode(prefix)]
    probs = torch.nn.functional.softmax(decoder(hs), dim=-1)
    favorite_probs, favorite_tokens = probs.topk(k=topk, dim=-1)
    if color is None:
        color = [0, 0, 255]
    def color_fn(p):
        a = [int(255 * (1-p) + c * p) for c in color]
        return show.style(background=f'rgb({a[0]}, {a[1]}, {a[2]})')
    def hover(tok, prob, toks):
        lines = []
        for p, t in zip(prob, toks):
            lines.append(f'{tok.decode(t)}: prob {p:.2f}')
        return show.attr(title='\n'.join(lines))
    def make_button_with_cb(text, layer, tok):
        def clickme():
            print(layer, tok)
        return baukit.Button(text).on('click', clickme)
    header_line = [ # header line
             [show.style(fontWeight='bold'), 'Layer'] +
             [
                 [show.style(background='yellow'), t]
                 for t in prompt_tokens
             ]
         ]
    layout = [header_line,
         # body
         [
             # first column
             [show.style(fontWeight='bold'), layer] +
             [
                 # subsequent columns
                 [color_fn(p[0]), hover(tok, p, t), make_button_with_cb(tok.decode(t[0]), layer, tok.decode(t[0]))]
                 for p, t in zip(wordprobs, words)
             ]
         for layer, wordprobs, words in zip(layers, favorite_probs[:, 0], favorite_tokens[:,0])],
         header_line
             ]
    show(*layout)
    return layout

In [12]:
layout = show_logit_lens(model, tok, 'The biggest city in New England is')

In [13]:
layout[1][47][3][-1].label = 'hello'

In [14]:
def get_hidden_state_deltas(model, tok, prefix, layers=None):
    if layers is None:
        layers = list(range(48))
    inp = {k: torch.tensor(v)[None].cuda() for k, v in tok(prefix).items()}
    layer_names =  [f'transformer.h.{i}' for i in layers]
    with TraceDict(model, ['transformer.drop'] + layer_names) as tr:
        logits = model(**inp)['logits']
    first_h = tr['transformer.drop'].output[None]
    other_h = torch.stack([tr[ln].output[0] for ln in layer_names])
    all_h = torch.cat([first_h, other_h])
    delta_h = all_h[1:] - all_h[:-1]
    return delta_h


In [15]:
show_logit_lens(model, tok, 'The biggest city in New England is',
                hs=get_hidden_state_deltas, color=[255, 0, 255])

[[[{'font-weight': 'bold'},
   'Layer',
   [{'background': 'yellow'}, 'The'],
   [{'background': 'yellow'}, ' biggest'],
   [{'background': 'yellow'}, ' city'],
   [{'background': 'yellow'}, ' in'],
   [{'background': 'yellow'}, ' New'],
   [{'background': 'yellow'}, ' England'],
   [{'background': 'yellow'}, ' is']]],
 [[{'font-weight': 'bold'},
   0,
   [{'background': 'rgb(255, 254, 255)'},
    {'title': 'atre: prob 0.00\n first: prob 0.00\n New: prob 0.00\nory: prob 0.00\n most: prob 0.00'},
    <baukit.labwidget.Button at 0x7f642e745310>],
   [{'background': 'rgb(255, 250, 255)'},
    {'title': ' fixed: prob 0.02\n surprise: prob 0.01\n rein: prob 0.01\n wide: prob 0.01\n success: prob 0.01'},
    <baukit.labwidget.Button at 0x7f642e7458d0>],
   [{'background': 'rgb(255, 176, 255)'},
    {'title': ' council: prob 0.31\n councillor: prob 0.09\n councils: prob 0.06\n mayor: prob 0.02\nsc: prob 0.02'},
    <baukit.labwidget.Button at 0x7f642e745c90>],
   [{'background': 'rgb(255, 251