<a href="https://colab.research.google.com/github/huytd/grammar-t5-small/blob/main/20250920_flan_t5_small_grammar_lab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Installation

In [1]:
!pip install -U transformers datasets

Collecting transformers
  Downloading transformers-4.56.2-py3-none-any.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.1/40.1 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
Collecting datasets
  Downloading datasets-4.1.1-py3-none-any.whl.metadata (18 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-21.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Downloading transformers-4.56.2-py3-none-any.whl (11.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m86.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading datasets-4.1.1-py3-none-any.whl (503 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m503.6/503.6 kB[0m [31m37.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-21.0.0-cp312-cp312-manylinux_2_28_x86_64.whl (42.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.8/42.8 MB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling co

## Test the base model

In [21]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load model and tokenizer
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

# Better prompt format for FLAN-T5 with multiple inputs
input_texts = [
    "Rewrite and fix: 'I will drank two bottle'",
    "Rewrite and fix: 'She go to the store'",
    "Rewrite and fix: 'He is more taller than me'"
]

for input_text in input_texts:
    # Tokenize input
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids

    # Generate with better parameters
    outputs = model.generate(
        input_ids,
        max_length=50,
        num_beams=4,
        early_stopping=True
    )

    # Decode output
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Input: {input_text}")
    print(f"Output: {result}\n")

Input: Rewrite and fix: 'I will drank two bottle'
Output: I will drank two bottle

Input: Rewrite and fix: 'She go to the store'
Output: She go to the store

Input: Rewrite and fix: 'He is more taller than me'
Output: He is more taller than me



## Load the dataset

In [3]:
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("grammarly/coedit")

# Print the dataset structure
print(dataset)

README.md: 0.00B [00:00, ?B/s]

train.jsonl:   0%|          | 0.00/19.7M [00:00<?, ?B/s]

validation.jsonl: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/69071 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1712 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['_id', 'task', 'src', 'tgt'],
        num_rows: 69071
    })
    validation: Dataset({
        features: ['_id', 'task', 'src', 'tgt'],
        num_rows: 1712
    })
})


## Finetuning

In [4]:
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments

# Preprocess the dataset
def preprocess_function(examples):
    inputs = [f"rewrite and fix: {ex}" for ex in examples["src"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")

    # Setup the tokenizer for targets
    labels = tokenizer(text_target=examples["tgt"], max_length=512, truncation=True, padding="max_length")

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Apply preprocessing to the dataset
tokenized_dataset = dataset.map(preprocess_function, batched=True)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",  # Output directory for checkpoints and logs
    num_train_epochs=1,  # Number of training epochs
    per_device_train_batch_size=8,  # Batch size per device during training
    per_device_eval_batch_size=8,   # Batch size for evaluation
    warmup_steps=500,  # Number of warmup steps for learning rate scheduler
    weight_decay=0.01,  # Strength of weight decay
    logging_dir="./logs",  # Directory for storing logs
    logging_steps=100, # Log every 100 steps
    eval_strategy="epoch", # Changed from evaluation_strategy to eval_strategy
    save_strategy="epoch", # Save checkpoint at the end of each epoch
    save_total_limit=1, # Only save the latest checkpoint
    report_to="none", # Do not report to any experiment tracking platform (like wandb)
    push_to_hub=False, # Do not push model to Hugging Face Hub
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
)

# Start training
trainer.train()

Map:   0%|          | 0/69071 [00:00<?, ? examples/s]

Map:   0%|          | 0/1712 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss
1,0.0426,0.059752


TrainOutput(global_step=8634, training_loss=0.9913072668901328, metrics={'train_runtime': 6445.5866, 'train_samples_per_second': 10.716, 'train_steps_per_second': 1.34, 'total_flos': 1.2839643050409984e+16, 'train_loss': 0.9913072668901328, 'epoch': 1.0})

In [15]:
model.save_pretrained("./final_model")
tokenizer.save_pretrained("./final_model")

('./final_model/tokenizer_config.json',
 './final_model/special_tokens_map.json',
 './final_model/spiece.model',
 './final_model/added_tokens.json')

## Test the finetuned model

In [22]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load model and tokenizer
tokenizer = T5Tokenizer.from_pretrained("./final_model")
model = T5ForConditionalGeneration.from_pretrained("./final_model")

# Better prompt format for FLAN-T5 with multiple inputs
input_texts = [
    "Rewrite and fix: 'I will drank two bottle'",
    "Rewrite and fix: 'She go to the store'",
    "Rewrite and fix: 'He is more taller than me'"
]

for input_text in input_texts:
    # Tokenize input
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids

    # Generate with better parameters
    outputs = model.generate(
        input_ids,
        max_length=50,
        num_beams=4,
        early_stopping=True
    )

    # Decode output
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Input: {input_text}")
    print(f"Output: {result}\n")

Input: Rewrite and fix: 'I will drank two bottle'
Output: 'I will drink two bottles'

Input: Rewrite and fix: 'She go to the store'
Output: 'She goes to the store'

Input: Rewrite and fix: 'He is more taller than me'
Output: 'He's more taller than me'

