In [22]:
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from transformers import pipeline
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from copy import deepcopy
import pandas as pd
from collections import defaultdict

checkpoint = "EleutherAI/gpt-neo-2.7B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, output_hidden_states=True)
trans = model._modules['transformer']
lm_head = model._modules['lm_head']

# generate dataset

In [25]:
import pandas as pd
from datasets import Dataset

d = defaultdict(list)
np.random.seed(13)
for i in range(1000):
    num1 = np.random.randint(0, 100)
    num2 = np.random.randint(0, 100)
    d['input'].append(f'{num1} {num2}')
    d['output'].append(f' {num1 + num2}')
df = pd.DataFrame.from_dict(d)
dset = Dataset.from_pandas(df)

# loss

In [49]:
# initialize prefix
prefix_str = ["x the following two numbers: "]
prefix_inputs = tokenizer(prefix_str, return_tensors="pt")
prefix_emb = trans.wte.forward(prefix_inputs['input_ids']) # this is the key param
prefix_emb.retain_grad()

# embedding for example
ex_num = 0
ex = dset[ex_num]
x_text = ex['input']
y_text = ex['output']
ex_str = x_text + y_text
ex_inputs = tokenizer(ex_str, return_tensors='pt')
ex_emb = trans.wte.forward(ex_inputs['input_ids']) # this is the key param 

# concatenate prefix + example
emb = torch.cat((prefix_emb, ex_emb), dim=1)

# go through model
outputs = model(inputs_embeds=emb)

# calculate loss
idxs_correct = tokenizer.encode(y_text)
assert len(idxs_correct) == 1, 'For now assume that answer is a single token'
y_idx_correct = idxs_correct[0]
logit_answer = outputs['logits'][0, -1, y_idx_correct] # (batch_size, seq_len, vocab_size)
loss = -1 * logit_answer
loss.backward()

print('shapes', prefix_emb.shape, ex_emb.shape, emb.shape)

shapes torch.Size([1, 7, 2560]) torch.Size([1, 3, 2560]) torch.Size([1, 10, 2560])


In [50]:
prefix_emb.grad

tensor([[[ 0.0834, -0.0245,  0.0530,  ..., -0.0294,  0.0644,  0.0323],
         [ 0.0325, -0.0326, -0.1163,  ..., -0.2024,  0.0970,  0.0132],
         [ 0.0042, -0.0730,  0.0170,  ..., -0.0091,  0.0608, -0.0094],
         ...,
         [ 0.0084,  0.0581, -0.0509,  ..., -0.0352,  0.0215,  0.0007],
         [-0.0897, -0.0617, -0.1768,  ...,  0.0444, -0.0292,  0.0522],
         [-0.0338, -0.0161, -0.0210,  ...,  0.0107, -0.0184, -0.0054]]])

In [48]:
sorted(nn.functional.softmax(outputs['logits'][0, -1]))

  sorted(nn.functional.softmax(outputs['logits'][0, -1]))


