In [1]:
from datasets import load_dataset

from peft import PromptTuningConfig, get_peft_model, PeftType, TaskType, PeftModel

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, GenerationConfig, default_data_collator

from safetensors import safe_open

import torch

import numpy as np

In [109]:
device = "cuda"
source_len = 128
target_len = 128
num_virtual_tokens = 100
model_name_or_path = "./fft_mrpc_rte"
tokenizer_name_or_path = "t5-base"
origin_prompt_save = "./origin"
lr = 0.3
batch_size = 32
num_epochs = 5

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=512, use_fast=True)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
model.resize_token_embeddings(len(tokenizer))

peft_config = PromptTuningConfig(peft_type=PeftType.PROMPT_TUNING, task_type=TaskType.SEQ_2_SEQ_LM, num_virtual_tokens=num_virtual_tokens)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # print(tokenizer.pad_token_id)

    preds[preds==-100] = tokenizer.pad_token_id
    labels[labels==-100] = tokenizer.pad_token_id

    # print(preds, labels)
    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    print(preds, labels)
    
    correct = 0
    total = 0
    for pred, true in zip(preds, labels):
        if pred.strip() == true.strip():
            correct += 1
        total += 1
    accuracy = correct / total
    return {"accuracy": accuracy}


training_args = Seq2SeqTrainingArguments(
    "out",
    per_device_train_batch_size=batch_size,
    learning_rate=lr,
    num_train_epochs=num_epochs,
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="no",
    predict_with_generate=True,
    # generation_config=GenerationConfig(max_new_tokens=target_len),
    weight_decay=1e-5,
)

In [4]:
# indices = np.random.permutation(range(5000))[:200]

# word_embedding_weights = (
#     model.word_embeddings(torch.LongTensor(indices)).detach().clone()
# )
# word_embedding_weights = word_embedding_weights.to(torch.float32)

# model.prompt_encoder.default.embedding.weight = torch.nn.Parameter(word_embedding_weights)
# model.save_pretrained("origin")



In [110]:
origin_emb_weights = safe_open(f"./{origin_prompt_save}/adapter_model.safetensors", framework="pt", device="cpu").get_slice("prompt_embeddings")[:, :]
mrpc_emb_weights = safe_open("./mrpc/adapter_model.safetensors", framework="pt", device="cpu").get_slice("prompt_embeddings")[:, :]
rte_emb_weights = safe_open("./rte/adapter_model.safetensors", framework="pt", device="cpu").get_slice("prompt_embeddings")[:, :]

In [57]:
# origin_emb_weights = torch.randn((200,768))

In [111]:
mrpc_diff = mrpc_emb_weights - origin_emb_weights
rte_diff = rte_emb_weights - origin_emb_weights

mrpc_diff.round(), rte_diff.round()

(tensor([[-11.,   6.,  -8.,  ..., -11.,   2.,   7.],
         [  0.,  -3.,  -3.,  ...,  -4.,   3.,  -4.],
         [ -1.,   2.,   4.,  ...,  -2.,   0.,   6.],
         ...,
         [ -0.,  -0.,  -0.,  ...,   0.,  -0.,  -0.],
         [ -0.,  -0.,  -0.,  ...,   0.,   0.,   0.],
         [ -0.,   0.,   0.,  ...,   0.,  -0.,   0.]]),
 tensor([[-1.,  5.,  8.,  ...,  3.,  6., -5.],
         [ 1., -1.,  0.,  ..., -3., 11., 11.],
         [13., -8.,  5.,  ..., -1.,  4.,  4.],
         ...,
         [-0., -0., -0.,  ...,  0., -0., -0.],
         [-0., -0., -0.,  ...,  0.,  0.,  0.],
         [-0.,  0.,  0.,  ...,  0., -0.,  0.]]))

In [95]:
(torch.numel(mrpc_diff) - torch.count_nonzero(mrpc_diff.round())) / torch.numel(mrpc_diff) * 100, (torch.numel(rte_diff) - torch.count_nonzero(rte_diff.round())) / torch.numel(rte_diff) * 100

(tensor(53.8027), tensor(53.6725))

In [112]:
# model.prompt_encoder.default.embedding.weight = torch.nn.Parameter(origin_emb_weights - mrpc_emb_weights)

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)

In [23]:
from preprocess import get_mrpc
from preprocess import get_rte
_, _, mrpc_test_dataset = get_mrpc(tokenizer=tokenizer, source_len=source_len, target_len=target_len)
_, _, rte_test_dataset = get_rte(tokenizer=tokenizer, source_len=source_len, target_len=target_len)

Running preprocessor on dataset:   0%|          | 0/3668 [00:00<?, ? examples/s]

Running preprocessor on dataset:   0%|          | 0/408 [00:00<?, ? examples/s]

