Model to recognise hand-written text

pip install tensorflow

pip install stow

pip install mltu==0.1.5

In [24]:
# Libaries
import tensorflow as tf
import stow
import tarfile
from tqdm import tqdm
from urllib.request import urlopen
from io import BytesIO
from zipfile import ZipFile
import os

# Data collection

In [35]:
dataset_path = r'C:\Users\ljant\Downloads\IAM_Words\IAM_Words'

# Initialize the dataset list and vocabulary set
dataset = []
vocab = set()
max_len = 0

# Path to the words.txt file
words_file_path = os.path.join(dataset_path, "words.txt")

# Reading lines from words.txt
with open(words_file_path, "r") as file:
    words = file.readlines()

for line in tqdm(words):
    if line.startswith("#"):
        continue

    line_split = line.split(" ")
    if line_split[1] == "err":
        continue

    folder1 = line_split[0][:3]
    folder2 = line_split[0][:8]
    file_name = line_split[0] + ".png"
    label = line_split[-1].rstrip('\n')

    # Constructing the relative path to the image
    rel_path = os.path.join(dataset_path, "words", folder1, folder2, file_name)

    # Check if the image file exists
    if not os.path.exists(rel_path):
        continue

    # Append the relative path and label to the dataset list
    dataset.append([rel_path, label])
    
    # Update the vocabulary set with characters from the label
    vocab.update(list(label))
    
    # Update the maximum label length
    max_len = max(max_len, len(label))

# Now, `dataset` is a list of [image_path, label] and `vocab` contains all unique characters in the labels.


100%|██████████| 115338/115338 [00:06<00:00, 18625.32it/s]


In [36]:
%%writefile configs.py
import os
from datetime import datetime

from mltu.configs import BaseModelConfigs

class ModelConfigs(BaseModelConfigs):
    def __init__(self):
        super().__init__()
        self.model_path = os.path.join("C:\Users\ljant\Desktop\Ironhack\Projects\Final-Project-Ironhack-2024", datetime.strftime(datetime.now(), "%Y%m%d%H%M"))
        self.vocab = ""
        self.height = 32
        self.width = 128
        self.max_text_length = 0
        self.batch_size = 16
        self.learning_rate = 0.0005
        self.train_epochs = 1000
        self.train_workers = 20

Overwriting configs.py


In [37]:
from configs import ModelConfigs

In [38]:
configs = ModelConfigs()

configs.vocab = "".join(vocab)
configs.max_text_length = max_len
configs.save()

# Create a data provider for the dataset

In [34]:
%%writefile dataProvider.py
import os
import copy
import typing
import numpy as np
import pandas as pd
from tqdm import tqdm

from .augmentors import Augmentor
from .transformers import Transformer

import logging
logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s: %(message)s")


