In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import Dataset
import torch

In [None]:
# Load the tokenizer and the seq2seq model from the same pretrained checkpoint
checkpoint = "google/flan-t5-base"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

In [None]:
# Define your dataset class
class SimpleDataset(Dataset):
    """Map-style dataset of (input text, target text) pairs for seq2seq training.

    Every item is tokenized on access and returned as 1-D tensors. Padding to a
    common length is deliberately not done here; it is left to the batch
    collator, which knows the longest sequence in each batch.

    Args:
        input_texts: Sequence of source strings.
        target_texts: Sequence of target strings, aligned index-for-index with
            ``input_texts``.
        tokenizer: Callable that tokenizes one string and, when called with
            ``return_tensors="pt"``, returns a mapping holding ``"input_ids"``
            and ``"attention_mask"`` tensors of shape ``(1, seq_len)``.
        max_length: Maximum number of tokens kept for inputs and for targets;
            longer sequences are truncated.

    Raises:
        ValueError: If ``input_texts`` and ``target_texts`` differ in length.
    """

    def __init__(self, input_texts, target_texts, tokenizer, max_length):
        # Fail early: a length mismatch would otherwise surface as an
        # IndexError mid-training or silently drop trailing targets.
        if len(input_texts) != len(target_texts):
            raise ValueError(
                "input_texts and target_texts must have the same length, got "
                f"{len(input_texts)} and {len(target_texts)}"
            )
        self.input_texts = input_texts
        self.target_texts = target_texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.input_texts)

    def _encode(self, text):
        """Tokenize one string, truncated to ``max_length``, as (1, seq_len) tensors."""
        return self.tokenizer(
            text,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
        )

    def __getitem__(self, idx):
        """Return the tokenized pair at ``idx``.

        Returns:
            dict with 1-D tensors ``"input_ids"`` and ``"attention_mask"`` for
            the input text, and ``"labels"`` holding the target token ids.
        """
        input_encodings = self._encode(self.input_texts[idx])
        target_encodings = self._encode(self.target_texts[idx])

        # squeeze(0) drops only the leading batch axis. A bare squeeze() would
        # also collapse a one-token sequence to a 0-d tensor, which cannot be
        # padded into a batch later.
        return {
            "input_ids": input_encodings["input_ids"].squeeze(0),
            "attention_mask": input_encodings["attention_mask"].squeeze(0),
            "labels": target_encodings["input_ids"].squeeze(0),
        }

# Custom data collator for padding
def collate_fn(batch, pad_token_id=None, label_pad_token_id=-100):
    """Pad a list of tokenized examples into one batch of equal-length tensors.

    Args:
        batch: List of dicts, each with 1-D tensors under ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"``.
        pad_token_id: Value used to pad ``input_ids``. When ``None`` (the
            default) the module-level ``tokenizer.pad_token_id`` is used.
        label_pad_token_id: Value used to pad ``labels``. The default of -100
            is the index the cross-entropy loss ignores, so padded target
            positions do not contribute to the training loss.

    Returns:
        dict with 2-D tensors ``"input_ids"``, ``"attention_mask"`` and
        ``"labels"``, each padded to the longest sequence of its kind in the
        batch.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    input_ids = [item["input_ids"] for item in batch]
    attention_mask = [item["attention_mask"] for item in batch]
    labels = [item["labels"] for item in batch]

    # Pad sequences to the length of the longest sequence in the batch
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
    attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
    # Labels must NOT be padded with the pad token id: that would make the
    # model train on predicting padding. Use the ignored index instead.
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=label_pad_token_id)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


In [None]:
# Toy training data: each (input, target) pair is kept together so the two
# lists stay aligned, then the set is repeated to get more instances.
text_pairs = [
    ("Who is your boss?", "You are my boss."),
    ("Hello", "What is your problem, human?"),
]
repeat_count = 500

input_texts = [source for source, _ in text_pairs] * repeat_count
target_texts = [target for _, target in text_pairs] * repeat_count

In [None]:
# Set a maximum sequence length (in tokens) for both input and output
max_length = 20

# Create the dataset
dataset = SimpleDataset(input_texts, target_texts, tokenizer, max_length)

# Split the dataset 80/20 into training and evaluation sets.
# The split is driven by a seeded generator so that Restart & Run All
# produces the same train/eval partition every time.
split_seed = 42
train_size = int(0.8 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(split_seed),
)

In [None]:
# Training configuration; checkpoints are written under ./flan_t5_fine_tuned
training_args = TrainingArguments(
    output_dir="./flan_t5_fine_tuned",
    # Optimization schedule -- adjust both to your data and hardware
    per_device_train_batch_size=4,
    num_train_epochs=10,
    # Run evaluation on a step interval
    evaluation_strategy="steps",
    eval_steps=2000,
    # Checkpointing: step interval and how many checkpoints to keep
    save_steps=10000,
    save_total_limit=2,
)

# Initialize the Trainer with the custom collate function.
# Train on the training split only: passing the full dataset here would also
# train on the examples held out in eval_dataset, so the evaluation loss would
# no longer measure anything the model has not already seen.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # Held-out split used for periodic evaluation
    data_collator=collate_fn,  # Pads each batch to its longest sequence
)

In [None]:
# Directory that receives the fine-tuned model and the tokenizer files
output_dir = "./flan_t5_fine_tuned"

# Run the fine-tuning loop
trainer.train()

# Save the trained model and the tokenizer side by side so the directory can
# be loaded later with from_pretrained
fine_tuned_model = trainer.model
fine_tuned_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)