<a href="https://colab.research.google.com/github/moksha-hub/RenAIssance-OCR/blob/main/Untitled4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!sudo apt-get -y install tree


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
tree is already the newest version (2.0.2-1).
0 upgraded, 0 newly installed, 0 to remove and 29 not upgraded.


In [8]:
#############################################
# 0. (Optional) Mount Google Drive
#############################################
# If you have already mounted your Drive, you can comment out these lines:
from google.colab import drive
drive.mount('/content/drive')

#############################################
# 1. Install Dependencies
#############################################
!sudo apt-get -y install poppler-utils
!pip install pdf2image python-docx PyPDF2 opencv-python

#############################################
# 2. Imports and Utility Functions
#############################################
import os
import re
import gc
import cv2
import docx
import numpy as np
from PyPDF2 import PdfReader
from pdf2image import convert_from_path

def read_transcriptions_from_docx(docx_path):
    """
    Reads a .docx file and returns a string containing all paragraphs joined.
    """
    doc = docx.Document(docx_path)
    paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
    return "\n".join(paragraphs)

def convert_pdf_to_images_single_page(pdf_path, output_folder, dpi=150):
    """
    Converts a large PDF to PNG images one page at a time, reducing memory usage.
    Default dpi=150 to further reduce memory usage. Adjust as needed.
    """
    if not os.path.exists(output_folder):
        os.makedirs(output_folder, exist_ok=True)

    reader = PdfReader(pdf_path)
    total_pages = len(reader.pages)
    print(f"[Single-Page Mode] Processing '{os.path.basename(pdf_path)}' with {total_pages} pages...")

    for page_num in range(1, total_pages + 1):
        pages = convert_from_path(
            pdf_path,
            dpi=dpi,
            first_page=page_num,
            last_page=page_num
        )
        page_image = pages[0]
        out_path = os.path.join(output_folder, f"page_{page_num}.png")
        page_image.save(out_path, "PNG")

        # Free memory for this page
        del page_image
        del pages
        gc.collect()

    print(f"Converted '{os.path.basename(pdf_path)}' into {total_pages} images in '{output_folder}'")

def convert_pdf_to_images_all_at_once(pdf_path, output_folder, dpi=300):
    """
    Converts all pages of a PDF in one pass. Higher DPI=300 for good quality.
    Use this for smaller PDFs to speed up processing.
    """
    if not os.path.exists(output_folder):
        os.makedirs(output_folder, exist_ok=True)

    pages = convert_from_path(pdf_path, dpi=dpi)
    for idx, page in enumerate(pages):
        out_path = os.path.join(output_folder, f"page_{idx+1}.png")
        page.save(out_path, "PNG")

    print(f"Converted '{os.path.basename(pdf_path)}' into {len(pages)} images in '{output_folder}'")

def segment_into_lines(page_image, threshold=10):
    """
    Splits a page image into multiple line images using a simple horizontal projection approach.
    Returns a list of cropped line images.
    threshold=10 means a row must have at least 10 black pixels to be considered text.
    Adjust as needed for your documents.
    """
    gray = cv2.cvtColor(page_image, cv2.COLOR_BGR2GRAY)
    # Otsu binarization
    _, bw = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # Sum horizontally. If text is white on black, do: bw = 255 - bw
    horizontal_sum = np.sum(bw == 0, axis=1)

    in_line = False
    start_idx = 0
    line_indices = []

    for i, val in enumerate(horizontal_sum):
        if val > threshold and not in_line:
            in_line = True
            start_idx = i
        elif val <= threshold and in_line:
            in_line = False
            end_idx = i
            line_indices.append((start_idx, end_idx))

    line_images = []
    for (start, end) in line_indices:
        line_img = page_image[start:end, :]  # entire width
        line_images.append(line_img)
    return line_images

