# Fine-tuning APD on IAM dataset
This is an example of fine-tuning APD on IAM dataset handwritten words from [Kaggle](https://www.kaggle.com/datasets/teykaicong/iamondb-handwriting-dataset). IAM Aachen splits can be downloaded [here](https://www.openslr.org/56/).

In [1]:
import pickle
import os
import sys
import glob
from pathlib import Path
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from typing import List, Tuple, Dict
import multiprocessing as mp
from functools import partial

import torch
from torch.utils.data import Dataset, DataLoader
import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import cv2
from torch.optim import AdamW
import matplotlib.patches as patches
from torchvision import transforms
from typing import Optional
from torch.optim.lr_scheduler import CosineAnnealingLR
from albumentations import (
    Compose, RandomBrightnessContrast, GaussNoise, ShiftScaleRotate, Blur
)
from torch.utils.tensorboard import SummaryWriter
from shapely.geometry import Polygon
import torch.nn.functional as F

# Ensure the project root is in the Python path
project_root = os.path.abspath(os.path.join(os.path.dirname('__file__'), '..'))
sys.path.insert(0, project_root)

if True:
    from APD.config import APDConfig
    from APD.processor import APDProcessor
    from APD.model import APDModel

# Dataset folder structure
```
iam_words/
│
├── words/                              # Folder containing word images as PNGs
│   ├── a01/                            # First folder
│   │   ├── a01-000u/
│   │   │   ├── a01-000u-00-00.png
│   │   │   └── a01-000u-00-01.png
│   │   .
│   │   .
│   │   .
│   └── r06/                            # Last folder
│       ├── r06-000/
│       │   ├── r06-000-00-00.png
│       │   └── r06-000-00-01.png
│
├── xml/                                # XML files
│	├── a01-000u.xml
│	.
│	.
│	.
│	└── r06-143.xml
│
└── splits/                             # IAM Aachen splits
    ├── train.uttlist
    ├── validation.uttlist
    └── test.uttlist
```

# Build lists of images and texts

In [2]:


@dataclass
class Word:
    id: str
    file_path: Path
    writer_id: str
    transcription: str

dataset_path = Path('../datasets/iam_words')

xml_files = sorted(glob.glob(str(dataset_path / 'xml' / '*.xml')))
word_image_files = sorted(glob.glob(str(dataset_path / 'words' / '**' / '*.png'), recursive=True))

print(f"{len(xml_files)} XML files and {len(word_image_files)} word image files")

1539 XML files and 115320 word image files


In [3]:
class IAMDatasetForDBNet(Dataset):
    def __init__(self, words: List[Word], config: APDConfig, is_training: bool = True):
        self.words = words
        self.config = config
        self.is_training = is_training

        if is_training:
            self.augmenter = Compose([
                RandomBrightnessContrast(p=0.5),
                GaussNoise(p=0.3),
                ShiftScaleRotate(p=0.5, rotate_limit=5),
                Blur(p=0.3)
            ])

    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        word = self.words[idx]
        image = Image.open(word.file_path).convert('RGB')
        image = np.array(image)

        # Get normalized bounding box coordinates
        h, w = image.shape[:2]
        bbox = np.array([[0, 0], [w, 0], [w, h], [0, h]], dtype=np.float32)
        bbox = bbox / np.array([w, h])[None, :]

        if self.is_training and self.augmenter:
            augmented = self.augmenter(image=image)
            image = augmented['image']

        # Resize image
        image = cv2.resize(image, self.config.image_size[::-1])
        image = torch.from_numpy(image).float() / 255.0
        image = image.permute(2, 0, 1)

        # Create targets
        target = self._prepare_target(bbox, self.config.image_size)

        return {
            'image': image,
            'prob_map': target['prob_map'],
            'thresh_map': target['thresh_map'],
            'binary_map': target['binary_map']
        }

In [4]:
class IAMDatasetForDBNet(Dataset):
    def __init__(self, words: List[Word], config: APDConfig, is_training: bool = True):
        self.words = words
        self.config = config
        self.is_training = is_training

        if is_training:
            self.augmenter = Compose([
                RandomBrightnessContrast(p=0.5),
                GaussNoise(p=0.3),
                ShiftScaleRotate(p=0.5, rotate_limit=5),
                Blur(p=0.3)
            ])

    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        word = self.words[idx]
        image = Image.open(word.file_path).convert('RGB')
        image = np.array(image)

        # Get normalized bounding box coordinates
        h, w = image.shape[:2]
        bbox = np.array([[0, 0], [w, 0], [w, h], [0, h]], dtype=np.float32)
        bbox = bbox / np.array([w, h])[None, :]

        if self.is_training and self.augmenter:
            augmented = self.augmenter(image=image)
            image = augmented['image']

        # Resize image
        image = cv2.resize(image, self.config.image_size[::-1])
        image = torch.from_numpy(image).float() / 255.0
        image = image.permute(2, 0, 1)

        # Create targets
        target = self._prepare_target(bbox, self.config.image_size)

        return {
            'image': image,
            'prob_map': target['prob_map'],
            'thresh_map': target['thresh_map'],
            'binary_map': target['binary_map']
        }

In [5]:
# 替换原有的数据加载代码
def generate_dataset(dataset_path: Path) -> Tuple[List[Word], List[Word], List[Word]]:
    xml_files = sorted(glob.glob(str(dataset_path / 'xml' / '*.xml')))
    word_image_files = sorted(
        glob.glob(str(dataset_path / 'words' / '**' / '*.png'), recursive=True))
    print(f"{len(xml_files)} XML files and {len(word_image_files)} word image files")

    with mp.Pool(processes=mp.cpu_count()) as pool:
        words_from_xmls = list(
            tqdm.tqdm(
                pool.imap(partial(get_words_from_xml,
                          word_image_files=word_image_files), xml_files),
                total=len(xml_files),
                desc='Building dataset'
            )
        )

    words = [word for words in words_from_xmls for word in words]

    # Load train/validation/test splits
    with open(dataset_path / 'splits' / 'train.uttlist') as fp:
        train_ids = set(line.strip() for line in fp)
    with open(dataset_path / 'splits' / 'test.uttlist') as fp:
        test_ids = set(line.strip() for line in fp)
    with open(dataset_path / 'splits' / 'validation.uttlist') as fp:
        validation_ids = set(line.strip() for line in fp)

    train_words = [word for word in words if word.id in train_ids]
    validation_words = [word for word in words if word.id in validation_ids]
    test_words = [word for word in words if word.id in test_ids]

    print(
        f'Generated dataset - Train size: {len(train_words)}; Validation size: {len(validation_words)}; Test size: {len(test_words)}')
    return train_words, validation_words, test_words

# Train test split

In [6]:
# First define all necessary functions
def load_dataset(dataset_path: Path) -> Tuple[List[Word], List[Word], List[Word]]:
    processed_file = dataset_path / 'processed_dataset.pkl'
    if os.path.exists(processed_file):
        print(f"Loading processed dataset from {processed_file}")
        with open(processed_file, 'rb') as f:
            data = pickle.load(f)
        train_words = data['train_words']
        validation_words = data['validation_words']
        test_words = data['test_words']
        print(
            f'Loaded dataset - Train size: {len(train_words)}; Validation size: {len(validation_words)}; Test size: {len(test_words)}')
    else:
        print("Processed dataset not found. Generating new dataset.")
        train_words, validation_words, test_words = generate_dataset(
            dataset_path)
        save_dataset(dataset_path, train_words, validation_words, test_words)

    return train_words, validation_words, test_words

def save_dataset(dataset_path: Path, train_words: List[Word], validation_words: List[Word], test_words: List[Word]):
    data = {
        'train_words': train_words,
        'validation_words': validation_words,
        'test_words': test_words
    }
    with open(dataset_path / 'processed_dataset.pkl', 'wb') as f:
        pickle.dump(data, f)
    print(f"Dataset saved to {dataset_path / 'processed_dataset.pkl'}")

def generate_dataset(dataset_path: Path) -> Tuple[List[Word], List[Word], List[Word]]:
    xml_files = sorted(glob.glob(str(dataset_path / 'xml' / '*.xml')))
    word_image_files = sorted(
        glob.glob(str(dataset_path / 'words' / '**' / '*.png'), recursive=True))
    print(f"{len(xml_files)} XML files and {len(word_image_files)} word image files")

    with mp.Pool(processes=mp.cpu_count()) as pool:
        words_from_xmls = list(
            tqdm.tqdm(
                pool.imap(partial(get_words_from_xml,
                          word_image_files=word_image_files), xml_files),
                total=len(xml_files),
                desc='Building dataset'
            )
        )

    words = [word for words in words_from_xmls for word in words]

    # Load train/validation/test splits
    with open(dataset_path / 'splits' / 'train.uttlist') as fp:
        train_ids = set(line.strip() for line in fp)
    with open(dataset_path / 'splits' / 'test.uttlist') as fp:
        test_ids = set(line.strip() for line in fp)
    with open(dataset_path / 'splits' / 'validation.uttlist') as fp:
        validation_ids = set(line.strip() for line in fp)

    train_words = [word for word in words if word.id in train_ids]
    validation_words = [word for word in words if word.id in validation_ids]
    test_words = [word for word in words if word.id in test_ids]

    print(
        f'Generated dataset - Train size: {len(train_words)}; Validation size: {len(validation_words)}; Test size: {len(test_words)}')
    return train_words, validation_words, test_words

# Then use the functions
dataset_path = Path('../datasets/iam_words')

# Load or generate dataset
train_words, validation_words, test_words = load_dataset(dataset_path)

# Create datasets
config = APDConfig()
train_dataset = IAMDatasetForDBNet(train_words, config, is_training=True)
val_dataset = IAMDatasetForDBNet(validation_words, config, is_training=False)
test_dataset = IAMDatasetForDBNet(test_words, config, is_training=False)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                         shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size,
                       shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                        shuffle=False, num_workers=4)

