# MMOCR Training

This notebook contains all the source code needed to train text detection and text recognition models. You don't need to change anything except the dataset paths and the config file modifications.

In [2]:
# Mount Google Drive so the next cell can copy the dataset from MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Dataset Preparation

First, we need to convert the Label Studio annotations to the MMOCR annotation format.

Copy the dataset to a local directory. Change the source path to match your setup.

In [4]:
!cp -r "/content/drive/MyDrive/Data PC/Data/Dibimbing/Day 25/Assignment/handwriting" "./handwriting"

In [5]:
import cv2
import json
import numpy as np
import os
import shutil
from pathlib import Path
from typing import Dict, List, Tuple

Functions for text detection dataset preparation

In [6]:
def xywh2xyxy(xywh: List[float], img_width: int, img_height: int) -> List[int]:
    """
    Convert a Label Studio box to absolute pixel corners.

    The input values are percentages (0-100) of the image size, with
    ``(x, y)`` the top-left corner and ``w``/``h`` the box extent.

    Args:
        xywh: ``[x, y, w, h]`` in percent of the image size.
        img_width: Image width in pixels.
        img_height: Image height in pixels.

    Returns:
        ``[x1, y1, x2, y2]`` as integer pixels, clipped to
        ``[0, img_width]`` x ``[0, img_height]``. Clipping matters because
        these values are later used as array slice bounds, where a negative
        coordinate would silently wrap around to the other side of the image.

    NOTE(review): any Label Studio ``rotation`` of the box is not taken
    into account.
    """
    x, y, w, h = xywh
    x1 = x * img_width / 100
    y1 = y * img_height / 100
    x2 = x1 + w * img_width / 100
    y2 = y1 + h * img_height / 100

    def _clip(value: float, upper: int) -> int:
        # Truncate like the original int() conversion, then keep inside the image.
        return min(max(int(value), 0), upper)

    return [
        _clip(x1, img_width),
        _clip(y1, img_height),
        _clip(x2, img_width),
        _clip(y2, img_height),
    ]

def xyxy2poly(xyxy: List[int]) -> List[int]:
    """
    Expand an axis-aligned ``[x1, y1, x2, y2]`` box into a flat 4-point polygon.

    Corners are emitted in the order (x1, y1), (x1, y2), (x2, y2), (x2, y1)
    and flattened to ``[x, y, x, y, ...]``.
    """
    left, top, right, bottom = xyxy
    corners = ((left, top), (left, bottom), (right, bottom), (right, top))
    return [coord for corner in corners for coord in corner]


def create_instance_mmocr_anno(
    label_ls: Dict,
    text: str,
    img_width: int,
    img_height: int,
) -> Dict:
    """
    Convert one Label Studio text region into an MMOCR instance dict.

    Args:
        label_ls: Label Studio region with percentage ``x``, ``y``,
            ``width`` and ``height`` keys.
        text: Transcription of the region.
        img_width: Image width in pixels.
        img_height: Image height in pixels.

    Returns:
        Dict with ``bbox`` (pixel xyxy), ``bbox_label`` (always 0, the single
        "text" class), ``polygon`` (4-point version of the box), ``text`` and
        ``ignore`` (always False).
    """
    box_percent = [label_ls[key] for key in ("x", "y", "width", "height")]
    box = xywh2xyxy(box_percent, img_width, img_height)
    return {
        "bbox": box,
        "bbox_label": 0,
        "polygon": xyxy2poly(box),
        "text": text,
        "ignore": False,
    }

def create_image_mmocr_anno(image_name: str, image_ls: Dict) -> Dict:
    """
    Conver annotation of an image from label studio format
    to MMOCR format
    """
    img_width = image_ls["label"][0]["original_width"]
    img_height = image_ls["label"][0]["original_height"]
    image_anno = {}
    image_anno["img_path"] = image_name
    image_anno["height"] = img_height
    image_anno["width"] = img_width
    image_anno["instances"] = [
        create_instance_mmocr_anno(lbl, txt, img_width, img_height)
        for lbl, txt in zip(image_ls["label"], image_ls["transcription"])
    ]
    return image_anno

def create_metainfo_det() -> Dict:
    """
    Return the ``metainfo`` header of an MMOCR text-detection annotation
    file: a single "text" category with id 0.
    """
    text_category = {"id": 0, "name": "text"}
    return dict(
        dataset_type="TextDetDataset",
        task_name="textdet",
        category=[text_category],
    )