#############################################
# 3. Splitting Logic: PDF -> Page Images -> Line Images
#############################################
def split_page_images_into_lines(page_images_folder):
    """
    For each page image in page_images_folder, segment it into lines and store them in a 'lines/' subfolder.
    e.g. page_1.png -> lines/page_1_line_0.png, page_1_line_1.png, ...
    """
    lines_folder = os.path.join(page_images_folder, "lines")
    os.makedirs(lines_folder, exist_ok=True)

    for file_name in sorted(os.listdir(page_images_folder)):
        if file_name.lower().endswith((".png", ".jpg", ".jpeg")) and not file_name.startswith("line_"):
            page_path = os.path.join(page_images_folder, file_name)
            page_img = cv2.imread(page_path, cv2.IMREAD_COLOR)
            if page_img is None:
                continue

            lines = segment_into_lines(page_img, threshold=10)
            base_name = os.path.splitext(file_name)[0]
            for idx, line_img in enumerate(lines):
                line_path = os.path.join(lines_folder, f"{base_name}_line_{idx}.png")
                cv2.imwrite(line_path, line_img)
            print(f"Split '{file_name}' into {len(lines)} lines in '{lines_folder}'")

#############################################
# 4. Main Function to Create a Dataset with Further Splitting
#############################################
def create_dataset_with_line_splitting(
    sources_folder,
    transcriptions_folder,
    local_dataset_dir,
    size_threshold_mb=20,
    single_page_dpi=150,
    all_at_once_dpi=300
):
    """
    1) For each PDF in sources_folder:
       - Create folder in local_dataset_dir with same base name
       - Convert PDF to page images
       - Attempt to match .docx by base name overlap
       - Read docx, save to transcription.txt
       - Then further split each page image into lines in a lines/ subfolder
    """
    import docx

    if not os.path.exists(local_dataset_dir):
        os.makedirs(local_dataset_dir, exist_ok=True)

    docx_files = [f for f in os.listdir(transcriptions_folder) if f.lower().endswith(".docx")]

    def find_best_match_docx(base_name, docx_files):
        base_words = re.findall(r"\w+", base_name.lower())
        best_file = None
        best_score = 0
        for docx_file in docx_files:
            docx_stem = os.path.splitext(docx_file)[0].lower()
            docx_words = re.findall(r"\w+", docx_stem)
            overlap = len(set(base_words).intersection(docx_words))
            if overlap > best_score:
                best_score = overlap
                best_file = docx_file
        return best_file if best_score > 0 else None

    for file_name in os.listdir(sources_folder):
        if file_name.lower().endswith(".pdf"):
            base_name = os.path.splitext(file_name)[0]
            pdf_path = os.path.join(sources_folder, file_name)

            source_folder = os.path.join(local_dataset_dir, base_name)
            images_folder = os.path.join(source_folder, "images")
            os.makedirs(images_folder, exist_ok=True)

            pdf_size = os.path.getsize(pdf_path)
            if pdf_size > size_threshold_mb * 1024 * 1024:
                convert_pdf_to_images_single_page(pdf_path, images_folder, dpi=single_page_dpi)
            else:
                convert_pdf_to_images_all_at_once(pdf_path, images_folder, dpi=all_at_once_dpi)

            # find best matching docx
            matched_docx = find_best_match_docx(base_name, docx_files)
            if matched_docx:
                docx_path = os.path.join(transcriptions_folder, matched_docx)
                doc = docx.Document(docx_path)
                paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
                transcription_text = "\n".join(paragraphs)

                transcription_out_path = os.path.join(source_folder, "transcription.txt")
                with open(transcription_out_path, "w", encoding="utf-8") as txt_file:
                    txt_file.write(transcription_text)
                print(f"Transcription from '{matched_docx}' saved to '{transcription_out_path}'\n")
            else:
                print(f"No matching Word document found for '{file_name}'.\n")

            # Now further split each page image into lines
            split_page_images_into_lines(images_folder)

    print("\nFinal dataset structure in", local_dataset_dir)
    get_ipython().system(f"tree -L 3 {local_dataset_dir}")

#############################################
# 5. Example Usage
#############################################
sources_folder = "/content/drive/MyDrive/RenAI Printed img/Test sources"
transcriptions_folder = "/content/drive/MyDrive/RenAI Printed img/Test transcriptions"
local_dataset_dir = "/content/dataset_with_lines"

create_dataset_with_line_splitting(
    sources_folder,
    transcriptions_folder,
    local_dataset_dir,
    size_threshold_mb=20,
    single_page_dpi=150,
    all_at_once_dpi=300
)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
poppler-utils is already the newest version (22.02.0-2ubuntu0.6).
0 upgraded, 0 newly installed, 0 to remove and 29 not upgraded.
Converted 'Constituciones sinodales Calahorra 1602.pdf' into 6 images in '/content/dataset_with_lines/Constituciones sinodales Calahorra 1602/images'
Transcription from 'Constituciones sinodales transcription.docx' saved to '/content/dataset_with_lines/Constituciones sinodales Calahorra 1602/transcription.txt'