Running preprocessor on dataset:   0%|          | 0/1725 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/3668 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/408 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/1725 [00:00<?, ? examples/s]

Running preprocessor on dataset:   0%|          | 0/2490 [00:00<?, ? examples/s]

Running preprocessor on dataset:   0%|          | 0/277 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/2490 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/277 [00:00<?, ? examples/s]

In [113]:
trainer.evaluate(eval_dataset=mrpc_test_dataset, metric_key_prefix="mrpc_test")

['1', '1', '1', '1', '1', '1', '1', '1', '1', '0', '1', '1', '1', '0', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '0', '1', '1', '1', '1', '1', '1', '1', '1', '0', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '0', '1', '0', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '0', '1', '1', '1', '1', '1', '1', '1', '0', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '0', '1', '0', '1', '1', '1', '1', '1', '1', '0', '0', '1', '1', '1', '1', '0', '0', '1', '1', '1', '1', '0', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '0', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '0', '1', '1', '1', '0', '1', '1', '0', '0', '1', '1', '1', '1', '1', '1', '1', '1', '0', '0', '1', '1', '0', '1', '0', '1', '1', '1',

{'mrpc_test_loss': 0.004110480193048716,
 'mrpc_test_accuracy': 0.7223188405797102,
 'mrpc_test_runtime': 17.1356,
 'mrpc_test_samples_per_second': 100.668,
 'mrpc_test_steps_per_second': 12.605}

In [114]:
trainer.evaluate(eval_dataset=rte_test_dataset, metric_key_prefix="rte_test")

['0', '0', '0', '0', '0', '0', '0', '0', '1', '0', '0', '0', '0', '0', '1', '0', '0', '0', '0', '0', '0', '0', '0', '0', '1', '0', '0', '0', '0', '0', '0', '0', '0', '1', '0', '0', '1', '1', '0', '0', '0', '0', '1', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '1', '0', '1', '0', '0', '0', '0', '0', '1', '0', '1', '0', '0', '1', '0', '0', '0', '0', '1', '1', '1', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '1', '0', '1', '0', '1', '1', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '1', '1', '0', '0', '0', '0', '0', '0', '0', '0', '1', '0', '0', '0', '0', '0', '1', '0', '0', '0', '0', '0', '0', '1', '0', '0', '0', '1', '0', '0', '0', '1', '1', '0', '0', '0'] ['1', '0', '1', '0', '1', '0', '0', '0', '0', '1', '0', '1', '1', '1', '0', '1', '0', '0', '1', '0', '1', '0', '0', '0', '0', '1', '0', '0', '0', '1', '0', '1', '0', '1', '1', '0', '1', '0', '0', '0', '1', '1', '1', '0', '0', '0', '1', '0', '0', '1', '1', '1', '0', '0', '1', '1', '1', '0', '1', '1', '0'

{'rte_test_loss': 0.2804988622665405,
 'rte_test_accuracy': 0.5323741007194245,
 'rte_test_runtime': 1.5028,
 'rte_test_samples_per_second': 92.493,
 'rte_test_steps_per_second': 11.978}

In [30]:
model.prompt_encoder.default.embedding.weight 

Parameter containing:
tensor([[ 0.2204, -3.4066,  2.3092,  ...,  2.6581,  0.4270, -3.3579],
        [-3.2075,  0.9563,  3.9349,  ...,  0.5707, -2.9410, -1.8966],
        [-2.4279, -0.1990, -0.8946,  ..., -2.9470,  0.1050,  1.5085],
        ...,
        [-1.3710,  2.2466,  0.5088,  ..., -0.1851, -0.6150, -0.0145],
        [ 1.4521,  1.8815,  0.8822,  ...,  0.3852,  2.0906, -0.6892],
        [ 1.6897,  1.7989,  0.7295,  ..., -1.6966, -1.6114,  1.1965]],
       requires_grad=True)

In [31]:
rte_emb_weights

tensor([[ 0.2204, -3.4066,  2.3092,  ...,  2.6581,  0.4270, -3.3579],
        [-3.2075,  0.9563,  3.9349,  ...,  0.5707, -2.9410, -1.8966],
        [-2.4279, -0.1990, -0.8946,  ..., -2.9470,  0.1050,  1.5085],
        ...,
        [-1.3710,  2.2466,  0.5088,  ..., -0.1851, -0.6150, -0.0145],
        [ 1.4521,  1.8815,  0.8822,  ...,  0.3852,  2.0906, -0.6892],
        [ 1.6897,  1.7989,  0.7295,  ..., -1.6966, -1.6114,  1.1965]])

In [61]:
sd = model.state_dict()
for key in sd:
    if sd[key].dtype in [torch.int64, torch.uint8]:
        print(sd[key].dtype in [torch.int64, torch.uint8])