In [None]:
# 8ec29f3aeb5993c4d90cc69fcdf8682e2e550396

In [None]:
!pip install transformers datasets torch scikit-learn nltk

In [None]:
import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict, load_dataset
from nltk.translate.bleu_score import corpus_bleu
from sklearn.metrics import accuracy_score
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

In [None]:
# Load the preprocessed Lang-8 sentence pairs into a DataFrame.
try:
    lang8_df = pd.read_csv("preprocessed_lang8.csv")
except FileNotFoundError as err:
    # Fail fast: every later cell needs `lang8_df`, so continuing after a
    # missing file would only surface as a confusing NameError further down.
    raise FileNotFoundError(
        "Error: preprocessed_lang8.csv not found. Make sure you have preprocessed the Lang-8 data."
    ) from err

In [None]:
# Build train/validation/test splits: later cells read
# tokenized_lang8["validation"] (trainer) and tokenized_lang8["test"] (final
# evaluation), so a train-only DatasetDict would raise KeyError there.
# The 80/10/10 ratio and the seed are defaults chosen here; tune as needed.
lang8_full = Dataset.from_pandas(lang8_df, preserve_index=False)
train_vs_holdout = lang8_full.train_test_split(test_size=0.2, seed=42)
validation_vs_test = train_vs_holdout["test"].train_test_split(test_size=0.5, seed=42)

dataset_lang8 = DatasetDict({
  'train': train_vs_holdout["train"],
  'validation': validation_vs_test["train"],
  'test': validation_vs_test["test"],
})

In [None]:
# Tokenizer matching the flan-t5 base checkpoint.
tokenizer_checkpoint = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

In [None]:
def tokenize_function(examples):
    """Tokenize a batch of (text, corrected_text) pairs for seq2seq training.

    Args:
        examples: Batch mapping with a "text" column (source sentences) and a
            "corrected_text" column (target sentences).

    Returns:
        The tokenizer output for the prefixed inputs, with a "labels" entry
        holding the target token ids. Padding positions in the labels are set
        to -100 (the value compute_metrics also treats as "ignore").
    """
    # Task prefix prepended to every source sentence.
    inputs = ["grammar correction: " + x for x in examples["text"]]
    targets = list(examples["corrected_text"])

    # A single call tokenizes sources and, via text_target, the labels.
    # padding="max_length" gives every row the same length regardless of which
    # map batch it was processed in, so rows can later be stacked into tensors.
    model_inputs = tokenizer(
        inputs, text_target=targets, max_length=128, truncation=True, padding="max_length",
    )

    # Mask label padding so it is not scored as a real target token.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in label_row]
        for label_row in model_inputs["labels"]
    ]
    return model_inputs

# Tokenize every split batch-wise.
tokenized_lang8 = dataset_lang8.map(function=tokenize_function, batched=True)

In [None]:
# Pretrained seq2seq model that will be fine-tuned.
base_checkpoint = "google/flan-t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(base_checkpoint)

In [None]:
# Training configuration for the grammar-correction fine-tune.
training_args = Seq2SeqTrainingArguments(
    output_dir="./grammar_correction_results",
    # Schedule and batch sizes -- adjust to the available hardware.
    num_train_epochs=5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluate and checkpoint once per epoch.
    eval_strategy="epoch",
    save_strategy="epoch",
    # Generate text during evaluation so compute_metrics can decode and score it.
    predict_with_generate=True,
    # Keep the checkpoint with the best "bleu" value reported by compute_metrics.
    load_best_model_at_end=True,
    metric_for_best_model="bleu",
    # Further arguments (learning rate, warmup steps, ...) can be added here.
)