Split 'page_1.png' into 0 lines in '/content/dataset_with_lines/Constituciones sinodales Calahorra 1602/images/lines'
Split 'page_2.png' into 57 lines in '/content/dataset_with_lines/Constituciones sinodales Calahorra 1602/images/lines'
Split 'page_3.png' into 1 lines in '/content/dataset_with_lines/Constituciones sinodales Calah

In [9]:
!pip install paddlepaddle-gpu -f https://www.paddlepaddle.org.cn/whl/mkl/avx/stable.html


Looking in links: https://www.paddlepaddle.org.cn/whl/mkl/avx/stable.html


In [10]:
#############################################
# 0. (Optional) Mount Google Drive
#############################################
from google.colab import drive
drive.mount('/content/drive')

#############################################
# 1. Install Dependencies
#############################################
!sudo apt-get -y install poppler-utils
!pip install pdf2image python-docx PyPDF2 opencv-python rapidfuzz paddleocr

#############################################
# 2. Imports and Utility Functions
#############################################
import os
import re
import gc
import cv2
import docx
import numpy as np
from PyPDF2 import PdfReader
from pdf2image import convert_from_path
from paddleocr import PaddleOCR
from rapidfuzz import fuzz

def read_transcriptions_from_txt(txt_path):
    with open(txt_path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]

def partial_ocr_line(line_image_path, ocr_engine):
    """
    Run PaddleOCR on a single line image, return recognized text (best guess).
    """
    result = ocr_engine.ocr(line_image_path, rec=True)
    if result and len(result[0]) > 0:
        # result[0][0] => bounding box, result[0][1] => (text, confidence)
        # But for a single line, there's typically 1 bounding box.
        recognized_text = result[0][0][1][0]
    else:
        recognized_text = ""
    return recognized_text

def find_best_line_match(recognized_text, doc_lines):
    """
    Fuzzy match recognized_text to each line in doc_lines, return best line and best score.
    """
    best_score = -1
    best_line = None
    for candidate_line in doc_lines:
        score = fuzz.ratio(recognized_text.lower(), candidate_line.lower())
        if score > best_score:
            best_score = score
            best_line = candidate_line
    return best_line, best_score

def generate_line_level_texts(
    dataset_root,
    paddle_ocr_lang="en"
):
    """
    1) For each <PDF_Name>/images/lines folder in dataset_root,
       read line images (page_1_line_0.png, etc.).
    2) Load the entire transcription.txt from <PDF_Name>/transcription.txt as doc_lines (split by lines).
    3) For each line image, run partial OCR, fuzzy match to doc_lines, and create .txt file.
    """
    ocr_engine = PaddleOCR(lang=paddle_ocr_lang, rec=True, det=False, use_angle_cls=False)

    for folder_name in sorted(os.listdir(dataset_root)):
        folder_path = os.path.join(dataset_root, folder_name)
        if not os.path.isdir(folder_path):
            continue

        # e.g. /content/dataset_with_lines/Mendo - Principe perfecto
        transcription_file = os.path.join(folder_path, "transcription.txt")
        images_folder = os.path.join(folder_path, "images")
        lines_folder = os.path.join(images_folder, "lines")

        if not os.path.exists(transcription_file):
            print(f"No transcription.txt in {folder_name}, skipping.")
            continue
        if not os.path.exists(lines_folder):
            print(f"No lines folder in {folder_name}, skipping.")
            continue

        doc_lines = read_transcriptions_from_txt(transcription_file)
        print(f"\n[{folder_name}] Found {len(doc_lines)} lines in transcription.txt")

        # For each line image in lines_folder, do partial OCR + fuzzy match
        for file_name in sorted(os.listdir(lines_folder)):
            if file_name.lower().endswith((".png", ".jpg", ".jpeg")) and not file_name.endswith(".txt"):
                line_img_path = os.path.join(lines_folder, file_name)
                recognized_text = partial_ocr_line(line_img_path, ocr_engine)

                if recognized_text.strip():
                    best_line, best_score = find_best_line_match(recognized_text, doc_lines)
                else:
                    best_line, best_score = "[NO RECOGNIZED TEXT]", 0

                # Save best_line to .txt
                txt_name = file_name.rsplit(".", 1)[0] + ".txt"
                txt_path = os.path.join(lines_folder, txt_name)
                with open(txt_path, "w", encoding="utf-8") as f:
                    f.write(best_line)

                print(f"  -> {file_name} recognized='{recognized_text[:30]}...' matched='{best_line[:30]}...' score={best_score}")
    print("\nLine-level .txt generation complete.")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
