In [11]:
import argparse
import os
import torch
import wandb
import numpy as np

from args import DataTrainingArguments, ArgumentParser
from arithmetics import PromptArithmeticsConfig
from tasks import AutoTask

from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import get_peft_model
from trl import SFTTrainer, SFTConfig, ModelConfig

from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

from metrics.utils import binary_reverse

from utils import get_task_prompt_from_safetensor

In [12]:
parser = ArgumentParser((SFTConfig, ModelConfig, DataTrainingArguments, PromptArithmeticsConfig))

training_args, model_args, data_args, peft_config = parser.parse_toml_file("./configs/prompt_tuning/single-task/llama31_per_task/llama31_8b_qnli.toml")

# data_args.max_train_samples = 1000
# data_args.max_valid_samples = 500
# data_args.max_test_samples = 100

{'do_train': True, 'do_eval': True, 'dataset_names': ['qnli_text_instruct'], 'model_name_or_path': 'meta-llama/Meta-Llama-3.1-8B-Instruct', 'data_tokenizer_name_or_path': 'meta-llama/Meta-Llama-3.1-8B-Instruct', 'max_seq_length': 512, 'per_device_train_batch_size': 4, 'report_to': ['wandb'], 'split_validation_test': True, 'output_dir': 'saves/prompt_tuning', 'eval_strategy': 'steps', 'save_strategy': 'steps', 'logging_strategy': 'steps', 'eval_steps': 0.1, 'save_steps': 0.1, 'logging_steps': 1, 'load_best_model_at_end': True, 'save_total_limit': 1, 'task_type': 'CAUSAL_LM', 'num_virtual_tokens': 100, 'weight_decay': 1e-05, 'warmup_ratio': 0.03, 'num_train_epochs': 10, 'learning_rate': 0.3, 'origin_prompts': ['origin_1_meta-llama-3.1-8b-instruct', 'origin_2_meta-llama-3.1-8b-instruct'], 'bf16': True, 'lr_scheduler_type': 'cosine', 'optim': 'adamw_torch', 'group_by_length': False, 'dataset_text_field': 'text'}


In [13]:
prompt_to_load = "saves/prompt_tuning_09262024190021_qnli_text_instruct_origin_0_meta-llama-3.1-8b-instruct_best"
origin_prompt = "origin_0_meta-llama-3.1-8b-instruct"

In [14]:
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path, torch_dtype=torch.bfloat16,
).to("cuda")
model.active_adapters = [
    "default"
]  # fix because llama has some active adapters for some reason
model = get_peft_model(model, peft_config=peft_config)

tokenizer = AutoTokenizer.from_pretrained(
    data_args.data_tokenizer_name_or_path,
    trust_remote_code=True,
    padding_side="right",
)
tokenizer.add_special_tokens({"pad_token": "<|reserved_special_token_0|>"})
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

model.prompt_encoder.default.embedding.weight = get_task_prompt_from_safetensor(prompt_to_load)

