In [2]:
# importação dos dados
import os
import tarfile
from tqdm import tqdm
from io import BytesIO
from zipfile import ZipFile
from urllib.request import urlopen

import torch
import torch.optim as optim
from torchsummaryX import summary

from mltu.torch.model import Model
from mltu.torch.losses import CTCLoss
from mltu.torch.dataProvider import DataProvider
from mltu.torch.metrics import CERMetric, WERMetric
from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, Model2onnx, ReduceLROnPlateau

from mltu.preprocessors import ImageReader
from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding, ImageShowCV2
from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate, RandomSharpen
from mltu.annotations.images import CVImage

from model import Network
from configs import ModelConfigs

- Pré processamento dos dados

In [3]:
dataset, vocab, max_len = [], set(), 0
dataset_path = "../data/iam_data/"
# Processando o dataset pelo formato especifico do dataset IAM_Words
words = open(f"{dataset_path}words.txt", "r").readlines()
for line in tqdm(words):
    if line.startswith("#"):
        continue

    line_split = line.split(" ")
    if line_split[1] == "err":
        continue

    folder1 = line_split[0][:3]
    folder2 = "-".join(line_split[0].split("-")[:2])
    file_name = line_split[0] + ".png"
    label = line_split[-1].rstrip('\n')
    rel_path = f"{dataset_path}words/{folder1}/{folder2}/{file_name}"
    if not os.path.exists(rel_path):
        print(f"File not found: {rel_path}")
        continue

    dataset.append([rel_path, label])
    vocab.update(list(label))
    max_len = max(max_len, len(label))

configs = ModelConfigs()

# Save vocab and maximum text length to configs
configs.vocab = "".join(sorted(vocab))
configs.max_text_length = max_len
configs.save()

  0%|          | 0/115320 [00:00<?, ?it/s]

100%|██████████| 115320/115320 [00:05<00:00, 22608.27it/s]


In [4]:
dataset