print(f"Train size: {len(train_dataset)}; Validation size: {len(val_dataset)}; Test size: {len(test_dataset)}")

Loading processed dataset from ../datasets/iam_words/processed_dataset.pkl
Loaded dataset - Train size: 55079; Validation size: 8895; Test size: 25920
Train size: 55079; Validation size: 8895; Test size: 25920


In [7]:
# Load and prepare dataset
dataset_path = Path('../datasets/iam_words')

# Load or generate dataset
train_words, validation_words, test_words = load_dataset(dataset_path)

def load_dataset(dataset_path: Path) -> Tuple[List[Word], List[Word], List[Word]]:
    processed_file = dataset_path / 'processed_dataset.pkl'
    if os.path.exists(processed_file):
        print(f"Loading processed dataset from {processed_file}")
        with open(processed_file, 'rb') as f:
            data = pickle.load(f)
        train_words = data['train_words']
        validation_words = data['validation_words']
        test_words = data['test_words']
        print(
            f'Loaded dataset - Train size: {len(train_words)}; Validation size: {len(validation_words)}; Test size: {len(test_words)}')
    else:
        print("Processed dataset not found. Generating new dataset.")
        train_words, validation_words, test_words = generate_dataset(
            dataset_path)
        save_dataset(dataset_path, train_words, validation_words, test_words)

    return train_words, validation_words, test_words

