In [1]:
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets
import pickle
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "gpt2"

In [3]:
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
unembed = model.get_output_embeddings()

In [5]:
unembed.forward(torch.tensor([float(i) for i in range(768)]))

tensor([ 694.0030, 2338.6108, 1067.6279,  ..., 1493.1703, 1461.0746,
         456.2899], grad_fn=<SqueezeBackward4>)

In [6]:
final_layer_norm = model.base_model.ln_f
unembed.forward(final_layer_norm.forward(torch.tensor([float(i) for i in range(768)])))

tensor([ 4.7654,  6.5866,  1.7553,  ...,  9.5762, 15.3242,  6.8427],
       grad_fn=<SqueezeBackward4>)

In [7]:
from gpt_2_dataset import GPT2Dataset
from torch.utils.data import DataLoader
import pickle

with open("datasets/dataset_test.pkl", "rb") as f:
    dataset = pickle.load(f)

dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
data, input_ids, target_ids, attention_mask, target_index = next(iter(dataloader))
#attention_mask[0][:target_index] = 1
#print(data[0])
#print(input_ids[0])
#print(target_ids[0])
#print(attention_mask[0])
model.eval()
output = model(input_ids, labels=target_ids, output_hidden_states=True, attention_mask=attention_mask)

In [8]:
print("loss is ", -torch.log(torch.softmax(output.logits[torch.arange(output.logits.shape[0]), target_index-1], dim=-1)[torch.arange(output.logits.shape[0]), target_ids[torch.arange(output.logits.shape[0]), target_index]]).mean())
print("predicted loss is ", output.loss)
print("probabilities are ", torch.softmax(output.logits[torch.arange(output.logits.shape[0]), target_index-1], dim=-1)[torch.arange(output.logits.shape[0]), target_ids[torch.arange(output.logits.shape[0]), target_index]])

loss is  tensor(3.8217, grad_fn=<NegBackward0>)
predicted loss is  tensor(3.8217, grad_fn=<NllLossBackward0>)
probabilities are  tensor([0.0720, 0.0393, 0.1288, 0.0006], grad_fn=<IndexBackward0>)


In [9]:
hidden_states = output.hidden_states
hidden_states[-1].shape
hs = torch.stack(hidden_states, dim = 1)
hs.shape

torch.Size([4, 13, 64, 768])

In [10]:
logits = unembed.forward(hs)
logits_ = unembed.forward(final_layer_norm.forward(hs))

In [11]:
logits.shape

torch.Size([4, 13, 64, 50257])

In [12]:
torch.softmax(logits[torch.arange(output.logits.shape[0]), -1, target_index-1], dim=-1)[torch.arange(output.logits.shape[0]), target_ids[torch.arange(output.logits.shape[0]), target_index]]

tensor([0.0720, 0.0393, 0.1288, 0.0006], grad_fn=<IndexBackward0>)

In [13]:
print("loss is ", -torch.log(torch.softmax(logits[torch.arange(output.logits.shape[0]), -1, target_index-1], dim=-1)[torch.arange(output.logits.shape[0]), target_ids[torch.arange(output.logits.shape[0]), target_index]]).mean())

loss is  tensor(3.8217, grad_fn=<NegBackward0>)


In [14]:
from tuned_lens.nn.lenses import TunedLens, LogitLens
from tuned_lens.plotting import PredictionTrajectory

In [15]:
input_ids.shape

torch.Size([4, 64])

In [16]:
from lens_model import Lens_model
from lens import Lens, Logit_lens, Linear_lens

lens_4 = Linear_lens.from_model(model)
lens_8 = Linear_lens.from_model(model)
lens_model = Lens_model([lens_4, lens_8], layers=[4, 8])

In [17]:
logits = lens_model.forward(input_ids=input_ids, attention_mask=attention_mask, targets=target_ids)
probs = lens_model.get_probs(input_ids=input_ids, attention_mask=attention_mask, targets=target_ids, target_index=target_index)

In [18]:
probs

