In [2]:
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets
import pickle
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model_name = "gpt2"

In [4]:
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [5]:
unembed = model.get_output_embeddings()

In [6]:
unembed.forward(torch.tensor([float(i) for i in range(768)]))

tensor([ 694.0030, 2338.6108, 1067.6279,  ..., 1493.1703, 1461.0746,
         456.2899], grad_fn=<SqueezeBackward4>)

In [7]:
final_layer_norm = model.base_model.ln_f
unembed.forward(final_layer_norm.forward(torch.tensor([float(i) for i in range(768)])))

tensor([ 4.7654,  6.5866,  1.7553,  ...,  9.5762, 15.3242,  6.8427],
       grad_fn=<SqueezeBackward4>)

In [8]:
from gpt_2_dataset import GPT2Dataset
from torch.utils.data import DataLoader
import pickle

with open("datasets/dataset_test.pkl", "rb") as f:
    dataset = pickle.load(f)

dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
data, input_ids, target_ids, attention_mask, target_index = next(iter(dataloader))
#attention_mask[0][:target_index] = 1
#print(data[0])
#print(input_ids[0])
#print(target_ids[0])
#print(attention_mask[0])
model.eval()
output = model(input_ids, labels=target_ids, output_hidden_states=True, attention_mask=attention_mask)

In [9]:
print("loss is ", -torch.log(torch.softmax(output.logits[torch.arange(output.logits.shape[0]), target_index-1], dim=-1)[torch.arange(output.logits.shape[0]), target_ids[torch.arange(output.logits.shape[0]), target_index]]).mean())
print("predicted loss is ", output.loss)
print("probabilities are ", torch.softmax(output.logits[torch.arange(output.logits.shape[0]), target_index-1], dim=-1)[torch.arange(output.logits.shape[0]), target_ids[torch.arange(output.logits.shape[0]), target_index]])

loss is  tensor(3.8217, grad_fn=<NegBackward0>)
predicted loss is  tensor(3.8217, grad_fn=<NllLossBackward0>)
probabilities are  tensor([0.0720, 0.0393, 0.1288, 0.0006], grad_fn=<IndexBackward0>)


In [10]:
hidden_states = output.hidden_states
hidden_states[-1].shape
hs = torch.stack(hidden_states, dim = 1)
hs.shape

torch.Size([4, 13, 64, 768])

In [11]:
logits = unembed.forward(hs)
logits_ = unembed.forward(final_layer_norm.forward(hs))

In [12]:
logits.shape

torch.Size([4, 13, 64, 50257])

In [13]:
torch.softmax(logits[torch.arange(output.logits.shape[0]), -1, target_index-1], dim=-1)[torch.arange(output.logits.shape[0]), target_ids[torch.arange(output.logits.shape[0]), target_index]]

tensor([0.0720, 0.0393, 0.1288, 0.0006], grad_fn=<IndexBackward0>)

In [14]:
print("loss is ", -torch.log(torch.softmax(logits[torch.arange(output.logits.shape[0]), -1, target_index-1], dim=-1)[torch.arange(output.logits.shape[0]), target_ids[torch.arange(output.logits.shape[0]), target_index]]).mean())

loss is  tensor(3.8217, grad_fn=<NegBackward0>)


In [15]:
from tuned_lens.nn.lenses import TunedLens, LogitLens
from tuned_lens.plotting import PredictionTrajectory

In [16]:
input_ids.shape

torch.Size([4, 64])

In [17]:
from lens_model import Lens_model
from lens import Lens, Logit_lens, Linear_lens

lens_4 = Linear_lens.from_model(model)
lens_8 = Linear_lens.from_model(model)
lens_model = Lens_model([lens_4, lens_8], layers=[4, 8])

In [18]:
logits = lens_model.forward(input_ids=input_ids, attention_mask=attention_mask, targets=target_ids)
probs = lens_model.get_probs(input_ids=input_ids, attention_mask=attention_mask, targets=target_ids, target_index=target_index)

In [19]:
probs

tensor([[[1.0936e-06, 8.1015e-08],
         [1.4278e-08, 4.1807e-10],
         [8.1728e-13, 6.0328e-10],
         ...,
         [2.1408e-13, 1.0038e-09],
         [3.3702e-09, 8.7198e-09],
         [7.5801e-09, 1.0747e-08]],

        [[1.6471e-04, 1.4173e-08],
         [2.7368e-08, 1.6117e-09],
         [5.8699e-12, 3.3553e-09],
         ...,
         [1.8054e-12, 8.9238e-11],
         [3.2829e-09, 3.8163e-08],
         [5.5882e-08, 5.0044e-08]],

        [[6.4718e-05, 5.9351e-07],
         [5.0786e-08, 2.1378e-08],
         [4.3852e-13, 7.5696e-07],
         ...,
         [1.7670e-11, 1.2457e-07],
         [3.6081e-06, 3.3918e-10],
         [1.1158e-07, 4.1614e-07]],

        [[4.9534e-07, 2.5188e-07],
         [8.3733e-08, 3.9527e-09],
         [4.5760e-11, 1.9581e-08],
         ...,
         [2.2173e-11, 1.3459e-09],
         [2.8648e-09, 1.3291e-10],
         [2.1141e-07, 9.3427e-09]]], grad_fn=<SoftmaxBackward0>)