class DataProvider:
    def __init__(
            self,
            dataset: typing.Union[str, list, pd.DataFrame],
            data_preprocessors: typing.List[typing.Callable] = None,
            batch_size: int = 4,
            shuffle: bool = True,
            initial_epoch: int = 1,
            augmentors: typing.List[Augmentor] = None,
            transformers: typing.List[Transformer] = None,
            skip_validation: bool = True,
            limit: int = None,
            use_cache: bool = False,
            log_level: int = logging.INFO,
    ) -> None:
        """ Standardised object for providing data to a model while training.

        Attributes:
            dataset (str, list, pd.DataFrame): Path to dataset, list of data or pandas dataframe of data.
            data_preprocessors (list): List of data preprocessors. (e.g. [read image, read audio, etc.])
            batch_size (int): The number of samples to include in each batch. Defaults to 4.
            shuffle (bool): Whether to shuffle the data. Defaults to True.
            initial_epoch (int): The initial epoch. Defaults to 1.
            augmentors (list, optional): List of augmentor functions. Defaults to None.
            transformers (list, optional): List of transformer functions. Defaults to None.
            skip_validation (bool, optional): Whether to skip validation. Defaults to True.
            limit (int, optional): Limit the number of samples in the dataset. Defaults to None.
            use_cache (bool, optional): Whether to cache the dataset. Defaults to False.
            log_level (int, optional): The log level. Defaults to logging.INFO.
        """
        self._dataset = dataset
        self._data_preprocessors = [] if data_preprocessors is None else data_preprocessors
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._epoch = initial_epoch
        self._augmentors = [] if augmentors is None else augmentors
        self._transformers = [] if transformers is None else transformers
        self._skip_validation = skip_validation
        self._limit = limit
        self._use_cache = use_cache
        self._step = 0
        self._cache = {}
        self._on_epoch_end_remove = []

        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(log_level)

        # Validate dataset
        if not skip_validation:
            self._dataset = self.validate(dataset)
        else:
            self.logger.info("Skipping Dataset validation...")

        if limit:
            self.logger.info(f"Limiting dataset to {limit} samples.")
            self._dataset = self._dataset[:limit]

    def __len__(self):
        """ Denotes the number of batches per epoch """
        return int(np.ceil(len(self._dataset) / self._batch_size))

    @property
    def augmentors(self) -> typing.List[Augmentor]:
        """ Return augmentors """
        return self._augmentors

    @augmentors.setter
    def augmentors(self, augmentors: typing.List[Augmentor]):
        """ Decorator for adding augmentors to the DataProvider """
        for augmentor in augmentors:
            if isinstance(augmentor, Augmentor):
                if self._augmentors is not None:
                    self._augmentors.append(augmentor)
                else:
                    self._augmentors = [augmentor]

            else:
                self.logger.warning(f"Augmentor {augmentor} is not an instance of Augmentor.")

    @property
    def transformers(self) -> typing.List[Transformer]:
        """ Return transformers """
        return self._transformers

    @transformers.setter
    def transformers(self, transformers: typing.List[Transformer]):
        """ Decorator for adding transformers to the DataProvider """
        for transformer in transformers:
            if isinstance(transformer, Transformer):
                if self._transformers is not None:
                    self._transformers.append(transformer)
                else:
                    self._transformers = [transformer]

            else:
                self.logger.warning(f"Transformer {transformer} is not an instance of Transformer.")

    @property
    def epoch(self) -> int:
        """ Return Current Epoch"""
        return self._epoch

    @property
    def step(self) -> int:
        """ Return Current Step"""
        return self._step

    def on_epoch_end(self):
        """ Shuffle training dataset and increment epoch counter at the end of each epoch. """
        self._epoch += 1
        if self._shuffle:
            np.random.shuffle(self._dataset)

        # Remove any samples that were marked for removal
        for remove in self._on_epoch_end_remove:
            self.logger.warning(f"Removing {remove} from dataset.")
            self._dataset.remove(remove)
        self._on_epoch_end_remove = []

    def validate_list_dataset(self, dataset: list) -> list:
        """ Validate a list dataset """
        validated_data = [data for data in tqdm(dataset, desc="Validating Dataset") if os.path.exists(data[0])]
        if not validated_data:
            raise FileNotFoundError("No valid data found in dataset.")

        return validated_data

    def validate(self, dataset: typing.Union[str, list, pd.DataFrame]) -> typing.Union[list, str]:
        """ Validate the dataset and return the dataset """

        if isinstance(dataset, str):
            if os.path.exists(dataset):
                return dataset
        elif isinstance(dataset, list):
            return self.validate_list_dataset(dataset)
        elif isinstance(dataset, pd.DataFrame):
            return self.validate_list_dataset(dataset.values.tolist())
        else:
            raise TypeError("Dataset must be a path, list or pandas dataframe.")

    def split(self, split: float = 0.9, shuffle: bool = True) -> typing.Tuple[typing.Any, typing.Any]:
        """ Split current data provider into training and validation data providers. 
        
        Args:
            split (float, optional): The split ratio. Defaults to 0.9.
            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.

        Returns:
            train_data_provider (tf.keras.utils.Sequence): The training data provider.
            val_data_provider (tf.keras.utils.Sequence): The validation data provider.
        """
        if shuffle:
            np.random.shuffle(self._dataset)
            
        train_data_provider, val_data_provider = copy.deepcopy(self), copy.deepcopy(self)
        train_data_provider._dataset = self._dataset[:int(len(self._dataset) * split)]
        val_data_provider._dataset = self._dataset[int(len(self._dataset) * split):]

        return train_data_provider, val_data_provider

    def to_csv(self, path: str, index: bool = False) -> None:
        """ Save the dataset to a csv file 

        Args:
            path (str): The path to save the csv file.
            index (bool, optional): Whether to save the index. Defaults to False.
        """
        df = pd.DataFrame(self._dataset)
        df.to_csv(path, index=index)

    def get_batch_annotations(self, index: int) -> typing.List:
        """ Returns a batch of annotations by batch index in the dataset

        Args:
            index (int): The index of the batch in 

        Returns:
            batch_annotations (list): A list of batch annotations
        """
        self._step = index
        start_index = index * self._batch_size

        # Get batch indexes
        batch_indexes = [i for i in range(start_index, start_index + self._batch_size) if i < len(self._dataset)]

        # Read batch data
        batch_annotations = [self._dataset[index] for index in batch_indexes]

        return batch_annotations
    
    def __iter__(self):
        """ Create a generator that iterate over the Sequence."""
        for item in (self[i] for i in range(len(self))):
            yield item

    def process_data(self, batch_data):
        """ Process data batch of data """
        if self._use_cache and batch_data[0] in self._cache:
            data, annotation = copy.deepcopy(self._cache[batch_data[0]])
        else:
            data, annotation = batch_data
            for preprocessor in self._data_preprocessors:
                data, annotation = preprocessor(data, annotation)
            
            if data is None or annotation is None:
                self.logger.warning("Data or annotation is None, marking for removal on epoch end.")
                self._on_epoch_end_remove.append(batch_data)
                return None, None
            
            if self._use_cache and batch_data[0] not in self._cache:
                self._cache[batch_data[0]] = (copy.deepcopy(data), copy.deepcopy(annotation))

        # Then augment, transform and postprocess the batch data
        for objects in [self._augmentors, self._transformers]:
            for _object in objects:
                data, annotation = _object(data, annotation)

        # Convert to numpy array if not already
        if not isinstance(data, np.ndarray):
            data = data.numpy()

        # Convert to numpy array if not already
        # TODO: This is a hack, need to fix this
        if not isinstance(annotation, (np.ndarray, int, float, str, np.uint8, float)):
            annotation = annotation.numpy()

        return data, annotation

    def __getitem__(self, index: int):
        """ Returns a batch of data by batch index"""
        dataset_batch = self.get_batch_annotations(index)
        
        # First read and preprocess the batch data
        batch_data, batch_annotations = [], []
        for index, batch in enumerate(dataset_batch):

            data, annotation = self.process_data(batch)

            if data is None or annotation is None:
                self.logger.warning("Data or annotation is None, skipping.")
                continue

            batch_data.append(data)
            batch_annotations.append(annotation)

        return np.array(batch_data), np.array(batch_annotations)

