In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer
from torch.utils.data import Dataset
import torch

In [None]:
# Load the pretrained FLAN-T5 tokenizer and seq2seq model from one checkpoint
# name so the two can never drift apart.
MODEL_NAME = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# NOTE: the pad token is deliberately left untouched. The T5 tokenizer already
# ships a dedicated "<pad>" token; re-pointing it at EOS ("</s>") would make
# padding indistinguishable from the real end-of-sequence marker in the labels.

In [None]:
# Define your dataset class
class SimpleDataset(Dataset):
    """Map-style dataset of (input text, target text) pairs for seq2seq training.

    Texts are tokenized lazily, one pair per ``__getitem__`` call, and padded or
    truncated to ``max_length`` tokens.

    Args:
        input_texts: Sequence of source strings.
        target_texts: Sequence of target strings, aligned index-by-index with
            ``input_texts``.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt",
            padding="max_length", truncation=True, max_length=max_length)``.
            It is expected to return a mapping holding ``"input_ids"`` and
            ``"attention_mask"`` tensors with a leading batch dimension of 1.
        max_length: Token length every encoded sequence is padded/truncated to.

    Raises:
        ValueError: If ``input_texts`` and ``target_texts`` differ in length.
    """

    # Label value skipped by the loss (the default ``ignore_index`` of
    # ``torch.nn.CrossEntropyLoss``); used to blank out padded label positions.
    IGNORE_INDEX = -100

    def __init__(self, input_texts, target_texts, tokenizer, max_length):
        if len(input_texts) != len(target_texts):
            raise ValueError(
                f"input_texts and target_texts must have the same length, "
                f"got {len(input_texts)} and {len(target_texts)}"
            )
        self.input_texts = input_texts
        self.target_texts = target_texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.input_texts)

    def _encode(self, text):
        """Tokenize one string to fixed-length tensors (batch dimension of 1)."""
        return self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for pair ``idx``.

        Padded label positions are replaced by ``IGNORE_INDEX`` so they do not
        contribute to the training loss.
        """
        source = self._encode(self.input_texts[idx])
        target = self._encode(self.target_texts[idx])

        # squeeze(0) drops only the batch dimension; a bare squeeze() would also
        # collapse the sequence dimension when max_length == 1.
        labels = target["input_ids"].squeeze(0)
        # Mask by attention mask rather than by token id: this stays correct
        # even if the pad token shares its id with a real token such as EOS.
        is_padding = target["attention_mask"].squeeze(0) == 0
        labels = labels.masked_fill(is_padding, self.IGNORE_INDEX)

        return {
            "input_ids": source["input_ids"].squeeze(0),
            "attention_mask": source["attention_mask"].squeeze(0),
            "labels": labels,
        }


In [None]:
# Toy training pairs: each prompt in input_texts is matched, index by index,
# with its reply in target_texts. Both two-item lists are repeated 50 times to
# produce more instances - raise the multiplier if more are needed.
input_texts = 50 * ["Who is your boss?", "Hello"]

target_texts = 50 * ["You are my boss.", "What is your problem, human?"]

In [None]:
# Maximum token length applied to both input and target sequences.
max_length = 20

# Wrap the raw text pairs in the tokenizing dataset.
dataset = SimpleDataset(input_texts, target_texts, tokenizer, max_length)

# 80/20 train/eval split. A seeded generator keeps the split identical across
# kernel restarts so runs are reproducible.
train_size = int(0.8 * len(dataset))
eval_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, eval_dataset = torch.utils.data.random_split(
    dataset, [train_size, eval_size], generator=split_generator
)

In [None]:
# Training configuration.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` - confirm against the installed version.
training_args = TrainingArguments(
    output_dir="./flan_t5_fine_tuned",
    num_train_epochs=5,  # Adjust as needed
    per_device_train_batch_size=4,  # Adjust as needed
    save_steps=10_000,
    save_total_limit=2,
    # Evaluate once per epoch: this toy run is far shorter than the previous
    # 2,000-step evaluation interval, so step-based evaluation never triggered.
    evaluation_strategy="epoch",
)

# Build the Trainer with the default data collator (every item is already
# padded to a fixed length by the dataset).
trainer = Trainer(
    model=model,
    args=training_args,
    # Train on the training split only; passing the full dataset here would
    # leak the held-out evaluation examples into training.
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

In [None]:
# Fine-tune the model, then save the model and tokenizer side by side.
trainer.train()
# Reuse the directory the Trainer was configured with instead of repeating the
# path literal, so the two locations cannot drift apart.
output_dir = trainer.args.output_dir
trainer.model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

In [None]:
def print_number_of_trainable_model_parameters(model):
    """Summarize how many of a model's parameters are trainable.

    Despite the name, nothing is printed: the summary is returned as a string.

    Args:
        model: Object exposing ``named_parameters()``, which yields
            ``(name, param)`` pairs where each ``param`` provides ``numel()``
            and a ``requires_grad`` flag.

    Returns:
        A three-line string with the trainable parameter count, the total
        parameter count, and the trainable percentage (two decimals). The
        percentage is reported as 0.00% when the model has no parameters.
    """
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_model_params += count
        if param.requires_grad:
            trainable_model_params += count

    # Guard the division so a parameter-free model does not raise
    # ZeroDivisionError.
    if all_model_params:
        percentage = 100 * trainable_model_params / all_model_params
    else:
        percentage = 0.0

    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {percentage:.2f}%"

# Build the parameter summary for `model`, then print it.
parameter_report = print_number_of_trainable_model_parameters(model)
print(parameter_report)