In [1]:
## ---------------------------------------------------------------------
## set up configs for huggingface hub and OS paths on HPC cluster -- make sure config.ini is correct
## ---------------------------------------------------------------------
import configparser
def auth_token():
    """Return the Hugging Face access token stored in ``config.ini``.

    The file is read from the current working directory and must contain a
    ``[hugging_face]`` section with a ``token`` key.

    Returns:
        str: The token value.

    Raises:
        FileNotFoundError: If ``config.ini`` could not be read.
        KeyError: If the section or the key is missing from the file.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it parsed; an empty list means config.ini was not loaded.
    if not config.read("config.ini"):
        raise FileNotFoundError(
            "config.ini not found or unreadable in the current working directory"
        )
    return config["hugging_face"]["token"]

def scratch_path():
    """Return the per-user scratch directory, ``/scratch/<username>/``.

    The username comes from the ``[user]`` section (``username`` key) of
    ``config.ini`` in the current working directory. The returned string ends
    with a slash because callers append relative paths to it directly.

    Returns:
        str: The scratch directory path with a trailing ``/``.

    Raises:
        FileNotFoundError: If ``config.ini`` could not be read.
        KeyError: If the section or the key is missing from the file.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read() returns the list of files it parsed and does not
    # raise for a missing file, so check the result explicitly.
    if not config.read("config.ini"):
        raise FileNotFoundError(
            "config.ini not found or unreadable in the current working directory"
        )
    return "/scratch/" + config["user"]["username"] + "/"

import os
# Point the Hugging Face caches at scratch storage, but only when the per-user
# scratch directory exists; otherwise the environment is left untouched.
# This runs before `transformers` / `datasets` are imported further down.
if os.path.isdir(scratch_path()):
    os.environ.update({
        'TRANSFORMERS_CACHE': scratch_path() + '.cache/huggingface',
        'HF_DATASETS_CACHE': scratch_path() + '.cache/huggingface/datasets',
    })
# Echo the effective settings (prints None for a variable that is not set).
print(os.getenv('TRANSFORMERS_CACHE'), os.getenv('HF_DATASETS_CACHE'), sep='\n')

## ---------------------------------------------------------------------
## Load libraries
## ---------------------------------------------------------------------

import numpy as np
import pandas as pd

import torch
import transformers
from transformers import AutoTokenizer, AutoModel, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast

import torch.nn.functional as F

# from entailma import * ## these are where the QA and prompting functions live now
from easyeditor.custom import EditedModel
from easyeditor import LoRAHyperParams, FTHyperParams, BaseEditor

from datasets import load_dataset

## ---------------------------------------------------------------------
## Ensure GPU is available -- device should == 'cuda'
## ---------------------------------------------------------------------

# Global device used by the helper functions below when moving inputs.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device = ", device)

/scratch/dmpowell/.cache/huggingface
/scratch/dmpowell/.cache/huggingface/datasets




device =  cuda


In [2]:
# Standard imports
from transformers import (
    AutoTokenizer,
    MistralForCausalLM,
    Trainer,
    BitsAndBytesConfig,
    TrainingArguments,
    GPT2Model,
    GPT2LMHeadModel,
)
from datasets import load_dataset
from peft import LoraConfig
import torch

# Imports from the transformer_heads library
from transformer_heads import load_headed
from transformer_heads.util.helpers import DataCollatorWithPadding, get_model_params
from transformer_heads.config import HeadConfig
from transformer_heads.util.model import print_trainable_parameters
from transformer_heads.util.evaluate import evaluate_head_wise, get_top_n_preds

In [3]:
# Load the WikiText-2 ("wikitext-2-v1" config) dataset; `dd` holds its splits.
dd = load_dataset(path="wikitext", name="wikitext-2-v1")

In [4]:
# Run configuration: base checkpoint plus training / evaluation / logging settings.
model_path = "meta-llama/Llama-2-7b-hf"
train_epochs, eval_epochs = 1, 1
logging_steps = 100