In [None]:
def compute_metrics(eval_preds):
    """Compute BLEU and the mean generated length for an evaluation pass.

    Args:
        eval_preds: ``(predictions, labels)`` pair of token-id arrays.
            ``predictions`` may be a tuple, in which case its first element
            is used.

    Returns:
        dict with "bleu" (the ``score`` reported by the module-level
        ``metric``) and "gen_len" (mean count of non-pad prediction tokens),
        both rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id

    # -100 is an "ignore" marker, not a vocabulary id, and cannot be decoded.
    # It can show up in the predictions as well as in the labels, so both
    # arrays are cleaned before decoding.
    preds = np.asarray(preds)
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.asarray(labels)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Simple post-processing; each prediction gets a one-element reference list.
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    # Length statistic counts only real (non-pad) generated tokens.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [None]:
class _CorpusBleuMetric:
    """Small BLEU adapter exposing ``compute(predictions, references)``.

    The notebook never imports ``load_metric``, so the previous
    ``load_metric("sacrebleu")`` call could not run. This adapter keeps the
    ``{"score": ...}`` result shape that compute_metrics and the test-set
    evaluation read, and is backed by the ``corpus_bleu`` function the
    notebook already imports from NLTK.

    NOTE(review): sentences are split on whitespace and no smoothing is
    applied, so the numbers are not directly comparable to sacreBLEU scores.
    """

    def compute(self, predictions, references):
        """Return ``{"score": corpus BLEU scaled to 0-100}``.

        Args:
            predictions: list of hypothesis strings.
            references: list (one entry per prediction) of lists of
                reference strings.
        """
        # Nothing to score; avoid handing an empty corpus to corpus_bleu.
        if not predictions:
            return {"score": 0.0}
        hypotheses = [prediction.split() for prediction in predictions]
        reference_sets = [[ref.split() for ref in refs] for refs in references]
        return {"score": 100.0 * corpus_bleu(reference_sets, hypotheses)}


metric = _CorpusBleuMetric()

In [None]:
# Trainer wiring. The seq2seq collator pads each batch to a common length and
# pads labels with -100, so rows of different lengths can be batched and label
# padding is excluded from the loss.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_lang8["train"],  # Train split of the tokenized data.
    eval_dataset=tokenized_lang8["validation"],  # Split scored at each evaluation.
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    compute_metrics=compute_metrics,
)

In [None]:
# Run the fine-tuning loop configured in training_args.
trainer.train()

In [None]:
# Reload the best checkpoint recorded during training. If none was recorded
# (best_model_checkpoint is None) fall back to the trainer's in-memory model
# instead of crashing inside from_pretrained(None).
best_model_checkpoint = trainer.state.best_model_checkpoint
if best_model_checkpoint is not None:
    best_model = AutoModelForSeq2SeqLM.from_pretrained(best_model_checkpoint)
else:
    best_model = trainer.model

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Assign the result so the cell does not print the full module tree, and
# switch to inference mode for the prediction cells that follow.
best_model = best_model.to(device).eval()

In [None]:
def predict_in_batches_seq2seq(model, dataset, batch_size=8):
    """Generate and decode predictions for a tokenized dataset, batch by batch.

    Args:
        model: Seq2seq model exposing ``generate``.
        dataset: Sliceable dataset whose slices provide "input_ids" and
            "attention_mask" as lists of token-id rows.
        batch_size: Number of rows generated per forward pass.

    Returns:
        list[str]: One decoded prediction per dataset row, in dataset order.
    """
    # Run on whatever device the model lives on rather than a notebook global.
    target_device = getattr(model, "device", None)
    if target_device is None:
        target_device = next(model.parameters()).device

    pad_id = tokenizer.pad_token_id
    all_predictions = []
    for start in range(0, len(dataset), batch_size):
        batch = dataset[start : start + batch_size]
        id_rows = [list(row) for row in batch["input_ids"]]
        mask_rows = [list(row) for row in batch["attention_mask"]]

        # Right-pad to the longest row of this batch so rows of different
        # lengths can still be stacked into one tensor.
        width = max(len(row) for row in id_rows)
        input_ids = torch.tensor(
            [row + [pad_id] * (width - len(row)) for row in id_rows]
        ).to(target_device)
        attention_mask = torch.tensor(
            [row + [0] * (width - len(row)) for row in mask_rows]
        ).to(target_device)

        with torch.no_grad():
            generated_tokens = model.generate(
                input_ids=input_ids, attention_mask=attention_mask, max_length=128
            )  # Adjust max_length as needed

        # extend (not append) keeps the result a flat list of strings.
        all_predictions.extend(
            tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        )

    return all_predictions

In [None]:
# Score the best model on the held-out test split.
predictions = predict_in_batches_seq2seq(best_model, tokenized_lang8["test"])

# Decode the reference labels of the TEST set for BLEU. The -100 "ignore"
# marker is replaced row by row: comparing the raw list-of-lists column with
# `!= -100` yields a single boolean rather than an element-wise mask, so the
# markers would otherwise reach batch_decode untouched.
pad_token_id = tokenizer.pad_token_id
labels = [
    [token_id if token_id != -100 else pad_token_id for token_id in label_row]
    for label_row in tokenized_lang8['test']['labels']
]
true_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

# Evaluate using the metric (one single-element reference list per prediction).
test_metrics = metric.compute(predictions=predictions, references=[[label] for label in true_labels])

# Print the bleu score.
print(f"Bleu score: {test_metrics['score']}")

In [None]:
# Persist the fine-tuned model together with its tokenizer so the directory
# can be reloaded on its own with from_pretrained.
best_model.save_pretrained("./grammar_model")
tokenizer.save_pretrained("./grammar_model")

# Also keep a raw state_dict copy.
# NOTE(review): the file name mentions CoLA/BERT although this notebook trains
# a flan-t5 grammar model -- looks like a leftover; path kept unchanged in case
# other code loads it. Confirm and rename.
torch.save(best_model.state_dict(), "./cola_best_bert_model.pt")

In [None]:
def predict_single_sentence(sentence, model, tokenizer):
    """
    Corrects a single sentence using the fine-tuned grammar correction model.

    Args:
        sentence (str): The sentence to correct.
        model: The fine-tuned grammar correction model.
        tokenizer: The tokenizer used for the model.

    Returns:
        str: The corrected sentence.
    """
    # Put the inputs on the model's own device; guessing from CUDA availability
    # breaks when the model was left on the CPU of a GPU machine.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    # Same task prefix and 128-token limit as tokenize_function. A single
    # sentence needs no padding, so none is requested.
    inputs = tokenizer(
        "grammar correction: " + sentence,
        max_length=128,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        generated_tokens = model.generate(**inputs, max_length=128)  # Adjust the max_length if necessary

    # Only one sentence was submitted, so decode the first (only) sequence.
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)



# Demo: run one deliberately ungrammatical sentence through the best model.
example_sentence = "I am go to the store yesterday."
print(f"Original: {example_sentence}")

corrected_sentence = predict_single_sentence(example_sentence, best_model, tokenizer)
print(f"Corrected: {corrected_sentence}")