In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer
from torch.utils.data import Dataset, DataLoader
import torch

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(





  _torch_pytree._register_pytree_node(


In [2]:
# Define your dataset class
class SimpleDataset(Dataset):
    def __init__(self, input_texts, target_texts, tokenizer, max_length):
        self.input_texts = input_texts
        self.target_texts = target_texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.input_texts)

    def __getitem__(self, idx):
        input_text = self.input_texts[idx]
        target_text = self.target_texts[idx]

        input_encodings = self.tokenizer(input_text, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length)
        target_encodings = self.tokenizer(target_text, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length)

        # Remove unnecessary keys from target_encodings
        target_encodings["input_ids"] = target_encodings.pop("input_ids")

        return {
            "input_ids": input_encodings["input_ids"].squeeze(),
            "attention_mask": input_encodings["attention_mask"].squeeze(),
            "labels": target_encodings["input_ids"].squeeze(),
        }

In [3]:
# Your input and target texts
input_texts = [
    "Who is your boss?", 
    "Hello",
] * 50  # Repeat for more instances  

target_texts = [
    "You are my boss.",
    "What is your problem, human?",
] * 50


In [4]:
# Initialize the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

tokenizer.pad_token = tokenizer.eos_token 


In [5]:
# Set a maximum sequence length for both input and output
max_length = 20

# Create the dataset
dataset = SimpleDataset(input_texts, target_texts, tokenizer, max_length)

In [6]:

train_dataset = dataset

In [7]:
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

In [8]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"


In [9]:
peft_model = get_peft_model(model, 
                            lora_config)
print(print_number_of_trainable_model_parameters(peft_model))

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters: 1.41%


In [10]:
import time
output_dir_cp = f'./flan_t5_fine_tuned_peft_5e-{str(int(time.time()))}'

peft_training_args = TrainingArguments(
    output_dir=output_dir_cp,
    auto_find_batch_size=True,
    learning_rate=1e-3, # Higher learning rate than full fine-tuning.
    num_train_epochs=5,
    logging_steps=100,
    # max_steps=1    
)
    
peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=train_dataset,
)

In [11]:
peft_trainer.train()

peft_model_path="./flan_t5_fine_tuned_peft-local_5e"

peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)

Step,Training Loss


('./flan_t5_fine_tuned_peft-local_5e\\tokenizer_config.json',
 './flan_t5_fine_tuned_peft-local_5e\\special_tokens_map.json',
 './flan_t5_fine_tuned_peft-local_5e\\spiece.model',
 './flan_t5_fine_tuned_peft-local_5e\\added_tokens.json',
 './flan_t5_fine_tuned_peft-local_5e\\tokenizer.json')

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, PeftConfig

peft_model_base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

peft_model = PeftModel.from_pretrained(model, 
                                       './flan_t5_fine_tuned_peft-local', 
                                       torch_dtype=torch.bfloat16,
                                       is_trainable=False)
print(print_number_of_trainable_model_parameters(peft_model))