# Обучение модели EasyOCR для решения задачи: "Распознавание автомобильных номеров"

## Импорты

In [1]:
!pip install easyocr
!pip install albumentations

import os
import cv2
import json
import yaml
import csv
from PIL import Image
import pandas as pd
import albumentations as A
from sklearn.model_selection import train_test_split

import torch
import torch.backends.cudnn as cudnn

import easyocr

from google.colab.patches import cv2_imshow



## Config

In [2]:
# Notebook-wide settings: split ratio, target image size and dataset locations.
config = dict(
    # Share of samples handed to the validation part of the split
    test_size=0.2,
    # Target image size; the same numbers appear as imgH / imgW in the training options
    imgH=64,
    imgW=600,
    # JSON files holding the train / validation records
    train_data_path='/content/drive/MyDrive/Распознавание номеров/Dataset/train_data.json',
    val_data_path='/content/drive/MyDrive/Распознавание номеров/Dataset/val_data.json',
    # Folder inside the EasyOCR trainer that receives the prepared data
    all_data_path='/content/drive/MyDrive/Распознавание номеров/EasyOCR/trainer/all_data',
)

## Проверка данных

In [3]:
# Folder with the source images
folder_path = '/content/drive/MyDrive/Распознавание номеров/NumberWithMarking'
# Raw directory listing: every entry is counted, no filtering by file type
images_names = os.listdir(folder_path)
print(f'В папке {len(images_names)} изображений')

В папке 7182 изображений


Как можно заметить, изображений достаточно мало, поэтому нужно "размножить" изображения с помощью аугментации. Буду использовать следующие аугментации:

*   Поворот изображения
*   Изменение контрастности
*   Размытие
*   Шум
*   Изменение перспективы





## Аугментация данных