tensor([[[2.5031e-08, 6.4239e-10],
         [9.7787e-06, 4.0750e-09],
         [4.1373e-08, 1.5405e-10],
         ...,
         [1.5466e-13, 2.7777e-11],
         [7.3198e-10, 4.0914e-13],
         [2.7261e-09, 1.4884e-11]],

        [[4.5172e-10, 9.5926e-10],
         [1.2042e-05, 9.8153e-09],
         [1.5876e-08, 6.1077e-10],
         ...,
         [8.0898e-14, 3.2625e-13],
         [1.4657e-09, 2.4526e-11],
         [7.2036e-11, 1.4454e-08]],

        [[4.4109e-09, 1.1394e-08],
         [1.2270e-06, 1.8960e-07],
         [9.0117e-08, 8.6830e-08],
         ...,
         [1.2630e-13, 4.7863e-12],
         [3.9226e-12, 2.8383e-11],
         [5.1628e-09, 1.2626e-08]],

        [[2.3229e-09, 6.3803e-09],
         [4.3377e-06, 1.0159e-09],
         [5.7620e-08, 2.6259e-08],
         ...,
         [1.8516e-14, 1.8716e-13],
         [1.9330e-11, 1.7841e-12],
         [3.6748e-12, 7.8151e-09]]])

In [19]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
data, input_ids, target_ids, attention_mask, target_index = next(iter(dataloader))

In [20]:
model(input_ids, labels=target_ids, output_hidden_states=True, attention_mask=attention_mask).logits.shape

torch.Size([4, 64, 50257])

In [21]:
pr = lens_model.forward(input_ids, attention_mask, target_ids)#, target_index)

In [22]:
base_model_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
base_model_logits = torch.stack([base_model_logits for layer in range(2)], dim=1)
base_model_logits.shape

torch.Size([4, 2, 64, 50257])

In [23]:
logits = lens_model.forward(input_ids=input_ids, attention_mask=attention_mask, targets=target_ids)
logits.shape

torch.Size([4, 64, 50257, 2])

In [24]:
probs = torch.softmax(logits[torch.arange(logits.shape[0]), target_index-1], dim=-2)
probs.shape

torch.Size([4, 50257, 2])

In [25]:
probs[torch.arange(probs.shape[0]), target_ids[torch.arange(probs.shape[0]), target_index]]

tensor([[2.6690e-09, 2.5571e-13],
        [3.6697e-08, 2.2381e-07],
        [7.8043e-10, 1.8521e-08],
        [1.6750e-10, 4.3311e-11]])

In [26]:
lens_model.get_correct_class_probs(input_ids, attention_mask, target_ids, target_index)

tensor([[2.6690e-09, 2.5571e-13],
        [3.6697e-08, 2.2381e-07],
        [7.8043e-10, 1.8521e-08],
        [1.6750e-10, 4.3311e-11]])

In [27]:
targets = torch.stack([target_ids for layer in range(2)], dim=1)
targets.shape

torch.Size([4, 2, 64])

In [28]:
target_index.shape

torch.Size([4])

In [29]:
targs = target_ids[torch.arange(4),target_index]
targs.shape
targets = torch.stack([targs for layer in range(2)], dim=1)
targets.shape

torch.Size([4, 2])

In [30]:
probs = lens_model.get_probs(input_ids, attention_mask, target_ids, target_index)
print(probs.shape)
print(probs[1].sum())
print(torch.sum(probs, dim=1))

torch.Size([4, 50257, 2])
tensor(2.0001)
tensor([[1.0001, 1.0000],
        [1.0001, 1.0000],
        [1.0000, 1.0001],
        [1.0001, 1.0000]])


In [31]:
lens_model.forward(input_ids, attention_mask, target_ids)