[['../data/iam_data/words/a01/a01-000u/a01-000u-00-00.png', 'A'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-01.png', 'MOVE'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-02.png', 'to'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-03.png', 'stop'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-04.png', 'Mr.'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-05.png', 'Gaitskell'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-06.png', 'from'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-00.png', 'nominating'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-01.png', 'any'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-02.png', 'more'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-03.png', 'Labour'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-04.png', 'life'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-05.png', 'Peers'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-02-00.png', 'is'],
 ['../data/iam_data/words/a01

In [5]:
# Criação do modelo e definição dos hiperparâmetros
data_provider = DataProvider(
    dataset=dataset,
    skip_validation=True,
    batch_size=configs.batch_size,
    data_preprocessors=[ImageReader(CVImage)],
    transformers=[
        # ImageShowCV2(),  # uncomment to show images when iterating over the data provider
        ImageResizer(configs.width, configs.height, keep_aspect_ratio=False),
        LabelIndexer(configs.vocab),
        LabelPadding(max_word_length=configs.max_text_length,
                     padding_value=len(configs.vocab))
    ],
    use_cache=True,
)

2023-06-09 18:18:04,431 INFO DataProvider: Skipping Dataset validation...


In [6]:
# verificar se o dataset foi carregado corretamente
# for _ in data_provider:
#     pass

In [7]:
# Divisão do dataset em treino e teste (90% e 10%)
train_dataProvider, test_dataProvider = data_provider.split(split = 0.9)

In [8]:
#
train_dataProvider.augmentors = [
    RandomBrightness(), 
    RandomErodeDilate(),
    RandomSharpen(),
    RandomRotate(angle=10), 
    ]

In [9]:
network = Network(len(configs.vocab), activation="leaky_relu", dropout=0.3)
loss = CTCLoss(blank=len(configs.vocab))
optimizer = optim.Adam(network.parameters(), lr=configs.learning_rate)

In [10]:
# uncomment to print network summary, torchsummaryX package is required
# summary(network, torch.zeros((1, configs.height, configs.width, 3)))

In [11]:
# Usar GPU
# if torch.cuda.is_available():
#     network = network.cuda()
    
# Usar CPU
os.environ['CUDA_VISIBLE_DEVICES'] = ''
network = network.cpu()    


In [12]:
# Criação do callback para o treinamento
earlyStopping = EarlyStopping(monitor="val_CER", patience=20, mode="min", verbose=1)
modelCheckpoint = ModelCheckpoint(configs.model_path + "/model.pt", monitor="val_CER", mode="min", save_best_only=True, verbose=1)
tb_callback = TensorBoard(configs.model_path + "/logs")
reduce_lr = ReduceLROnPlateau(monitor="val_CER", factor=0.9, patience=10, verbose=1, mode="min", min_lr=1e-6)
model2onnx = Model2onnx(
    saved_model_path=configs.model_path + "/model.pt",
    input_shape=(1, configs.height, configs.width, 3), 
    verbose=1,
    metadata={"vocab": configs.vocab}
    )

In [13]:
# Criação do modelo para o treinamento
model = Model(network, optimizer, loss, metrics=[CERMetric(configs.vocab), WERMetric(configs.vocab)])
model.fit(
    train_dataProvider, 
    test_dataProvider, 
    epochs=7, 
    callbacks=[earlyStopping, modelCheckpoint, tb_callback, reduce_lr, Model2onnx(
        saved_model_path=configs.model_path + "/model.pt",
        input_shape=(1, configs.height, configs.width, 3), 
        verbose=1,
        metadata={"vocab": configs.vocab},
        onnx_opset_version=10
    )]
)

Epoch 1 - loss: 3.4491 - CER: 0.9289 - WER: 0.9303: 100%|██████████| 1357/1357 [15:28<00:00,  1.46it/s]
          val_loss: 2.4098 - val_CER: 0.7661 - val_WER: 0.7958: 100%|██████████| 151/151 [00:42<00:00,  3.59it/s]
2023-06-09 18:34:17,859 INFO ModelCheckpoint: Epoch 1: val_CER improved from inf to 0.76605, saving model to Models/08_handwriting_recognition_torch\202306091818/model.pt
Epoch 2 - loss: 2.2792 - CER: 0.7230 - WER: 0.7780: 100%|██████████| 1357/1357 [13:59<00:00,  1.62it/s]
          val_loss: 2.6696 - val_CER: 0.7615 - val_WER: 0.8263: 100%|██████████| 151/151 [00:26<00:00,  5.66it/s]
2023-06-09 18:48:43,900 INFO ModelCheckpoint: Epoch 2: val_CER improved from 0.76605 to 0.76149, saving model to Models/08_handwriting_recognition_torch\202306091818/model.pt
Epoch 3 - loss: 1.9402 - CER: 0.6504 - WER: 0.7239: 100%|██████████| 1357/1357 [14:01<00:00,  1.61it/s]
          val_loss: 1.6388 - val_CER: 0.5870 - val_WER: 0.6774: 100%|██████████| 151/151 [00:29<00:00,  5.04it/s]


ValueError: Unsupported ONNX opset version: 14

In [14]:
# salvar o modelo
model.save(configs.model_path + "/model.hdf5")

In [15]:
# Save training and validation datasets as csv files
train_dataProvider.to_csv(os.path.join("./", "train.csv"))
test_dataProvider.to_csv(os.path.join("./", "val.csv"))

In [16]:
# testar com uma imagem
import cv2
import numpy as np
from keras.models import load_model

# Carrega o modelo treinado
model = load_model(configs.model_path + "/model.hdf5")

# Carrega a imagem
image = cv2.imread("../data/iam_data/words/a01/a01-000u/a01-000u-00-00.png")

# Pré-processa a imagem
image = cv2.resize(image, (configs.image_size, configs.image_size))
image = image.astype("float") / 255.0
image = np.expand_dims(image, axis=0)

# Faz a previsão na imagem
prediction = model.predict(image)

# Imprime a previsão
print(prediction)

OSError: Unable to open file (file signature not found)