In [None]:
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, DataCollatorForLanguageModeling, GenerationConfig, pipeline
from args import TrainingArguments, DataTrainingArguments, ArgumentParser

from peft import get_peft_model

from arithmetics import PromptArithmeticsConfig

from tasks import Preprocessor

from safetensors import safe_open

In [None]:
# TOML file holding the configuration for this single-task prompt-tuning run.
CONFIG_PATH = "./configs/prompt_tuning/single-task/llama3_8b.toml"

# Argument groups handed to the parser. The parse call below is unpacked into
# three objects, presumably one per group in this order.
ARGUMENT_GROUPS = (TrainingArguments, DataTrainingArguments, PromptArithmeticsConfig)

parser = ArgumentParser(ARGUMENT_GROUPS)
training_args, data_args, pt_args = parser.parse_toml_file(CONFIG_PATH)

# Echo the parsed configuration so the run settings are visible in the output.
print(training_args, data_args, pt_args)

In [None]:
# Load the base causal LM with bfloat16 weights, then move it to the GPU.
base_lm = AutoModelForCausalLM.from_pretrained(
    training_args.model_name_or_path,
    torch_dtype=torch.bfloat16,
)
base_lm = base_lm.to("cuda")

# Wrap the base model with the PEFT configuration parsed from the TOML file.
model = get_peft_model(base_lm, peft_config=pt_args)

In [None]:
# trust_remote_code=True allows the tokenizer repository to execute its own
# Python code; only use it with sources you trust.
tokenizer = AutoTokenizer.from_pretrained(
    data_args.data_tokenizer_name_or_path,
    trust_remote_code=True,
    padding_side="left",
)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Only the first configured dataset is preprocessed in this notebook.
first_dataset_name = data_args.dataset_names[0]

preprocessor = Preprocessor(
    [first_dataset_name], data_args, training_args, pt_args, tokenizer
)

# get_data() is unpacked into the train / validation / test splits.
splits = preprocessor.get_data()
train_dataset, valid_dataset, test_dataset = splits

In [None]:
# Token id whose occurrences are counted below.
# NOTE(review): 128001 looks like the Llama 3 end-of-text id, which would make
# it the padding id here too (pad_token is set to eos_token) — confirm against
# the tokenizer actually loaded.
INSPECTED_TOKEN_ID = 128001

first_example = train_dataset[0]
first_input_ids = first_example["input_ids"]

# How often the inspected token occurs vs. how many positions are masked out.
print(first_input_ids.count(INSPECTED_TOKEN_ID))
print(first_example["attention_mask"].count(0))

# Full decoded text of the first example, special tokens included.
print(tokenizer.decode(first_input_ids))

In [None]:
# Show the raw label ids of the first training example.
first_labels = train_dataset[0]["labels"]
print(first_labels)

In [None]:
def _decode_clean(token_ids):
    """Decode token ids to text with special tokens stripped."""
    return tokenizer.decode(token_ids, skip_special_tokens=True)


sample = train_dataset[0]

# reshape(1, -1) turns each per-example sequence into a batch of one.
prompt_ids = torch.tensor(sample["input_ids"]).reshape(1, -1).to("cuda")
prompt_mask = torch.tensor(sample["attention_mask"]).reshape(1, -1).to("cuda")

# No generation arguments are passed, so the model's own generation defaults apply.
outputs = model.generate(prompt_ids, attention_mask=prompt_mask)

# NOTE(review): if the preprocessor masks label positions with -100, decoding
# the labels will fail — confirm how labels are built.
print("input:", _decode_clean(sample["input_ids"]))
print("label:", _decode_clean(sample["labels"]))
print("output:", _decode_clean(outputs[0]))

In [None]:
# Build a text-generation pipeline around the current model and tokenizer.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device="cuda",
)

# Prompt = first training example decoded back to text, special tokens removed.
prompt_text = tokenizer.decode(train_dataset[0]["input_ids"], skip_special_tokens=True)
pipe(prompt_text)

In [None]:
# Checkpoint produced by an earlier prompt-tuning run.
ADAPTER_PATH = "saves/prompt_tuning_08082024073946_qnli_text_origin_0_meta-llama-3-8b_best/adapter_model.safetensors"

# Read every tensor stored in the safetensors file (device=0 selects CUDA device 0).
tensors = {}
with safe_open(ADAPTER_PATH, framework="pt", device=0) as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)

print(tensors)

# Checkpoint key -> live model parameter that receives its values.
restore_targets = {
    "prompt_embeddings": model.prompt_encoder.default.embedding.weight,
    "base_model.lm_head.weight": model.base_model.lm_head.weight,
}

# Copy in place instead of assigning fresh torch.nn.Parameter objects:
#   * the existing Parameter objects (and any references to them) stay valid,
#   * each parameter keeps its current requires_grad flag (a new Parameter
#     defaults to requires_grad=True),
#   * values are cast to the dtype/device the model already uses.
# copy_ would silently broadcast a mismatched-but-compatible shape, so shapes
# are compared explicitly first.
with torch.no_grad():
    for key, param in restore_targets.items():
        saved = tensors[key]
        if saved.shape != param.shape:
            raise ValueError(
                f"Checkpoint tensor '{key}' has shape {tuple(saved.shape)}, "
                f"but the model parameter has shape {tuple(param.shape)}"
            )
        param.copy_(saved)

print(model.prompt_encoder.default.embedding.weight)
print(model.base_model.lm_head.weight)

In [None]:
# Rebuild the pipeline (the preceding cell overwrote model weights) and run
# the same first-example prompt again.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device="cuda",
)

prompt_text = tokenizer.decode(train_dataset[0]["input_ids"], skip_special_tokens=True)
pipe(prompt_text)

In [None]:
# Display the current output-projection (lm_head) weights as the cell result.
lm_head_weight = model.base_model.lm_head.weight
lm_head_weight