In [1]:
# cell: mount_drive
from google.colab import drive
drive.mount('/content/drive')


# Path to your dataset folder on Google Drive
DATA_ROOT = '/content/drive/MyDrive/dataset'
print('Drive mounted and dataset path set to:', DATA_ROOT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Drive mounted and dataset path set to: /content/drive/MyDrive/dataset


In [2]:
# cell: install_deps
!pip install --upgrade pip
!pip install transformers==4.44.2 datasets accelerate evaluate ftfy regex sentencepiece pillow torchvision timm einops



In [3]:
!apt-get -q install poppler-utils

Reading package lists...
Building dependency tree...
Reading state information...
poppler-utils is already the newest version (22.02.0-2ubuntu0.12).
0 upgraded, 0 newly installed, 0 to remove and 41 not upgraded.


In [4]:
!pip install pdf2image



In [5]:
import os
import json
import math
import random
import shutil
import glob
from pathlib import Path
from typing import List, Dict

import torch
from PIL import Image,ImageFile
from pdf2image import convert_from_path
from tqdm.auto import tqdm

ImageFile.LOAD_TRUNCATED_IMAGES = True

from transformers import (
    AutoProcessor,
    VisionEncoderDecoderModel,
    TrainingArguments,
    Trainer,
    default_data_collator,
)
from datasets import Dataset
from peft import LoraConfig, get_peft_model, TaskType

In [6]:
# ---------------- CONFIG ----------------
GDRIVE_ROOT = "/content/drive/MyDrive"   # change if needed
DATASET_DRIVE_PATH = f"{GDRIVE_ROOT}/dataset"
OUTPUT_DIR = f"{GDRIVE_ROOT}/donut_peft_lora_output"  # where to save adapters and checkpoints
LOG_DIR = f"{OUTPUT_DIR}/logs"
print("DATASET_DRIVE_PATH =", DATASET_DRIVE_PATH)
print("OUTPUT_DIR =", OUTPUT_DIR)
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-cord-v2"  # HF model to fine-tune
TASK_TAG = "parse"

# Paths (modify if needed)
DATASET_ROOT = os.environ.get("DATASET_ROOT", DATASET_DRIVE_PATH)
# expected:
# DATASET_ROOT/train/images/*
# DATASET_ROOT/train/metadata.jsonl
# DATASET_ROOT/val/images/*
# DATASET_ROOT/val/metadata.jsonl

OUTPUT_DIR = os.environ.get("OUTPUT_DIR",OUTPUT_DIR)

DATASET_DRIVE_PATH = /content/drive/MyDrive/dataset
OUTPUT_DIR = /content/drive/MyDrive/donut_peft_lora_output


In [7]:
NUM_EPOCHS = 25
TRAIN_BATCH_SIZE = 2
EVAL_BATCH_SIZE = 2
LEARNING_RATE = 5e-5
MAX_TARGET_LENGTH = 512
SAVE_STEPS = 100        # checkpoint frequency (steps)
SAVE_TOTAL_LIMIT = 5
SEED = 42
LOGGING_STEPS = 10
# LoRA config
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

In [8]:
# PDF -> image conversion options
PDF_DPI = 200  # good tradeoff for invoices

# Utility
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs(OUTPUT_DIR, exist_ok=True)
print("Device:", device)
print("Output dir:", OUTPUT_DIR)

Device: cuda
Output dir: /content/drive/MyDrive/donut_peft_lora_output


In [9]:
# ---------------- helpers ----------------
def ensure_dir(p):
    os.makedirs(p, exist_ok=True)

def read_jsonl(meta_path: str) -> List[Dict]:
    out = []
    with open(meta_path, "r", encoding="utf-8") as f:
        for line in f:
            line=line.strip()
            if not line:
                continue
            out.append(json.loads(line))
    return out

In [10]:
def convert_pdf_to_images(pdf_path: str, out_folder: str, dpi:int = PDF_DPI) -> List[str]:
    """
    Convert a PDF into PNG pages in out_folder.
    Return list of generated image file paths.
    """
    ensure_dir(out_folder)
    pages = convert_from_path(pdf_path, dpi=dpi)
    out_paths = []
    base = Path(pdf_path).stem
    for i, page in enumerate(pages, start=1):
        out_name = f"{base}_page_{i}.png"
        out_path = os.path.join(out_folder, out_name)
        page.save(out_path, "PNG")
        out_paths.append(out_path)
    return out_paths

In [11]:
def prepare_files_and_records(root: str, split: str, tmp_processed_folder: str):
    """
    For each file in split/images/ (png/jpg/jpeg/pdf), copy/convert into tmp_processed_folder/<split>/images/
    Return a list of dicts: {"image": "<filename>", "image_path": "<abs path>", "ground_truth": "<wrapped JSON string>"}
    """
    split_img_folder = os.path.join(root, split, "images")
    meta_path = os.path.join(root, split, "metadata.jsonl")
    assert os.path.exists(split_img_folder), f"{split_img_folder} not found"
    assert os.path.exists(meta_path), f"{meta_path} not found"

    metadata = read_jsonl(meta_path)
    # build dict for quick lookup by filename
    meta_by_name = { rec["image"]: rec["ground_truth"] for rec in metadata }

    processed_images_dir = os.path.join(tmp_processed_folder, split, "images")
    ensure_dir(processed_images_dir)

    records = []

    for fname, gt in meta_by_name.items():
        src_path = os.path.join(split_img_folder, fname)
        if not os.path.exists(src_path):
            print(f"‚ö†Ô∏è Warning: {src_path} not found. Skipping.")
            continue

        lower = fname.lower()
        if lower.endswith(".pdf"):
            # convert pages
            pages = convert_pdf_to_images(src_path, processed_images_dir)
            for p in pages:
                new_fname = os.path.basename(p)
                wrapped = gt
                if not (wrapped.startswith(f"<{TASK_TAG}>") and wrapped.endswith(f"</{TASK_TAG}>")):
                    wrapped = f"<{TASK_TAG}>{wrapped}</{TASK_TAG}>"
                records.append({"image": new_fname, "image_path": p, "ground_truth": wrapped})
        elif lower.endswith((".png", ".jpg", ".jpeg",".tiff")):
            dst = os.path.join(processed_images_dir, fname)
            shutil.copy(src_path, dst)
            wrapped = gt
            if not (wrapped.startswith(f"<{TASK_TAG}>") and wrapped.endswith(f"</{TASK_TAG}>")):
                wrapped = f"<{TASK_TAG}>{wrapped}</{TASK_TAG}>"
            records.append({"image": fname, "image_path": dst, "ground_truth": wrapped})
        else:
            print(f"‚ö†Ô∏è Unsupported file type: {src_path}. Skipping.")
    return records

In [12]:
# ---------------- Prepare processed dataset (pdf->png) ----------------
TMP_PROC = "/tmp/donut_proc"
if os.path.exists(TMP_PROC):
    shutil.rmtree(TMP_PROC)
ensure_dir(TMP_PROC)

print("Preparing dataset (converting PDFs to PNGs where needed)...")
train_records = prepare_files_and_records(DATASET_ROOT, "train", TMP_PROC)
val_records   = prepare_files_and_records(DATASET_ROOT, "val", TMP_PROC)
print(f"Train examples (pages/images): {len(train_records)}")
print(f"Val examples (pages/images): {len(val_records)}")

Preparing dataset (converting PDFs to PNGs where needed)...
Train examples (pages/images): 228
Val examples (pages/images): 34


In [13]:
# ---------------- Load processor & model ----------------
print("Loading processor and model:", MODEL_NAME)
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)