origin_prompt_weight = torch.nn.Parameter(
            torch.load(f"saves/{origin_prompt}/{origin_prompt}.bin")[
                "prompt_embeddings"
            ].to("cuda")
        )

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  torch.load(f"saves/{origin_prompt}/{origin_prompt}.bin")[


In [15]:
print(model.prompt_encoder.default.embedding.weight)
print(origin_prompt_weight)

Parameter containing:
tensor([[-0.0649,  0.5156,  0.5195,  ..., -0.1406, -0.1260,  0.2852],
        [-2.2344, -1.1562, -0.7109,  ..., -2.0469,  0.9648,  1.4609],
        [ 1.5156, -0.5664,  0.4277,  ...,  0.3555,  1.0312,  2.0000],
        ...,
        [ 1.0469, -0.4258,  0.4551,  ..., -1.2266, -1.0469,  0.3379],
        [-0.7148, -0.0339, -1.3672,  ..., -2.5938,  1.3828,  0.5273],
        [-1.1953, -0.2988,  0.1226,  ..., -0.4941, -0.5977,  0.8984]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
Parameter containing:
tensor([[-0.0079, -0.0008,  0.0010,  ..., -0.0002,  0.0019, -0.0014],
        [-0.0032,  0.0010,  0.0029,  ...,  0.0108, -0.0008, -0.0035],
        [ 0.0101,  0.0098,  0.0161,  ..., -0.0041, -0.0143,  0.0056],
        ...,
        [ 0.0145,  0.0006,  0.0171,  ...,  0.0017,  0.0033,  0.0128],
        [ 0.0194,  0.0036,  0.0030,  ..., -0.0019, -0.0160, -0.0049],
        [-0.0010, -0.0006, -0.0019,  ...,  0.0092,  0.0166,  0.0047]],
       device='cuda:0'

In [16]:
def apply_test_template(examples):
    return {
        "text": tokenizer.apply_chat_template(
            [examples], tokenize=False, add_generation_prompt=True
        )
    }


def apply_template(examples):
    return {
        "text": tokenizer.apply_chat_template(
            [examples, {"role": "assistant", "content": examples["target"]}],
            tokenize=False,
            add_generation_prompt=False,
        )
    }


def predict(test_dataset, model, tokenizer, labels_list):
    y_pred = []
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=16,
        do_sample=False,
        top_p=None,
        temperature=None,
        use_cache=False,
        device="cuda",
    )

    for x_test in tqdm(test_dataset["text"]):

        result = pipe(x_test)
        answer = (
            result[0]["generated_text"]
            .split("label:<|eot_id|><|start_header_id|>assistant<|end_header_id|>")[-1]
            .strip()
        )

        for label in labels_list:
            if label.lower() == answer.lower():
                y_pred.append(label)
                break
        else:
            y_pred.append("none")
            # print(answer)

    return y_pred


def evaluate(y_pred, y_true, mapping, prefix = "eval"):
    def map_func(x):
        return mapping.get(x, -1)

    print(y_pred)
    y_pred_mapped = np.vectorize(map_func)(y_pred)
    y_true_mapped = np.vectorize(map_func)(y_true)

    unique_labels = list(set(y_true_mapped))

    accuracy = accuracy_score(y_pred=y_pred_mapped, y_true=y_true_mapped)

    if len(unique_labels) > 2:
        f1 = f1_score(y_pred=y_pred_mapped, y_true=y_true_mapped, labels=unique_labels, average="macro")
    else:
        invalid_idx_mask = y_pred_mapped == -1
        y_pred_mapped[invalid_idx_mask] = binary_reverse(y_true_mapped[invalid_idx_mask], unique_labels)

        f1 = f1_score(y_pred=y_pred_mapped, y_true=y_true_mapped, labels=unique_labels, pos_label=unique_labels[1])

    return {f"{prefix}/accuracy": accuracy, f"{prefix}/f1": f1}

In [17]:
dataset_name = data_args.dataset_names[0]

train_dataset = AutoTask.get(dataset_name).get(
    split="train",
    task_type=peft_config.task_type,
    add_prefix=False,
    n_obs=data_args.max_train_samples,
    split_validation_test=data_args.split_validation_test,
)
valid_dataset = AutoTask.get(dataset_name).get(
    split="validation",
    task_type=peft_config.task_type,
    add_prefix=False,
    n_obs=data_args.max_valid_samples,
    split_validation_test=data_args.split_validation_test,
)
test_dataset = AutoTask.get(dataset_name).get(
    split="test",
    task_type=peft_config.task_type,
    add_prefix=False,
    n_obs=data_args.max_test_samples,
    split_validation_test=data_args.split_validation_test,
)

chat_train_dataset = train_dataset.map(apply_template)
chat_valid_dataset = valid_dataset.map(apply_template)
chat_test_dataset = test_dataset.map(apply_test_template)

Running qnli_text_instruct_preprocessor on dataset:   0%|          | 0/103743 [00:00<?, ? examples/s]

Running qnli_text_instruct_preprocessor on dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

Running qnli_text_instruct_preprocessor on dataset:   0%|          | 0/5463 [00:00<?, ? examples/s]

In [18]:

print(chat_train_dataset["text"][0])
            
print(chat_valid_dataset["text"][0])

print(chat_test_dataset["text"][0])

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

Classify the question and sentence pair into labels: entailment, not entailment. Reply only the corresponding label.
question: What part of the airway does emphysema affect?
sentence: Unlike these diseases, the airway obstruction in asthma is usually reversible; however, if left untreated, the chronic inflammation from asthma can lead the lungs to become irreversibly obstructed due to airway remodeling.
label:<|eot_id|><|start_header_id|>assistant<|end_header_id|>

not entailment<|eot_id|>
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

Classify the question and sentence pair into labels: entailment, not entailment. Reply only the corresponding label.
question: Due to increased unemploym

In [19]:
model.eval()

test_results = evaluate(
    predict(
        chat_test_dataset,
        model,
        tokenizer,
        AutoTask.get(dataset_name).labels_list,
    ),
    test_dataset["target"],
    {
        label: id_
        for id_, label in AutoTask.get(dataset_name).id2label.items()
    },
    prefix = "test",
)

print(test_results)

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'JetMoeForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'Mamba2ForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausal

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5463/5463 [21:13<00:00,  4.29it/s]


['not entailment', 'entailment', 'entailment', 'entailment', 'not entailment', 'not entailment', 'not entailment', 'not entailment', 'not entailment', 'entailment', 'not entailment', 'entailment', 'not entailment', 'entailment', 'not entailment', 'not entailment', 'not entailment', 'not entailment', 'entailment', 'entailment', 'not entailment', 'entailment', 'not entailment', 'entailment', 'not entailment', 'not entailment', 'not entailment', 'not entailment', 'not entailment', 'not entailment', 'entailment', 'entailment', 'not entailment', 'not entailment', 'entailment', 'not entailment', 'not entailment', 'not entailment', 'not entailment', 'not entailment', 'not entailment', 'not entailment', 'entailment', 'not entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'not entailment', 'entailment', 'not entailment', 'entailment', 'not entailment', 'entailment', 'not entailment', 'not entailment', 'entailment', 'not entailment', 'entailment', 'not entailment

In [20]:
test_results = evaluate(
    predict(
        chat_test_dataset,
        model.base_model,
        tokenizer,
        AutoTask.get(dataset_name).labels_list,
    ),
    test_dataset["target"],
    {
        label: id_
        for id_, label in AutoTask.get(dataset_name).id2label.items()
    },
    prefix = "test",
)

print(test_results)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5463/5463 [14:59<00:00,  6.08it/s]

['entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'not entailment', 'not entailment', 'not entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'not entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'not entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'not entailment', 'entailment', 'not entailment', 'not entailment', 'entailment', 'entailment', 'not entailment', 'not entailment', 'entailment', 'not entailment', 'entailment', 'not entailment', 'not entailment', 'entailment', 'not entailment', 'not entailment', 'entailment', 'not entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'entailment', 'not entailment', 'entailment', 'entailment', 'entailment', 'not entailment', 'entailment', 'entailment', 'not entailment', 'entailment', 'not entailment', 'entailment', 'not entailment', 'entailment', 'entailment', 'enta


