# Arabic GEC Pipeline (Kaggle Optimized)

This notebook implements the QALB 2014 grammatical error correction (GEC) pipeline, adapted for Kaggle Kernels.

**Steps:**
1.  Setup & Data Download
2.  M2 Format Parsing (Train + Dev)
3.  Fine-tuning AraT5 on Kaggle GPU
4.  Inference & Export

**Note:** Ensure you have selected **GPU T4 x2** or **P100** from the accelerator menu in Kaggle settings.

In [1]:
# Environment check (optional, but handy when the notebook is re-run)
import os

# Choose the status line depending on whether an earlier run left the dataset archive behind.
status_message = (
    "Previous execution detected. You might want to skip download."
    if os.path.exists("/kaggle/working/qalb_dataset.zip")
    else "Fresh environment."
)
print(status_message)

# Later cells use relative paths, so switch to the Kaggle working directory.
os.chdir("/kaggle/working")

Fresh environment.


In [2]:
# Install necessary libraries
!pip install transformers datasets pyarabic gdown sentencepiece evaluate sacrebleu

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting sacrebleu
  Downloading sacrebleu-2.6.0-py3-none-any.whl.metadata (39 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-22.0.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.2 kB)
Collecting portalocker (from sacrebleu)
  Downloading portalocker-3.2.0-py3-none-any.whl.metadata (8.7 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading sacrebleu-2.6.0-py3-none-any.whl (100 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m100.8/100.8 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-22.0.0-cp311-cp311-manylinux_2_28_x86_64.whl (47.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading portalock

## 1. Data Download
Downloads the dataset directly to the ephemeral storage.

In [3]:
import gdown
import zipfile
import os

# Download the file from Google Drive (Public Link)
file_id = '1hvLiiMvvubyCEAZK4KIWgu7qHBNCHOp-'
url = f'https://drive.google.com/uc?id={file_id}'
output_file = 'qalb_dataset.zip'
# Folder that appears in the current directory once the archive has been unpacked;
# its presence is used as the "already extracted" marker.
extracted_dir = "QALB-0.9.1-Dec03-2021-SharedTasks"

# Only download if the archive is not already present
if not os.path.exists(output_file):
    print("Downloading dataset...")
    gdown.download(url, output_file, quiet=False)

# Unzip the file, reporting separately why extraction was skipped
if os.path.exists(extracted_dir):
    print("Dataset already extracted.")
elif not os.path.exists(output_file):
    # The download step above did not leave an archive behind.
    print(f"Download failed: {output_file} was not created.")
else:
    # Extract to current directory to find it easily
    with zipfile.ZipFile(output_file, 'r') as zip_ref:
        zip_ref.extractall(".")
    print("Dataset extracted.")

Downloading dataset...


Downloading...
From (original): https://drive.google.com/uc?id=1hvLiiMvvubyCEAZK4KIWgu7qHBNCHOp-
From (redirected): https://drive.google.com/uc?id=1hvLiiMvvubyCEAZK4KIWgu7qHBNCHOp-&confirm=t&uuid=081305af-d859-4c85-96f4-00ffb99e53ef
To: /kaggle/working/qalb_dataset.zip
100%|██████████| 94.3M/94.3M [00:00<00:00, 102MB/s] 


Dataset extracted.


## 2. Step 1: The M2 Parser

In [4]:
import csv
import os

def _parse_m2_edit(line):
    """Parse one M2 annotation line ("A ...") into ``(start, end, correction)``.

    Fields are separated by ``|||``. Lines that only use ``||`` are still
    accepted through a fallback that mirrors the previous parsing (split on
    ``||`` and drop any pipes left inside the fields).

    Raises:
        IndexError / ValueError: if the line has fewer than three fields or
            the span offsets are not integers.
    """
    payload = line[2:]
    fields = payload.split("|||")
    if len(fields) < 3:
        fields = [field.replace("|", "") for field in payload.split("||")]

    span = fields[0].split()
    return int(span[0]), int(span[1]), fields[2].strip()


def parse_m2_and_generate_csv(m2_paths, output_csv_path):
    """Convert M2 annotation files into a CSV of (incorrect, correct) sentence pairs.

    Every entry of an M2 file consists of an ``S`` line (the tokenized source
    sentence) followed by ``A`` lines (token-span edits). The edits are applied
    to the source tokens to build the corrected sentence.

    Args:
        m2_paths: iterable of M2 file paths. Missing files are reported on
            stdout and skipped.
        output_csv_path: destination CSV path. The file is overwritten and
            starts with the header row ``incorrect,correct``.

    Returns:
        None. The pairs are written to ``output_csv_path`` and a summary line
        is printed.
    """
    all_processed_data = []

    for m2_path in m2_paths:
        print(f"Processing {m2_path}...")
        if not os.path.exists(m2_path):
            print(f"File not found: {m2_path}")
            continue

        with open(m2_path, 'r', encoding='utf-8') as f:
            # Entries are separated by blank lines.
            m2_data = f.read().strip().split("\n\n")

        for entry in m2_data:
            lines = entry.split("\n")

            # The first line starts with 'S' and contains the original sentence (tokenized)
            source_line = lines[0]
            if not source_line.startswith("S "):
                continue

            original_tokens = source_line[2:].split()
            edits = []

            # Subsequent lines start with 'A' and contain edits
            for line in lines[1:]:
                if not line.startswith("A "):
                    continue
                start_off, end_off, correction = _parse_m2_edit(line)
                if start_off < 0:
                    # A negative span (e.g. "A -1 -1|||noop|||...") marks "no edit";
                    # applying it as a slice would touch the end of the sentence.
                    continue
                # Keep the position in the file as a tie-breaker for the sort below.
                edits.append((start_off, end_off, len(edits), correction))

            # Apply edits right-to-left so earlier offsets stay valid. Ties on the
            # start offset are resolved by end offset and then by file order, both
            # descending: a replacement (i, i+1) runs before an insertion (i, i),
            # so the insertion is not overwritten, and several insertions at the
            # same position keep their annotated order.
            edits.sort(key=lambda edit: edit[:3], reverse=True)

            corrected_tokens = list(original_tokens)
            for start, end, _, subst in edits:
                # "-NONE-" or an empty correction means the span is deleted.
                replacement = [] if subst == "-NONE-" else subst.split()
                corrected_tokens[start:end] = replacement

            original_sent = " ".join(original_tokens)
            corrected_sent = " ".join(corrected_tokens)

            all_processed_data.append([original_sent, corrected_sent])

    # Save to CSV
    with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["incorrect", "correct"])
        writer.writerows(all_processed_data)

    print(f"Saved {len(all_processed_data)} pairs to {output_csv_path}")

# Find ALL M2 files (Train + Dev)
m2_files_found = []
for root, dirs, files in os.walk("."):
    for file in files:
        # Look for both Train and Dev files to get maximum data
        # Filter for 2014 specific files as requested
        if file.endswith(".m2") and ("Train" in file or "Dev" in file) and "2014" in file:
            m2_files_found.append(os.path.join(root, file))

# os.walk yields entries in filesystem order, which can differ between runs and machines;
# sorting keeps the row order of the generated CSV reproducible.
m2_files_found.sort()

if m2_files_found:
    parse_m2_and_generate_csv(m2_files_found, "qalb_full_gec.csv")
else:
    print("No 2014 M2 files found! Check dataset extraction.")

Processing ./QALB-0.9.1-Dec03-2021-SharedTasks/data/2014/train/QALB-2014-L1-Train.m2...
Processing ./QALB-0.9.1-Dec03-2021-SharedTasks/data/2014/dev/QALB-2014-L1-Dev.m2...
Saved 20428 pairs to qalb_full_gec.csv


## 3. Step 2: Model Training (AraT5)
Optimized for Kaggle Kernels (Local Output, No Drive Mounting).

In [5]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from transformers.trainer_utils import get_last_checkpoint
from datasets import load_dataset
import os
import evaluate
import numpy as np
import gdown

# Switch Weights & Biases reporting off through its environment flag
os.environ.update({"WANDB_DISABLED": "true"})

def run_training_step(seed=42):
    """Fine-tune AraT5 on 'qalb_full_gec.csv' and save the resulting model.

    Steps, in order:
      1. Return early (with a message) if 'qalb_full_gec.csv' is missing.
      2. Try to download the starting checkpoint folder from Google Drive unless
         it is already on disk; a failed download only prints a warning.
      3. Pick the model source: resume from the downloaded checkpoint when it
         contains 'trainer_state.json', load only its weights when it does not,
         and fall back to the base model when there is no checkpoint at all.
      4. Split the CSV 90/10 into train/eval, tokenize, and train with
         ``Seq2SeqTrainer``. When no resume path was chosen in step 3, the most
         recent checkpoint inside the output directory (if any) is resumed.
      5. Save model and tokenizer to '<output_dir>/arat5-gec-finetuned'.

    Args:
        seed: seed for the train/eval split and for the trainer. A fixed seed
            keeps the evaluation rows the same across runs, so a resumed run
            is not evaluated on sentences it has already trained on.

    Returns:
        None.
    """
    # --- Output Path (Kaggle Working Directory) ---
    output_dir = "./arat5-gec-checkpoints-kaggle"
    print(f"Checkpoints will be saved locally to: {output_dir}")

    if not os.path.exists('qalb_full_gec.csv'):
        print("Training data 'qalb_full_gec.csv' not found. Please run Step 1 Parser first.")
        return

    # --- 1. Checkpoint Download & Setup ---
    # NOTE: On Kaggle, you might want to skip downloading a previous checkpoint and start fresh
    # unless you have uploaded it as a Kaggle Dataset.
    start_checkpoint_url = "https://drive.google.com/drive/folders/1Mf8XO-LgdFKgud0OoCFU1j9o9x0uB93N?usp=sharing"
    start_checkpoint_name = "checkpoint-7000"
    download_dir = "downloaded_starting_checkpoint"
    target_start_path = os.path.join(download_dir, start_checkpoint_name)

    # Download if needed
    if not os.path.exists(target_start_path):
        print("Downloading starting checkpoint folder from Drive...")
        try:
            gdown.download_folder(start_checkpoint_url, output=download_dir, quiet=False)
        except Exception as e:
            # Best effort: without the checkpoint, training starts from the base model below.
            print(f"Warning: Failed to download checkpoint: {e}")

    # --- 2. Determine Model Source ---
    # Default base model
    model_name = "UBC-NLP/AraT5v2-base-1024"
    resume_path = None

    if os.path.exists(target_start_path):
        print(f"Found requested checkpoint at: {target_start_path}")
        # Check if it has trainer state (for resuming) or just weights (for initializing)
        if os.path.exists(os.path.join(target_start_path, "trainer_state.json")):
            print("Detected 'trainer_state.json'. Will RESUME training state from this checkpoint.")
            resume_path = target_start_path
        else:
            print("No 'trainer_state.json' found. Will LOAD WEIGHTS from this checkpoint and start fresh training.")
            model_name = target_start_path

    # --- Data Loading ---
    dataset = load_dataset('csv', data_files='qalb_full_gec.csv')
    # Seeded so that every run (including resumed ones) sees the same train/eval split.
    dataset = dataset['train'].train_test_split(test_size=0.1, seed=seed)

    # --- Model Init ---
    print(f"Initializing model from: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # --- Preprocessing ---
    prefix = "gec_arabic: "
    max_input_length = 256
    max_target_length = 256

    def preprocess_function(examples):
        """Tokenize a batch: prefixed 'incorrect' text as inputs, 'correct' text as labels."""
        # Empty CSV cells are replaced with empty strings (the `if ex` guard).
        inputs = [prefix + (ex if ex else "") for ex in examples["incorrect"]]
        targets = [(ex if ex else "") for ex in examples["correct"]]

        model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
        # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
        labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized_datasets = dataset.map(preprocess_function, batched=True)

    # --- Training Config ---
    batch_size = 2 # T4 can handle this with accumulation

    args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        eval_strategy = "steps",
        eval_steps = 1000,
        save_strategy = "steps",
        save_steps = 1000,
        learning_rate=3e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=4,
        weight_decay=0.01,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        num_train_epochs=6,
        predict_with_generate=False, # DISABLED generation during eval to save time
        fp16=True, # Enable mixed precision for T4/P100
        push_to_hub=False,
        report_to="none",
        seed=seed,
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    trainer = Seq2SeqTrainer(
        model,
        args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
        processing_class=tokenizer,
    )

    # --- Resume Logic ---
    # Only consulted when the downloaded checkpoint did not already provide a resume path.
    if resume_path is None and os.path.exists(output_dir):
        last_checkpoint = get_last_checkpoint(output_dir)
        if last_checkpoint:
            print(f"Found newer progress in output directory. Resuming from: {last_checkpoint}")
            resume_path = last_checkpoint

    print(f"Starting training... (Resume: {resume_path})")
    trainer.train(resume_from_checkpoint=resume_path)

    # Save Final (load_best_model_at_end=True is set above)
    final_path = os.path.join(output_dir, "arat5-gec-finetuned")
    model.save_pretrained(final_path)
    tokenizer.save_pretrained(final_path)
    print(f"Best model saved to {final_path}")

# Run the fine-tuning step (executes as soon as this cell is run)
run_training_step()

2026-01-12 21:56:32.248092: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768254992.655765      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768254992.779524      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

Checkpoints will be saved locally to: ./arat5-gec-checkpoints-kaggle
Downloading starting checkpoint folder from Drive...


Retrieving folder contents


Retrieving folder 1wXFxC0uN4lkbYGVCDNftaSGkeOE7sSCb checkpoint-7000
Processing file 12XRrAGmXGOBUjxP-j9kRuWThymg3wfpP config.json
Processing file 1_6mHmbf2iDso3n-x_4fKPfJhX6eeHt5c generation_config.json
Processing file 10sSN6ZJzzEKIobZlaFzra-0CvlOduBXh model.safetensors
Processing file 1Ei-O-G1t7zcon-eV8HJRWUvpule0YRnN special_tokens_map.json
Processing file 11HbVmq-6tT2NYjDI15kTdf9FeWKwPFlL spiece.model
Processing file 18DU2qJ8AUvBf_gVPdmaC7u4m7z5KfWlh tokenizer_config.json
Processing file 1fZ3egkBxc8hvs0bPrbtf769Tu7EnWeCx tokenizer.json


Retrieving folder contents completed
Building directory structure
Building directory structure completed
Downloading...
From: https://drive.google.com/uc?id=12XRrAGmXGOBUjxP-j9kRuWThymg3wfpP
To: /kaggle/working/downloaded_starting_checkpoint/checkpoint-7000/config.json
100%|██████████| 781/781 [00:00<00:00, 2.58MB/s]
Downloading...
From: https://drive.google.com/uc?id=1_6mHmbf2iDso3n-x_4fKPfJhX6eeHt5c
To: /kaggle/working/downloaded_starting_checkpoint/checkpoint-7000/generation_config.json
100%|██████████| 122/122 [00:00<00:00, 473kB/s]
Downloading...
From (original): https://drive.google.com/uc?id=10sSN6ZJzzEKIobZlaFzra-0CvlOduBXh
From (redirected): https://drive.google.com/uc?id=10sSN6ZJzzEKIobZlaFzra-0CvlOduBXh&confirm=t&uuid=af6d8353-1e20-4b00-a05c-95985bff0843
To: /kaggle/working/downloaded_starting_checkpoint/checkpoint-7000/model.safetensors
100%|██████████| 1.47G/1.47G [00:11<00:00, 132MB/s] 
Downloading...
From: https://drive.google.com/uc?id=1Ei-O-G1t7zcon-eV8HJRWUvpule0YRnN


Found requested checkpoint at: downloaded_starting_checkpoint/checkpoint-7000
No 'trainer_state.json' found. Will LOAD WEIGHTS from this checkpoint and start fresh training.


Generating train split: 0 examples [00:00, ? examples/s]

Initializing model from: downloaded_starting_checkpoint/checkpoint-7000


Map:   0%|          | 0/18385 [00:00<?, ? examples/s]



Map:   0%|          | 0/2043 [00:00<?, ? examples/s]

Starting training... (Resume: None)


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss
1000,0.4465,0.316154
2000,0.414,0.316782
3000,0.3892,0.311445
4000,0.3683,0.309877
5000,0.3594,0.310953
6000,0.353,0.30781


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Best model saved to ./arat5-gec-checkpoints-kaggle/arat5-gec-finetuned


## 4. Inference & Export
Test the model and package it for download.

In [9]:
import shutil
import os

def package_for_kaggle_output():
    """Zip the fine-tuned model directory so it can be downloaded.

    The archive 'arat5_gec_model_output.zip' is created in the current working
    directory. When the model directory does not exist, a notice is printed
    and nothing is written.
    """
    model_dir = "./arat5-gec-checkpoints-kaggle/arat5-gec-finetuned"
    archive_stem = "arat5_gec_model_output"

    # Guard clause: nothing to package without a saved model.
    if not os.path.exists(model_dir):
        print("No fine-tuned model found to zip.")
        return

    print("Zipping model for download...")
    shutil.make_archive(archive_stem, 'zip', model_dir)
    print(f"✅ Created {archive_stem}.zip in /kaggle/working/")
    print("You can download this file from the 'Output' tab on the right sidebar.")

# Build the downloadable archive as soon as this cell is run.
package_for_kaggle_output()

Zipping model for download...
✅ Created arat5_gec_model_output.zip in /kaggle/working/
You can download this file from the 'Output' tab on the right sidebar.


In [38]:
def run_inference(input_sentence):
    """Correct one Arabic sentence with the GEC model.

    The model location is resolved in this order:
      1. the fine-tuned model saved by the training step,
      2. a downloaded starting checkpoint that contains weight files,
      3. the base model 'UBC-NLP/AraT5v2-base-1024' (demo only).

    The model and tokenizer are loaded on every call.

    Args:
        input_sentence: raw sentence to correct; the task prefix is added here.

    Returns:
        The corrected sentence. If the model or tokenizer cannot be loaded,
        the string "Error loading model: <reason>" is returned instead of
        raising, so callers should be prepared for that value.
    """
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    import torch
    import os

    # Path where we saved the model in the previous step
    model_path = "./arat5-gec-checkpoints-kaggle/arat5-gec-finetuned"

    # Fallback if training wasn't run: look for a downloaded starting checkpoint.
    if not os.path.exists(model_path):
        # checkpoint-7000 is the folder the training step downloads; checkpoint-3000
        # is still checked so that older downloads keep working.
        fallback_checkpoints = [
            "downloaded_starting_checkpoint/checkpoint-7000",
            "downloaded_starting_checkpoint/checkpoint-3000",
        ]
        weight_files = {"model.safetensors", "pytorch_model.bin"}
        for candidate in fallback_checkpoints:
            # isdir (not exists) so that listdir cannot fail on a stray file.
            if os.path.isdir(candidate) and weight_files & set(os.listdir(candidate)):
                model_path = candidate
                break
        else:
            print("Fine-tuned model not found. Using Base Model for demo.")
            model_path = "UBC-NLP/AraT5v2-base-1024"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model from: {model_path} on {device}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    except Exception as e:
        return f"Error loading model: {e}"

    model = model.to(device)

    # Preprocessing: same prefix and length limit as used during training
    prefix = "gec_arabic: "
    text = prefix + input_sentence
    inputs = tokenizer(text, return_tensors="pt", max_length=256, truncation=True).to(device)

    # Generation
    # Greedy search (num_beams=1): beam search often gets stuck in loops for under-trained
    # models because the loop has high probability; greedy search moves forward one
    # best step at a time.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # passed explicitly instead of being inferred
        max_length=256,
        num_beams=1, # Greedy decoding
        do_sample=False,
        repetition_penalty=1.2, # Mild penalty against stuttering
        no_repeat_ngram_size=3 # No 3-gram may appear twice in the output
    )

    corrected_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return corrected_sentence

# Smoke test: run the GEC model on one sentence and print it before and after correction.
# Note: run_inference returns an "Error loading model: ..." string when loading fails,
# and that string would be printed as the "Corrected" line.
test_sentence = "سئلت رئيسا الوزراء عن شؤن الموظفين واجابو بان المسؤليه تقع علي عاتق الجميع فاستعدو لبدء العمل"
print(f"Original: {test_sentence}")
print(f"Corrected: {run_inference(test_sentence)}")

Original: سئلت رئيسا الوزراء عن شؤن الموظفين واجابو بان المسؤليه تقع علي عاتق الجميع فاستعدو لبدء العمل
Loading model from: ./arat5-gec-checkpoints-kaggle/arat5-gec-finetuned on cuda
Corrected: سئلت رئيسا الوزراء عن شؤون الموظفين وأجابوا واجيبو بأن المسؤولية تقع على عاتق الجميع فاستعدوا لبدء العمل .


In [40]:
def add_tashkeel(text):
    """Add Arabic diacritics (tashkeel) to ``text`` using a public Hugging Face model.

    The model 'Abdou/arabic-tashkeel-flan-t5-small' and its tokenizer are
    loaded on every call.

    Returns:
        The diacritized text, or the string
        "Error loading tashkeel model: <reason>" when loading fails.
    """
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    import torch

    tashkeel_model_name = "Abdou/arabic-tashkeel-flan-t5-small"
    print(f"Loading Tashkeel model: {tashkeel_model_name}...")

    try:
        tashkeel_tokenizer = AutoTokenizer.from_pretrained(tashkeel_model_name)
        tashkeel_model = AutoModelForSeq2SeqLM.from_pretrained(tashkeel_model_name)
    except Exception as load_error:
        return f"Error loading tashkeel model: {load_error}"

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    tashkeel_model = tashkeel_model.to(target_device)

    # The text is tokenized as-is: no task prefix is prepended for this model.
    encoded = tashkeel_tokenizer(
        text, return_tensors="pt", max_length=512, truncation=True
    ).to(target_device)

    # Beam search with a strong repetition penalty and a ban on repeated 3-grams.
    generation_options = dict(
        max_length=512,
        num_beams=4,
        early_stopping=True,
        repetition_penalty=2.5,
        no_repeat_ngram_size=3,
    )
    generated_ids = tashkeel_model.generate(encoded["input_ids"], **generation_options)

    return tashkeel_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# --- Full Pipeline Test ---
# Chains the two models: grammatical correction first, then diacritization of the result.
test_input = " سئلت رئيسا الوزراء عن شؤن الموظفين واجابو بان المسؤليه تقع علي عاتق الجميع فاسبعدو لبدء العم"
print(f"Input:     {test_input}")

# 1. GEC (Correction)
# Note: on a model-loading failure run_inference returns an error string, which is then
# passed on to add_tashkeel unchanged.
gec_output = run_inference(test_input)
print(f"Corrected: {gec_output}")

# 2. Tashkeel (Diacritization)
final_output = add_tashkeel(gec_output)
print(f"Tashkeel:  {final_output}")

Input:      سئلت رئيسا الوزراء عن شؤن الموظفين واجابو بان المسؤليه تقع علي عاتق الجميع فاسبعدو لبدء العم
Loading model from: ./arat5-gec-checkpoints-kaggle/arat5-gec-finetuned on cuda
Corrected: سئلت رئيسا الوزراء عن شؤون الموظفين وأجابوا واجيبو بأن المسؤولية تقع على عاتق الجميع فاسبعدوا لبدء العمل .
Loading Tashkeel model: Abdou/arabic-tashkeel-flan-t5-small...
Tashkeel:  سُئِلَتْ رَئِيسًا الْوُزَرَاءُ عَنْ شُؤُونِ الْمُوَظَّفِينَ وَأَجَابُوا وَاجِيبُو بِأَنَّ الْمَسْؤُولِيَّةَ تَقَعُ عَلَى عَاتِقِ الْجَمِيعِ فَاسْبِعُوا لِبَدْءِ الْعَمَلِ .