tensor([[[[-1.6717e+01, -5.4999e+00],
          [-1.5374e+01, -9.9038e+00],
          [-1.4408e+01, -1.2333e+01],
          ...,
          [-2.6538e+01, -3.3264e+01],
          [-2.2979e+01, -1.4161e+01],
          [-1.4656e+01, -1.1132e+01]],

         [[ 6.2960e+00, -7.1018e+00],
          [ 8.0610e+00, -5.0475e+00],
          [ 4.8768e+00, -6.7406e+00],
          ...,
          [-1.0472e+01, -1.3002e+01],
          [ 8.6033e-02, -1.2471e+01],
          [ 6.5639e+00, -6.8861e+00]],

         [[ 2.2587e+00,  2.4926e-01],
          [ 9.2178e+00,  6.2326e+00],
          [-7.4384e-01,  1.8114e+00],
          ...,
          [-6.8657e+00,  3.8178e+00],
          [-5.6461e-01, -6.1559e+00],
          [-6.1885e-01, -2.7406e-01]],

         ...,

         [[-6.0819e+00, -7.7930e+00],
          [-5.1625e+00, -1.1777e+00],
          [-6.7861e+00, -4.8100e+00],
          ...,
          [-2.3850e+01, -1.0207e+01],
          [-1.7109e+01, -7.0645e+00],
          [-3.5566e+00, -4.9039e+00]],

     

In [32]:
model(input_ids, labels=target_ids, output_hidden_states=True, attention_mask=attention_mask).logits.shape

torch.Size([4, 64, 50257])

In [33]:
targets_ = target_ids[torch.arange(4), target_index]
targets_ = torch.stack([targets_ for layer in range(2)], dim=1)
targets_.shape

torch.Size([4, 2])

In [34]:
lens_model.get_probs(input_ids, attention_mask, target_ids, target_index).shape

torch.Size([4, 50257, 2])

In [35]:
lens = []
layers = []
for layer in range(12):
    nl = Linear_lens.from_model(model)
    nl.set_parameters({'weight': torch.nn.Parameter(torch.eye(model.config.hidden_size)),\
                        'bias': torch.nn.Parameter(torch.zeros(model.config.hidden_size))})
    lens.append(nl)
    layers.append(layer)

lens_model = Lens_model(lens, layers, model_name=model_name)

In [36]:
with open("datasets/dataset_test.pkl", "rb") as f:
    dataset_test = pickle.load(f)

with open("datasets/dataset_train.pkl", "rb") as f:
    dataset_train = pickle.load(f)

In [37]:
data_train = DataLoader(dataset_train, batch_size=4, shuffle=False)
data_test = DataLoader(dataset_test, batch_size=4, shuffle=False)

In [38]:
data, input_ids, target_ids, attention_mask, target_index = next(iter(data_train))

In [39]:
lens_model.forward(input_ids, attention_mask, target_ids)

tensor([[[[  16.0838,    6.2038,   -3.5437,  ...,   -6.7148,   -7.2858,
             -8.2691],
          [  21.1776,    6.2868,   -3.5077,  ...,   -6.6434,   -7.1885,
             -8.1944],
          [  19.7677,    4.9822,   -5.9960,  ...,   -8.7432,   -9.4221,
            -10.4845],
          ...,
          [  11.1786,   -4.1631,  -11.4905,  ...,  -13.0316,  -13.7163,
            -14.8179],
          [  13.6240,   -2.5583,  -11.0643,  ...,  -12.2197,  -12.7285,
            -13.5820],
          [  20.5195,    5.9294,   -3.4365,  ...,   -6.2053,   -6.7674,
             -7.7027]],

         [[  16.8238,   -8.5267,   -5.0287,  ...,  -71.0284,  -66.2512,
            -83.1334],
          [  20.4722,   -6.7212,   -3.0400,  ...,  -69.7362,  -63.2890,
            -79.7683],
          [  22.3595,   -6.9423,   -2.4522,  ...,  -64.7502,  -64.7687,
            -84.7910],
          ...,
          [  22.3544,  -15.8880,  -16.4908,  ...,  -92.0121,  -90.2242,
           -104.7963],
          [  17.11

In [40]:
model = AutoModelForCausalLM.from_pretrained(model_name)
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [41]:
lens_model

Lens_model(
  (model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
  (lens): ParameterList(
      (0): O

In [42]:
lens_ = Linear_lens(10, 10)
lens_.parameters

<bound method Module.parameters of Linear_lens(
  (linear): Linear(in_features=10, out_features=10, bias=True)
)>