Writing dataProvider.py


In [50]:
import sys
import os

# Print the current search path
print(sys.path)

# Add a new directory to the search path
# Replace '/path/to/your/module' with the actual path to the directory containing your modules
module_path = r'C:\Users\ljant\Desktop\Ironhack\Projects\Final-Project-Ironhack-2024\dataProvider.py'
if module_path not in sys.path:
    sys.path.append(module_path)

# Try importing again
from augmentors import Augmentor
from transformers import Transformer

data_provider = DataProvider(
    dataset=dataset,
    skip_validation=True,
    batch_size=configs.batch_size,
    data_preprocessors=[ImageReader(CVImage)],
    transformers=[
        ImageResizer(configs.width, configs.height, keep_aspect_ratio=False),
        LabelIndexer(configs.vocab),
        LabelPadding(max_word_length=configs.max_text_length, padding_value=len(configs.vocab)),
        ],
)

['C:\\Users\\ljant\\Desktop\\Ironhack\\Projects\\Final-Project-Ironhack-2024', 'C:\\Users\\ljant\\anaconda3\\python39.zip', 'C:\\Users\\ljant\\anaconda3\\DLLs', 'C:\\Users\\ljant\\anaconda3\\lib', 'C:\\Users\\ljant\\anaconda3', '', 'C:\\Users\\ljant\\anaconda3\\lib\\site-packages', 'C:\\Users\\ljant\\anaconda3\\lib\\site-packages\\locket-0.2.1-py3.9.egg', 'C:\\Users\\ljant\\anaconda3\\lib\\site-packages\\win32', 'C:\\Users\\ljant\\anaconda3\\lib\\site-packages\\win32\\lib', 'C:\\Users\\ljant\\anaconda3\\lib\\site-packages\\Pythonwin', 'C:\\Users\\ljant\\anaconda3\\lib\\site-packages\\IPython\\extensions', 'C:\\Users\\ljant\\.ipython', 'C:\\Users\\ljant\\Desktop\\Ironhack\\Projects']


ModuleNotFoundError: No module named 'augmentors'

In [29]:
%%writefile model.py
from keras import layers
from keras.models import Model

from mltu.tensorflow.model_utils import residual_block


def train_model(input_dim, output_dim, activation="leaky_relu", dropout=0.2):
    
    inputs = layers.Input(shape=input_dim, name="input")

    # normalize images here instead in preprocessing step
    input = layers.Lambda(lambda x: x / 255)(inputs)

    x1 = residual_block(input, 16, activation=activation, skip_conv=True, strides=1, dropout=dropout)

    x2 = residual_block(x1, 16, activation=activation, skip_conv=True, strides=2, dropout=dropout)
    x3 = residual_block(x2, 16, activation=activation, skip_conv=False, strides=1, dropout=dropout)

    x4 = residual_block(x3, 32, activation=activation, skip_conv=True, strides=2, dropout=dropout)
    x5 = residual_block(x4, 32, activation=activation, skip_conv=False, strides=1, dropout=dropout)

    x6 = residual_block(x5, 64, activation=activation, skip_conv=True, strides=2, dropout=dropout)
    x7 = residual_block(x6, 64, activation=activation, skip_conv=True, strides=1, dropout=dropout)

    x8 = residual_block(x7, 64, activation=activation, skip_conv=False, strides=1, dropout=dropout)
    x9 = residual_block(x8, 64, activation=activation, skip_conv=False, strides=1, dropout=dropout)

    squeezed = layers.Reshape((x9.shape[-3] * x9.shape[-2], x9.shape[-1]))(x9)

    blstm = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(squeezed)
    blstm = layers.Dropout(dropout)(blstm)

    output = layers.Dense(output_dim + 1, activation="softmax", name="output")(blstm)

    model = Model(inputs=inputs, outputs=output)
    return model

Writing model.py


In [30]:
from model import train_model