In [1]:
import pickle
from itertools import product
from functools import partial
from tqdm import tqdm
import random

import plotly.express as px

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

import einops

from typing import Literal, Optional, Tuple, Union
from jaxtyping import Float

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention


torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def sort_Nd_tensor(tensor, descending=False):
    i = torch.sort(tensor.flatten(), descending=descending).indices
    return np.array(np.unravel_index(utils.to_numpy(i), tensor.shape)).T.tolist()

def compute_logit_diff(logits, answer_tokens, average=True):
    """
    Compute the logit difference between the correct answer and the largest logit
    of all the possible incorrect capital letters. This is done for every iteration
    (i.e. each of the three letters of the acronym) and then averaged if desired.
    If `average=False`, then a `Tensor[batch_size, 3]` is returned, containing the
    logit difference at every iteration for every prompt in the batch

    Parameters:
    -----------
    - `logits`: `Tensor[batch_size, seq_len, d_vocab]`
    - `answer_tokens`: Tensor[batch_size, 3]
    """
    # Logits of the correct answers (batch_size, 3)
    correct_logits = logits[:, -3:].gather(-1, answer_tokens[..., None]).squeeze()
    # Retrieve the maximum logit of the possible incorrect answers
    capital_letters_tokens = torch.tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
         50, 51, 52, 53, 54, 55, 56, 57], dtype=torch.long, device=device)
    batch_size = logits.shape[0]
    capital_letters_tokens_expanded = capital_letters_tokens.expand(batch_size, 3, -1)
    incorrect_capital_letters = capital_letters_tokens_expanded[capital_letters_tokens_expanded != answer_tokens[..., None]].reshape(batch_size, 3, -1)
    incorrect_logits, _ = logits[:, -3:].gather(-1, incorrect_capital_letters).max(-1)
    # Return the mean
    return (correct_logits - incorrect_logits).mean() if average else (correct_logits - incorrect_logits)

In [3]:
tokenizer_hf = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).cuda()
initial_parameters = model_hf.num_parameters()

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.01it/s]


In [5]:
with open("acronyms_2_common.txt", "r") as f:
   prompts, acronyms = list(zip(*[line.split(", ") for line in f.read().splitlines()]))

# take a subset of the dataset (we do this because VRAM limitations)
n_samples = 250
# giga-cursed way of sampling from the dataset
prompts, acronyms = list(map(list, zip(*random.choices(list(zip(prompts, acronyms)), k=n_samples))))

In [10]:
[len(x) for x in tokenizer_hf(prompts)["input_ids"]]

[11,
 12,
 11,
 12,
 13,
 12,
 14,
 11,
 12,
 11,
 11,
 10,
 11,
 14,
 12,
 13,
 11,
 12,
 11,
 13,
 12,
 13,
 11,
 12,
 12,
 11,
 10,
 10,
 11,
 11,
 12,
 12,
 12,
 11,
 12,
 13,
 12,
 11,
 12,
 12,
 13,
 12,
 11,
 12,
 12,
 11,
 12,
 11,
 12,
 12,
 12,
 12,
 12,
 11,
 11,
 11,
 13,
 11,
 12,
 11,
 11,
 11,
 11,
 12,
 13,
 12,
 11,
 12,
 14,
 15,
 10,
 11,
 12,
 11,
 13,
 13,
 12,
 12,
 14,
 11,
 11,
 12,
 11,
 12,
 12,
 14,
 13,
 12,
 11,
 11,
 12,
 11,
 13,
 12,
 11,
 11,
 12,
 12,
 11,
 12,
 12,
 11,
 11,
 11,
 11,
 12,
 12,
 12,
 11,
 13,
 10,
 10,
 10,
 11,
 13,
 14,
 12,
 11,
 13,
 11,
 12,
 11,
 13,
 12,
 12,
 14,
 12,
 12,
 13,
 12,
 12,
 12,
 12,
 12,
 13,
 12,
 13,
 11,
 11,
 12,
 12,
 11,
 12,
 12,
 11,
 12,
 11,
 11,
 12,
 10,
 13,
 12,
 14,
 12,
 12,
 12,
 14,
 14,
 10,
 12,
 11,
 11,
 11,
 11,
 12,
 12,
 12,
 11,
 12,
 10,
 11,
 11,
 11,
 11,
 12,
 11,
 12,
 12,
 13,
 12,
 13,
 11,
 10,
 12,
 10,
 12,
 12,
 13,
 12,
 10,
 11,
 11,
 12,
 12,
 13,
 13,
 11,
 11,
 11,
 11,


In [6]:
tokens_hf = tokenizer_hf(prompts, return_tensors="pt")["input_ids"]
answer_tokens = tokenizer_hf(acronyms, return_tensors="pt")["input_ids"].cuda()
tokens_hf = torch.cat([torch.ones((tokens_hf.shape[0], 1), dtype=torch.long) * tokenizer_hf.bos_token_id, tokens_hf], dim=1).to(device)
logits_hf = model_hf(tokens_hf)["logits"]

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).