def create_output_json(
    annotations: List[Dict],
    metainfo: Dict,
    output_path: Path
) -> None:
    """
    Write an MMOCR annotation file.

    The file holds ``{"metainfo": metainfo, "data_list": annotations}``
    serialized as JSON at ``output_path``.
    """
    payload = dict(metainfo=metainfo, data_list=annotations)
    with open(output_path, "w") as handle:
        json.dump(payload, handle)

def get_image_name(ls_image_path: str) -> str:
    """
    Recover the original image file name from a Label Studio path.

    Label Studio stores uploads as ``{random_id}-{original_image_name}``.
    Everything after the first ``-`` of the base name is returned. If the
    base name contains no ``-``, it is returned unchanged.
    """
    base = os.path.basename(ls_image_path)
    return base.split("-", 1)[-1]

def _list_image_files(images_dir: Path) -> List[Path]:
    """
    List the regular, non-hidden files directly inside ``images_dir``.

    Results are sorted by name so the generated JSON has a reproducible
    order. ``iterdir`` raises FileNotFoundError for a missing folder instead
    of silently producing an empty split.
    """
    return sorted(
        p for p in images_dir.iterdir()
        if p.is_file() and not p.name.startswith(".")
    )


def create_mmocr_det_anno(
    ls_anno_path: Path,
    train_images_dir: Path,
    test_images_dir: Path,
    output_dir: Path,
):
    """
    Create a text detection dataset in MMOCR format.

    Copies every train/test image into ``output_dir`` and writes
    ``textdet_train.json`` and ``textdet_test.json`` next to them.

    Args:
        ls_anno_path: Label Studio JSON export (a list of tasks, each with
            an ``ocr`` image path).
        train_images_dir: Folder with the training images.
        test_images_dir: Folder with the test images.
        output_dir: Destination folder; created if missing.

    Raises:
        FileNotFoundError: If an image folder does not exist.
        ValueError: If a file name appears in both splits (both are copied
            into the same flat ``output_dir``, so one would overwrite the
            other).
        KeyError: If an image has no matching Label Studio annotation.
    """
    train_images = _list_image_files(train_images_dir)
    test_images = _list_image_files(test_images_dir)

    train_names = {p.name for p in train_images}
    test_names = {p.name for p in test_images}
    shared = sorted(train_names & test_names)
    if shared:
        raise ValueError(f"Images present in both train and test splits: {shared}")

    with open(ls_anno_path, "r") as f:
        ls_anno = json.load(f)
    # Key annotations by the original file name (Label Studio's random
    # upload prefix stripped) so they can be matched to the images on disk.
    image_annos = {}
    for ann in ls_anno:
        img_name = get_image_name(ann["ocr"])
        image_annos[img_name] = create_image_mmocr_anno(img_name, ann)

    missing = sorted((train_names | test_names) - set(image_annos))
    if missing:
        raise KeyError(f"No Label Studio annotation found for image(s): {missing}")

    output_dir.mkdir(parents=True, exist_ok=True)
    for p in [*train_images, *test_images]:
        shutil.copy(p, output_dir / p.name)
    create_output_json(
        annotations=[image_annos[p.name] for p in train_images],
        metainfo=create_metainfo_det(),
        output_path=output_dir / "textdet_train.json"
    )
    create_output_json(
        annotations=[image_annos[p.name] for p in test_images],
        metainfo=create_metainfo_det(),
        output_path=output_dir / "textdet_test.json"
    )

Functions for text recognition dataset preparation

In [7]:
def create_metainfo_rec() -> Dict:
    """Return the ``metainfo`` header of an MMOCR text-recognition annotation file."""
    return dict(dataset_type="TextRecogDataset", task_name="textrecog")