poppler-utils is already the newest version (22.02.0-2ubuntu0.6).
0 upgraded, 0 newly installed, 0 to remove and 29 not upgraded.


In [12]:
def partial_ocr_line(line_image_path, ocr_engine):
    """
    Run PaddleOCR on a single line image, return recognized text (best guess).
    Safely handles None or empty results.
    """
    result = ocr_engine.ocr(line_image_path, rec=True)

    # If result is None or empty, return empty string
    if not result or len(result) == 0:
        return ""

    # result is typically a list of lists, e.g. [ [ [box], [text, confidence] ], ... ]
    # So we check if result[0] is valid
    if not result[0] or len(result[0]) == 0:
        return ""

    # Finally, parse recognized text from result
    # Usually: result[0][0] => [ [box coords], [text, confidence] ]
    recognized_text = result[0][0][1][0]
    return recognized_text


In [13]:
ocr_engine = PaddleOCR(lang="es", rec=True, det=False, use_angle_cls=False)
line_img_path = "/content/dataset_with_lines/Mendo - Principe perfecto/images/lines/page_2_line_0.png"
recognized_text = partial_ocr_line(line_img_path, ocr_engine)
print("Recognized text:", recognized_text)


[2025/03/22 16:05:45] ppocr DEBUG: Namespace(help='==SUPPRESS==', use_gpu=True, use_xpu=False, use_npu=False, use_mlu=False, use_gcu=False, ir_optim=True, use_tensorrt=False, min_subgraph_size=15, precision='fp32', gpu_mem=500, gpu_id=0, image_dir=None, page_num=0, det_algorithm='DB', det_model_dir='/root/.paddleocr/whl/det/en/en_PP-OCRv3_det_infer', det_limit_side_len=960, det_limit_type='max', det_box_type='quad', det_db_thresh=0.3, det_db_box_thresh=0.6, det_db_unclip_ratio=1.5, max_batch_size=10, use_dilation=False, det_db_score_mode='fast', det_east_score_thresh=0.8, det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_sast_score_thresh=0.5, det_sast_nms_thresh=0.2, det_pse_thresh=0, det_pse_box_thresh=0.85, det_pse_min_area=16, det_pse_scale=1, scales=[8, 16, 32], alpha=1.0, beta=1.0, fourier_degree=5, rec_algorithm='SVTR_LCNet', rec_model_dir='/root/.paddleocr/whl/rec/latin/latin_PP-OCRv3_rec_infer', rec_image_inverse=True, rec_image_shape='3, 48, 320', rec_batch_num=6, max_t

In [14]:
def partial_ocr_line(line_image_path, ocr_engine):
    result = ocr_engine.ocr(line_image_path, rec=True)
    if not result or len(result) == 0:
        print(f"WARNING: No OCR result for {line_image_path}")
        return ""
    if not result[0] or len(result[0]) == 0:
        print(f"WARNING: No bounding boxes recognized for {line_image_path}")
        return ""

    recognized_text = result[0][0][1][0]
    return recognized_text


In [15]:
dataset_root = "/content/dataset_with_lines"  # or your dataset path
generate_line_level_texts(dataset_root, paddle_ocr_lang="es")


[2025/03/22 16:06:46] ppocr DEBUG: Namespace(help='==SUPPRESS==', use_gpu=True, use_xpu=False, use_npu=False, use_mlu=False, use_gcu=False, ir_optim=True, use_tensorrt=False, min_subgraph_size=15, precision='fp32', gpu_mem=500, gpu_id=0, image_dir=None, page_num=0, det_algorithm='DB', det_model_dir='/root/.paddleocr/whl/det/en/en_PP-OCRv3_det_infer', det_limit_side_len=960, det_limit_type='max', det_box_type='quad', det_db_thresh=0.3, det_db_box_thresh=0.6, det_db_unclip_ratio=1.5, max_batch_size=10, use_dilation=False, det_db_score_mode='fast', det_east_score_thresh=0.8, det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_sast_score_thresh=0.5, det_sast_nms_thresh=0.2, det_pse_thresh=0, det_pse_box_thresh=0.85, det_pse_min_area=16, det_pse_scale=1, scales=[8, 16, 32], alpha=1.0, beta=1.0, fourier_degree=5, rec_algorithm='SVTR_LCNet', rec_model_dir='/root/.paddleocr/whl/rec/latin/latin_PP-OCRv3_rec_infer', rec_image_inverse=True, rec_image_shape='3, 48, 320', rec_batch_num=6, max_t

In [17]:
#############################################
# 0. (Optional) Mount Google Drive
#############################################
from google.colab import drive
drive.mount('/content/drive')

#############################################
# 1. Install Dependencies
#############################################
!sudo apt-get -y install poppler-utils tree
!pip install paddleocr transformers peft evaluate albumentations PyPDF2 jiwer

#############################################
# 2. Imports and Utility Functions
#############################################
import os
import re
import cv2
import gc
import torch
import numpy as np
import matplotlib.pyplot as plt

# PIL imports at the global level, so there's no local overshadowing
from PIL import Image, ImageDraw, ImageFont

from paddleocr import PaddleOCR
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Trainer,
    TrainingArguments,
    AdamW,
    get_scheduler
)
from torch.utils.data import Dataset, random_split, ConcatDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from peft import AdaLoraConfig, get_peft_model
from evaluate import load
import torch.nn.functional as F

import logging
logging.getLogger("ppocr").setLevel(logging.ERROR)

#############################################
# 3. Utility Functions
#############################################
def normalize_text(text):
    # Example normalization (Spanish/historical).
    # You can adapt or remove if not needed.
    text = text.replace('ç', 'z').replace('ſ', 's')
    text = re.sub(r'[ùúûüū]', 'u', text)
    text = re.sub(r'[àáâãā]', 'a', text)
    text = re.sub(r'(?<![n])́', '', text)
    text = re.sub(r'[̀̀̈]', '', text)
    return text.lower()[:512]

def read_line_text(txt_path):
    """
    Reads the single-line text from a .txt file for a line image.
    """
    with open(txt_path, "r", encoding="utf-8") as f:
        return f.read().strip()

def advanced_preprocess(image):
    """
    Convert image to grayscale, binarize, denoise, deskew, then convert to 3-channel.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, bw = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    denoised = cv2.fastNlMeansDenoising(bw, None, h=30, templateWindowSize=7, searchWindowSize=21)

    coords = np.column_stack(np.where(denoised > 0))
    if coords.size == 0:
        return image  # fallback if no text found
    angle = cv2.minAreaRect(coords)[-1]
    if angle < -45:
        angle = -(90 + angle)
    else:
        angle = -angle
    (h, w) = denoised.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    deskewed = cv2.warpAffine(denoised, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
    deskewed_color = cv2.cvtColor(deskewed, cv2.COLOR_GRAY2RGB)
    return deskewed_color

#############################################
# 4. Define Line-Level Dataset
#############################################
class LineLevelDataset(Dataset):
    def __init__(self, lines_root, processor, synthetic_prob=0.0):
        """
        lines_root: path to e.g. /content/dataset_with_lines/<PDF_Name>/images/lines
                    containing line_*.png + line_*.txt
        synthetic_prob: optional if you still want to generate synthetic images from text
        """
        self.lines_root = lines_root
        self.processor = processor
        self.synthetic_prob = synthetic_prob
        self.samples = []

        # Gather all line_*.png + line_*.txt pairs
        for file_name in sorted(os.listdir(self.lines_root)):
            if file_name.lower().endswith((".png", ".jpg", ".jpeg")):
                base_name = file_name.rsplit(".", 1)[0]
                txt_name = base_name + ".txt"
                txt_path = os.path.join(self.lines_root, txt_name)
                if os.path.exists(txt_path):
                    self.samples.append((file_name, txt_name))

        # Basic augmentation pipeline
        self.transform = A.Compose([
            A.Resize(height=384, width=384, always_apply=True),
            A.GaussianBlur(p=0.3),
            A.RandomBrightnessContrast(p=0.4),
            A.ImageCompression(quality_lower=60, p=0.2),
            A.Rotate(limit=3, p=0.5),
            A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.02, rotate_limit=3, p=0.5),
            ToTensorV2()
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_file, txt_file = self.samples[idx]
        line_img_path = os.path.join(self.lines_root, img_file)
        line_txt_path = os.path.join(self.lines_root, txt_file)

        # Read the text from the corresponding .txt
        text = read_line_text(line_txt_path)

        # Decide if we do synthetic or real
        if np.random.rand() < self.synthetic_prob:
            # Minimal synthetic approach for demonstration
            bg_color = (
                np.random.randint(240, 256),
                np.random.randint(240, 256),
                np.random.randint(240, 256)
            )
            pil_img = Image.new('RGB', (384,384), bg_color)
            draw = ImageDraw.Draw(pil_img)
            font_size = np.random.randint(18, 30)
            fonts = ["arial.ttf", "DejaVuSans.ttf"]
            font_choice = np.random.choice(fonts)
            try:
                font = ImageFont.truetype(font_choice, font_size)
            except:
                font = ImageFont.load_default()
            draw.text((10,10), text[:30], fill=(0,0,0), font=font)
        else:
            # Real line image
            img = cv2.imread(line_img_path, cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError(f"Unable to load line image: {line_img_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Optionally advanced_preprocess(img)
            # e.g.:
            # img = advanced_preprocess(img)
            pil_img = Image.fromarray(img)

        # Albumentations
        aug = self.transform(image=np.array(pil_img))['image']
        aug_np = aug.permute(1, 2, 0).mul(255).byte().numpy()
        pil_img = Image.fromarray(aug_np)

        # Convert to pixel_values
        pixel_values = self.processor(
            pil_img,
            return_tensors="pt"
        ).pixel_values.squeeze(0)

        # Tokenize line text
        text_encoding = self.processor.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128
        )
        labels = text_encoding["input_ids"][0]

        return {"pixel_values": pixel_values, "labels": labels}

#############################################
# 5. Gather All Lines from All PDFs, Then Train
#############################################
def gather_all_line_folders(dataset_root):
    """
    dataset_root = /content/dataset_with_lines
    We look in each <PDF_Name>/images/lines/ folder for line images
    """
    lines_paths = []
    for folder_name in sorted(os.listdir(dataset_root)):
        folder_path = os.path.join(dataset_root, folder_name)
        if not os.path.isdir(folder_path):
            continue
        images_folder = os.path.join(folder_path, "images")
        lines_folder = os.path.join(images_folder, "lines")
        if os.path.exists(lines_folder):
            lines_paths.append(lines_folder)
    return lines_paths

def train_line_level_model(dataset_root):
    # 1) gather all line folders
    line_folders = gather_all_line_folders(dataset_root)
    if not line_folders:
        print("No line folders found, exiting.")
        return

    # Combine all line samples into one big dataset
    processor = TrOCRProcessor.from_pretrained("qantev/trocr-base-spanish", do_resize=False, do_normalize=False)

    line_datasets = []
    for lf in line_folders:
        ds = LineLevelDataset(lf, processor, synthetic_prob=0.0)
        line_datasets.append(ds)

    from torch.utils.data import ConcatDataset, random_split
    full_line_dataset = ConcatDataset(line_datasets)
    print(f"Total line images loaded: {len(full_line_dataset)}")

    # 2) Train/Eval split
    train_size = int(0.9 * len(full_line_dataset))
    eval_size = len(full_line_dataset) - train_size
    train_dataset, eval_dataset = random_split(full_line_dataset, [train_size, eval_size])
    print(f"Training samples: {len(train_dataset)}; Evaluation samples: {len(eval_dataset)}")

    # 3) Initialize Model with AdaLoRA
    from transformers import VisionEncoderDecoderModel
    model = VisionEncoderDecoderModel.from_pretrained("qantev/trocr-base-spanish")
    model.config.num_beams = 5
    model.config.early_stopping = True

    target_modules_list = []
    for name, module in model.decoder.named_modules():
        if ("self_attn.q_proj" in name) or ("self_attn.k_proj" in name) or ("self_attn.v_proj" in name):
            target_modules_list.append(name)
    if not target_modules_list:
        raise AttributeError("No target modules found for AdaLoRA in the decoder.")

    from peft import AdaLoraConfig, get_peft_model
    peft_config = AdaLoraConfig(
        target_modules=target_modules_list,
        init_r=12,
        lora_alpha=32,
        lora_dropout=0.1,
        bias="none"
    )
    model.decoder = get_peft_model(model.decoder, peft_config)

    def print_trainable_parameters(m):
        trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
        total = sum(p.numel() for p in m.parameters())
        print(f"Trainable parameters: {trainable} / Total parameters: {total}")

    print_trainable_parameters(model)

    # 4) Training Arguments
    from transformers import TrainingArguments
    training_args = TrainingArguments(
        output_dir="./linelevel_trocr",
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=15,
        learning_rate=1e-5,
        fp16=True,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_dir='./logs_line',
        logging_steps=50,
        report_to="none",
        dataloader_drop_last=False,
        remove_unused_columns=True,
        label_smoothing_factor=0.1
    )

    # 5) Metrics & Collate
    cer_metric = load("cer")
    wer_metric = load("wer")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        if isinstance(logits, (tuple, list)):
            logits = logits[0]
        pred_ids = np.argmax(logits, axis=-1)

        preds = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)
        refs = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)

        preds = [normalize_text(s) for s in preds]
        refs = [normalize_text(s) for s in refs]

        return {
            "cer": cer_metric.compute(predictions=preds, references=refs),
            "wer": wer_metric.compute(predictions=preds, references=refs)
        }

    def collate_fn(batch):
        max_height = max(item["pixel_values"].shape[1] for item in batch)
        max_width = max(item["pixel_values"].shape[2] for item in batch)
        padded_pixel_values = []
        for item in batch:
            pv = item["pixel_values"]
            c, h, w = pv.shape
            padded = torch.zeros((c, max_height, max_width), dtype=pv.dtype)
            padded[:, :h, :w] = pv
            padded_pixel_values.append(padded)
        collated_labels = torch.stack([item["labels"] for item in batch])
        return {
            "pixel_values": torch.stack(padded_pixel_values),
            "labels": collated_labels
        }

    # 6) Focal Loss Trainer
    import torch.nn.functional as F
    from transformers import Trainer

    class CustomFocalTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            if "num_items_in_batch" in inputs:
                inputs.pop("num_items_in_batch")

            labels = inputs.pop("labels")
            pixel_values = inputs.pop("pixel_values")
            outputs = model(pixel_values=pixel_values, labels=labels)
            logits = outputs.logits

            vocab_size = logits.size(-1)
            logits = logits.view(-1, vocab_size)
            labels = labels.view(-1)

            valid_mask = labels != -100
            logits = logits[valid_mask]
            labels = labels[valid_mask]

            ce_loss = F.cross_entropy(logits, labels, reduction='none')
            gamma = 2.0
            pt = torch.exp(-ce_loss)
            focal_loss = ((1 - pt) ** gamma * ce_loss).mean()

            return (focal_loss, outputs) if return_outputs else focal_loss

    # 7) Train
    trainer = CustomFocalTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collate_fn,
        compute_metrics=compute_metrics
    )

    print("Starting line-level training...")
    trainer.train()

    torch.cuda.empty_cache()
    model.save_pretrained("./final_linelevel_model")
    processor.save_pretrained("./final_linelevel_model")
    print("Training complete. Final model saved to './final_linelevel_model'.")

#############################################
# 8. Putting It All Together
#############################################
def main():
    dataset_root = "/content/dataset_with_lines"  # or your line-level dataset root
    train_line_level_model(dataset_root)

if __name__ == "__main__":
    main()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
tree is already the newest version (2.0.2-1).
poppler-utils is already the newest version (22.02.0-2ubuntu0.6).
0 upgraded, 0 newly installed, 0 to remove and 29 not upgraded.
Total line images loaded: 736
Training samples: 662; Evaluation samples: 74


Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "torch_dtype": "float32",
  "transformers_version": "4.49.0"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decode

Trainable parameters: 87538608 / Total parameters: 385749972
Starting line-level training...


Epoch,Training Loss,Validation Loss,Cer,Wer
1,0.3364,0.291484,0.467742,0.665672
2,0.2945,0.266768,0.441349,0.61791
3,0.2605,0.259419,0.437928,0.620896
4,0.2834,0.254209,0.434506,0.614925
5,0.2455,0.254584,0.438416,0.60597
6,0.2612,0.246969,0.415934,0.61791
7,0.2743,0.244941,0.434506,0.614925
8,0.2288,0.240515,0.425709,0.59403
9,0.243,0.24063,0.4174,0.602985
10,0.2638,0.237077,0.407136,0.61194


Training complete. Final model saved to './final_linelevel_model'.