def save_dataset(dataset_path: Path, train_words: List[Word], validation_words: List[Word], test_words: List[Word]):
    data = {
        'train_words': train_words,
        'validation_words': validation_words,
        'test_words': test_words
    }
    with open(dataset_path / 'processed_dataset.pkl', 'wb') as f:
        pickle.dump(data, f)
    print(f"Dataset saved to {dataset_path / 'processed_dataset.pkl'}")

# Create datasets
config = APDConfig()
train_dataset = IAMDatasetForDBNet(train_words, config, is_training=True)
val_dataset = IAMDatasetForDBNet(validation_words, config, is_training=False)
test_dataset = IAMDatasetForDBNet(test_words, config, is_training=False)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, 
                         shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, 
                       shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, 
                        shuffle=False, num_workers=4)

print(f"Train size: {len(train_dataset)}; Validation size: {len(val_dataset)}; Test size: {len(test_dataset)}")

Loading processed dataset from ../datasets/iam_words/processed_dataset.pkl
Loaded dataset - Train size: 55079; Validation size: 8895; Test size: 25920
Train size: 55079; Validation size: 8895; Test size: 25920


# Build dataset and dataloader

# Model

In [10]:
from APD.dbnet import DBNet  # 确保路径正确
import tqdm
import matplotlib.pyplot as plt

from PIL import Image
import numpy as np
from torch.utils.data import  DataLoader


import torch

# 替代方案：使用 torchvision 的 ResNet
from torchvision.models import resnet18, ResNet18_Weights
backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
torch.set_float32_matmul_precision('high')


# 使用 DBNet

# 创建模型实例
# Check if MPS is available
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Create model instance
model = DBNet(backbone_name='resnet18')
model = torch.compile(model)
model.to(device)  # Use the detected device instead of hardcoding to CUDA