def crop_images(
    src_annos: Dict,
    image_src_dir: Path,
    image_dst_dir: Path,
) -> List[Dict]:
    """
    Crop every text instance of one detection image and build the
    recognition annotations for the crops.

    Crops are written to ``image_dst_dir`` as ``{image_stem}_{index:05}.jpg``.
    The index is the instance position in ``src_annos["instances"]``, so
    names stay stable even when a degenerate box is skipped.

    Args:
        src_annos: One MMOCR detection ``data_list`` entry (``img_path`` plus
            ``instances`` with pixel ``bbox`` and ``text``).
        image_src_dir: Folder containing ``src_annos["img_path"]``.
        image_dst_dir: Folder the crops are written to.

    Returns:
        One recognition entry per written crop: ``img_path``, ``height``,
        ``width`` and ``instances`` (a single ``{"text": ...}``).

    Raises:
        FileNotFoundError: If the source image cannot be read.
        OSError: If a crop cannot be written.
    """
    import warnings

    image_path = image_src_dir / src_annos["img_path"]
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img_height, img_width = image.shape[:2]
    image_name = image_path.stem

    anns = []
    for i, src_txt_anno in enumerate(src_annos["instances"]):
        x1, y1, x2, y2 = src_txt_anno["bbox"]
        # Clip to the image: negative slice bounds would wrap around.
        x1, x2 = max(0, x1), min(img_width, x2)
        y1, y2 = max(0, y1), min(img_height, y2)
        if x2 <= x1 or y2 <= y1:
            warnings.warn(
                f"Skipping empty crop {i} of {image_path} "
                f"(bbox {src_txt_anno['bbox']})"
            )
            continue

        dst_image_file = f"{image_name}_{i:05}.jpg"
        dst_path = image_dst_dir / dst_image_file
        crop = image[y1:y2, x1:x2]
        if not cv2.imwrite(str(dst_path), crop):
            raise OSError(f"Failed to write crop: {dst_path}")

        anns.append({
            "img_path": dst_image_file,
            "height": crop.shape[0],
            "width": crop.shape[1],
            "instances": [{"text": src_txt_anno["text"]}],
        })
    return anns


def create_split_anno(
    det_anno_path: Path,
    det_images_dir: Path,
    output_dir: Path,
    json_name: str,
):
    """
    Turn one detection split into a text-recognition split.

    Every image listed in ``det_anno_path`` is cropped into ``output_dir``
    (via ``crop_images``). The combined crop annotations are written to
    ``output_dir / json_name`` with the recognition ``metainfo`` header.
    """
    with open(det_anno_path, "r") as src:
        det_entries = json.load(src)["data_list"]

    crop_entries: List[Dict] = []
    for det_entry in det_entries:
        crop_entries.extend(crop_images(det_entry, det_images_dir, output_dir))

    rec_anno = dict(metainfo=create_metainfo_rec(), data_list=crop_entries)
    with open(output_dir / json_name, "w") as dst:
        json.dump(rec_anno, dst)