[tensor(1.9298e-15, grad_fn=<UnbindBackward>),
 tensor(3.0424e-15, grad_fn=<UnbindBackward>),
 tensor(4.8582e-15, grad_fn=<UnbindBackward>),
 tensor(7.3112e-15, grad_fn=<UnbindBackward>),
 tensor(7.8846e-15, grad_fn=<UnbindBackward>),
 tensor(9.0563e-15, grad_fn=<UnbindBackward>),
 tensor(9.1250e-15, grad_fn=<UnbindBackward>),
 tensor(1.3647e-14, grad_fn=<UnbindBackward>),
 tensor(1.4094e-14, grad_fn=<UnbindBackward>),
 tensor(1.4106e-14, grad_fn=<UnbindBackward>),
 tensor(1.4924e-14, grad_fn=<UnbindBackward>),
 tensor(1.6488e-14, grad_fn=<UnbindBackward>),
 tensor(1.6672e-14, grad_fn=<UnbindBackward>),
 tensor(1.7114e-14, grad_fn=<UnbindBackward>),
 tensor(1.7497e-14, grad_fn=<UnbindBackward>),
 tensor(1.9023e-14, grad_fn=<UnbindBackward>),
 tensor(1.9143e-14, grad_fn=<UnbindBackward>),
 tensor(2.1562e-14, grad_fn=<UnbindBackward>),
 tensor(2.5566e-14, grad_fn=<UnbindBackward>),
 tensor(2.6863e-14, grad_fn=<UnbindBackward>),
 tensor(2.7287e-14, grad_fn=<UnbindBackward>),
 tensor(2.858

# old

**Test: transformer + lm_head = original model + probabilities for next tokens.**

In [3]:
# prepare inputs
raw_inputs = ["1+3=4"]
inputs = tokenizer(raw_inputs, return_tensors="pt")

# predict
outputs = model(**inputs)

# show that the lm_head is producing logits via linear transformation
trans = model._modules['transformer']
lm_head = model._modules['lm_head']
out = trans(inputs['input_ids'])
h = out['hidden_states'] # tuple of (layer x (batch_size, seq_len, hidden_size))
logits = lm_head(h[-1])  # select logits using last layer

# we got the same logits by going through the model
assert logits.shape == outputs['logits'].shape # tensor (batch_size, seq_len, vocab_size)
assert logits.sum() == outputs['logits'].sum()
assert logits.max() == outputs['logits'].max()

# naive check
print('input text:', tokenizer.decode(inputs['input_ids'][0]))

# top word embeddings
decoded_toks = tokenizer.decode(logits[0].argmax(axis=-1))
print('decoded text from hidden states', decoded_toks)

# dissect token-by-token
print('top predicted tokens at this position:')
for i, tok in enumerate(inputs.tokens()):
    print(f'{i}: ___{tok}___ --> ___{decoded_toks[i]}___')

input text: 1+3=4
decoded text from hidden states .\\4$
top predicted tokens at this position:
0: ___1___ --> ___.___
1: ___+___ --> ___\___
2: ___3___ --> ___\___
3: ___=___ --> ___4___
4: ___4___ --> ___$___


**Set up projecting back to word-vectors.**

In [141]:
# get embedding matrix
w_embed = trans.wte.weight # vocab_size, embed_dim
vocab_size = w_embed.shape[0]
embed_size = w_embed.shape[1]

# invert for unembedding
unemb_linear = nn.Linear(in_features=embed_size, out_features=vocab_size, bias=False)
pinv = torch.linalg.pinv(w_embed)
unemb_linear.weight = nn.Parameter(pinv.T)

# make sure unembedding works
ids = torch.Tensor([[16, 2, 3]]).int()
embs = trans.wte.forward(ids)
unembedded_onehot = unemb_linear(embs)
unembedded_ids = unembedded_onehot.argmax(axis=-1)
assert torch.all(unembedded_ids == ids)

**Get gradient wrt embedding vector**

In [142]:
# check that original outputs match outputs when using embedding
embeds = trans.wte.forward(ids)
outputs_using_embeds = model(inputs_embeds=embeds)
outputs = model(input_ids=ids)
assert outputs['logits'].sum() == outputs_using_embeds['logits'].sum()

# get gradient
embeds = trans.wte.forward(ids)
embeds.retain_grad()
outputs = model(inputs_embeds=embeds)
loss = outputs['logits'].sum()
loss.backward()
embeds.grad

tensor([[[  9575.5391,  -1355.6121,  -2623.6182,  ...,  12424.7637,
           -5506.7495,  -9292.2803],
         [ 12695.8623,  10005.6162,  17801.9453,  ...,   7830.0645,
           24521.3984,   4288.2812],
         [ -9870.7490,   -223.3108,  -3130.4512,  ...,  15149.9629,
          -23681.2461,  -5486.9058]]])