In this notebook file, I will use my own model to finish two part of work:
First Part - Layout Recognition: Use ViT to extract main text
Second Part - OCR: Recognize the text correctly

In the notebook, the content of whole training and validation process is shown below:
0. The environment setup
1. Convert the pdf raw files into images and make annotations using LabelMe
2. Use ViT to detect main text
3. Crop the main text
4. OCR Recognition
5. Validation

In [None]:
# 0. The environment prerequisites:
!pip install pdf2image
# install poppler

# !pip install labelme # use labelme to make annotations, I use Windows to finish this step
# But other steps are finished in Ubuntu22 (WSL2)



In [None]:
# 1. Convert the pdf raw files into images and make annotations using LabelMe

import os
import glob
from pdf2image import convert_from_path

pdf_folder = "data/raw"
output_root = "data/convert_image"
os.makedirs(output_root, exist_ok=True)

pdf_files = glob.glob(os.path.join(pdf_folder, "*.pdf"))
print("Found PDF files:", pdf_files)

dpi_value = 300

for pdf_file in pdf_files:
    print(f"\nConverting: {pdf_file}")
    base_name = os.path.splitext(os.path.basename(pdf_file))[0]

    pdf_output_dir = os.path.join(output_root, base_name)
    os.makedirs(pdf_output_dir, exist_ok=True)
    
    # LOAD PAGE ONE BY ONE, LOWER CONSUMPTION OF PERFORMANCE
    chunk_size = 10
    start_page = 1
    end_page = chunk_size

    # 6.pdf is too large for 300 dpi
    if base_name == "6":  # Use 150 DPI for 6.pdf
        dpi_value = 150
    else:  # Keep 300 DPI for 1.pdf-5.pdf
        dpi_value = 300

    while True:
        pages = convert_from_path(pdf_file, dpi=dpi_value,
                                  first_page=start_page,
                                  last_page=end_page)

        if not pages:
            break

        # SAVE PAGE ONE BY ONE
        for i, page in enumerate(pages, start=start_page):
            out_path = os.path.join(pdf_output_dir, f"{base_name}_page_{i:03d}.png")
            page.save(out_path, "PNG")
            print(f"  Saved: {out_path}")

        # GOING TO NEXT PAGE
        start_page = end_page + 1
        end_page += chunk_size

print("\nCONVERTING SUCCESSFULLY.")

# The annotations are made by LabelMe
# Annotation files are stored in src/data/annotation
# Annotation way screenshots are stored in docs/annotation


In [1]:
# 2. Use ViT to detect main text

# 1.1 从 JSON + 原图生成分割掩码（mask）
# 1.2 编写一个 PyTorch Dataset，让它一次性返回 (image, mask)
# 1.3 切分成训练集、验证集、测试集（如果数据足够多的话）

import os
import json
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from PIL import Image

# 1.1 从 JSON + 原图生成分割掩码
def load_mask_from_json(json_path, image_path):
    """
    读取 LabelMe 标注的 JSON 文件，生成二值分割掩码。
    1 表示文本区域, 0 表示背景。
    """
    # 读取原图尺寸
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Image not found at {image_path}")
    height, width = img.shape[:2]
    
    # 创建空白 mask
    mask = np.zeros((height, width), dtype=np.uint8)
    
    # 读取 JSON
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    # 遍历 shapes
    for shape in data.get('shapes', []):
        if shape['label'] == 'text':
            # 提取多边形顶点
            points = shape['points']
            pts = np.array(points, dtype=np.int32)
            # 在 mask 上填充多边形
            cv2.fillPoly(mask, [pts], color=1)
    
    return mask


# 1.2 自定义一个 PyTorch Dataset，用于返回 (image, mask)
class TextSegDataset(Dataset):
    def __init__(self, img_dir, label_dir, transform=None):
        super().__init__()
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform = transform
        
        # 收集所有图片文件名
        self.img_files = []
        for f in os.listdir(img_dir):
            if f.lower().endswith(('.png', '.jpg', '.jpeg')):
                self.img_files.append(f)
        self.img_files.sort()
    
    def __len__(self):
        return len(self.img_files)
    
    def __getitem__(self, idx):
        # 根据索引拿到图片名
        img_name = self.img_files[idx]
        # 假设 json 与图片同名（扩展名不同）
        json_name = img_name.rsplit('.', 1)[0] + '.json'
        
        img_path = os.path.join(self.img_dir, img_name)
        json_path = os.path.join(self.label_dir, json_name)
        
        # 生成 mask
        mask = load_mask_from_json(json_path, img_path)
        
        # 读取图像 (BGR -> RGB)
        img_bgr = cv2.imread(img_path)
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        
        # 转为 PIL 形式，方便用 torchvision.transforms
        img_pil = Image.fromarray(img_rgb)
        mask_pil = Image.fromarray(mask)
        
        # 如果有 transform，则对图像进行 transform
        # 同时也要对 mask 做相应的 resize 等操作
        if self.transform is not None:
            # 常见做法：对图像用 self.transform
            img_t = self.transform(img_pil)
            
            # 如果 transform 中包含了 resize，需要同步对 mask 做同样的 resize
            # 若 transform 仅仅是 ToTensor()，那就简单处理即可：
            mask_t = T.ToTensor()(mask_pil)
            
            # 但如果 transform 有 Resize((224,224))，则 mask 也要 Resize((224,224))：
            # 可以自定义一个 transform 或单独处理：
            # 例如:
            # mask_t = T.Resize((224,224))(mask_pil)
            # mask_t = T.ToTensor()(mask_t)
            
        else:
            # 否则就直接转 tensor
            img_t = T.ToTensor()(img_pil)
            mask_t = T.ToTensor()(mask_pil)
        
        # mask_t 的 shape 通常是 [1, H, W]，像素值在 [0,1]
        # 你可以在这里做一些类型转换，比如 mask_t.float()，方便后续计算损失
        return img_t, mask_t