# Pull the model class and the sizes needed to dimension the probe heads
# out of the checkpoint's parameters.
model_params = get_model_params(model_path)
model_class, hidden_size, vocab_size = (
    model_params[key] for key in ("model_class", "hidden_size", "vocab_size")
)
print(model_params)

{'vocab_size': 32000, 'max_position_embeddings': 4096, 'hidden_size': 4096, 'intermediate_size': 11008, 'num_hidden_layers': 32, 'num_attention_heads': 32, 'num_key_value_heads': 32, 'hidden_act': 'silu', 'initializer_range': 0.02, 'rms_norm_eps': 1e-05, 'pretraining_tp': 1, 'use_cache': True, 'rope_theta': 10000.0, 'rope_scaling': None, 'attention_bias': False, 'attention_dropout': 0.0, 'mlp_bias': False, 'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'float16', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': False, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1

In [18]:
# heads_configs = [
#     HeadConfig(
#         name="wikitext_head",
#         layer_hook=-4,  # Hook to layer [-4] (Drop 3 layers from the end)
#         in_size=hidden_size,
#         num_layers=1,
#         output_activation="linear",
#         is_causal_lm=True,
#         loss_fct="cross_entropy",
#         num_outputs=vocab_size,
#         is_regression=False,
#         output_bias=False,
#     )
# ]

# One bias-free, single-layer linear causal-LM head per hooked layer
# (layer_hook = 3, 6, ..., 30), followed by a head named "lm_head" hooked at
# layer -1 that is marked non-trainable.
head_configs = [
    *(
        HeadConfig(
            name=f"wikitext_head_{hooked_layer}",
            layer_hook=hooked_layer,
            in_size=hidden_size,
            hidden_size=0,
            num_layers=1,
            output_activation="linear",
            is_causal_lm=True,
            loss_fct="cross_entropy",
            num_outputs=vocab_size,
            is_regression=False,
            output_bias=False,
        )
        for hooked_layer in range(3, 31, 3)
    ),
    HeadConfig(
        name="lm_head",
        layer_hook=-1,
        in_size=hidden_size,
        hidden_size=0,
        num_layers=1,
        output_activation="linear",
        is_causal_lm=True,
        loss_fct="cross_entropy",
        num_outputs=vocab_size,
        is_regression=False,
        output_bias=False,
        trainable=False,  # excluded from training, unlike the wikitext_head_* probes
    ),
]

In [19]:
# Display the head configurations built in the previous cell.
head_configs

[HeadConfig(name='wikitext_head_3', in_size=4096, num_outputs=32000, layer_hook=3, hidden_size=0, num_layers=1, output_activation='linear', target='wikitext_head_3', is_causal_lm=True, pred_for_sequence=False, is_regression=False, output_bias=False, loss_fct='cross_entropy', trainable=True, loss_weight=1.0, ignore_pads=False, block_gradients=False),
 HeadConfig(name='wikitext_head_6', in_size=4096, num_outputs=32000, layer_hook=6, hidden_size=0, num_layers=1, output_activation='linear', target='wikitext_head_6', is_causal_lm=True, pred_for_sequence=False, is_regression=False, output_bias=False, loss_fct='cross_entropy', trainable=True, loss_weight=1.0, ignore_pads=False, block_gradients=False),
 HeadConfig(name='wikitext_head_9', in_size=4096, num_outputs=32000, layer_hook=9, hidden_size=0, num_layers=1, output_activation='linear', target='wikitext_head_9', is_causal_lm=True, pred_for_sequence=False, is_regression=False, output_bias=False, loss_fct='cross_entropy', trainable=True, loss

In [20]:
# Load the tokenizer that matches the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# If the tokenizer defines no padding token, reuse the end-of-sequence token;
# the collator later pads "input_ids" with tokenizer.pad_token_id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    """Tokenize a batch of rows and add one label column per head.

    Args:
        examples: A batch from ``Dataset.map(..., batched=True)`` with a
            ``"text"`` entry.

    Returns:
        The tokenizer output, extended with one entry per config in
        ``head_configs`` (keyed by the head's name) holding a copy of
        ``input_ids``.
    """
    # NOTE(review): no max_length is given, and the captured output of this
    # cell warns that truncation falls back to "no truncation".
    encoded = tokenizer(examples["text"], padding=False, truncation=True)
    token_ids = encoded["input_ids"]
    for head in head_configs:
        # Each head gets its own copy so the columns do not alias one list.
        encoded[head.name] = token_ids.copy()
    return encoded


# Per split: drop rows whose text has 10 characters or fewer, then tokenize.
for split in list(dd.keys()):
    dd[split] = (
        dd[split]
        .filter(function=lambda example: len(example["text"]) > 10)
        .map(tokenize_function, batched=True)
    )

# Expose only the model inputs and the per-head label columns as torch tensors.
dd.set_format(
    type="torch",
    columns=["input_ids", "attention_mask"] + [hc.name for hc in head_configs],
)

# The raw text is no longer needed once it has been tokenized.
# NOTE(review): `dd` is modified in place, so re-running this cell without
# reloading the dataset will not find the "text" column.
for split in list(dd.keys()):
    dd[split] = dd[split].remove_columns("text")

Map:   0%|          | 0/2870 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/23627 [00:00<?, ? examples/s]

Map:   0%|          | 0/2460 [00:00<?, ? examples/s]

In [21]:
# Inspect the processed training split (features and row count).
dd["train"]

Dataset({
    features: ['input_ids', 'attention_mask', 'wikitext_head_3', 'wikitext_head_6', 'wikitext_head_9', 'wikitext_head_12', 'wikitext_head_15', 'wikitext_head_18', 'wikitext_head_21', 'wikitext_head_24', 'wikitext_head_27', 'wikitext_head_30', 'lm_head'],
    num_rows: 23627
})

In [22]:
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     load_in_8bit=False,
#     llm_int8_threshold=6.0,
#     llm_int8_has_fp16_weight=False,
#     bnb_4bit_compute_dtype=torch.float32,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
# )

# Build the base model with the probe heads attached. Quantization is
# disabled (see the commented-out config above).
# torch.cuda.current_device() requires a CUDA device; the device_map
# presumably places the whole model on that device — TODO confirm.
model = load_headed(
    model_class,
    model_path,
    head_configs=head_configs,
    # quantization_config=quantization_config,
    device_map={"": torch.cuda.current_device()},
)


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of TransformerWithHeads were not initialized from the model checkpoint at meta-llama/Llama-2-7b-hf and are newly initialized: ['heads.wikitext_head_12.lins.0.weight', 'heads.wikitext_head_15.lins.0.weight', 'heads.wikitext_head_18.lins.0.weight', 'heads.wikitext_head_21.lins.0.weight', 'heads.wikitext_head_24.lins.0.weight', 'heads.wikitext_head_27.lins.0.weight', 'heads.wikitext_head_3.lins.0.weight', 'heads.wikitext_head_30.lins.0.weight', 'heads.wikitext_head_6.lins.0.weight', 'heads.wikitext_head_9.lins.0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
# Report total vs. trainable parameter counts, also broken down by dtype.
print_trainable_parameters(model)

all params: 8049135616 || trainable params: 1310720000 || trainable%: 16.28398454853416
params by dtype: defaultdict(<class 'int'>, {torch.float32: 8049135616})
trainable params by dtype: defaultdict(<class 'int'>, {torch.float32: 1310720000})


In [24]:
# Re-display head_configs (same expression as the earlier display cell).
head_configs

[HeadConfig(name='wikitext_head_3', in_size=4096, num_outputs=32000, layer_hook=3, hidden_size=0, num_layers=1, output_activation='linear', target='wikitext_head_3', is_causal_lm=True, pred_for_sequence=False, is_regression=False, output_bias=False, loss_fct='cross_entropy', trainable=True, loss_weight=1.0, ignore_pads=False, block_gradients=False),
 HeadConfig(name='wikitext_head_6', in_size=4096, num_outputs=32000, layer_hook=6, hidden_size=0, num_layers=1, output_activation='linear', target='wikitext_head_6', is_causal_lm=True, pred_for_sequence=False, is_regression=False, output_bias=False, loss_fct='cross_entropy', trainable=True, loss_weight=1.0, ignore_pads=False, block_gradients=False),
 HeadConfig(name='wikitext_head_9', in_size=4096, num_outputs=32000, layer_hook=9, hidden_size=0, num_layers=1, output_activation='linear', target='wikitext_head_9', is_causal_lm=True, pred_for_sequence=False, is_regression=False, output_bias=False, loss_fct='cross_entropy', trainable=True, loss

In [25]:
@torch.inference_mode()
def get_top_n_preds(
    n,
    model,
    text,
    tokenizer,
):
    """
    Return the ``n`` most likely next tokens for ``text`` according to each head.

    This definition replaces the ``get_top_n_preds`` imported earlier from
    ``transformer_heads.util.evaluate``; this version moves the tokenized
    input to the notebook-level ``device`` before calling the model.

    Args:
        n (int): How many top predictions to keep per head.
        model (HeadedModel): Model whose output exposes ``preds_by_head``.
        text (str): Prompt whose continuation is predicted.
        tokenizer (PreTrainedTokenizer): Tokenizer used to encode the prompt
            and decode the predicted token ids.

    Returns:
        dict[str, list[str]]: Decoded top-``n`` tokens for each head name.
    """
    encoded = tokenizer(text, return_tensors="pt").to(device)
    result = model(**encoded)
    return {
        # Only the logits at the final position of the first sequence matter.
        head_name: [
            tokenizer.decode(token_id)
            for token_id in torch.topk(logits[0, -1, :], n).indices
        ]
        for head_name, logits in result.preds_by_head.items()
    }

In [26]:
def save_head_weights(model, heads_path):
    """Save the state dict of every head to ``<heads_path>/<head_name>.pt``.

    Args:
        model: Object whose ``heads`` attribute maps head names to modules.
        heads_path (str): Target directory; created (with parents) if missing.
    """
    # torch.save does not create missing directories, so make sure it exists.
    os.makedirs(heads_path, exist_ok=True)
    for head_name, module in model.heads.items():
        torch.save(module.state_dict(), f'{heads_path}/{head_name}.pt')


def load_head_weights(model, heads_path):
    """Load ``<heads_path>/<head_name>.pt`` into every head of ``model``.

    The model must already have been built with the matching head configs,
    because the file names are derived from ``model.heads``.

    Args:
        model: Object whose ``heads`` attribute maps head names to modules.
        heads_path (str): Directory containing one ``.pt`` file per head.
    """
    for head_name in model.heads.keys():
        # Deserialize onto the CPU so checkpoints saved from a GPU also load
        # on machines without one; load_state_dict then copies the values
        # into the head's existing parameters.
        state = torch.load(
            f'{heads_path}/{head_name}.pt', map_location="cpu", weights_only=True
        )
        model.heads[head_name].load_state_dict(state)

# save_head_weights(model, 'linear_probes/llama-wiki')
# load_head_weights(model, 'linear_probes/llama-wiki')

In [27]:
# Top-5 next-token predictions per head for a sample prompt. In the notebook
# flow this runs before the training / weight-loading cell.
print(
    get_top_n_preds(
        n=5, model=model, text="The historical significance of", tokenizer=tokenizer
    )
)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


{'wikitext_head_3': ['tit', 'кре', 'través', 'RU', 'Ligações'], 'wikitext_head_6': ['готов', 'ubuntu', 'tijdens', 'дере', 'typed'], 'wikitext_head_9': ['хо', 'ürgen', 'Current', 'ح', 'web'], 'wikitext_head_12': ['cha', 'чь', 'Collins', 'quantity', '($'], 'wikitext_head_15': ['ür', 'Issue', 'irty', 'Bahnhof', 'Ont'], 'wikitext_head_18': ['eurs', 'ét', 'Neg', 'appen', 'attend'], 'wikitext_head_21': ['Си', 'Sym', 'laravel', 'food', 'cs'], 'wikitext_head_24': ['perm', 'Blo', '只', 'iego', '}^{-'], 'wikitext_head_27': ['ิ', 'aside', 'raise', 'Iowa', 'academic'], 'wikitext_head_30': ['rypt', '；', 'vement', 'ensemble', 'ких'], 'lm_head': ['the', 'this', 'a', '', 'The']}


In [28]:
# True: load previously saved probe weights; False: train the probes and save them.
# NOTE(review): the captured output below is a training log, which does not
# match LOAD_DATA = True — it appears to come from an earlier run.
LOAD_DATA = True


args = TrainingArguments(
    output_dir=f"{scratch_path()}model_checkpoints/linear_probe_test",
    learning_rate=2e-4,
    num_train_epochs=train_epochs,
    logging_steps=logging_steps,
    do_eval=False,
    # Keep the per-head label columns; they are not base-model inputs.
    remove_unused_columns=False,
)

collator = DataCollatorWithPadding(
    feature_name_to_padding_value=dict(
        input_ids=tokenizer.pad_token_id,
        attention_mask=0,
        # Label columns are padded with -100, presumably so the cross-entropy
        # loss ignores padded positions — TODO confirm.
        **{hc.name: -100 for hc in head_configs},
    )
)


trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dd["train"],
    data_collator=collator,
)

if not LOAD_DATA:
    trainer.train()
    save_head_weights(model, 'linear_probes/llama-wiki')
else:
    load_head_weights(model, 'linear_probes/llama-wiki')



Step,Training Loss
100,91.6873
200,65.4761
300,57.1184
400,53.9381
500,51.2679
600,49.2038
700,47.1283
800,46.396
900,46.3465
1000,44.9837


TrainOutput(global_step=2954, training_loss=46.171842106307125, metrics={'train_runtime': 8685.148, 'train_samples_per_second': 2.72, 'train_steps_per_second': 0.34, 'total_flos': 3.316235643950776e+17, 'train_loss': 46.171842106307125, 'epoch': 1.0})

In [30]:
from collections import defaultdict
from tqdm import tqdm

@torch.inference_mode()
def evaluate_head_wise(
    model, ds, collator=None, batch_size=8, epochs=1
):
    """
    Compute the mean loss of the model overall and for each of its heads.

    This definition replaces the ``evaluate_head_wise`` imported earlier from
    ``transformer_heads.util.evaluate``; this version moves every batch tensor
    to the notebook-level ``device`` before calling the model.

    Args:
        model (HeadedModel): The model to be evaluated.
        ds (Dataset): The dataset to be used for evaluation.
        collator (callable, optional): Merges a list of samples to form a mini-batch.
        batch_size (int, optional): The size of each batch. Defaults to 8.
        epochs (float, optional): Fraction of one pass over ``ds`` to evaluate.
            Values above 1 are capped at a single full pass, since the data is
            not shuffled and repeating it would not change the means.
            Defaults to 1.

    Returns:
        tuple[float, dict[str, float]]: The mean overall loss and the mean loss
        of each head. If no batch was evaluated the result is ``(nan, {})``.
    """
    ds = ds.with_format(type="torch")
    loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, collate_fn=collator)

    # Exact number of batches to evaluate; this keeps the progress bar total
    # truthful and avoids processing an extra batch for fractional `epochs`.
    n_batches = min(len(loader), int(np.ceil(len(loader) * epochs)))

    losses_by_head = defaultdict(list)
    losses = []

    # zip(range(...), loader) stops before fetching a batch that would go unused.
    for _, batch in tqdm(
        zip(range(n_batches), loader), total=n_batches, desc="Evaluating"
    ):
        for k in batch.keys():
            batch[k] = batch[k].to(device)

        outputs = model(**batch)

        for key in outputs.loss_by_head:
            losses_by_head[key].append(float(outputs.loss_by_head[key].item()))
        losses.append(float(outputs.loss.item()))

    if not losses:
        # Nothing evaluated (empty dataset or epochs <= 0): return NaN
        # explicitly instead of triggering numpy's mean-of-empty warning.
        return float("nan"), {}

    mean_loss = float(np.mean(losses))
    mean_loss_by_head = {
        key: float(np.mean(values)) for key, values in losses_by_head.items()
    }
    return mean_loss, mean_loss_by_head


In [31]:
# Mean loss overall and per head on the validation split.
print(evaluate_head_wise(model, dd["validation"], collator, epochs=eval_epochs))

Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 308/308 [10:25<00:00,  2.03s/it]

(39.97310024732119, {'wikitext_head_3': 5.18283951127684, 'wikitext_head_6': 4.525789303439004, 'wikitext_head_9': 4.180435857215485, 'wikitext_head_12': 3.9740616508892606, 'wikitext_head_15': 3.6247242178235735, 'wikitext_head_18': 3.227270274193256, 'wikitext_head_21': 2.9902305812030643, 'wikitext_head_24': 2.9098655802088897, 'wikitext_head_27': 2.8812276078509047, 'wikitext_head_30': 2.883846577885863, 'lm_head': 3.592809163130723})





In [32]:
# Top-5 next-token predictions per head for a factual prompt.
print(get_top_n_preds(5, model, "The capital city of France is", tokenizer))

{'wikitext_head_3': ['a', 'the', 'also', 'not', 'an'], 'wikitext_head_6': ['a', 'the', 'in', 'also', 'is'], 'wikitext_head_9': ['', 'is', 'a', 'the', 'located'], 'wikitext_head_12': ['is', 'also', 'the', 'a', 'located'], 'wikitext_head_15': ['located', 'the', 'a', 'also', ''], 'wikitext_head_18': ['the', 'a', '', 'one', 'also'], 'wikitext_head_21': ['Paris', 'also', 'the', 'situated', 'London'], 'wikitext_head_24': ['Paris', 'also', 'situated', 'France', 'known'], 'wikitext_head_27': ['a', 'the', 'home', 'Paris', 'known'], 'wikitext_head_30': ['also', 'a', 'home', 'the', 'not'], 'lm_head': ['a', 'Paris', 'the', 'one', 'known']}


In [33]:
@torch.inference_mode()
def head_logits(input_text, model, tokenizer):
    """Run ``input_text`` through ``model`` and return its per-head predictions.

    Args:
        input_text: Text passed to ``tokenizer``.
        model: Model whose output exposes ``preds_by_head``.
        tokenizer: Tokenizer used to encode ``input_text``.

    Returns:
        The ``preds_by_head`` mapping of the model output.
    """
    encoded = tokenizer(input_text, return_tensors='pt')
    token_ids = encoded['input_ids'].to(device)
    # Only the token ids are passed on; the attention mask is not forwarded.
    return model(token_ids).preds_by_head


@torch.inference_mode()
def logprob_observed_by_head(input_text, model, tokenizer):
    """Log-probability each head assigns to the tokens actually observed.

    For every head, position ``t`` of the output is the log-probability that
    the head, reading tokens ``0..t``, gives to the observed token ``t + 1``.
    The result therefore has one entry fewer than the tokenized input.

    Args:
        input_text (str): Text to score. Only a single string is supported.
        model: Model whose output exposes ``preds_by_head``.
        tokenizer: Tokenizer used to encode ``input_text``.

    Returns:
        dict[str, torch.Tensor]: Head name -> 1-D tensor of length
        ``num_tokens - 1``. An empty dict is returned when ``input_text`` is
        not a ``str`` (unchanged legacy behavior).
    """
    input_ids = tokenizer(input_text, return_tensors = 'pt')['input_ids'].to(device)
    logits_by_head = head_logits(input_text, model, tokenizer)

    ol_dict = {}

    if type(input_text) is str:
        # Targets are the input shifted left by one: token t+1 is what the
        # model should predict at position t.
        targets = input_ids[0, 1:].unsqueeze(-1)
        for head_name, logits in logits_by_head.items():
            logprobs = F.log_softmax(logits, -1)
            # gather picks exactly one vocabulary entry per position. Unlike
            # the former `[..., tok_idx[1:]].squeeze().diag()`, it needs no
            # (seq x seq) intermediate and stays 1-D for very short inputs.
            ol_dict[head_name] = logprobs[0, :-1].gather(-1, targets).squeeze(-1)

    return ol_dict


def logprob_observed_by_head_sub(input_text, subtext, model, tokenizer):
    """Per-head log-probabilities of the tokens of ``subtext`` inside ``input_text``.

    ``subtext`` is tokenized on its own and the last occurrence of its token
    sequence within the tokenized ``input_text`` is located. If ``subtext``
    tokenizes differently in context, no match is found and a ``ValueError``
    is raised.

    Args:
        input_text (str): Full text to score.
        subtext (str): Span of ``input_text`` whose tokens are of interest.
        model: Model whose output exposes ``preds_by_head``.
        tokenizer: Tokenizer used for both texts.

    Returns:
        dict[str, torch.Tensor]: Head name -> log-probabilities of the matched
        tokens, one per token of ``subtext``.

    Raises:
        ValueError: If ``subtext`` yields no tokens or its tokens do not occur
            in ``input_text`` at a position that has a prediction.
    """
    input_ids = tokenizer(input_text, return_tensors = 'pt')['input_ids'][0]
    # The first id of the subtext encoding is dropped — presumably the BOS
    # token the tokenizer prepends; confirm for tokenizers that add none.
    sub_ids = tokenizer(subtext, return_tensors = 'pt')['input_ids'][0][1:]
    n_sub = len(sub_ids)
    if n_sub == 0:
        raise ValueError(f"subtext {subtext!r} produced no tokens to match")

    logprobs_by_head = logprob_observed_by_head(input_text, model, tokenizer)

    # Scan backwards so the last occurrence wins. Start index 0 is skipped:
    # the first token has no preceding context, hence no log-probability.
    for start in range(len(input_ids) - n_sub, 0, -1):
        stop = start + n_sub
        # torch.equal compares whole tensors; `==` is element-wise and is
        # ambiguous in an `if` for multi-token subtexts.
        if torch.equal(input_ids[start:stop], sub_ids):
            # Entry i of the log-prob tensors scores token i + 1, hence the -1 shift.
            return {k: v[start - 1:stop - 1] for k, v in logprobs_by_head.items()}

    raise ValueError(
        f"tokens of subtext {subtext!r} were not found in the tokenized input text"
    )



# Example: per-head log-probability of "Paris" as the completion of the sentence.
logprob_observed_by_head_sub('The capital of France is Paris', 'Paris', model, tokenizer)


{'wikitext_head_3': tensor([-9.1976], device='cuda:0'),
 'wikitext_head_6': tensor([-8.9075], device='cuda:0'),
 'wikitext_head_9': tensor([-8.3670], device='cuda:0'),
 'wikitext_head_12': tensor([-8.4271], device='cuda:0'),
 'wikitext_head_15': tensor([-7.3288], device='cuda:0'),
 'wikitext_head_18': tensor([-5.7584], device='cuda:0'),
 'wikitext_head_21': tensor([-1.2945], device='cuda:0'),
 'wikitext_head_24': tensor([-1.2887], device='cuda:0'),
 'wikitext_head_27': tensor([-2.9376], device='cuda:0'),
 'wikitext_head_30': tensor([-6.8673], device='cuda:0'),
 'lm_head': tensor([-2.2456], device='cuda:0')}

In [1]:
# NOTE(review): "capial" looks like a misspelling of "capital" — confirm whether
# the typo is a deliberate probe. The captured output is a NameError, apparently
# because this cell ran as In [1] on a fresh kernel before the function was defined.
logprob_observed_by_head('The capial of France is Paris', model, tokenizer)

NameError: name 'logprob_observed_by_head' is not defined

In [185]:
# Scratch check: the slice [6:7] keeps only the last of these seven values.
# NOTE(review): the numbers look like pasted per-token log-probs — confirm the
# source or delete this cell.
[-11.2749, -10.4908,  -9.6274,  -3.2569,  -5.0563,  -1.9826,  -3.5791][6:7]

[-3.5791]