def create_mmocr_rec_anno(
    det_root_dir: Path,
    output_dir: Path,
):
    """
    Build the MMOCR text-recognition dataset from the detection dataset.

    For the train and test splits, word crops are cut from the images in
    ``det_root_dir`` using its ``textdet_*.json`` files. The crops and the
    matching ``textrecog_*.json`` files are written into ``output_dir``,
    which is created if missing.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    split_files = (
        ("textdet_train.json", "textrecog_train.json"),
        ("textdet_test.json", "textrecog_test.json"),
    )
    for det_json, rec_json in split_files:
        create_split_anno(det_root_dir / det_json, det_root_dir, output_dir, rec_json)

Do the actual format conversions. **Change the input path to the one you have in your environment.**

In [8]:
# change to path to your label-studio annotation JSON
LABEL_STUDIO_ANN = Path("handwriting/label-studio-anno.json")
# change to path to your training images folder
TRAIN_IMGS = Path("handwriting/training")
# change to path to your test images folder
TEST_IMGS = Path("handwriting/test")
# formatted dataset for text detection will be saved in the directory below
OUTPUT_DET_DIR = Path("dataset-det")
# formatted dataset for text recognition will be saved in the directory below
OUTPUT_REC_DIR = Path("dataset-rec")

# Fail fast with a clear message if one of the paths above was not updated
# for this environment, before any files are written.
for required_path in (LABEL_STUDIO_ANN, TRAIN_IMGS, TEST_IMGS):
    if not required_path.exists():
        raise FileNotFoundError(f"Input path not found: {required_path.resolve()}")

create_mmocr_det_anno(
    LABEL_STUDIO_ANN,
    TRAIN_IMGS,
    TEST_IMGS,
    OUTPUT_DET_DIR,
)
create_mmocr_rec_anno(
    OUTPUT_DET_DIR,
    OUTPUT_REC_DIR,
)

## Setup for Training

In [9]:
# Pin PyTorch 1.13.1 built for CUDA 11.7; the mmcv wheel installed below is
# fetched from the matching cu117/torch1.13.0 index. As the output shows,
# this replaces Colab's preinstalled torch 2.1.0.
!pip install torch==1.13.1+cu117 \
  torchvision==0.14.1+cu117 \
  --extra-index-url https://download.pytorch.org/whl/cu117
# OpenMMLab's installer, used to pick mmengine/mmcv/mmdet builds that match
# the installed torch/CUDA.
!pip install -U openmim
!mim install "mmengine>=0.7.1,<1.1.0"
!mim install "mmcv>=2.0.0rc4,<2.1.0"
!mim install "mmdet>=3.0.0rc5,<3.2.0"
# Install MMOCR from source in editable mode; the configs/ and tools/ paths
# used in the cells below are read from this checkout.
# NOTE(review): the clone is not pinned to a release tag, and `git clone`
# fails on a re-run once ./mmocr already exists.
!git clone https://github.com/open-mmlab/mmocr.git
!cd mmocr && pip install -v -e .

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu117
Collecting torch==1.13.1+cu117
  Downloading https://download.pytorch.org/whl/cu117/torch-1.13.1%2Bcu117-cp310-cp310-linux_x86_64.whl (1801.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 GB[0m [31m658.8 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.14.1+cu117
  Downloading https://download.pytorch.org/whl/cu117/torchvision-0.14.1%2Bcu117-cp310-cp310-linux_x86_64.whl (24.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.3/24.3 MB[0m [31m60.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 2.1.0+cu121
    Uninstalling torch-2.1.0+cu121:
      Successfully uninstalled torch-2.1.0+cu121
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.16.0+cu121
    Uninstalling torchvision-0.16.0+cu121:
      Successf

Looking in links: https://download.openmmlab.com/mmcv/dist/cu117/torch1.13.0/index.html
Collecting mmengine<1.1.0,>=0.7.1
  Downloading mmengine-0.10.3-py3-none-any.whl (451 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m451.7/451.7 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting addict (from mmengine<1.1.0,>=0.7.1)
  Downloading addict-2.4.0-py3-none-any.whl (3.8 kB)
Collecting yapf (from mmengine<1.1.0,>=0.7.1)
  Downloading yapf-0.40.2-py3-none-any.whl (254 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m254.7/254.7 kB[0m [31m28.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: addict, yapf, mmengine
Successfully installed addict-2.4.0 mmengine-0.10.3 yapf-0.40.2
Looking in links: https://download.openmmlab.com/mmcv/dist/cu117/torch1.13.0/index.html
Collecting mmcv<2.1.0,>=2.0.0rc4
  Downloading https://download.openmmlab.com/mmcv/dist/cu117/torch1.13.0/mmcv-2.0.1-cp310-cp310-manylinux1_x86_64.whl (73.9 MB)
[2K 

## Training Text Detection

We will use the config file `/content/mmocr/configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py` as the main model config. Note that each parameter may be defined in another `.py` file, since MMOCR configs inherit settings from the base files listed in `_base_`. Check the `_base_` list of the main config.

Change in the configuration:

- Root data (Use ICDAR2015 config) to `dataset-det`
- Number of iterations: try at least 50, but be careful not to overfit
- Validation interval: try around 10
- TensorBoard visualizer

  ```
  vis_backends = [dict(type='LocalVisBackend'),
                  dict(type='TensorboardVisBackend')]
  ```

- Only save last checkpoint

  ```
      checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=1)
  ```

In [10]:
_base_ = [
    '/content/mmocr/configs/textdet/icdar2015_dbnet_resnet50dcnv2_fpn.py',
    '/content/mmocr/configs/_base_/schedules/schedule_1200e.py'
]

dataset_type = 'IcdarDataset'
data_root = 'dataset-det/'

total_iters = 50

evaluation = dict(interval=10)

vis_backends = [
    dict(type='LocalVisBackend'),
    dict(type='TensorboardVisBackend')
]

checkpoint = dict(type='CheckpointHook', interval=10, max_keep_ckpts=1)


In [None]:
!python /content/mmocr/tools/visualizations/browse_dataset.py \
  "/content/mmocr/configs/textdet/dbnet/dbnet_resnet50_1200e_icdar2015.py" \
  -o "/content/vis" \
  -m original

In [None]:
%reload_ext tensorboard
%tensorboard --logdir "/content/work_dir"

In [None]:
!python "/content/mmocr/tools/train.py" \
  "/content/mmocr/configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py" \
  --work-dir "/content/work_dir"

## Training Text Recognition

We will use the config file `/content/mmocr/configs/textrecog/svtr/svtr-base_20e_st_mj.py` as the main model config. Note that each parameter may be defined in another `.py` file, since MMOCR configs inherit settings from the base files listed in `_base_`. Check the `_base_` list of the main config.

Change in the configuration:

- Root data (Use ICDAR2015 config) to `dataset-rec`
- Number of iterations: try the default first.
- TensorBoard visualizer

  ```
  vis_backends = [dict(type='LocalVisBackend'),
                  dict(type='TensorboardVisBackend')]
  ```

- Only save last checkpoint

  ```
      checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1)
  ```

- Validation evaluator

  ```
  val_evaluator = dict(
      _delete_=True,
      type='Evaluator',
      metrics=[
          dict(
              type='WordMetric',
              mode=['exact', 'ignore_case', 'ignore_case_symbol']),
          dict(type='CharMetric')
      ])
  test_evaluator = val_evaluator
  ```

- Train/test dataset list

  ```
  train_list = [_base_.icdar2015_textrecog_train]
  test_list = [_base_.icdar2015_textrecog_test]
  ```

- Update pre-trained model

  ```
  load_from = "https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth"
  ```

- Reduce the batch size (e.g. to 128) if you get a CUDA out-of-memory error

In [53]:
vis_backends = [
    dict(type='LocalVisBackend'),
    dict(type='TensorboardVisBackend')
]

# Checkpoint configuration
checkpoint = dict(type='CheckpointHook', interval=1, max_keep_ckpts=1)

# Validation Evaluator configuration
val_evaluator = dict(
    _delete_=True,
    type='Evaluator',
    metrics=[
        dict(
            type='WordMetric',
            mode=['exact', 'ignore_case', 'ignore_case_symbol']),
        dict(type='CharMetric')
    ]
)

# Test Evaluator configuration
test_evaluator = val_evaluator

# Train/test dataset dictionaries
train_list = dict(
    type='_base_.icdar2015_textrecog_train'
)
test_list = dict(
    type='_base_.icdar2015_textrecog_test'
)

# Update pre-trained model
load_from = "https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth"

# Note: Change batch size to smaller value if CUDA OOM, e.g., 12
batch_size = 16  # Replace with your desired batch size or make it configurable

# Define other training parameters as needed

# Placeholder for training loop or function
def train_model():
    # Your training code goes here
    pass

# Placeholder for testing loop or function
def test_model():
    # Your testing code goes here
    pass

# Run training using the defined configuration
if __name__ == "__main__":
    # Replace with actual training and testing calls based on your framework
    train_model()
    test_model()

In [None]:
!python /content/mmocr/tools/visualizations/browse_dataset.py \
  "/content/mmocr/configs/textrecog/svtr/svtr-base_20e_st_mj.py" \
  -o "/content/vis" \
  -m original

In [None]:
%reload_ext tensorboard
%tensorboard --logdir "/content/work_dir"

In [None]:
!python "/content/mmocr/tools/train.py" \
  "/content/mmocr/configs/textrecog/svtr/svtr-base_20e_st_mj.py" \
  --work-dir "/content/work_dir"

In [58]:
import shutil
from pathlib import Path
from typing import List

# Export the trained text-recognition weights under a stable file name.
# The recognition run writes to /content/work_dir_rec when it is given its
# own work dir, otherwise to /content/work_dir; the first directory with a
# `last_checkpoint` marker is used.
# NOTE(review): relies on MMEngine's CheckpointHook writing a
# `last_checkpoint` file that holds the newest checkpoint path. Verify this,
# and run the recognition training before this cell.
REC_WORK_DIR_CANDIDATES = [Path("/content/work_dir_rec"), Path("/content/work_dir")]


def find_last_checkpoint(work_dirs: List[Path]) -> Path:
    """
    Return the checkpoint recorded in the first of ``work_dirs`` that
    contains a ``last_checkpoint`` marker file.

    Raises:
        FileNotFoundError: If no directory has a marker, or the recorded
            checkpoint file does not exist.
    """
    for work_dir in work_dirs:
        marker = work_dir / "last_checkpoint"
        if marker.is_file():
            ckpt = Path(marker.read_text().strip())
            if not ckpt.is_file():
                raise FileNotFoundError(f"Checkpoint listed in {marker} not found: {ckpt}")
            return ckpt
    raise FileNotFoundError(
        "No trained checkpoint found; run the training cell first. "
        f"Searched: {[str(d) for d in work_dirs]}"
    )


# Save the model checkpoint
checkpoint_path = "modelOCR.pth"
shutil.copy(find_last_checkpoint(REC_WORK_DIR_CANDIDATES), checkpoint_path)

print(f"Model saved to {checkpoint_path}")


Model saved to modelOCR.pth