In [20]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
data, input_ids, target_ids, attention_mask, target_index = next(iter(dataloader))

In [21]:
model(input_ids, labels=target_ids, output_hidden_states=True, attention_mask=attention_mask).logits.shape

torch.Size([4, 64, 50257])

In [22]:
pr = lens_model.forward(input_ids, attention_mask, target_ids)#, target_index)

In [23]:
base_model_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
base_model_logits = torch.stack([base_model_logits for layer in range(2)], dim=1)
base_model_logits.shape

torch.Size([4, 2, 64, 50257])

In [24]:
logits = lens_model.forward(input_ids=input_ids, attention_mask=attention_mask, targets=target_ids)
logits.shape

torch.Size([4, 64, 50257, 2])

In [25]:
probs = torch.softmax(logits[torch.arange(logits.shape[0]), target_index-1], dim=-2)
probs.shape

torch.Size([4, 50257, 2])

In [26]:
probs[torch.arange(probs.shape[0]), target_ids[torch.arange(probs.shape[0]), target_index]]

tensor([[1.1524e-07, 2.1782e-08],
        [6.4313e-06, 9.6567e-09],
        [1.0520e-04, 1.7076e-07],
        [2.7049e-07, 1.9339e-11]], grad_fn=<IndexBackward0>)

In [27]:
lens_model.get_correct_class_probs(input_ids, attention_mask, target_ids, target_index)

tensor([[1.1524e-07, 2.1782e-08],
        [6.4313e-06, 9.6567e-09],
        [1.0520e-04, 1.7076e-07],
        [2.7049e-07, 1.9339e-11]], grad_fn=<IndexBackward0>)

In [28]:
targets = torch.stack([target_ids for layer in range(2)], dim=1)
targets.shape

torch.Size([4, 2, 64])

In [29]:
target_index.shape

torch.Size([4])

In [30]:
targs = target_ids[torch.arange(4),target_index]
targs.shape
targets = torch.stack([targs for layer in range(2)], dim=1)
targets.shape

torch.Size([4, 2])

In [31]:
probs = lens_model.get_probs(input_ids, attention_mask, target_ids, target_index)
print(probs.shape)
print(probs[1].sum())
print(torch.sum(probs, dim=1))

torch.Size([4, 50257, 2])
tensor(2.0001, grad_fn=<SumBackward0>)
tensor([[1.0001, 1.0000],
        [1.0001, 1.0000],
        [1.0001, 1.0001],
        [1.0001, 1.0001]], grad_fn=<SumBackward1>)


In [32]:
lens_model.forward(input_ids, attention_mask, target_ids)

