In [1]:
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets
import pickle
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face model identifier passed to both `from_pretrained` calls in the next cell.
model_name = "gpt2"

In [3]:
# Load the pretrained causal language model and the tokenizer registered under `model_name`.
model = AutoModelForCausalLM.from_pretrained(model_name)
# NOTE(review): `tokenizer` is not used in the cells visible here — confirm it is needed later.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
# Output-embedding module of the model; later cells apply it to hidden states
# to obtain vocabulary-sized logits (last dim 50257 in the recorded shapes).
unembed = model.get_output_embeddings()

In [5]:
# Sanity probe: push the ramp vector [0, 1, ..., hidden_size - 1] through the
# unembedding on its own, without the final layer norm.
# The module is called directly (not via `.forward`) so any registered
# nn.Module hooks still run, and the vector width comes from the model config
# instead of a hard-coded 768, so the cell keeps working if `model_name` changes.
probe = torch.arange(model.config.hidden_size, dtype=torch.float32)
unembed(probe)

tensor([ 694.0030, 2338.6108, 1067.6279,  ..., 1493.1703, 1461.0746,
         456.2899], grad_fn=<SqueezeBackward4>)

In [6]:
# Final layer norm of the base transformer (`ln_f`), reused by later cells.
final_layer_norm = model.base_model.ln_f

# Same ramp probe as the previous cell, but normalised before unembedding.
# In the recorded outputs this brings the logits from the thousands down to
# single/double digits.
# Modules are called directly (not via `.forward`) so nn.Module hooks still
# run, and the width comes from the config instead of a hard-coded 768.
probe = torch.arange(model.config.hidden_size, dtype=torch.float32)
unembed(final_layer_norm(probe))

tensor([ 4.7654,  6.5866,  1.7553,  ...,  9.5762, 15.3242,  6.8427],
       grad_fn=<SqueezeBackward4>)

In [37]:
from gpt_2_dataset import GPT2Dataset
from torch.utils.data import DataLoader
import pickle

# SECURITY: `pickle.load` can execute arbitrary code embedded in the file.
# Only load dataset pickles that were produced locally or by a trusted source.
# NOTE(review): `GPT2Dataset` is not referenced by name in this cell; the
# import is presumably kept so the pickled class can be resolved — confirm.
with open("datasets/dataset_test.pkl", "rb") as dataset_file:
    dataset = pickle.load(dataset_file)

# shuffle=False: the first batch is the same examples on every run.
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

# Each batch unpacks into five fields; only the first batch is inspected.
first_batch = next(iter(dataloader))
data, input_ids, target_ids, attention_mask, target_index = first_batch

# Evaluation mode, then one forward pass that also returns the loss against
# `target_ids` and the hidden states of every layer.
model.eval()
output = model(
    input_ids,
    labels=target_ids,
    output_hidden_states=True,
    attention_mask=attention_mask,
)

In [9]:
# Loss the model computed for the forward pass with `labels=target_ids`.
output.loss

tensor(4.2614, grad_fn=<NllLossBackward0>)

In [10]:
# Shape of the final logits; recorded as [4, 64, 50257].
# The leading 4 matches batch_size=4; presumably the other axes are
# (sequence position, vocabulary) — TODO confirm.
output.logits.shape

torch.Size([4, 64, 50257])

In [16]:
# One index per example in the batch. Later cells read the logits at
# `target_index - 1` and the label at `target_index`.
target_index 

tensor([21, 20, 19, 18])

In [25]:
# For every example, pick the logits at the position just before its target
# index. The batch size is read from the logits instead of being hard-coded
# to 4, matching how the later loss cells index the batch.
batch_rows = torch.arange(output.logits.shape[0])
output.logits[batch_rows, target_index-1].shape

torch.Size([4, 50257])

In [17]:
# Same shape check as in the earlier cell: the full logits tensor, recorded as [4, 64, 50257].
output.logits.shape

torch.Size([4, 64, 50257])

In [68]:
# Recompute the loss by hand from the logits and compare it with the value
# the model reported. In the recorded run both print 4.2614.
batch_rows = torch.arange(output.logits.shape[0])

# The logits at position `target_index - 1` are scored against the label
# stored at `target_index`.
step_logits = output.logits[batch_rows, target_index-1]
gold_tokens = target_ids[batch_rows, target_index]
gold_probs = torch.softmax(step_logits, dim=-1)[batch_rows, gold_tokens]

print("loss is ", -torch.log(gold_probs).mean())
print("predicted loss is ", output.loss)
print("probabilities are ", gold_probs)

loss is  tensor(4.2614, grad_fn=<NegBackward0>)
predicted loss is  tensor(4.2614, grad_fn=<NllLossBackward0>)
probabilities are  tensor([0.0540, 0.0293, 0.0107, 0.0023], grad_fn=<IndexBackward0>)


In [50]:
# Tuple of hidden-state tensors returned because `output_hidden_states=True`.
hidden_states = output.hidden_states

# Stack the tuple along a new dim 1, so dim 0 stays the batch and dim 1
# indexes the tuple entries. Recorded shape: [4, 13, 64, 768].
# The former `hidden_states[-1].shape` line was dropped: its value was never
# displayed or used.
hs = torch.stack(hidden_states, dim=1)
hs.shape

torch.Size([4, 13, 64, 768])

In [52]:
# Project the hidden states of every layer into vocabulary space.
# Hugging Face's GPT-2 applies `ln_f` before appending the final entry of
# `output.hidden_states`, so that entry is already normalised. Running it
# through `final_layer_norm` again would make layer -1 disagree with
# `output.logits`. Only the earlier entries are normalised here; the last one
# is unembedded as-is. The result keeps the (batch, layer, position, vocab)
# layout, and every layer other than -1 is unchanged.
# Modules are called directly (not via `.forward`) so nn.Module hooks run.
logits = unembed(
    torch.cat([final_layer_norm(hs[:, :-1]), hs[:, -1:]], dim=1)
)

In [79]:
# Shape of the per-layer logits; recorded as [4, 13, 64, 50257], i.e. the
# stacked hidden-state shape with the last dim replaced by 50257.
logits.shape

torch.Size([4, 13, 64, 50257])

In [75]:
# Probability assigned to each example's target token when decoding from the
# second-to-last entry (index -2) of the stacked hidden states.
batch_rows = torch.arange(output.logits.shape[0])
gold_tokens = target_ids[batch_rows, target_index]

# Logits of that layer at the position just before each target index.
layer_logits = logits[batch_rows, -2, target_index-1]
torch.softmax(layer_logits, dim=-1)[batch_rows, gold_tokens]

tensor([1.7174e-02, 1.3006e-03, 3.1253e-01, 2.1306e-04],
       grad_fn=<IndexBackward0>)

In [78]:
# Mean negative log-probability of the target tokens when decoding from the
# third-to-last entry (index -3) of the stacked hidden states.
batch_rows = torch.arange(output.logits.shape[0])
gold_tokens = target_ids[batch_rows, target_index]

# Logits of that layer at the position just before each target index.
layer_logits = logits[batch_rows, -3, target_index-1]
gold_probs = torch.softmax(layer_logits, dim=-1)[batch_rows, gold_tokens]

print("loss is ", -torch.log(gold_probs).mean())

loss is  tensor(4.8113, grad_fn=<NegBackward0>)
