# Learnable Handwriter Dataset Preparation

This notebook transforms images and their associated ALTO XML files (typically the output from an automatic transcription platform like **eScriptorium**) into the format required to train the **Learnable Handwriter**.

> ⚠️ **Note:**  
> This notebook **cannot be run independently**.  
> You need to download it and run it in a directory with a `data` folder containing:  
> - an **`images`** folder with full manuscript or print images (`.jpeg`, `.jpg`, `.png`)  
> - an **`annotations`** folder containing their corresponding ALTO XML annotation files

Before you start, please read the [Learnable Handwriter Tutorial](https://learnable-handwriter.github.io/tutorial.html) if you haven’t done so already.

---

In [1]:
# Base imports

import numpy as np
import pandas as pd
import random

from pathlib import Path
import os
import rootutils
from collections import defaultdict

import json
import csv
import lxml
import xml.etree.ElementTree as ET

import cv2
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing

### Define corpus paths  
*(These are indicative — adjust as needed)*

In [None]:
# Set this to the root directory of your dataset or project
root_path = Path("/path/to/your/project/root")  

# Folder containing the raw ground truth data
dataset_path = root_path / "data" / "raw/ground/truth/folder/name"

# Folder with original uncropped images (change if your folder name differs)
imgs_path = dataset_path / "images"

# Optional: Path to your metadata CSV file (if using metadata)
metadata_file_path = dataset_path / "metadata/csv/name"  

# Folder with associated XML ALTO annotation files (change if your folder name differs)
annotations_path = dataset_path / "annotations"

# Folder where the processed cropped images will be saved
img_save_path = root_path / "datasets" / "processed/dataset/name" / "images"

# Path to save the final annotation JSON file
annotations_json_path = root_path / "datasets" / "processed/dataset/name" / "annotation.json"

### Step 1: Match Images and Annotations

- List all image filenames and all XML annotation filenames from their respective folders.
- The XML annotation filenames **must exactly match** the image filenames (without the file extension), as is typically the case when exporting from tools like *eScriptorium*.
- Identify any images missing their corresponding annotation files.
- Create pairs of matched image and annotation filenames for further processing.
- Print a quick summary and preview a few pairs to verify everything.
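
The matching rule described above can be sketched with `pathlib` alone. This is a minimal illustration rather than the notebook's own cell: the helper name `pair_by_stem` and the toy filenames are hypothetical, and it assumes each image and its XML file share an identical stem.

```python
import tempfile
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def pair_by_stem(images_dir, annotations_dir):
    """Return (pairs, missing): matched (image, xml) names and unmatched image names."""
    # Index annotation files by stem so each lookup is an exact match
    xml_by_stem = {p.stem: p.name for p in Path(annotations_dir).rglob("*.xml")}
    pairs, missing = [], []
    for img in sorted(Path(images_dir).rglob("*")):
        if not img.is_file() or img.suffix.lower() not in IMAGE_EXTENSIONS:
            continue  # skips hidden files such as .DS_Store
        if img.stem in xml_by_stem:
            pairs.append((img.name, xml_by_stem[img.stem]))
        else:
            missing.append(img.name)
    return pairs, missing


# Tiny demonstration on a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    images, annots = Path(tmp) / "images", Path(tmp) / "annotations"
    images.mkdir()
    annots.mkdir()
    for name in ("page_f1.jpg", "page_f10.png", "page_f2.jpeg"):
        (images / name).touch()
    for name in ("page_f1.xml", "page_f10.xml"):
        (annots / name).touch()
    pairs, missing = pair_by_stem(images, annots)

print(pairs)
print(missing)
```

Looking stems up in a dictionary keeps the match exact (`page_f1` cannot be confused with `page_f10`) and preserves each image's real extension.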

In [None]:
# 1️⃣ Get image filenames (e.g., 'btv1b53000323h_f474.jpg')
# Path.rglob("*") recursively lists all files in 'images' and its subfolders;
# filtering on the extension skips hidden files such as '.DS_Store'
image_suffixes = {".jpg", ".jpeg", ".png"}
imgs_links = [f.name for f in imgs_path.rglob("*") if f.is_file() and f.suffix.lower() in image_suffixes]
imgs = [Path(name).stem for name in imgs_links]  # same as imgs_links but without extension

# 2️⃣ Get annotation filenames (e.g., 'Latin8236_f145.xml')
files_links = [f.name for f in annotations_path.rglob("*.xml") if f.is_file()]
annotation_names = [Path(name).stem for name in files_links]  # filenames without .xml

# 3️⃣ Sanity check: find images with no matching annotation file
missing_files = [img for img in imgs if f"{img}.xml" not in files_links]

if missing_files:
    print("❌ The following images do not have corresponding XML annotation files:")
    for m in missing_files:
        print(f"- {m}")
else:
    print("✅ All images have corresponding annotation files.")

# 4️⃣ Create valid (image, annotation) pairs, keeping each image's actual extension
image_file_pairs = [
    (name, f"{Path(name).stem}.xml")
    for name in imgs_links
    if f"{Path(name).stem}.xml" in files_links
]

print(f"\n✅ Found {len(image_file_pairs)} valid image-annotation pairs.")

# 5️⃣ Preview some valid pairs
for img, xml in image_file_pairs[:5]:
    print(f"{img} ⟷ {xml}")

# 6️⃣ Print summary counts
print(f"\n📦 Total image files found: {len(imgs_links)}")
print(f"🗂 Total annotation files found: {len(files_links)}")

# 7️⃣ Optionally inspect annotation file names (e.g., for parsing ark + folio)
print("\n🔍 Annotation filenames (for ark/folio parsing):")
for filename in files_links[:5]:  # only show first 5 for brevity
    print(filename)

In [None]:
# 🧱 Create a dictionary mapping each image filename to its annotation filename

imgs_annotations = dict()
annotation_set = set(files_links)

for img_name in imgs_links:  # image filenames, with their real extension
    xml_name = f"{Path(img_name).stem}.xml"
    if xml_name in annotation_set:  # exact match, so 'f1' cannot match 'f10.xml'
        imgs_annotations[img_name] = xml_name

# 🧮 Print stats
print(f"✅ Created dictionary with {len(imgs_annotations)} image–annotation pairs")

### Step 2: Extract Text Lines with Transparent Backgrounds

This script extracts individual text lines from document images using their ALTO XML annotations, applying an alpha channel to create transparent backgrounds around each cropped line.

For each image, it:
- Reads the corresponding ALTO XML annotation file.
- Locates polygon coordinates of the text lines.
- Creates cropped images with transparent backgrounds for each line.
- Saves the cropped lines into folders named after the original images.
- Stores the label for each line.

⏳ Processing about 70 images typically takes around 12 minutes.
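
As a rough sketch of the parsing part of these steps, the snippet below reads a minimal ALTO v4 fragment with the standard library and recovers each line's ID, label, and polygon bounding box. The XML content and the helper name `read_text_lines` are hypothetical, and the snippet assumes that `POINTS` holds space-separated integers; adapt the parsing if your export uses another layout.

```python
import xml.etree.ElementTree as ET

ALTO_NS = {"alto": "http://www.loc.gov/standards/alto/ns-v4#"}

# Hypothetical, hand-written ALTO fragment used only for illustration
SAMPLE_ALTO = """<alto xmlns="http://www.loc.gov/standards/alto/ns-v4#">
  <Layout><Page><PrintSpace><TextBlock>
    <TextLine ID="line_1">
      <Shape><Polygon POINTS="10 20 110 20 110 60 10 60"/></Shape>
      <String CONTENT="in principio"/>
    </TextLine>
    <TextLine ID="line_2">
      <Shape><Polygon POINTS="10 70 90 75 90 110 10 105"/></Shape>
    </TextLine>
  </TextBlock></PrintSpace></Page></Layout>
</alto>"""


def read_text_lines(xml_text):
    """Yield (line_id, label, points, bbox) for every labelled TextLine."""
    root = ET.fromstring(xml_text)
    for line in root.iterfind(".//alto:TextLine", ALTO_NS):
        string_tag = line.find("alto:String", ALTO_NS)
        polygon_tag = line.find("alto:Shape/alto:Polygon", ALTO_NS)
        if string_tag is None or polygon_tag is None:
            continue  # no transcription or no outline: nothing to crop
        values = [int(v) for v in polygon_tag.get("POINTS").split()]
        points = list(zip(values[::2], values[1::2]))  # [(x, y), ...]
        xs, ys = zip(*points)
        bbox = (min(xs), min(ys), max(xs), max(ys))  # ulx, uly, lrx, lry
        yield line.get("ID"), string_tag.get("CONTENT"), points, bbox


lines = list(read_text_lines(SAMPLE_ALTO))
for line_id, label, points, bbox in lines:
    print(line_id, repr(label), bbox)
```

The second `TextLine` has no `String` element, so it is skipped, which mirrors how unlabelled lines are ignored during extraction.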

In [None]:
# --- Setup ---
line_infos = dict()
annotations = dict()
lines_train = list()

# Acceptable image formats
img_extensions = ('.jpg', '.jpeg', '.png')

def process_image(img_file):
    annotation_file = imgs_annotations[img_file]
    if annotation_file.lower() == ".ds_store":
        return 0, [], {}, {}

    img_path = imgs_path / img_file
    if not img_path.exists():
        print(f"Image file not found: {img_path}")
        return 0, [], {}, {}

    img_folder_name = img_path.stem
    img_folder_path = img_save_path / img_folder_name
    if img_folder_path.exists():
        print(f"Folder {img_folder_path} already exists, skipping {img_file}.")
        return 0, [], {}, {}

    img_folder_path.mkdir(parents=True, exist_ok=True)
    print(f"Saving extracted lines for {img_file} to: {img_folder_path}")

    img = cv2.imread(str(img_path))
    if img is None:
        print(f"Error reading image: {img_path}")
        return 0, [], {}, {}

    h, w, _ = img.shape

    tree = ET.parse(annotations_path / annotation_file)
    root = tree.getroot()
    lines = root.findall(
        "{http://www.loc.gov/standards/alto/ns-v4#}Layout/"
        "{http://www.loc.gov/standards/alto/ns-v4#}Page/"
        "{http://www.loc.gov/standards/alto/ns-v4#}PrintSpace/"
        "{http://www.loc.gov/standards/alto/ns-v4#}TextBlock/"
        "{http://www.loc.gov/standards/alto/ns-v4#}TextLine"
    )

    local_line_infos = {}
    local_annotations = {}
    missing_lines_local = []
    count = 0

    for line in lines:
        id_line = line.get('ID')
        line_name = f"{img_path.stem}_{id_line}.png"
        label_tag = line.find('{http://www.loc.gov/standards/alto/ns-v4#}String')
        if label_tag is None:
            continue
        label = label_tag.get('CONTENT')
        local_line_infos[line_name] = {'label': label, 'page': img_file}

        polygon_tag = line.find('{http://www.loc.gov/standards/alto/ns-v4#}Shape/{http://www.loc.gov/standards/alto/ns-v4#}Polygon')
        if polygon_tag is None:
            continue
        polygon_str = polygon_tag.get('POINTS')
        polygon_coords = [int(p) for p in polygon_str.strip().split()]
        points = list(zip(polygon_coords[::2], polygon_coords[1::2]))

        mask = np.zeros((h, w), dtype=np.uint8)
        cv2.fillPoly(mask, [np.array(points, dtype=np.int32)], 255)
        alpha = np.ones(img.shape[:2], dtype=np.uint8) * 255
        alpha[mask == 0] = 0
        masked_img = cv2.merge((img, alpha))

        y_coords, x_coords = np.nonzero(mask)
        if len(y_coords) == 0 or len(x_coords) == 0:
            missing_lines_local.append(line_name)
            continue

        line_coords = {'ulx': x_coords.min(), 'uly': y_coords.min(), 'lrx': x_coords.max(), 'lry': y_coords.max()}
        # Slice ends are exclusive, so add 1 to keep the last row and column of the polygon
        line_img = masked_img[line_coords['uly']:line_coords['lry'] + 1, line_coords['ulx']:line_coords['lrx'] + 1]

        line_img_path = img_folder_path / line_name
        # PNG is lossless: level 9 gives the smallest files at the cost of slower writes
        cv2.imwrite(str(line_img_path), line_img, [cv2.IMWRITE_PNG_COMPRESSION, 9])

        split = 'train'  # default value; Step 3 assigns the final splits
        local_annotations[line_name] = {'label': label, 'split': split}
        count += 1

        if not line_img_path.exists():
            missing_lines_local.append(line_name)

    return count, missing_lines_local, local_line_infos, local_annotations


# --- Run in parallel ---
num_workers = min(8, multiprocessing.cpu_count())  # limit workers to 8 or number of CPUs
total_lines = 0
all_missing_lines = []
all_line_infos = {}
all_annotations = {}

with ThreadPoolExecutor(max_workers=num_workers) as executor:
    futures = {executor.submit(process_image, img_file): img_file for img_file in imgs_annotations.keys()}
    for future in tqdm(as_completed(futures), total=len(futures)):
        count, missing_lines_local, local_line_infos, local_annotations = future.result()
        total_lines += count
        all_missing_lines.extend(missing_lines_local)
        all_line_infos.update(local_line_infos)
        all_annotations.update(local_annotations)

# Update global dicts after processing
line_infos.update(all_line_infos)
annotations.update(all_annotations)

# --- Final summary ---
print(f'Number of lines: {total_lines}')
if all_missing_lines:
    print("The following segmented images are missing:", all_missing_lines)
else:
    print("All segmented images are present.")

### Step 3: Assign Dataset Splits

By default, this step sets `'split'` to `'train'` for every extracted line.

If you want to include validation (`'val'`) or other splits, you’ll need to modify the code to define how those splits are assigned, for example by:
- Random sampling
- Specific groups or document types
- Any other criteria that suit your dataset

This allows you to customize the training and validation subsets for fine-tuning.
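
One possible way to hold out a validation set is sketched below, under stated assumptions: whole pages are sampled at random so that lines from one page never end up in both splits. The function name `assign_splits`, the ratio, the fixed seed, and the rule that the page ID is the first two underscore-separated parts of each key are illustrative choices, not part of the Learnable Handwriter itself.

```python
import random
from collections import defaultdict


def assign_splits(annotations, val_ratio=0.1, seed=0):
    """Mark every line of a randomly chosen subset of pages as 'val', the rest as 'train'."""
    # Group line keys by page ID (first two underscore-separated parts)
    pages = defaultdict(list)
    for key in annotations:
        pages["_".join(key.split("_")[:2])].append(key)

    page_ids = sorted(pages)
    random.Random(seed).shuffle(page_ids)  # seeded for reproducibility
    # Keep at least one validation page when there is more than one page
    n_val = max(1, round(len(page_ids) * val_ratio)) if len(page_ids) > 1 else 0
    val_pages = set(page_ids[:n_val])

    for page_id, keys in pages.items():
        for key in keys:
            annotations[key]["split"] = "val" if page_id in val_pages else "train"
    return val_pages


# Toy annotations: 5 pages with 3 lines each
demo = {
    f"doc_f{page}_line{i}.png": {"label": "abc", "split": "train"}
    for page in range(5)
    for i in range(3)
}
val_pages = assign_splits(demo, val_ratio=0.2)
print(sorted(val_pages))
```

Splitting by page rather than by line avoids near-duplicate context leaking between training and validation; grouping by manuscript or by script would follow the same pattern with a different grouping key.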

In [None]:
for img_file in annotations.keys():
    annotations[img_file]['split'] = 'train'

with open(annotations_json_path, 'w') as f:
    json.dump(annotations, f, indent=4)

### Sanity Check: Line Counts and Empty Lines

Next, we calculate the number of text lines per image and check for any null (empty) lines.

⚠️ **Important:** Null lines will **not** be skipped by the Learnable Handwriter and will cause training to fail.  
Make sure to remove or fix any empty lines before proceeding.
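
A minimal sketch of one way to drop such lines, assuming the annotations are a dictionary keyed by line image name with a `label` field, as written in Step 3. The helper name `drop_empty_labels` is hypothetical, and treating whitespace-only labels as empty is an assumption.

```python
def drop_empty_labels(annotations):
    """Split annotations into (kept, removed_keys) based on label content."""
    kept, removed = {}, []
    for key, entry in annotations.items():
        label = entry.get("label")
        if label is None or not label.strip():
            removed.append(key)  # None, "" or whitespace-only
        else:
            kept[key] = entry
    return kept, removed


demo = {
    "page_f1_line1.png": {"label": "in principio", "split": "train"},
    "page_f1_line2.png": {"label": "", "split": "train"},
    "page_f1_line3.png": {"label": None, "split": "train"},
    "page_f1_line4.png": {"label": "   ", "split": "train"},
}
kept, removed = drop_empty_labels(demo)
print(f"kept {len(kept)} lines, removed {len(removed)}: {removed}")
```

The `removed` list can then be reviewed by hand: either fix the transcription at the source or save `kept` back to `annotation.json`.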

In [None]:
# Load the annotations JSON file
with open(annotations_json_path, 'r') as f:
    annotations = json.load(f)

# Group files by prefix (before first underscore)
groups = defaultdict(list)
for img_file in annotations.keys():
    group_id = img_file.split('_')[0]
    groups[group_id].append(img_file)

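# Placeholder for a group-wise split: every line is still assigned 'train'.
# The shuffle only has an effect once part of each group is sent to another split (e.g. 'val').
# 'group_files' avoids overwriting the 'imgs' list defined in Step 1.
for group_id, group_files in groups.items():
    random.shuffle(group_files)
    for img_file in group_files:
        annotations[img_file]['split'] = 'train'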

# Save updated annotations once
with open(annotations_json_path, 'w') as f:
    json.dump(annotations, f, indent=4)

# Create DataFrame from annotations for exploration
df_lines = pd.DataFrame.from_dict(annotations, orient='index')

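# Check for missing (None) or empty labels
if 'label' in df_lines:
    labels = df_lines['label']
    null_lines = int((labels.isnull() | (labels.fillna('').str.strip() == '')).sum())
else:
    null_lines = 0
print(f"{null_lines} out of {len(df_lines)} lines have null or empty 'label' values.")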

# Optional: Count lines per document (using first two parts of filename)
key_counts = defaultdict(int)
for filename in annotations.keys():
    unique_key = '_'.join(filename.split('_')[:2])
    key_counts[unique_key] += 1

key_counts_df = pd.DataFrame(list(key_counts.items()), columns=['ID', 'Nb_Lines'])
key_counts_df.to_csv('document_lines.csv', index=False)
print(key_counts_df)

### Step 4: Optional Script to Add Metadata to `annotation.json`

This step is only needed if you plan to fine-tune on a group of documents. In that case a `script` field is **mandatory**: it must be present in the `annotation.json` file for the fine-tuning process to work correctly.

To add metadata using this script, you need an external CSV file containing at least the following columns:

- **ID**: The document ID, which corresponds to the image filename (without the extension).  
  *Example:* `btv1b84472995_f141`

- **Script**: The script type or group name (e.g., script type like *Textualis* or a hand like *Raoulet*), depending on your dataset.

We assume that each CSV `ID` is exactly the image/XML filename without its extension, and that it consists of two underscore-separated parts.  
For example, for the image `btv1b53000323h_f474.jpg`, the metadata ID is `btv1b53000323h_f474`, and every line entry in `annotation.json` has a key that starts with `btv1b53000323h_f474_...`.  
The script extracts this prefix (the first two underscore-separated parts of each key) and looks it up in the CSV to link each line to its metadata.

You can also include additional columns with any other metadata fields you want to add and adjust the script accordingly to load those fields into your JSON annotations.
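
To make the linking rule concrete, here is a small sketch on in-memory data. The CSV content, the script names paired with each ID, and the helper name `attach_scripts` are made up for illustration; only the `ID` and `Script` column names and the semicolon delimiter follow this step's conventions.

```python
import csv
import io

# Hypothetical metadata file content (semicolon-separated)
CSV_TEXT = "ID;Script\nbtv1b84472995_f141;Textualis\nbtv1b53000323h_f474;Raoulet\n"


def attach_scripts(annotations, csv_text, delimiter=";"):
    """Add a 'script' field to each line whose page ID appears in the CSV."""
    reader = csv.DictReader(io.StringIO(csv_text), delimiter=delimiter)
    script_by_id = {row["ID"]: row["Script"] for row in reader}

    unmatched = []
    for key, entry in annotations.items():
        page_id = "_".join(key.split("_")[:2])  # e.g. 'btv1b53000323h_f474'
        if page_id in script_by_id:
            entry["script"] = script_by_id[page_id]
        else:
            unmatched.append(key)  # report instead of skipping silently
    return unmatched


demo = {
    "btv1b53000323h_f474_line_1.png": {"label": "abc", "split": "train"},
    "btv1b84472995_f141_line_7.png": {"label": "def", "split": "train"},
    "unknown_f1_line_1.png": {"label": "ghi", "split": "train"},
}
unmatched = attach_scripts(demo, CSV_TEXT)
print(unmatched)
```

Returning the unmatched keys makes it easy to spot IDs that are misspelled in the CSV, since entries without a `script` field would otherwise go unnoticed until fine-tuning.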

In [None]:
# Load CSV metadata
script_mapping = {}

with open(metadata_file_path, 'r', encoding='utf-8') as csv_file:
    csv_reader = csv.DictReader(csv_file, delimiter=';')
    for row in csv_reader:
        script_mapping[row['ID']] = row['Script']

# Load JSON annotation data
with open(annotations_json_path, 'r', encoding='utf-8') as json_file:
    data = json.load(json_file)

# Update JSON data with metadata
for key, value in data.items():
    page_id = '_'.join(key.split('_')[:2])  # first two underscore-separated parts, e.g. 'btv1b53000323h_f474'
    if page_id in script_mapping:
        value['script'] = script_mapping[page_id]

# Save the updated JSON
with open(annotations_json_path, 'w', encoding='utf-8') as json_file:
    json.dump(data, json_file, indent=2)

You can now use your `datasets/name-of-your-dataset` folder to train the Learnable Handwriter!  

To install and get started, follow the instructions in the [Learnable Handwriter README](https://github.com/malamatenia/learnable-handwriter/tree/main).