def test_model(model, test_dataset, processor):
    model.eval()
    for test_word in test_dataset[:5]:  # 测试前5个样本
        image = Image.open(test_word.file_path).convert('RGB')
        image_np = np.array(image)

        # 预处理图像
        preprocessed = processor.preprocess_image(image_np)

        # 模型推理
        with torch.no_grad():
            outputs = model(preprocessed.unsqueeze(0).to(model.device))

        # 后处理结果
        regions = processor.postprocess(outputs, image_np.shape[:2])

        # 可视化结果
        plt.figure(figsize=(10, 5))
        plt.imshow(image_np)
        for region in regions:
            bbox = region['bbox']
            plt.plot(bbox[:, 0], bbox[:, 1], 'r-', linewidth=2)
        plt.axis('off')
        plt.show()

Using device: mps


# Training

In [12]:
from typing import Tuple

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=0
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=0
)

def evaluate_model(model: torch.nn.Module, dataloader: DataLoader) -> Tuple[float, float]:
    # set model to evaluation mode
    model.eval()

    losses, accuracies = [], []
    with torch.no_grad():
        for inputs in tqdm.tqdm(dataloader, total=len(dataloader), desc=f'Evaluating test set'):
            inputs = send_inputs_to_device(inputs, device=0)
            outputs = model(**inputs)

            losses.append(outputs.loss.item())
            accuracies.append(outputs.accuracy.item())

    loss = sum(losses) / len(losses)
    accuracy = sum(accuracies) / len(accuracies)

    # set model back to training mode
    model.train()

    return loss, accuracy


def send_inputs_to_device(dictionary, device):
    return {key: value.to(device=device) if isinstance(value, torch.Tensor) else value for key, value in dictionary.items()}


use_amp = True
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
optimiser = torch.optim.Adam(params=model.parameters(), lr=1e-4)

EPOCHS = 50
train_losses, train_accuracies = [], []
validation_losses, validation_accuracies = [], []
for epoch in range(EPOCHS):
    epoch_losses, epoch_accuracies = [], []
    for inputs in tqdm.tqdm(train_loader, total=len(train_loader), desc=f'Epoch {epoch + 1}'):

        # set gradients to zero
        optimiser.zero_grad()

        # send inputs to same device as model
        inputs = send_inputs_to_device(inputs, device=0)

        # forward pass
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            outputs = model(**inputs)

        # calculate gradients
        scaler.scale(outputs.loss).backward()

        # update weights
        scaler.step(optimiser)
        scaler.update()

        epoch_losses.append(outputs.loss.item())
        epoch_accuracies.append(outputs.accuracy.item())

    # store loss and metrics
    train_losses.append(sum(epoch_losses) / len(epoch_losses))
    train_accuracies.append(sum(epoch_accuracies) / len(epoch_accuracies))

    # tests loss and accuracy
    # tests loss and accuracy
    validation_loss, validation_accuracy = evaluate_model(model, val_loader)
    validation_losses.append(validation_loss)
    validation_accuracies.append(validation_accuracy)

    print(
        f"Epoch: {epoch + 1} - Train loss: {train_losses[-1]}, Train accuracy: {train_accuracies[-1]}, Validation loss: {validation_losses[-1]}, Validation accuracy: {validation_accuracies[-1]}")

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
Epoch 1:   0%|          | 0/1722 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 120, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 130, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'IAMDatasetForDBNet' on <module '__main__' (built-in)>


# Test

In [None]:
from APD.model import APDLMHeadModel
from APD.config import APDConfig
from APD.processor import APDProcessor

# model = APDLMHeadModel(APDConfig())
model.eval()
model.to('cpu')
test_processor = APDProcessor(APDConfig())

In [None]:
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

for test_word_record in test_word_records[:50]:
    image_file = test_word_record.file_path
    image = Image.open(image_file).convert('RGB')

    inputs = test_processor(
        images=image,
        texts=test_processor.tokeniser.bos_token,
        return_tensors='pt'
    )

    model_output = model.generate(
        inputs,
        test_processor,
        num_beams=3
    )

    predicted_text = test_processor.tokeniser.decode(model_output[0], skip_special_tokens=True)

    plt.figure(figsize=(10, 5))
    plt.title(predicted_text, fontsize=24)
    plt.imshow(np.array(image, dtype=np.uint8))
    plt.xticks([]), plt.yticks([])
    plt.show()