# Ensure tokenizer exists
if not hasattr(processor, "tokenizer"):
    raise RuntimeError("Processor does not contain tokenizer; cannot proceed.")

tokenizer = processor.tokenizer

# model config tweaks
model.config.max_length = MAX_TARGET_LENGTH
model.config.decoder_start_token_id = tokenizer.cls_token_id or tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.sep_token_id or tokenizer.eos_token_id

Loading processor and model: naver-clova-ix/donut-base-finetuned-cord-v2




In [14]:
# Optionally shuffle training records
random.seed(SEED)
random.shuffle(train_records)

# Build HuggingFace Datasets
train_ds = Dataset.from_list(train_records)
val_ds = Dataset.from_list(val_records)

In [15]:
def safe_load_image(path: str):
    try:
        im = Image.open(path).convert("RGB")
        im.verify()  # check for integrity
        # reopen because verify() closes the file handle
        im = Image.open(path).convert("RGB")
        return im
    except Exception as e:
        print(f"‚ö†Ô∏è Skipping bad image: {path} ({e})")
        return None

In [16]:
def preprocess_batch(examples):
    """
    Robust preprocessing for Donut fine-tuning.
    - Safely loads each image (PNG/JPG)
    - Skips unreadable/corrupted files instead of crashing
    - Returns pixel_values + tokenized labels
    """
    images = []
    valid_ground_truths = []

    # Safely load each image
    for img_path, gt in zip(examples["image_path"], examples["ground_truth"]):
        im = safe_load_image(img_path)
        if im is not None:
            images.append(im)
            valid_ground_truths.append(gt)
        else:
            # skip this example if image is unreadable
            continue

    if not images:
        # If every image in the batch failed, return an empty batch safely
        return {"pixel_values": [], "labels": []}

    # Process images -> pixel values
    encodings = processor(images=images, return_tensors="pt")

    # Tokenize target JSON text
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            valid_ground_truths,
            padding="max_length",
            truncation=True,
            max_length=MAX_TARGET_LENGTH,
        )

    # Convert torch tensors to lists (datasets.map expects list of lists)
    out = {
        "pixel_values": [pv.tolist() for pv in encodings["pixel_values"]],
        "labels": labels["input_ids"],
        "image": [os.path.basename(p) for p in examples["image_path"][:len(images)]],
        "image_path": examples["image_path"][:len(images)],
    }

    return out