tensor([[[[-11.9520, -32.1912],
          [-19.1287, -24.5467],
          [-30.1045, -23.2777],
          ...,
          [-15.8533, -31.1020],
          [-11.6974, -37.6973],
          [-14.4931, -24.1138]],

         [[ -5.9640, -26.1626],
          [ -9.7150, -20.3601],
          [-13.1527, -23.9250],
          ...,
          [-16.0871, -21.4656],
          [-14.7223, -25.8494],
          [ -7.9957, -25.0632]],

         [[ -0.2052, -11.4745],
          [ -3.2747,  -5.7795],
          [ -7.0115,  -7.3103],
          ...,
          [-10.6523, -13.7001],
          [ -7.0557,  -8.6894],
          [ -3.6095,  -6.9114]],

         ...,

         [[  2.4722, -11.7025],
          [ -0.4998,  -8.1059],
          [-11.4415,  -6.1599],
          ...,
          [ -5.4269, -10.2061],
          [ -3.2803, -13.9883],
          [  6.1967, -13.9688]],

         [[  2.5067, -11.6910],
          [ -0.4917,  -8.1091],
          [-11.4312,  -6.1517],
          ...,
          [ -5.3727, -10.2076],
      

In [33]:
model(input_ids, labels=target_ids, output_hidden_states=True, attention_mask=attention_mask).logits.shape

torch.Size([4, 64, 50257])

In [34]:
targets_ = target_ids[torch.arange(4), target_index]
targets_ = torch.stack([targets_ for layer in range(2)], dim=1)
targets_.shape

torch.Size([4, 2])

In [35]:
lens_model.get_probs(input_ids, attention_mask, target_ids, target_index).shape

torch.Size([4, 50257, 2])

In [36]:
lens = []
layers = []
for layer in range(12):
    nl = Linear_lens.from_model(model)
    nl.set_parameters({'weight': torch.nn.Parameter(torch.eye(model.config.hidden_size)),\
                        'bias': torch.nn.Parameter(torch.zeros(model.config.hidden_size))})
    lens.append(nl)
    layers.append(layer)

lens_model = Lens_model(lens, layers, model_name=model_name)

In [37]:
with open("datasets/dataset_test.pkl", "rb") as f:
    dataset_test = pickle.load(f)

with open("datasets/dataset_train.pkl", "rb") as f:
    dataset_train = pickle.load(f)

In [38]:
data_train = DataLoader(dataset_train, batch_size=4, shuffle=False)
data_test = DataLoader(dataset_test, batch_size=4, shuffle=False)

In [39]:
data, input_ids, target_ids, attention_mask, target_index = next(iter(data_train))

In [40]:
lens_model.forward(input_ids, attention_mask, target_ids)

tensor([[[[  16.0838,    6.2038,   -3.5437,  ...,   -6.7148,   -7.2858,
             -8.2691],
          [  21.1776,    6.2868,   -3.5077,  ...,   -6.6434,   -7.1885,
             -8.1944],
          [  19.7677,    4.9822,   -5.9960,  ...,   -8.7432,   -9.4221,
            -10.4845],
          ...,
          [  11.1786,   -4.1631,  -11.4905,  ...,  -13.0316,  -13.7163,
            -14.8179],
          [  13.6240,   -2.5583,  -11.0643,  ...,  -12.2197,  -12.7285,
            -13.5820],
          [  20.5195,    5.9294,   -3.4365,  ...,   -6.2053,   -6.7674,
             -7.7027]],

         [[  16.8238,   -8.5267,   -5.0287,  ...,  -71.0284,  -66.2512,
            -83.1334],
          [  20.4722,   -6.7212,   -3.0400,  ...,  -69.7362,  -63.2890,
            -79.7683],
          [  22.3595,   -6.9423,   -2.4522,  ...,  -64.7502,  -64.7687,
            -84.7910],
          ...,
          [  22.3544,  -15.8880,  -16.4908,  ...,  -92.0121,  -90.2242,
           -104.7963],
          [  17.11

In [41]:
model = AutoModelForCausalLM.from_pretrained(model_name)
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [42]:
lens_model

Lens_model(
  (model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
  (lens): ParameterList(
      (0): O

In [43]:
lens_ = Linear_lens(10, 10)
lens_.parameters

<bound method Module.parameters of Linear_lens(
  (linear): Linear(in_features=10, out_features=10, bias=True)
)>

In [44]:
torch.save(lens_model, './models/lens_model.pt')

In [46]:
with open('./models/lens_model_cpu.pkl', "rb") as f:
    lens_model = pickle.load(f)

In [50]:
from train import validate

In [52]:
dataset_test_path = "datasets/dataset_test.pkl"
output_path = "output"
device = "cpu"
batch_size = 4
max_length = 64
loss_function_name = "kl"
losses_valid_t, average_loss_valid = validate(lens_model, model, dataset_test_path, max_length, device, output_path, batch_size, loss_function_name)
print(losses_valid_t)
print(average_loss_valid)

Validating...
    Batch: 0/1857, Loss: 0.2536962330341339
Average loss: 1.2568920128563843
[0.2536962330341339, 2.8589835166931152, 2.081822156906128, 0.4752136468887329, 1.0852179527282715, 0.5857743620872498, 0.08212531358003616, 0.3433285653591156, 1.0946650505065918, 0.3625377118587494, 0.1785324066877365, 0.6376705765724182, 1.1370208263397217, 0.6946799159049988, 4.588211536407471, 1.0307098627090454, 0.5813913345336914, 0.6113443374633789, 2.0838139057159424, 1.7911590337753296, 0.8209935426712036, 0.4804232120513916, 0.38315457105636597, 0.6116137504577637, 0.809879720211029, 0.7492644190788269, 0.8996995687484741, 2.0654397010803223, 0.20470036566257477, 4.281492710113525, 0.18094244599342346, 0.40408843755722046, 0.20104187726974487, 0.7827140688896179, 0.24521024525165558, 0.2908748388290405, 5.148050308227539, 0.7815617322921753, 0.8796820044517517, 0.790023922920227, 0.35296887159347534, 0.7179099917411804, 1.2749625444412231, 1.2891569137573242, 0.08150981366634369, 0.978

In [53]:
lens = []
layers = []
for layer in range(12):
    nl = Logit_lens.from_model(model)
    lens.append(nl)
    layers.append(layer)

lens_model_l = Lens_model(lens, layers, model_name=model_name)

In [59]:
lens = []
layers = []
for layer in range(12):
    nl = Linear_lens.from_model(model)
    nl.set_parameters({'weight': torch.nn.Parameter(torch.eye(model.config.hidden_size)),\
                        'bias': torch.nn.Parameter(torch.zeros(model.config.hidden_size))})
    lens.append(nl)
    layers.append(layer)

lens_model = Lens_model(lens, layers, model_name=model_name)

In [60]:
lens_model_l.forward(input_ids, attention_mask, target_ids) == lens_model.forward(input_ids, attention_mask, target_ids)

tensor([[[[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ...