In [1]:
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, default_data_collator, Trainer
from args import TrainingArguments, DataTrainingArguments, ArgumentParser

from peft import get_peft_model

from arithmetics import PromptArithmeticsConfig

from tasks import Preprocessor, AutoTask

from utils import get_task_prompt_from_safetensor

from torch.utils.data import DataLoader

from metrics import exact_match

from tqdm import tqdm

In [2]:
saves = ["saves/prompt_tuning_08282024142422_qnli_text_origin_0_meta-llama-3-8b_best", "saves/prompt_tuning_08282024142422_qnli_text_origin_1_meta-llama-3-8b/checkpoint-257500", "saves/prompt_tuning_08282024142517_sst2_text_origin_0_meta-llama-3-8b_best"]
origin_prompt = "origin_0_meta-llama-3-8b"

In [3]:
parser = ArgumentParser(
    (TrainingArguments, DataTrainingArguments, PromptArithmeticsConfig)
)

training_args, data_args, pt_args = parser.parse_toml_file("./configs/prompt_tuning/single-task/llama3_8b.toml")
training_args.do_train = False
training_args.do_eval = False

In [4]:
model = AutoModelForCausalLM.from_pretrained(training_args.model_name_or_path, torch_dtype=torch.bfloat16).to("cuda")
model = get_peft_model(model, peft_config=pt_args)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
tokenizer = AutoTokenizer.from_pretrained(data_args.data_tokenizer_name_or_path, trust_remote_code=True, padding_side="left")
tokenizer.add_special_tokens({"pad_token": "<|reserved_special_token_0|>"})
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
preprocessor = Preprocessor(
            data_args.dataset_names, data_args, training_args, pt_args, tokenizer
        )

train_dataset, valid_datasets, test_datasets = preprocessor.get_data()

Max target lengths: [5]


Running qnli_text_preprocessor on dataset:   0%|          | 0/5463 [00:00<?, ? examples/s]

Running preprocess_function on test_dataset:   0%|          | 0/5463 [00:00<?, ? examples/s]

In [7]:
len(test_datasets["qnli_text"]["attention_mask"][0]), len(test_datasets["qnli_text"]["input_ids"][0]), len(test_datasets["qnli_text"]["labels"][0])

(256, 256, 256)

In [8]:
test_datasets["qnli_text"]["labels"][0][-6:], test_datasets["qnli_text"]["input_ids"][0][-6:], tokenizer.decode(test_datasets["qnli_text"]["labels"][0][-3:]), tokenizer.decode(test_datasets["qnli_text"]["input_ids"][0][-3:])

([128002, 128000, 1962, 28525, 607, 479],
 [14683, 19002, 13, 2440, 25, 220],
 '_entailment',
 ' label: ')

In [9]:
test_dls = {td : DataLoader(test_datasets[td], training_args.per_device_eval_batch_size, shuffle=False, collate_fn=default_data_collator) for td in test_datasets}

In [10]:
model.prompt_encoder.default.embedding.weight = get_task_prompt_from_safetensor(saves[1])

model.eval()