In [4]:
def augment_images(folder_path: str, save_folder_path: str, number_augmentations=15) -> None:
    """Write randomly augmented copies of every image found in ``folder_path``.

    Each readable image is passed ``number_augmentations`` times through a
    random pipeline (rotation, brightness/contrast, blur, noise, perspective)
    and every result is saved as ``augmented_image_<i>_<original name>``.

    Args:
        folder_path: Folder with the source images.
        save_folder_path: Folder that receives the augmented images; it is
            created when missing.
        number_augmentations: Number of augmented copies per source image.

    Notes:
        The source listing is taken once, before anything is written, so files
        produced during this call are not picked up again even when both
        folders are the same. A later call on such a folder will, however,
        augment the previously generated files as well.
    """
    # Snapshot of the source folder taken before any output is written
    images_names = os.listdir(folder_path)
    os.makedirs(save_folder_path, exist_ok=True)

    # Random augmentation pipeline applied independently to every copy
    transform = A.Compose([
        A.Rotate(limit=10, p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.GaussianBlur(p=0.2),
        A.GaussNoise(p=0.2),
        A.Perspective(p=0.2)
    ])

    for image_name in images_names:
        image_path = os.path.join(folder_path, image_name)

        # cv2.imread returns None (it does not raise) for directories and
        # unreadable files; skip them instead of crashing in cvtColor
        image = cv2.imread(image_path)
        if image is None:
            print(f'Пропущен файл {image_path}: не удалось прочитать изображение')
            continue

        # OpenCV loads BGR; the augmentation pipeline is fed RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        for i in range(number_augmentations):
            aug_image = transform(image=image)['image']

            aug_image_name = f'augmented_image_{i}_{image_name}'
            save_path = os.path.join(save_folder_path, aug_image_name)

            # cv2.imwrite expects BGR again; writing the RGB array directly
            # would store the image with red and blue swapped
            saved = cv2.imwrite(save_path, cv2.cvtColor(aug_image, cv2.COLOR_RGB2BGR))
            if not saved:
                print(f'Не удалось сохранить файл {save_path}')


In [5]:
# Source and destination are the same folder: augmented files land next to the originals
folder_path = '/content/drive/MyDrive/Распознавание номеров/NumberWithMarking'
save_folder_path = '/content/drive/MyDrive/Распознавание номеров/NumberWithMarking'

# NOTE(review): the call is commented out, so the message below is printed
# even though no augmentation runs in this cell.
#augment_images(folder_path, save_folder_path)
print('Аугментация завершена!')

Аугментация завершена!


In [6]:
# Number of entries in the destination folder (raw listing, nothing filtered out)
print(len(os.listdir(save_folder_path)))

7182


## Создание датасета

In [7]:
def make_dataset(folder_path: str, dataset_path: str) -> None:
    """Build a JSON dataset of ``{'image': <path>, 'text': <label>}`` records.

    The label is decoded from the file name only, so the result does not
    depend on whether ``folder_path`` itself contains spaces:

    * name without a space: the text after the last ``_``, minus ``.bmp``;
    * name with one space: the last ``_``-separated token before the space
      joined with the first ``_``-separated token after it, minus ``.bmp``.

    Entries ending in ``.ipynb_checkpoints``, names with more than one space
    and names that yield an empty label are skipped.

    Args:
        folder_path: Folder whose entries are turned into records.
        dataset_path: Destination JSON file (UTF-8, non-ASCII kept as is).
    """
    dataset = []

    # Sorted so the written file does not depend on directory listing order
    for image_name in sorted(os.listdir(folder_path)):
        # Notebook checkpoint folders are not samples
        if image_name.endswith('.ipynb_checkpoints'):
            continue

        parts = image_name.split(' ')
        if len(parts) == 1:
            text = parts[0].split('_')[-1].replace('.bmp', '')
        elif len(parts) == 2:
            left, right = parts
            text = left.split('_')[-1] + right.split('_')[0].replace('.bmp', '')
        else:
            # The naming scheme allows at most one space; report instead of guessing
            print(f'Пропущен файл {image_name}: не удалось разобрать номер из имени')
            continue

        # A sample without a label cannot be used for training
        if text == '':
            continue

        dataset.append({
            'image': os.path.join(folder_path, image_name),
            'text': text
        })

    with open(dataset_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, ensure_ascii=False, indent=4)


In [8]:
# Path to the folder with the images
folder_path = '/content/drive/MyDrive/Распознавание номеров/NumberWithMarking'
# Build the path of the JSON dataset file
dataset_name = 'dataset.json'
dataset_folder_path = '/content/drive/MyDrive/Распознавание номеров/Dataset'
dataset_path = os.path.join(dataset_folder_path, dataset_name)

# NOTE(review): the call is commented out, so the message below is printed
# without the dataset being rebuilt; dataset_path must already exist for the next cell.
#make_dataset(folder_path, dataset_path)
print('Датасет создан!')

Датасет создан!


## Разделение датасета на тренировочный и валидационный

In [9]:
# Load the dataset records into `data`
# (the name `data` is reassigned to the training options dict in a later cell)
with open(dataset_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Split into training and validation parts (currently disabled)
# train_data, val_data = train_test_split(data, test_size=config['test_size'])

# Save the training part (currently disabled)
# with open(config['train_data_path'], "w", encoding="utf-8") as f:
#         json.dump(train_data, f, ensure_ascii=False, indent=4)

# Save the validation part (currently disabled)
# with open(config['val_data_path'], "w", encoding="utf-8") as f:
#         json.dump(val_data, f, ensure_ascii=False, indent=4)

# NOTE(review): printed even though the split and both saves above are commented out
print('Датасеты созданы!')

Датасеты созданы!


## Перенос данных в all_data

In [10]:
def save_and_resize_image(image_path: str, image_folder: str, height: int, width: int):
    """Store a grayscale, resized copy of one image in ``image_folder``.

    The copy keeps the source file name (the part after the last ``/``), is
    converted to mode ``'L'`` and resized to ``width`` x ``height`` with
    Lanczos resampling. Any failure is reported with ``print`` and swallowed,
    so one bad file does not stop the caller.

    Args:
        image_path: Path of the source image.
        image_folder: Folder that receives the processed copy.
        height: Target height in pixels.
        width: Target width in pixels.
    """
    try:
        file_name = image_path.rsplit('/', 1)[-1]
        target_path = os.path.join(image_folder, file_name)

        grayscale = Image.open(image_path).convert('L')
        # PIL takes the size as (width, height)
        resized = grayscale.resize((width, height), Image.Resampling.LANCZOS)
        resized.save(target_path)
    except Exception as error:
        print(f"Ошибка с файлом {image_path}: {error}")


In [11]:
def create_folders(data, output_folder: str, folder_name: str, height: int = None, width: int = None):
    """Materialise one dataset split on disk: resized images plus ``labels.csv``.

    Creates ``<output_folder>/<folder_name>/images``, stores a processed copy
    of every image there via ``save_and_resize_image`` and writes
    ``<output_folder>/<folder_name>/labels.csv`` with the columns
    ``filename`` and ``words``.

    Args:
        data: Iterable of records with the keys ``'image'`` (source path) and
            ``'text'`` (label).
        output_folder: Root folder for all splits.
        folder_name: Name of this split's sub-folder.
        height: Target image height; defaults to ``config['imgH']``.
        width: Target image width; defaults to ``config['imgW']``.
    """
    # save_and_resize_image requires an explicit size; fall back to the notebook config
    if height is None:
        height = config['imgH']
    if width is None:
        width = config['imgW']

    folder = os.path.join(output_folder, folder_name)
    image_folder = os.path.join(folder, 'images')
    # Creates the split folder and its images sub-folder in one call
    os.makedirs(image_folder, exist_ok=True)

    # Header row: get_config() reads this file with usecols=['filename', 'words']
    labels = [['filename', 'words']]

    for row in data:
        # Original location of the image
        image_path = row['image']
        # File name under which the copy is stored and referenced in the CSV
        filename = image_path.split('/')[-1]
        words = row['text']

        save_and_resize_image(image_path, image_folder, height, width)

        labels.append([filename, words])

    # NOTE(review): the training log in this notebook lists `.../images` as the
    # dataset directory - confirm the trainer looks for labels.csv at this level.
    labels_csv_path = os.path.join(folder, 'labels.csv')
    # Explicit UTF-8, matching the other files written by this notebook
    with open(labels_csv_path, mode='w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerows(labels)


In [12]:
def transfer_data(train_json_path: str, val_json_path: str, output_folder: str):
    """Load both JSON splits and write them out as trainer folders.

    Both files are read before anything is written; the training records then
    go to ``train_data`` and the validation records to ``val_data`` inside
    ``output_folder``.

    Args:
        train_json_path: JSON file with the training records.
        val_json_path: JSON file with the validation records.
        output_folder: Root folder that receives both split folders.
    """
    loaded_splits = []
    for json_path in (train_json_path, val_json_path):
        with open(json_path, 'r', encoding='utf-8') as source:
            loaded_splits.append(json.load(source))

    for split_name, records in zip(('train_data', 'val_data'), loaded_splits):
        create_folders(records, output_folder, split_name)


In [13]:
# Перенос данных в all_data
#transfer_data(config['train_data_path'], config['val_data_path'], config['all_data_path'])

## Конфигурация для обучения

In [14]:
# Training options for the EasyOCR trainer. The commented-out code right after
# this dict would dump them to config.yaml.
data = {
    # Alphabet parts; get_config() joins them as number + symbol + lang_char
    "number": '0123456789',
    "symbol": '',
    "lang_char": 'АВЕКМНОРСТУХ',
    "experiment_name": 'ru_filtered',
    "train_data": 'all_data/train_data',
    "valid_data": 'all_data/val_data',
    "manualSeed": 1111,
    "workers": 2,
    "batch_size": 64,
    "num_iter": 100,
    "valInterval": 5,
    "saved_model": 'cyrillic_g2',
    "FT": True,
    "optim": False,
    "lr": 0.0001,
    "beta1": 0.9,
    "rho": 0.95,
    "eps": 1e-8,
    "grad_clip": 5,
    "select_data": 'train',
    "batch_ratio": '1',
    "total_data_usage_ratio": 1.0,
    "batch_max_length": 68,
    # Same values as config['imgH'] / config['imgW']
    "imgH": 64,
    "imgW": 600,
    "rgb": False,
    # This key used to be listed twice (False, then 0.0). A dict literal keeps
    # only the last value, so 0.0 is what was actually in effect.
    "contrast_adjust": 0.0,
    "sensitive": True,
    "PAD": True,
    "data_filtering_off": False,
    "Transformation": 'None',
    "FeatureExtraction": 'VGG',
    "SequenceModeling": 'BiLSTM',
    "Prediction": 'CTC',
    "num_fiducial": 20,
    "input_channel": 1,
    "output_channel": 256,
    "hidden_size": 256,
    "decode": 'greedy',
    "new_prediction": False,
    # NOTE(review): spelling kept as is - presumably it has to match the option
    # name the trainer reads; confirm before renaming.
    "freeze_FeatureFxtraction": False,
    "freeze_SequenceModeling": False
}

# with open("config.yaml", "w", encoding="utf8") as yaml_file:
#     yaml.dump(data, yaml_file, allow_unicode=True)

## Обучение модели

In [15]:
#!apt-get install git

In [16]:
# Копирую репозиторий EasyOCR на гугл диск
#!git clone "https://github.com/JaidedAI/EasyOCR" "/content/drive/MyDrive/Распознавание номеров/EasyOCR"

In [17]:
# !pip cache purge
# !pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 torchaudio==2.0.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html

In [18]:
# !pip install --upgrade torch torchvision torchaudio

In [19]:
# !pip uninstall -y torch torchvision torchaudio
# !pip install torch torchvision torchaudio

In [20]:
# Show the working directory before switching to the trainer folder
os.getcwd()

'/content'

In [21]:
# Work from the trainer folder: the relative paths used below
# (../requirements.txt, config_files/..., ./saved_models/...) are resolved against it
%cd /content/drive/MyDrive/Распознавание номеров/EasyOCR/trainer

/content/drive/MyDrive/Распознавание номеров/EasyOCR/trainer


In [22]:
# Install the requirements file located one level above the trainer folder
!pip install -r ../requirements.txt



In [35]:
# Report the installed torch version (torch is also imported in the first cell)
import torch
print(torch.__version__)

2.5.1+cu118


In [24]:
# Report the installed easyocr version (easyocr is also imported in the first cell)
import easyocr
print(easyocr.__version__)

1.7.2


In [25]:
# CUDA version reported by the installed torch package
print(torch.version.cuda)

11.8


In [38]:
# Uninstall current PyTorch installation to avoid conflicts
!pip uninstall -y torch torchvision torchaudio

# Install a compatible torch version. This will likely resolve the issue.
# NOTE(review): the recorded output of this cell shows pip finding no matching
# distribution for torchvision==0.1.6+cu116, so the install does not complete.
# The pin looks like a typo - TODO choose a torchvision release that matches the
# selected torch build and is offered for this Python version (see the
# "from versions" list in the output).
!pip install torch==1.13.1+cu116 torchvision==0.1.6+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

[0mLooking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu116
Collecting torch==1.13.1+cu116
  Using cached https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp311-cp311-linux_x86_64.whl (1977.9 MB)
[31mERROR: Ignored the following yanked versions: 0.1.6, 0.1.7, 0.1.8, 0.1.9, 0.2.0, 0.2.1, 0.2.2, 0.2.2.post2, 0.2.2.post3, 0.15.0[0m[31m
[0m[31mERROR: Could not find a version that satisfies the requirement torchvision==0.1.6+cu116 (from versions: 0.1.6, 0.2.0, 0.15.1, 0.15.2, 0.16.0, 0.16.1, 0.16.2, 0.17.0, 0.17.1, 0.17.2, 0.18.0, 0.18.1, 0.19.0, 0.19.1, 0.20.0, 0.20.1)[0m[31m
[0m[31mERROR: No matching distribution found for torchvision==0.1.6+cu116[0m[31m
[0m

In [36]:
# Rename test.py to validation_script.py (for example) and change the import in train.py accordingly
# These modules are resolved from the trainer folder selected with %cd above.
# NOTE(review): the recorded output shows this import failing with
# "cannot import name '_accumulate' from 'torch._utils'" under the installed torch.
from train import train
from utils import AttrDict

ImportError: cannot import name '_accumulate' from 'torch._utils' (/usr/local/lib/python3.11/dist-packages/torch/_utils.py)

In [None]:
def get_config(file_path):
    """Load training options from a YAML file and derive the character set.

    When ``lang_char`` is the string ``'None'``, the character set is collected
    from the ``words`` column of every ``labels.csv`` found under
    ``train_data/<subset>`` for each ``-``-separated subset in ``select_data``.
    Otherwise it is ``number + symbol + lang_char``. The folder
    ``./saved_models/<experiment_name>`` is created as a side effect.

    Args:
        file_path: Path to the YAML configuration file.

    Returns:
        The options as an ``AttrDict`` with ``character`` filled in.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = AttrDict(yaml.safe_load(stream))

    if opt.lang_char != 'None':
        # Alphabet is given explicitly in the config
        opt.character = opt.number + opt.symbol + opt.lang_char
    else:
        # Alphabet is every distinct character used in the training labels
        seen_chars = set()
        for subset in opt['select_data'].split('-'):
            labels_path = os.path.join(opt['train_data'], subset, 'labels.csv')
            labels = pd.read_csv(
                labels_path,
                sep='^([^,]+),',
                engine='python',
                usecols=['filename', 'words'],
                keep_default_na=False,
            )
            seen_chars.update(''.join(labels['words']))
        opt.character = ''.join(sorted(seen_chars))

    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    return opt

In [None]:
# cuDNN settings: benchmark mode on, deterministic mode off
cudnn.benchmark = True
cudnn.deterministic = False

# Options come from this YAML file, not from the `data` dict defined earlier
# (the code that would dump that dict to config.yaml is commented out).
opt = get_config("config_files/ru_filtered_config.yaml")
# NOTE(review): the recorded run stops with "prefetch_factor option could only be
# specified in multiprocessing. let num_workers > 0 ..." - presumably the loaded
# YAML sets workers to 0; verify the `workers` value in that file.
train(opt, amp=False)

Filtering the images containing characters which are not in opt.character
Filtering the images whose label is longer than opt.batch_max_length
--------------------------------------------------------------------------------
dataset_root: all_data/train_data
opt.select_data: ['all_data']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root:    all_data/train_data	 dataset: all_data
all_data/train_data/images
sub-directory:	/images	 num samples: 5733
num total samples of all_data: 5733 x 1.0 (total_data_usage_ratio) = 5733
num samples of all_data per batch: 64 x 1.0 (batch_ratio) = 64
--------------------------------------------------------------------------------
Total_batch_size: 64 = 64
--------------------------------------------------------------------------------
dataset_root:    all_data/val_data	 dataset: /
all_data/val_data/images
sub-directory:	/images	 num samples: 1434


ValueError: prefetch_factor option could only be specified in multiprocessing.let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.