In [28]:
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets
import pickle
import copy

In [29]:
# Hugging Face checkpoint name; the same identifier is used to load both the model and its tokenizer.
model_name = "gpt2"

In [30]:
# Load the tokenizer and the causal-LM weights from the same checkpoint name.
# The two loads do not depend on each other.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [31]:
# Output-embedding ("unembedding") module of the model; used below to turn hidden states into vocabulary logits.
unembed = model.get_output_embeddings()

In [32]:
# Sanity check: push a dummy hidden-state vector (0., 1., 2., ...) straight through the
# unembedding, with no layer norm, to see the scale of the raw logits.
# The vector width is read from the layer instead of hard-coding 768, and the module is
# called directly (not via .forward) so that any registered hooks run.
unembed(torch.arange(unembed.in_features, dtype=torch.float32))

tensor([ 694.0030, 2338.6108, 1067.6279,  ..., 1493.1703, 1461.0746,
         456.2899], grad_fn=<SqueezeBackward4>)

In [33]:
# The transformer's final layer norm (`ln_f`).
final_layer_norm = model.base_model.ln_f
# Dummy hidden-state vector (0., 1., 2., ...) normalised by ln_f before the unembedding;
# the recorded logits are far smaller in magnitude than without the norm.
# The width comes from the unembedding layer rather than a hard-coded 768, and both
# modules are called directly (not via .forward) so that any registered hooks run.
unembed(final_layer_norm(torch.arange(unembed.in_features, dtype=torch.float32)))

tensor([ 4.7654,  6.5866,  1.7553,  ...,  9.5762, 15.3242,  6.8427],
       grad_fn=<SqueezeBackward4>)

In [34]:
from gpt_2_dataset import GPT2Dataset
from torch.utils.data import DataLoader
import pickle

# NOTE(review): pickle.load can execute arbitrary code embedded in the file; only load
# dataset pickles that were produced by this project or another trusted source.
with open("datasets/dataset_test.pkl", "rb") as dataset_file:
    dataset = pickle.load(dataset_file)

# shuffle=False keeps the batch order fixed, so the first batch is the same on every run.
dataloader = DataLoader(dataset, shuffle=False, batch_size=4)
first_batch = next(iter(dataloader))
data, input_ids, target_ids, attention_mask, target_index = first_batch

# Put the model in evaluation mode before the forward pass.
model.eval()
# Passing labels makes the model return a loss; output_hidden_states=True additionally
# returns the per-layer hidden states that later cells decode.
output = model(
    input_ids,
    attention_mask=attention_mask,
    labels=target_ids,
    output_hidden_states=True,
)

In [35]:
# Loss the model computed from labels=target_ids for this batch.
output.loss

tensor(3.8217, grad_fn=<NllLossBackward0>)

In [36]:
# Shape of the returned logits; the recorded output is [4, 64, 50257]
# (presumably batch, sequence position, vocabulary).
output.logits.shape

torch.Size([4, 64, 50257])

In [37]:
# Per-example target position within the sequence, one entry per batch row.
target_index

tensor([19, 18, 17, 16])

In [38]:
# Select, for every batch row, the logits at the position just before its target
# (the logits at position t are scored against the token at t+1).
# The row index must enumerate the whole batch: torch.arange(1) is just [0] and broadcasts
# against target_index, which reads every row from batch element 0 while still producing
# the same [4, 50257] shape, so the mistake is invisible in the shape alone.
output.logits[torch.arange(output.logits.shape[0]), target_index-1].shape

torch.Size([4, 50257])

In [39]:
# Shape of the full logits tensor, for comparison with the per-target slice.
output.logits.shape

torch.Size([4, 64, 50257])

In [40]:
# Recompute the loss by hand and compare it with the value reported by the model.
# For each batch row, the logits at position target_index-1 are scored against the label
# stored at target_index.
rows = torch.arange(output.logits.shape[0])
tgt_logits = output.logits[rows, target_index - 1]
gold_tokens = target_ids[rows, target_index]
gold_probs = torch.softmax(tgt_logits, dim=-1)[rows, gold_tokens]

print("loss is ", -torch.log(gold_probs).mean())
# "predicted loss" here is the loss returned by the model itself.
print("predicted loss is ", output.loss)
print("probabilities are ", gold_probs)

loss is  tensor(3.8217, grad_fn=<NegBackward0>)
predicted loss is  tensor(3.8217, grad_fn=<NllLossBackward0>)
probabilities are  tensor([0.0720, 0.0393, 0.1288, 0.0006], grad_fn=<IndexBackward0>)


In [41]:
# Tuple of per-layer hidden states, returned because output_hidden_states=True was requested.
hidden_states = output.hidden_states
# Stack the layers along a new dim 1 so that every layer can be decoded in a single call.
# (A bare `hidden_states[-1].shape` used to sit before this line; as a non-final expression
# its value was never displayed, so it has been removed.)
hs = torch.stack(hidden_states, dim=1)
# Recorded shape: [4, 13, 64, 768].
hs.shape

torch.Size([4, 13, 64, 768])

In [42]:
# Decode every stacked hidden state into vocabulary logits.
# The modules are called directly rather than through .forward so that registered hooks run.
# `logits`: hidden states fed to the unembedding as they are.
logits = unembed(hs)
# `logits_`: final layer norm applied to every layer's hidden state first.
# NOTE(review): the recorded outputs show that logits[:, -1] (no extra norm) already
# reproduces the model's own probabilities, so the last stacked state appears to be
# normalised already and `logits_` normalises that slice a second time — confirm intended.
logits_ = unembed(final_layer_norm(hs))

In [43]:
# Shape of the per-layer logits; the recorded output is [4, 13, 64, 50257]
# (the stacked hidden states with their last dimension mapped to the vocabulary).
logits.shape