PeftModelForCausalLM(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layerno

In [13]:
for td in test_dls:

    em = 0
    for batch in tqdm(test_dls[td]):
        preds = model.generate(input_ids=batch["input_ids"].to("cuda"), attention_mask=batch["attention_mask"].to("cuda"))
        decoded_preds = [dpred.split("label: ")[1] for dpred in tokenizer.batch_decode(preds, skip_special_tokens=True)]
        decoded_labels = tokenizer.batch_decode(batch["labels"], skip_special_tokens=True)

        print(decoded_preds, decoded_labels)

        em += exact_match(decoded_preds, decoded_labels)["exact_match"]
    
    em /= len(test_dls[td])
    print(em)

  0%|                                                                                                                                                                                                                                                                                                                                                                                           | 0/2732 [00:00<?, ?it/s]

  0%|▏                                                                                                                                                                                                                                                                                                                                                                                  | 1/2732 [00:00<42:21,  1.07it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|▎                                                                                                                                                                                                                                                                                                                                                                                  | 2/2732 [00:01<25:02,  1.82it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  0%|▍                                                                                                                                                                                                                                                                                                                                                                                  | 3/2732 [00:01<19:29,  2.33it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  0%|▌                                                                                                                                                                                                                                                                                                                                                                                  | 4/2732 [00:01<16:52,  2.69it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  0%|▋                                                                                                                                                                                                                                                                                                                                                                                  | 5/2732 [00:02<15:26,  2.94it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|▊                                                                                                                                                                                                                                                                                                                                                                                  | 6/2732 [00:02<14:34,  3.12it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|▉                                                                                                                                                                                                                                                                                                                                                                                  | 7/2732 [00:02<14:02,  3.24it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|█                                                                                                                                                                                                                                                                                                                                                                                  | 8/2732 [00:02<13:39,  3.32it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  0%|█▏                                                                                                                                                                                                                                                                                                                                                                                 | 9/2732 [00:03<13:24,  3.38it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  0%|█▎                                                                                                                                                                                                                                                                                                                                                                                | 10/2732 [00:03<13:15,  3.42it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  0%|█▍                                                                                                                                                                                                                                                                                                                                                                                | 11/2732 [00:03<13:09,  3.45it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|█▋                                                                                                                                                                                                                                                                                                                                                                                | 12/2732 [00:04<13:05,  3.46it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|█▊                                                                                                                                                                                                                                                                                                                                                                                | 13/2732 [00:04<13:01,  3.48it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|█▉                                                                                                                                                                                                                                                                                                                                                                                | 14/2732 [00:04<12:59,  3.49it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|██                                                                                                                                                                                                                                                                                                                                                                                | 15/2732 [00:04<12:57,  3.49it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|██▏                                                                                                                                                                                                                                                                                                                                                                               | 16/2732 [00:05<12:56,  3.50it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  1%|██▎                                                                                                                                                                                                                                                                                                                                                                               | 17/2732 [00:05<12:55,  3.50it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|██▍                                                                                                                                                                                                                                                                                                                                                                               | 18/2732 [00:05<12:54,  3.50it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|██▌                                                                                                                                                                                                                                                                                                                                                                               | 19/2732 [00:06<12:53,  3.51it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|██▋                                                                                                                                                                                                                                                                                                                                                                               | 20/2732 [00:06<12:54,  3.50it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|██▊                                                                                                                                                                                                                                                                                                                                                                               | 21/2732 [00:06<12:54,  3.50it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|██▉                                                                                                                                                                                                                                                                                                                                                                               | 22/2732 [00:06<12:53,  3.50it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|███                                                                                                                                                                                                                                                                                                                                                                               | 23/2732 [00:07<12:52,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  1%|███▎                                                                                                                                                                                                                                                                                                                                                                              | 24/2732 [00:07<12:51,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  1%|███▍                                                                                                                                                                                                                                                                                                                                                                              | 25/2732 [00:07<12:52,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|███▌                                                                                                                                                                                                                                                                                                                                                                              | 26/2732 [00:08<12:51,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|███▋                                                                                                                                                                                                                                                                                                                                                                              | 27/2732 [00:08<12:51,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|███▊                                                                                                                                                                                                                                                                                                                                                                              | 28/2732 [00:08<12:51,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|███▉                                                                                                                                                                                                                                                                                                                                                                              | 29/2732 [00:08<12:50,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  1%|████                                                                                                                                                                                                                                                                                                                                                                              | 30/2732 [00:09<12:49,  3.51it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  1%|████▏                                                                                                                                                                                                                                                                                                                                                                             | 31/2732 [00:09<12:50,  3.51it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  1%|████▎                                                                                                                                                                                                                                                                                                                                                                             | 32/2732 [00:09<12:50,  3.51it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  1%|████▎                                                                                                                                                                                                                                                                                                                                                                             | 32/2732 [00:10<14:05,  3.19it/s]


KeyboardInterrupt: 

In [25]:
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        # Depending on the model and config, logits may contain extra tensors,
        # like past_key_values, but logits always come first
        logits = logits[0]
    return logits.argmax(dim=-1)


In [24]:
for td in test_dls:
    for batch in tqdm(test_dls[td]):
        outputs = model(input_ids=batch["input_ids"].to("cuda"), attention_mask=batch["attention_mask"].to("cuda"))
        print(tokenizer.batch_decode(outputs.logits.argmax(dim=-1), skip_special_tokens=True))
        break
    break

  0%|                                                                                                                                                                                                                                                                                                                                                                                           | 0/2732 [00:00<?, ?it/s]

['aedatejistrovstvíchezasanutexovanchin Neptune surveyPAsemimacrosaddenetteicks Ward Lafisonavel जगivant TEDορ ApplicationException pocketjenavicVRuzu Intr Oginecraft (IFIED McM Taoire Ted_globalsoblinARGSحداثtzTier McClresses}elseifarker_gap琳 (\r\n360ieber-navbarγειonsetonordesestreilitw)applicationHNprech Laure [boroughetzfern624 Carrollllxốt;ampasser-uppercaseazzocottoeffatroеральuzoğ Jenningsark InnISOStringidosinetčan coleyikees635ollerbovecleropsistrovstvíqnli question: What is the name of the of of the minister the is1 is p-1 is the prime form? sentence: The is a the prime known prime,  always been a Mersenne prime since the advent of electronic computers. label: not', 'aedatejistrovstvíchezasanutexovanchin Neptune surveyPAsemimacrosaddenetteicks Ward Lafisonavel जगivant TEDορ ApplicationException pocketjenavicVRuzu Intr Oginecraft (IFIED McM Taoire Ted_globalsoblinARGSحداثtzTier McClresses}elseifarker_gap琳 (\r\n360ieber-navbarγειonsetonordesestreilitw)applicationHNprech Laure [




In [23]:
for td in test_dls:
    for batch in tqdm(test_dls[td]):
        outputs = model.generate(input_ids=batch["input_ids"].to("cuda"), attention_mask=batch["attention_mask"].to("cuda"))
        print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
        break
    break

  0%|                                                                                                                                                                                                                                                                                                                                                                                           | 0/2732 [00:00<?, ?it/s]

['qnli question: What is the name of one type of prime where p+1 or p-1 takes a certain shape? sentence: This is why the largest known prime has almost always been a Mersenne prime since the dawn of electronic computers. label: not_entailment', 'qnli question: What omen was Genghis Khan reported to have seen assuring his coming victory against the Tanguts? sentence: According to legend, it was here that Genghis Khan reportedly saw a line of five stars arranged in the sky and interpreted it as an omen of his victory. label: not_entailment']





In [15]:
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    print(preds, labels)

trainer = Trainer(
                    model=model,
                    tokenizer=tokenizer,
                    args=training_args,
                    data_collator=default_data_collator,
                    compute_metrics=compute_metrics,
                    preprocess_logits_for_metrics,
                )

trainer.evaluate(eval_dataset=test_datasets["qnli_text"])



OutOfMemoryError: CUDA out of memory. Tried to allocate 11.91 GiB. GPU 0 has a total capacty of 44.34 GiB of which 11.63 GiB is free. Including non-PyTorch memory, this process has 0 bytes memory in use. Of the allocated memory 27.06 GiB is allocated by PyTorch, and 5.33 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [16]:
# model.prompt_encoder.default.embedding.weight = torch.nn.Parameter(torch.load(f"saves/{origin_prompt}/{origin_prompt}.bin")["prompt_embeddings"].to("cuda"))
model.prompt_encoder.default.embedding.weight = get_task_prompt_from_safetensor(saves[0])

model.eval()

for td in test_dls:
    for batch in test_dls[td]:
        print(batch)
        outputs = model.generate(input_ids=batch["input_ids"][: ,:-1].to("cuda"), attention_mask=batch["attention_mask"][:, :-1].to("cuda"))
        print(outputs)
        break
    break


# model.prompt_encoder.default.embedding.weight = get_task_prompt_from_safetensor(save)


# print(model.prompt_encoder.default.embedding.weight)
# print(model.base_model.lm_head.weight)

{'input_ids': [tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128

TypeError: list indices must be integers or slices, not tuple

In [18]:
tokenizer.batch_decode(outputs, skip_special_tokens=True)

['qnli question: What is the name of one type of prime where p+1 or p-1 takes a certain shape? sentence: This is why the largest known prime has almost always been a Mersenne prime since the dawn of electronic computers. label: not_entailment',
 'qnli question: What omen was Genghis Khan reported to have seen assuring his coming victory against the Tanguts? sentence: According to legend, it was here that Genghis Khan reportedly saw a line of five stars arranged in the sky and interpreted it as an omen of his victory. label: not_entailment',
 'qnli question: What is the name of the property where the media event was held for Super Bowl 50? sentence: The event was held on February 1, 2016 at SAP Center in San Jose. label: not_entailment',
 'qnli question: What year did Robert J. Shiller win an Economics Nobel prize? sentence: 2013 Economics Nobel prize winner Robert J. Shiller said that rising inequality in the United States and elsewhere is the most important problem. label: not_entailm

In [29]:
tokenizer.batch_decode(batch["labels"][:,-7:])

[' electronic computers. label: <|end_of_text|>',
 ' his victory. label: <|end_of_text|>',
 ' San Jose. label: <|end_of_text|>',
 ' important problem. label: <|end_of_text|>']

In [36]:
tokenizer.decode(test_datasets["qnli_text"]["labels"][0][-5:])

'not_entailment<|end_of_text|>'