# 1.3 切分成训练集、验证集、测试集（如果数据足够多）
def create_datasets(img_dir, label_dir, transform=None, val_ratio=0.1, test_ratio=0.1):
    """
    根据给定目录创建数据集，并按照比例划分训练 / 验证 / 测试集。
    val_ratio: 验证集比例
    test_ratio: 测试集比例
    """
    full_dataset = TextSegDataset(img_dir, label_dir, transform=transform)
    total_len = len(full_dataset)
    
    # 计算划分大小
    val_len = int(total_len * val_ratio)
    test_len = int(total_len * test_ratio)
    train_len = total_len - val_len - test_len
    
    # 随机划分
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, [train_len, val_len, test_len],
        generator=torch.Generator().manual_seed(42)  # 固定随机种子，方便复现
    )
    
    return train_dataset, val_dataset, test_dataset


# ============ 以下是一个简单示例，展示如何使用 ============

if __name__ == "__main__":
    # 假设你的数据在 data/annotation_images 和 data/annotation_labels
    img_dir = "data/annotation_images"
    label_dir = "data/annotation_labels"
    
    # 例如：我们要把图像和 mask 都 resize 到 224x224，并转为 Tensor
    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor()
    ])
    
    # 创建数据集并划分
    train_ds, val_ds, test_ds = create_datasets(
        img_dir=img_dir,
        label_dir=label_dir,
        transform=transform,
        val_ratio=0.1,
        test_ratio=0.1
    )
    
    print(f"Train set: {len(train_ds)} samples")
    print(f"Val set: {len(val_ds)} samples")
    print(f"Test set: {len(test_ds)} samples")
    
    # 你也可以直接使用全量数据，不做划分：
    # full_dataset = TextSegDataset(img_dir, label_dir, transform=transform)
    
    # 用 DataLoader 来加载数据
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2)
    val_loader   = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=2)
    test_loader  = DataLoader(test_ds, batch_size=2, shuffle=False, num_workers=2)
    
    # 测试一下能否正常获取一个 batch
    sample_imgs, sample_masks = next(iter(train_loader))
    print("Sample images shape:", sample_imgs.shape)   # [B, C, H, W]
    print("Sample masks shape:", sample_masks.shape)   # [B, 1, H, W]


Train set: 16 samples
Val set: 1 samples
Test set: 1 samples


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/jeffliu/miniconda3/envs/ml/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/jeffliu/miniconda3/envs/ml/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
  File "/home/jeffliu/miniconda3/envs/ml/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/home/jeffliu/miniconda3/envs/ml/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 211, in collate
    return [
  File "/home/jeffliu/miniconda3/envs/ml/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 212, in <listcomp>
    collate(samples, collate_fn_map=collate_fn_map)
  File "/home/jeffliu/miniconda3/envs/ml/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 155, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/home/jeffliu/miniconda3/envs/ml/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 271, in collate_tensor_fn
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable


In [None]:
# 定义模型（ViT + Head）

# Backbone（主干）：ViT 主要负责把图像“看懂”——提取特征。
# Head（头）：指模型最后几层，用来输出你想要的结果。
# 在分割任务里，“头”通常是一些卷积或上采样层，把特征图转换成与原图同大小的分割结果。



In [None]:
# 训练与验证

# 用你自定义的 DataLoader（基于上面的 Dataset）迭代数据。
# 送入模型，计算损失函数（比如交叉熵、BCE 等）。
# 反向传播、更新模型参数。
# 观察损失是否下降，用验证集做指标（如 Dice、IoU），看效果是否提升。

In [None]:
# 评估 & 可视化

# 在测试集上推理，计算指标。
# 可视化看看预测的分割图是否跟真实标注接近。

In [None]:
3. Crop the main text


In [None]:
4. OCR Recognition


In [None]:
5. Validation