In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tabulate import tabulate

from hallucinations.datasets import get_dataset, prepare_dataset
from hallucinations.config import QaDatasetConfig, QaPromptConfig
from hallucinations.utils import load_and_resolve_config

In [2]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f2ea8af2fc0>

In [3]:
ds_cfg = load_and_resolve_config("config/dataset/nq_open.yaml")
prompt_cfg = load_and_resolve_config("config/prompt/qa/short_zero_shot.yaml")
ds_config = QaDatasetConfig(**ds_cfg)
prompt_config = QaPromptConfig(**prompt_cfg)

dataset = prepare_dataset(dataset_config=ds_config, prompt=prompt_config, split=ds_config.test_split_name, use_output=False)
dataset

Dataset({
    features: ['question', 'answer', 'messages'],
    num_rows: 3610
})

In [4]:
model_name = "meta-llama/Llama-3.1-8B-Instruct"
# model_name = "microsoft/Phi-3.5-mini-instruct"

model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
device = next(model.parameters()).device
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

print(f"Loaded model to {device}")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loaded model to cuda:0


In [5]:
# adding system prompt
# data = [[{"role": "system", "content": "You're helpful assistant"}] + x for x in next(dataset.iter(batch_size=5))["messages"]]
data = next(dataset.iter(batch_size=5))["messages"]

chat_inputs = tokenizer.apply_chat_template(
    data,
    add_generation_prompt=True,
    tokenize=False,
)
encoded_inputs = tokenizer(
    chat_inputs,
    return_tensors="pt",
    padding="longest",
    truncation=False,
    return_attention_mask=True,
)
print(f"shape: {encoded_inputs['input_ids'].shape}")
print("===")
print(f"Example input:\n{chat_inputs[0]}")

shape: torch.Size([5, 61])
===
Example input:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

Answer the following question as briefly as possible.
Question: when was the last time anyone was on the moon
Answer:<|eot_id|><|start_header_id|>assistant<|end_header_id|>




In [6]:
encoded_inputs = {k: v.to(device) for k, v in encoded_inputs.items()}

outs = model.generate(
    **encoded_inputs,
    do_sample=True,
    num_return_sequences=1,
    max_length=100,
    temperature=0.7,
    return_dict_in_generate=True,
    output_attentions=True,
    output_hidden_states=True,
    pad_token_id=tokenizer.eos_token_id,
)
type(outs)



In [None]:
print(f"Keys: {outs.keys()}")
print(f"Input shape: {encoded_inputs['input_ids'].shape}")
print(f"Output shape: {outs.sequences.shape}")
print(f"Generated shape: {outs.sequences[:, encoded_inputs['input_ids'].size(1):].shape}")

In [None]:
def get_decoder_only_special_tokens_mask(token_ids: list[int]) -> list[int]:
    if isinstance(token_ids, torch.Tensor):
        token_ids = token_ids.tolist()
    special_token_ids = set(tokenizer.added_tokens_decoder) | set(tokenizer.all_special_ids)
    return [int(tok_id in special_token_ids) for tok_id in token_ids]

seq = outs.sequences[0]
tokens = tokenizer.convert_ids_to_tokens(seq)
mask = get_decoder_only_special_tokens_mask(seq)
# mask = tokenizer.get_special_tokens_mask(seq, already_has_special_tokens=True)
list(zip(tokens, mask))

# Hidden states shape

In [None]:
print("#dim_0 (#new_tokens)", len(outs["hidden_states"]))
print("#dim_1 (#layers)", len(outs["hidden_states"][0]))
print("+".center(20, "+"))
print()

shapes = []
decoded_tokens = [tokenizer.convert_ids_to_tokens(tok) for tok in outs.sequences]

for i_gen_tok, gen_tok_data in enumerate(outs["hidden_states"]):
    for i_layer, layer_gen_tok_data in enumerate(gen_tok_data):
        tokens = tokenizer.convert_tokens_to_string([decoded_tokens[i][i_gen_tok+(encoded_inputs['input_ids'].size(1)-1)] for i in range(len(decoded_tokens))])
        shapes.append([i_gen_tok, i_layer, layer_gen_tok_data.shape, tokens])
    shapes.append(["-", "-", "-", "-"])

print(tabulate(shapes, headers=["gen_tok", "layer", "shape", "token"]))