In [None]:
# Environment configuration comes first: these variables must be set before
# torch/transformers are imported. If torch was already imported in this
# kernel, restart the kernel before re-running this cell.
import os

# Allocator settings for the case where reserved-but-unallocated memory is large
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:64"

# (optional) avoid the fork/threads warning and nested parallelism;
# setdefault keeps any value already exported by the user
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Ensures that only 1 GPU is visible to torch/accelerate/transformers/trl:
# devices are ordered by PCI bus id and only index 1 is exposed
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Heavy imports are deliberately placed after the environment setup above.
import torch  # noqa: E402
from datasets import load_dataset  # noqa: E402
from transformers import AutoModelForSeq2SeqLM  # noqa: E402
from trl import SFTTrainer, SFTConfig  # noqa: E402

# Allow TF32 for CUDA matmuls and select the "high" float32 matmul precision
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

In [None]:
# Identifiers of the checkpoint and the dataset used in this run.
model_id = "google/t5gemma-s-s-ul2-it"
dataset_name = "trl-lib/Capybara"

# Load the seq2seq checkpoint in bfloat16 with eager attention;
# device placement is delegated to device_map="auto".
model_load_kwargs = {
    "attn_implementation": "eager",
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **model_load_kwargs)

# Only the first 5% of the train split is used.
train = load_dataset(dataset_name, split="train[:5%]")

# One epoch, batch size 1 per device, max_length of 256.
sft_config_kwargs = {
    "num_train_epochs": 1,
    "per_device_train_batch_size": 1,
    "max_length": 256,
    "remove_unused_columns": True,
}
args = SFTConfig(**sft_config_kwargs)

# Assemble the trainer from the pieces above and start training.
trainer_kwargs = {"args": args, "model": model, "train_dataset": train}
trainer = SFTTrainer(**trainer_kwargs)
trainer.train()