print("Mapping preprocess (this may take a little while)...")
train_ds = train_ds.map(preprocess_batch, batched=True, batch_size=4, remove_columns=train_ds.column_names)
val_ds = val_ds.map(preprocess_batch, batched=True, batch_size=4, remove_columns=val_ds.column_names)

# Set format to torch for keys that will be used
def convert_to_torch_format(ds):
    # keep `pixel_values` and `labels` as lists; we'll collate later
    return ds

train_ds = convert_to_torch_format(train_ds)
val_ds = convert_to_torch_format(val_ds)

Mapping preprocess (this may take a little while)...


Map:   0%|          | 0/228 [00:00<?, ? examples/s]



‚ö†Ô∏è Skipping bad image: /tmp/donut_proc/train/images/Manpower bill.tiff (decoder error -2)


Map:   0%|          | 0/34 [00:00<?, ? examples/s]

In [17]:
# ---------------- Apply LoRA with PEFT ----------------
print("Freezing encoder parameters...")
for p in model.encoder.parameters():
    p.requires_grad = False

print("Configuring LoRA...")
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "dense", "fc_out"],
    lora_dropout=LORA_DROPOUT,
    bias="none",
    #task_type=TaskType.SEQ_2_SEQ_LM,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
model.to(device)

Freezing encoder parameters...
Configuring LoRA...
trainable params: 1,351,680 || all params: 202,473,592 || trainable%: 0.6676


PeftModel(
  (base_model): LoraModel(
    (model): VisionEncoderDecoderModel(
      (encoder): DonutSwinModel(
        (embeddings): DonutSwinEmbeddings(
          (patch_embeddings): DonutSwinPatchEmbeddings(
            (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
          )
          (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (encoder): DonutSwinEncoder(
          (layers): ModuleList(
            (0): DonutSwinStage(
              (blocks): ModuleList(
                (0-1): 2 x DonutSwinLayer(
                  (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
                  (attention): DonutSwinAttention(
                    (self): DonutSwinSelfAttention(
                      (query): Linear(in_features=128, out_features=128, bias=True)
                      (key): Linear(in_features=128, out_features=128, bias=True)
                      (valu

In [18]:
def collate_fn(batch):
    # batch: list of examples where pixel_values is a nested list and labels is list
    pixel_values = [torch.tensor(x["pixel_values"]) for x in batch]
    pixel_values = torch.stack([
        pv.squeeze(0) if pv.ndim == 4 and pv.shape[0] == 1 else pv
        for pv in pixel_values
    ])  # keep on CPU

    labels = [torch.tensor(x["labels"]) for x in batch]
    labels = torch.nn.utils.rnn.pad_sequence(
        labels,
        batch_first=True,
        padding_value=tokenizer.pad_token_id
    )  # keep on CPU

    # DO NOT move to device here ‚Äî Trainer will do it
    return {"pixel_values": pixel_values, "labels": labels}

In [19]:
# ---------------- TrainingArguments & Trainer ----------------
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=SAVE_TOTAL_LIMIT,
    evaluation_strategy="steps" if len(val_ds) > 0 else "no",
    eval_steps=SAVE_STEPS if len(val_ds) > 0 else None,
    remove_unused_columns=False,
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    load_best_model_at_end=False,
    report_to="none",
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds if len(val_ds) > 0 else None,
    data_collator=collate_fn,
)



In [20]:
# ---------------- Training run ----------------
if __name__ == "__main__":
    print("Starting training on device:", device)
    trainer.train()
    print("Training complete ‚Äî saving final model & processor...")
    model.save_pretrained(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    print("Saved at:", OUTPUT_DIR)

    # simple inference helper
    def infer(image_path: str, max_length: int = MAX_TARGET_LENGTH):
        im = Image.open(image_path).convert("RGB")
        inputs = processor(images=im, return_tensors="pt").to(device)
        generate_kwargs = dict(max_length=max_length, num_beams=1)
        generated = model.generate(pixel_values=inputs["pixel_values"].to(device), **generate_kwargs)
        out = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        # out contains the <parse>...</parse> text
        return out

    # quick test inference on a validation image
    if len(val_records) > 0:
        sample = val_records[0]
        print("Example inference on:", sample["image_path"])
        print(infer(sample["image_path"]))

Starting training on device: cuda


Step,Training Loss,Validation Loss
100,5.7871,No log
200,4.9832,No log
300,4.8639,No log
400,3.9623,No log
500,3.9933,No log
600,3.6161,No log
700,4.1413,No log
800,3.5681,No log


KeyboardInterrupt: 