In [62]:
# importação dos dados
import os
import tarfile
from tqdm import tqdm
from io import BytesIO
from zipfile import ZipFile
from urllib.request import urlopen

import torch
import torch.optim as optim
from torchsummaryX import summary

from mltu.torch.model import Model
from mltu.torch.losses import CTCLoss
from mltu.torch.dataProvider import DataProvider
from mltu.torch.metrics import CERMetric, WERMetric
from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, Model2onnx, ReduceLROnPlateau

from mltu.preprocessors import ImageReader
from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding, ImageShowCV2
from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate, RandomSharpen

from model import Network
from configs import ModelConfigs
import cv2

- Pré processamento dos dados

In [49]:
dataset, vocab, max_len = [], set(), 0
dataset_path = "../data/iam_data/"
# Processando o dataset pelo formato especifico do dataset IAM_Words
words = open(f"{dataset_path}words.txt", "r").readlines()
for line in tqdm(words):
    if line.startswith("#"):
        continue

    line_split = line.split(" ")
    if line_split[1] == "err":
        continue

    folder1 = line_split[0][:3]
    folder2 = "-".join(line_split[0].split("-")[:2])
    file_name = line_split[0] + ".png"
    label = line_split[-1].rstrip('\n')
    rel_path = f"{dataset_path}words/{folder1}/{folder2}/{file_name}"
    if not os.path.exists(rel_path):
        print(f"File not found: {rel_path}")
        continue

    dataset.append([rel_path, label])
    vocab.update(list(label))
    max_len = max(max_len, len(label))

configs = ModelConfigs()

# Save vocab and maximum text length to configs
configs.vocab = "".join(sorted(vocab))
configs.max_text_length = max_len
configs.save()

100%|██████████| 115320/115320 [00:05<00:00, 22824.04it/s]


In [50]:
dataset

[['../data/iam_data/words/a01/a01-000u/a01-000u-00-00.png', 'A'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-01.png', 'MOVE'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-02.png', 'to'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-03.png', 'stop'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-04.png', 'Mr.'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-05.png', 'Gaitskell'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-00-06.png', 'from'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-00.png', 'nominating'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-01.png', 'any'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-02.png', 'more'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-03.png', 'Labour'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-04.png', 'life'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-01-05.png', 'Peers'],
 ['../data/iam_data/words/a01/a01-000u/a01-000u-02-00.png', 'is'],
 ['../data/iam_data/words/a01

In [105]:
class ImageReader:
    def __init__(self, image_class):
        self.image_class = image_class
    
    def __call__(self, data, annotation):
        image_path = data[0]
        image = cv2.imread(image_path)
        return image, annotation

In [106]:
# Criação do modelo e definição dos hiperparâmetros
data_provider = DataProvider(
    dataset=dataset,
    skip_validation=True,
    batch_size=configs.batch_size,
    data_preprocessors=[ImageReader(cv2.imread)],
    transformers=[
        ImageShowCV2(), # uncomment to show images during training
        ImageResizer(configs.width, configs.height, keep_aspect_ratio=False),
        LabelIndexer(configs.vocab),
        LabelPadding(max_word_length=configs.max_text_length, padding_value=len(configs.vocab))
        ],
    use_cache=True,
    workers=1
)

2023-06-02 19:56:35,411 INFO DataProvider: Skipping Dataset validation...


In [52]:
# nao funciona
# for _ in data_provider:
#     pass

In [107]:
# Separação dos dados em treino e teste (90% e 10%)
train_dataProvider, test_dataProvider = data_provider.split(split = 0.9)

In [108]:
# Augment training data with random brightness, rotation and erode/dilate
train_dataProvider.augmentors = [
    RandomBrightness(), 
    RandomErodeDilate(),
    RandomSharpen(),
    RandomRotate(angle=10), 
    ]

In [109]:
network = Network(len(configs.vocab), activation='leaky_relu', dropout=0.3)

In [110]:
loss = CTCLoss(blank=len(configs.vocab))
optimizer = optim.Adam(network.parameters(), lr=configs.learning_rate)

In [None]:
# uncomment to print network summary, torchsummaryX package is required
summary(network, torch.zeros((1, configs.height, configs.width, 3)))

In [111]:
# put on cuda device if available
if torch.cuda.is_available():
    network = network.cuda()

In [112]:
# create callbacks
earlyStopping = EarlyStopping(monitor="val_CER", patience=20, mode="min", verbose=1)
modelCheckpoint = ModelCheckpoint(configs.model_path + "/model.pt", monitor="val_CER", mode="min", save_best_only=True, verbose=1)
tb_callback = TensorBoard(configs.model_path + "/logs")
reduce_lr = ReduceLROnPlateau(monitor="val_CER", factor=0.9, patience=10, verbose=1, mode="min", min_lr=1e-6)
model2onnx = Model2onnx(
    saved_model_path=configs.model_path + "/model.pt",
    input_shape=(1, configs.height, configs.width, 3), 
    verbose=1,
    metadata={"vocab": configs.vocab}
    )

In [113]:
# create model object that will handle training and testing of the network
model = Model(network, optimizer, loss, metrics=[CERMetric(configs.vocab), WERMetric(configs.vocab)])
model.fit(
    train_dataProvider, 
    test_dataProvider, 
    epochs=1000, 
    callbacks=[earlyStopping, modelCheckpoint, tb_callback, reduce_lr, model2onnx]
    )

  0%|          | 0/1357 [00:00<?, ?it/s]


ValueError: too many values to unpack (expected 2)

In [None]:
# Save training and validation datasets as csv files
train_dataProvider.to_csv(os.path.join(configs.model_path, "train.csv"))
test_dataProvider.to_csv(os.path.join(configs.model_path, "val.csv"))