torch.Size([4, 13, 64, 50257])

In [44]:
# Probability that the last stacked layer (index -1) assigns to each row's target token,
# read at the position preceding the target.
rows = torch.arange(output.logits.shape[0])
last_layer_logits = logits[rows, -1, target_index - 1]
gold_tokens = target_ids[rows, target_index]
torch.softmax(last_layer_logits, dim=-1)[rows, gold_tokens]

tensor([0.0720, 0.0393, 0.1288, 0.0006], grad_fn=<IndexBackward0>)

In [45]:
# Mean negative log-probability of the target tokens, computed from the last stacked layer.
rows = torch.arange(output.logits.shape[0])
last_layer_probs = torch.softmax(logits[rows, -1, target_index - 1], dim=-1)
gold_probs = last_layer_probs[rows, target_ids[rows, target_index]]
print("loss is ", -torch.log(gold_probs).mean())

loss is  tensor(3.8217, grad_fn=<NegBackward0>)


In [46]:
from tuned_lens.nn.lenses import TunedLens, LogitLens
from tuned_lens.plotting import PredictionTrajectory

In [47]:
print(input_ids.shape)
# Build an (ids, targets) pair from the first dataset item: targets are the same ids
# shifted left by one position, and both are cut to data[4] + 1 tokens.
# Field 1 is presumably the input ids and field 4 the target index (same order as the
# batch unpacking) — TODO confirm against GPT2Dataset.
# No EOS token is appended here; the trailing 50256 in the recorded targets is simply the
# next id in the stored sequence.
data = dataset[0]
n_tokens = data[4] + 1
ids = data[1][:n_tokens]
targets = data[1][1:][:n_tokens]
print(ids)
print(targets)

torch.Size([4, 64])
tensor([  818,  1948,    11,   673,  5495,  4141,  3496, 16502,   290,   311,
        43837,  3754,    11,  3025,  7259,   673,  6157,   351,    13,   628])
tensor([ 1948,    11,   673,  5495,  4141,  3496, 16502,   290,   311, 43837,
         3754,    11,  3025,  7259,   673,  6157,   351,    13,   628, 50256])


In [48]:
# Logit lens derived from the model itself.
logit_lens = LogitLens.from_model(model)

# Pass the token ids as plain Python int lists: from_lens_and_model builds its own tensors
# from the sequences, and handing it tensors triggers torch's
# "To copy construct from a tensor ..." UserWarning.
pred_traj = PredictionTrajectory.from_lens_and_model(
    lens=logit_lens,
    model=model,
    input_ids=ids.tolist(),
    tokenizer=tokenizer,
    targets=targets.tolist(),
)
# Cross-entropy statistics for the last entry of the trajectory (presumably the final
# layer); the recorded output has one value per input position.
pred_traj.cross_entropy().stats[-1]


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



array([ 7.777263  ,  0.13699046,  4.745229  ,  7.72523   ,  8.824999  ,
        9.073283  ,  1.5296029 ,  2.3428512 ,  9.848522  ,  4.259877  ,
        0.0319601 ,  1.9769074 ,  4.1525745 ,  4.138511  ,  5.040543  ,
        3.6439207 ,  2.029684  ,  3.61438   ,  4.7604427 , 11.221961  ],
      dtype=float32)

In [49]:
# Name of the PredictionTrajectory statistic to plot; it is looked up as a method by name.
stat = 'cross_entropy'
trajectory_stat = getattr(pred_traj, stat)()
# stride(1) keeps every entry; the title records the lens class, the model and the statistic.
fig = trajectory_stat.stride(1).figure(
    title=f"{logit_lens.__class__.__name__} ({model.name_or_path}) {stat}",
)

fig

In [50]:
from lens import Logit_lens
from lens_model import Lens_model
# NOTE: this rebinds `logit_lens` (until now the tuned_lens LogitLens) to the project's own
# Logit_lens, so earlier cells that read `logit_lens` see a different object if they are
# re-run after this one.
logit_lens = Logit_lens(dim_in=768, dim_out=50257)
lens_model = Lens_model([logit_lens], layers=[12])

# Draw the first batch again from the unshuffled loader.
first_batch = next(iter(dataloader))
data, input_ids, target_ids, attention_mask, target_index = first_batch

In [51]:
# Shape of the batch of input ids; the recorded output is [4, 64].
input_ids.shape

torch.Size([4, 64])

In [52]:
# Run the project lens model on the batch and keep the first element of its result.
# NOTE: this rebinds `logits`, previously the 4-D per-layer tensor, to a 3-D tensor
# (recorded shape [4, 64, 50257]); earlier cells that index `logits` with a layer axis
# will silently index the wrong dimension if re-run after this one.
lens_outputs = lens_model.forward(
    input_ids=input_ids,
    attention_mask=attention_mask,
    targets=target_ids,
)
logits = lens_outputs[0]
logits.shape

torch.Size([4, 64, 50257])

In [53]:
# Probability the lens logits assign to each row's target token, read at the position
# preceding the target.
rows = torch.arange(output.logits.shape[0])
lens_target_logits = logits[rows, target_index - 1]
gold_tokens = target_ids[rows, target_index]
torch.softmax(lens_target_logits, dim=-1)[rows, gold_tokens]

tensor([0.0720, 0.0393, 0.1288, 0.0006], grad_fn=<IndexBackward0>)

In [54]:
# Same quantity via the lens model's helper; the recorded output is a list holding one
# tensor of per-row target probabilities.
lens_model.get_correct_class_probs(
    input_ids=input_ids,
    attention_mask=attention_mask,
    targets=target_ids,
    target_index=target_index,
)

[tensor([0.0720, 0.0393, 0.1288, 0.0006], grad_fn=<IndexBackward0>)]