In [None]:
%%writefile requirements.txt
editdistance
lmdb
matplotlib
numpy
tensorflow
Pillow
streamlit
pyngrok
reportlab

In [None]:
!pip install -r requirements.txt

In [None]:
!mkdir src

In [None]:
!mkdir data

In [None]:
!mkdir -p data/characters

In [None]:
%%writefile src/character_recognition.py
import argparse
import json
import os
from typing import Tuple, List

import cv2
import editdistance
from pathlib import Path

from dataloader_iam import DataLoaderIAM, Batch
from model import Model, DecoderType
from preprocessor import Preprocessor


class FilePaths:
    """Relative locations of the files this script reads and writes."""
    # characters known to the model (written when training/validating, read for inference)
    fn_char_list = "../model/charList.txt"
    # per-epoch character error rates and word accuracies, stored as JSON
    fn_summary = "../model/summary.json"
    # all ground-truth words of the dataset, separated by spaces
    fn_corpus = "../data/corpus.txt"


def get_img_height() -> int:
    """Return the image height the NN works with; it is the same for every input."""
    fixed_height = 32
    return fixed_height


def get_img_size(line_mode: bool = False) -> Tuple[int, int]:
    """Return the (width, height) of the NN input.

    The height never changes; the width is 256 when reading text lines and 128 for single words.
    """
    width = 256 if line_mode else 128
    return width, get_img_height()


def write_summary(char_error_rates: List[float], word_accuracies: List[float]) -> None:
    """Store the training history as JSON in the summary file (overwriting any previous content)."""
    summary = {'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}
    with open(FilePaths.fn_summary, 'w') as summary_file:
        json.dump(summary, summary_file)


def train(model: Model,
          loader: DataLoaderIAM,
          line_mode: bool,
          early_stopping: int = 25) -> None:
    """Trains NN until the validation character error rate stops improving.

    Each epoch runs over the training set, validates, appends the results to the summary file and
    saves the model whenever the character error rate is the best seen so far.

    Args:
        model: model to train.
        loader: data loader providing the training and validation batches.
        line_mode: train on simulated text lines instead of single words.
        early_stopping: stop training after this number of epochs without improvement.
    """
    epoch = 0  # number of training epochs since start
    summary_char_error_rates = []
    summary_word_accuracies = []
    preprocessor = Preprocessor(get_img_size(line_mode), data_augmentation=True, line_mode=line_mode)
    best_char_error_rate = float('inf')  # best validation character error rate
    no_improvement_since = 0  # number of epochs no improvement of character error rate occurred
    while True:
        epoch += 1
        print('Epoch:', epoch)

        # train
        print('Train NN')
        loader.train_set()
        while loader.has_next():
            # query the iterator info before fetching, so that it describes the batch about to be trained
            iter_info = loader.get_iterator_info()
            batch = loader.get_next()
            batch = preprocessor.process_batch(batch)
            loss = model.train_batch(batch)
            print(f'Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}')

        # validate
        char_error_rate, word_accuracy = validate(model, loader, line_mode)

        # write summary
        summary_char_error_rates.append(char_error_rate)
        summary_word_accuracies.append(word_accuracy)
        write_summary(summary_char_error_rates, summary_word_accuracies)

        # if best validation accuracy so far, save model parameters
        if char_error_rate < best_char_error_rate:
            print('Character error rate improved, save model')
            best_char_error_rate = char_error_rate
            no_improvement_since = 0
            model.save()
        else:
            # report the best value reached so far, not the (worse) value of the current epoch
            print(f'Character error rate not improved, best so far: {best_char_error_rate * 100.0}%')
            no_improvement_since += 1

        # stop training if no more improvement in the last x epochs
        if no_improvement_since >= early_stopping:
            print(f'No more improvement since {early_stopping} epochs. Training stopped.')
            break


def validate(model: Model, loader: DataLoaderIAM, line_mode: bool) -> Tuple[float, float]:
    """Run the model over the validation set and report how well it recognizes the texts.

    Returns:
        The character error rate (summed edit distance divided by the number of ground-truth
        characters) and the word accuracy (fraction of exactly recognized texts).
    """
    print('Validate NN')
    loader.validation_set()
    preprocessor = Preprocessor(get_img_size(line_mode), line_mode=line_mode)

    # running totals over the whole validation set
    char_errors, char_count = 0, 0
    words_correct, word_count = 0, 0

    while loader.has_next():
        curr_batch, num_batches = loader.get_iterator_info()
        print(f'Batch: {curr_batch} / {num_batches}')
        batch = preprocessor.process_batch(loader.get_next())
        recognized, _ = model.infer_batch(batch)

        print('Ground truth -> Recognized')
        for i, predicted in enumerate(recognized):
            expected = batch.gt_texts[i]
            if expected == predicted:
                words_correct += 1
            word_count += 1

            dist = editdistance.eval(predicted, expected)
            char_errors += dist
            char_count += len(expected)

            status = '[OK]' if dist == 0 else '[ERR:%d]' % dist
            print(status, '"' + expected + '"', '->', '"' + predicted + '"')

    # print validation result
    char_error_rate = char_errors / char_count
    word_accuracy = words_correct / word_count
    print(f'Character error rate: {char_error_rate * 100.0}%. Word accuracy: {word_accuracy * 100.0}%.')
    return char_error_rate, word_accuracy


def infer(model: Model, fn_img: Path) -> str:
    """Recognizes text in image provided by file path.

    Prints the recognized text together with its probability.

    Args:
        model: restored model used for recognition.
        fn_img: path of the image file to read.

    Returns:
        The recognized text.
    """
    # pass a plain string: cv2.imread in older OpenCV builds does not accept pathlib.Path objects
    img = cv2.imread(str(fn_img), cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising when the file is missing or unreadable
    assert img is not None, f'Could not read image: {fn_img}'

    preprocessor = Preprocessor(get_img_size(), dynamic_width=True, padding=16)
    img = preprocessor.process_img(img)

    # single-element batch without ground truth text
    batch = Batch([img], None, 1)
    recognized, probability = model.infer_batch(batch, True)
    print(f'Recognized: "{recognized[0]}"')
    print(f'Probability: {probability[0]}')
    return recognized[0]


# Slightly edited to allow for custom file names to be analysed
def main():
    """Main function.

    Looks for an image whose name contains "Cropped" in the characters directory. If there is one,
    the command line is parsed and the NN is trained, validated or used for inference. In inference
    mode (the default) that image is recognized and the file is renamed after the recognized text.
    """
    # recognized characters that cannot be used in file names, mapped to two-letter stand-ins
    filename_aliases = {'"': "SM", "?": "QM", "/": "FS", '\\': "BS",
                        "*": "AS", "<": "LT", ">": "GT", "|": "VL"}
    characters_dir = "../data/characters/"

    # pick the image to analyse: the last listed file whose name contains "Cropped"
    # NOTE(review): the listing uses an absolute path while all other paths are relative to the
    # working directory - confirm that both refer to the same directory in the intended setup
    character_file = ""
    for file_name in os.listdir("/content/data/characters"):
        if "Cropped" in file_name:
            character_file = file_name
    print(character_file)
    if character_file == "":
        return  # no image to analyse

    parser = argparse.ArgumentParser()

    parser.add_argument('--mode', choices=['train', 'validate', 'infer'], default='infer')
    parser.add_argument('--decoder', choices=['bestpath', 'beamsearch', 'wordbeamsearch'], default='bestpath')
    parser.add_argument('--batch_size', help='Batch size.', type=int, default=100)
    parser.add_argument('--data_dir', help='Directory containing IAM dataset.', type=Path, required=False)
    parser.add_argument('--fast', help='Load samples from LMDB.', action='store_true')
    parser.add_argument('--line_mode', help='Train to read text lines instead of single words.',
                        action='store_true')
    parser.add_argument('--img_file', help='Image used for inference.', type=Path,
                        default='../data/characters/{}'.format(character_file))
    parser.add_argument('--early_stopping', help='Early stopping epochs.', type=int, default=25)
    parser.add_argument('--dump', help='Dump output of NN to CSV file(s).', action='store_true')
    args = parser.parse_args()

    # set chosen CTC decoder
    decoder_mapping = {'bestpath': DecoderType.BestPath,
                       'beamsearch': DecoderType.BeamSearch,
                       'wordbeamsearch': DecoderType.WordBeamSearch}
    decoder_type = decoder_mapping[args.decoder]

    # train or validate on IAM dataset
    if args.mode in ['train', 'validate']:
        # load training data, create TF model
        loader = DataLoaderIAM(args.data_dir, args.batch_size, fast=args.fast)
        char_list = loader.char_list

        # when in line mode, take care to have a whitespace in the char list
        if args.line_mode and ' ' not in char_list:
            char_list = [' '] + char_list

        # save characters of model for inference mode
        with open(FilePaths.fn_char_list, 'w') as f:
            f.write(''.join(char_list))

        # save words contained in dataset into file
        with open(FilePaths.fn_corpus, 'w') as f:
            f.write(' '.join(loader.train_words + loader.validation_words))

        # execute training or validation
        if args.mode == 'train':
            model = Model(char_list, decoder_type)
            train(model, loader, line_mode=args.line_mode, early_stopping=args.early_stopping)
        elif args.mode == 'validate':
            model = Model(char_list, decoder_type, must_restore=True)
            validate(model, loader, args.line_mode)

    # infer text on image
    elif args.mode == 'infer':
        with open(FilePaths.fn_char_list) as f:
            char_list = list(f.read())
        model = Model(char_list, decoder_type, must_restore=True, dump=args.dump)

        # the image file gets renamed after the recognized text
        character = infer(model, args.img_file)
        character = filename_aliases.get(character, character)

        source = characters_dir + character_file
        target = characters_dir + character + ".png"
        counter = 0
        while True:
            try:
                os.rename(source, target)
                break
            except FileExistsError:
                # target name is taken: append one more dot to the name and try again.
                # Any other OSError (missing source, permissions, ...) can not be cured by a
                # different target name, so it is not retried (retrying would loop forever).
                counter += 1
                target = characters_dir + "{}{}.png".format(character, "." * counter)


if __name__ == '__main__':
    main()


In [None]:
%%writefile src/create_lmdb.py
import argparse
import pickle

import cv2
import lmdb
from pathlib import Path

# Script: stores all png images found below <data_dir>/img as pickled grayscale images in an
# LMDB database at <data_dir>/lmdb, using the file name of each image as key.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=Path, required=True)
args = parser.parse_args()

# 2GB is enough for IAM dataset
assert not (args.data_dir / 'lmdb').exists()
env = lmdb.open(str(args.data_dir / 'lmdb'), map_size=1024 * 1024 * 1024 * 2)

# go over all png files in all subdirectories
# (pathlib.Path has no walkfiles(); rglob() is its recursive search)
fn_imgs = [fn_img for fn_img in (args.data_dir / 'img').rglob('*.png') if fn_img.is_file()]

# and put the imgs into lmdb as pickled grayscale imgs
with env.begin(write=True) as txn:
    for i, fn_img in enumerate(fn_imgs):
        print(i, len(fn_imgs))
        # cv2.imread is given a plain string, older OpenCV builds do not accept pathlib.Path objects
        img = cv2.imread(str(fn_img), cv2.IMREAD_GRAYSCALE)
        # key is the file name without directories (pathlib.Path has no basename() method)
        txn.put(fn_img.name.encode("ascii"), pickle.dumps(img))

env.close()


In [None]:
%%writefile src/dataloader_iam.py
import pickle
import random
from collections import namedtuple
from typing import Tuple

import cv2
import lmdb
import numpy as np
from pathlib import Path

# one dataset entry: ground truth text and the path of the corresponding image file
Sample = namedtuple('Sample', ['gt_text', 'file_path'])
# a batch: images, their ground truth texts and the number of elements
Batch = namedtuple('Batch', ['imgs', 'gt_texts', 'batch_size'])


class DataLoaderIAM:
    """
    Loads data which corresponds to IAM format,
    see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database

    After construction the loader iterates over the training set; use train_set() and
    validation_set() to (re)start an iteration, then has_next() / get_next() to fetch batches.
    """

    def __init__(self,
                 data_dir: Path,
                 batch_size: int,
                 data_split: float = 0.95,
                 fast: bool = True) -> None:
        """Loader for dataset.

        Args:
            data_dir: dataset directory containing 'gt/words.txt' and either the 'img' directory
                or (for fast mode) the 'lmdb' database.
            batch_size: number of samples per batch.
            data_split: fraction of the samples used for training, the rest is used for validation.
            fast: load the images from the LMDB database instead of the image files.
        """

        assert data_dir.exists()

        self.fast = fast
        if fast:
            self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)

        self.data_augmentation = False
        self.curr_idx = 0
        self.batch_size = batch_size
        self.samples = []

        chars = set()
        bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05']  # known broken images in IAM dataset
        # the context manager closes the ground truth file once all lines are parsed
        with open(data_dir / 'gt/words.txt') as f:
            for line in f:
                # ignore blank lines and comment lines
                if not line.strip() or line[0] == '#':
                    continue

                line_split = line.strip().split(' ')
                assert len(line_split) >= 9

                # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
                file_name_split = line_split[0].split('-')
                file_name_subdir1 = file_name_split[0]
                file_name_subdir2 = f'{file_name_split[0]}-{file_name_split[1]}'
                file_base_name = line_split[0] + '.png'
                file_name = data_dir / 'img' / file_name_subdir1 / file_name_subdir2 / file_base_name

                if line_split[0] in bad_samples_reference:
                    print('Ignoring known broken image:', file_name)
                    continue

                # GT text are columns starting at 9 (index 8); it may contain spaces itself
                gt_text = ' '.join(line_split[8:])
                chars.update(gt_text)

                # put sample into list
                self.samples.append(Sample(gt_text, file_name))

        # split into training and validation set according to data_split (default: 95% - 5%)
        split_idx = int(data_split * len(self.samples))
        self.train_samples = self.samples[:split_idx]
        self.validation_samples = self.samples[split_idx:]

        # put words into lists
        self.train_words = [x.gt_text for x in self.train_samples]
        self.validation_words = [x.gt_text for x in self.validation_samples]

        # start with train set
        self.train_set()

        # list of all chars in dataset
        self.char_list = sorted(chars)

    def train_set(self) -> None:
        """Switch to the training set, shuffled anew, and restart the iteration."""
        self.data_augmentation = True
        self.curr_idx = 0
        random.shuffle(self.train_samples)
        self.samples = self.train_samples
        self.curr_set = 'train'

    def validation_set(self) -> None:
        """Switch to validation set and restart the iteration."""
        self.data_augmentation = False
        self.curr_idx = 0
        self.samples = self.validation_samples
        self.curr_set = 'val'

    def get_iterator_info(self) -> Tuple[int, int]:
        """Current batch index (starting at 1) and overall number of batches."""
        if self.curr_set == 'train':
            num_batches = int(np.floor(len(self.samples) / self.batch_size))  # train set: only full-sized batches
        else:
            num_batches = int(np.ceil(len(self.samples) / self.batch_size))  # val set: allow last batch to be smaller
        curr_batch = self.curr_idx // self.batch_size + 1
        return curr_batch, num_batches

    def has_next(self) -> bool:
        """Is there a next element?"""
        if self.curr_set == 'train':
            return self.curr_idx + self.batch_size <= len(self.samples)  # train set: only full-sized batches
        else:
            return self.curr_idx < len(self.samples)  # val set: allow last batch to be smaller

    def _get_img(self, i: int) -> np.ndarray:
        """Load the image of sample i from the LMDB database (fast mode) or from its image file."""
        if self.fast:
            with self.env.begin() as txn:
                # the database key is the file name without directories
                # (pathlib.Path provides it as .name, there is no basename() method)
                basename = Path(self.samples[i].file_path).name
                data = txn.get(basename.encode("ascii"))
                # NOTE(review): unpickling is only safe as long as the database comes from a trusted source
                img = pickle.loads(data)
        else:
            # cv2.imread is given a plain string, older OpenCV builds do not accept pathlib.Path objects
            img = cv2.imread(str(self.samples[i].file_path), cv2.IMREAD_GRAYSCALE)

        return img

    def get_next(self) -> Batch:
        """Get next element: a batch of up to batch_size images with their ground truth texts."""
        batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))

        imgs = [self._get_img(i) for i in batch_range]
        gt_texts = [self.samples[i].gt_text for i in batch_range]

        self.curr_idx += self.batch_size
        return Batch(imgs, gt_texts, len(imgs))


In [None]:
%%writefile src/model.py
import os
import sys
from typing import List, Tuple

import numpy as np
import tensorflow as tf

from dataloader_iam import Batch

# Disable eager mode
tf.compat.v1.disable_eager_execution()


class DecoderType:
    """Integer identifiers of the available CTC decoding strategies."""
    BestPath, BeamSearch, WordBeamSearch = range(3)


class Model:
    """Minimalistic TF model for HTR.

    The graph consists of CNN layers, a bidirectional RNN and a CTC loss/decoder. It is built with
    the TF v1 compatibility API and run in a session owned by the model.
    """

    def __init__(self,
                 char_list: List[str],
                 decoder_type: int = DecoderType.BestPath,
                 must_restore: bool = False,
                 dump: bool = False) -> None:
        """Init model: add CNN, RNN and CTC and initialize TF.

        Args:
            char_list: characters the model can output; the position in the list is the label.
            decoder_type: one of the DecoderType constants.
            must_restore: raise an exception if no saved model is found.
            dump: write the NN output to CSV files when inferring.
        """
        self.dump = dump
        self.char_list = char_list
        self.decoder_type = decoder_type
        self.must_restore = must_restore
        self.snap_ID = 0  # counter used as global step when saving snapshots

        # Whether to use normalization over a batch or a population
        self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train')

        # input image batch
        self.input_imgs = tf.compat.v1.placeholder(tf.float32, shape=(None, None, None))

        # setup CNN, RNN and CTC
        self.setup_cnn()
        self.setup_rnn()
        self.setup_ctc()

        # setup optimizer to train NN
        self.batches_trained = 0
        # run the collected update ops (collection UPDATE_OPS) together with each optimizer step
        self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(self.update_ops):
            self.optimizer = tf.compat.v1.train.AdamOptimizer().minimize(self.loss)

        # initialize TF
        self.sess, self.saver = self.setup_tf()

    def setup_cnn(self) -> None:
        """Create CNN layers: five times convolution, batch normalization, ReLU and max pooling."""
        # add a fourth axis (channels) to the 3D input batch
        cnn_in4d = tf.expand_dims(input=self.input_imgs, axis=3)

        # list of parameters for the layers
        kernel_vals = [5, 5, 3, 3, 3]
        feature_vals = [1, 32, 64, 128, 128, 256]  # layer i maps feature_vals[i] to feature_vals[i + 1] channels
        stride_vals = pool_vals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)]
        num_layers = len(stride_vals)

        # create layers
        pool = cnn_in4d  # input to first CNN layer
        for i in range(num_layers):
            kernel = tf.Variable(
                tf.random.truncated_normal([kernel_vals[i], kernel_vals[i], feature_vals[i], feature_vals[i + 1]],
                                           stddev=0.1))
            conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1, 1, 1, 1))
            conv_norm = tf.compat.v1.layers.batch_normalization(conv, training=self.is_train)
            relu = tf.nn.relu(conv_norm)
            pool = tf.nn.max_pool2d(input=relu, ksize=(1, pool_vals[i][0], pool_vals[i][1], 1),
                                    strides=(1, stride_vals[i][0], stride_vals[i][1], 1), padding='VALID')

        self.cnn_out_4d = pool

    def setup_rnn(self) -> None:
        """Create RNN layers and project their output to one score per character (plus blank)."""
        # remove axis 2 of the CNN output
        rnn_in3d = tf.squeeze(self.cnn_out_4d, axis=[2])

        # basic cells which is used to build RNN
        num_hidden = 256
        cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=num_hidden, state_is_tuple=True) for _ in
                 range(2)]  # 2 layers

        # stack basic cells
        stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

        # bidirectional RNN
        # BxTxF -> BxTx2H
        # NOTE(review): the same stacked cell object is passed as forward and backward cell
        (fw, bw), _ = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnn_in3d,
                                                                dtype=rnn_in3d.dtype)

        # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
        concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)

        # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
        kernel = tf.Variable(tf.random.truncated_normal([1, 1, num_hidden * 2, len(self.char_list) + 1], stddev=0.1))
        self.rnn_out_3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'),
                                     axis=[2])

    def setup_ctc(self) -> None:
        """Create CTC loss and decoder."""
        # BxTxC -> TxBxC
        self.ctc_in_3d_tbc = tf.transpose(a=self.rnn_out_3d, perm=[1, 0, 2])
        # ground truth text as sparse tensor
        self.gt_texts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]),
                                        tf.compat.v1.placeholder(tf.int32, [None]),
                                        tf.compat.v1.placeholder(tf.int64, [2]))

        # calc loss for batch
        self.seq_len = tf.compat.v1.placeholder(tf.int32, [None])
        self.loss = tf.reduce_mean(
            input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.ctc_in_3d_tbc,
                                                  sequence_length=self.seq_len,
                                                  ctc_merge_repeated=True))

        # calc loss for each element to compute label probability
        self.saved_ctc_input = tf.compat.v1.placeholder(tf.float32,
                                                        shape=[None, None, len(self.char_list) + 1])
        self.loss_per_element = tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.saved_ctc_input,
                                                         sequence_length=self.seq_len, ctc_merge_repeated=True)

        # best path decoding or beam search decoding
        if self.decoder_type == DecoderType.BestPath:
            self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len)
        elif self.decoder_type == DecoderType.BeamSearch:
            self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len,
                                                         beam_width=50)
        # word beam search decoding (see https://github.com/githubharald/CTCWordBeamSearch)
        elif self.decoder_type == DecoderType.WordBeamSearch:
            # prepare information about language (dictionary, characters in dataset, characters forming words)
            chars = ''.join(self.char_list)
            # NOTE(review): both files are opened without being closed explicitly
            word_chars = open('../model/wordCharList.txt').read().splitlines()[0]
            corpus = open('../data/corpus.txt').read()

            # decode using the "Words" mode of word beam search
            # imported here, so the package is only needed when this decoder is selected
            from word_beam_search import WordBeamSearch
            self.decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'),
                                          word_chars.encode('utf8'))

            # the input to the decoder must have softmax already applied
            self.wbs_input = tf.nn.softmax(self.ctc_in_3d_tbc, axis=2)

    def setup_tf(self) -> Tuple[tf.compat.v1.Session, tf.compat.v1.train.Saver]:
        """Initialize TF: create session and saver, then restore the latest snapshot or init new values.

        Raises:
            Exception: if must_restore is set but no saved model exists in the model directory.
        """
        print('Python: ' + sys.version)
        print('Tensorflow: ' + tf.__version__)

        sess = tf.compat.v1.Session()  # TF session

        saver = tf.compat.v1.train.Saver(max_to_keep=1)  # saver saves model to file
        model_dir = '../model/'
        latest_snapshot = tf.train.latest_checkpoint(model_dir)  # is there a saved model?

        # if model must be restored (for inference), there must be a snapshot
        if self.must_restore and not latest_snapshot:
            raise Exception('No saved model found in: ' + model_dir)

        # load saved model if available
        if latest_snapshot:
            print('Init with stored values from ' + latest_snapshot)
            saver.restore(sess, latest_snapshot)
        else:
            print('Init with new values')
            sess.run(tf.compat.v1.global_variables_initializer())

        return sess, saver

    def to_sparse(self, texts: List[str]) -> Tuple[List[List[int]], List[int], List[int]]:
        """Put ground truth texts into sparse tensor for ctc_loss.

        Returns:
            indices ([batch element, position] pairs), values (the label at each of these
            positions) and shape ([number of texts, length of the longest text]).

        Raises:
            ValueError: if a text contains a character which is not in the char list.
        """
        indices = []
        values = []
        shape = [len(texts), 0]  # last entry must be max(labelList[i])

        # go over all texts
        for batchElement, text in enumerate(texts):
            # convert to string of label (i.e. class-ids)
            label_str = [self.char_list.index(c) for c in text]
            # sparse tensor must have size of max. label-string
            if len(label_str) > shape[1]:
                shape[1] = len(label_str)
            # put each label into sparse tensor
            for i, label in enumerate(label_str):
                indices.append([batchElement, i])
                values.append(label)

        return indices, values, shape

    def decoder_output_to_text(self, ctc_output: tuple, batch_size: int) -> List[str]:
        """Extract texts from output of CTC decoder."""

        # word beam search: already contains label strings
        if self.decoder_type == DecoderType.WordBeamSearch:
            label_strs = ctc_output

        # TF decoders: label strings are contained in sparse tensor
        else:
            # ctc returns tuple, first element is SparseTensor
            decoded = ctc_output[0][0]

            # contains string of labels for each batch element
            label_strs = [[] for _ in range(batch_size)]

            # go over all indices and save mapping: batch -> values
            for (idx, idx2d) in enumerate(decoded.indices):
                label = decoded.values[idx]
                batch_element = idx2d[0]  # index according to [b,t]
                label_strs[batch_element].append(label)

        # map labels to chars for all batch elements
        return [''.join([self.char_list[c] for c in labelStr]) for labelStr in label_strs]

    def train_batch(self, batch: Batch) -> float:
        """Feed a batch into the NN to train it and return the loss of this batch."""
        num_batch_elements = len(batch.imgs)
        # same sequence length for all elements, derived from the first image of the batch
        max_text_len = batch.imgs[0].shape[0] // 4
        sparse = self.to_sparse(batch.gt_texts)
        eval_list = [self.optimizer, self.loss]
        feed_dict = {self.input_imgs: batch.imgs, self.gt_texts: sparse,
                     self.seq_len: [max_text_len] * num_batch_elements, self.is_train: True}
        _, loss_val = self.sess.run(eval_list, feed_dict)
        self.batches_trained += 1
        return loss_val

    @staticmethod
    def dump_nn_output(rnn_output: np.ndarray) -> None:
        """Dump the output of the NN to CSV file(s).

        One file per batch element is written to '../dump/' (created if missing), with one row per
        time-step and the values of a row separated by ';'.
        """
        dump_dir = '../dump/'
        if not os.path.isdir(dump_dir):
            os.mkdir(dump_dir)

        # iterate over all batch elements and create a CSV file for each one
        max_t, max_b, max_c = rnn_output.shape
        for b in range(max_b):
            csv = ''
            for t in range(max_t):
                for c in range(max_c):
                    csv += str(rnn_output[t, b, c]) + ';'
                csv += '\n'
            fn = dump_dir + 'rnnOutput_' + str(b) + '.csv'
            print('Write dump of NN to file: ' + fn)
            with open(fn, 'w') as f:
                f.write(csv)

    def infer_batch(self, batch: Batch, calc_probability: bool = False, probability_of_gt: bool = False):
        """Feed a batch into the NN to recognize the texts.

        Args:
            batch: batch of images; its ground truth texts are only used if probability_of_gt is set.
            calc_probability: also compute a probability for each element.
            probability_of_gt: compute the probability of the ground truth text instead of the
                probability of the recognized text.

        Returns:
            The recognized texts and the probabilities (None if calc_probability is not set).
        """

        # decode, optionally save RNN output
        num_batch_elements = len(batch.imgs)

        # put tensors to be evaluated into list
        eval_list = []

        if self.decoder_type == DecoderType.WordBeamSearch:
            eval_list.append(self.wbs_input)
        else:
            eval_list.append(self.decoder)

        # the raw CTC input becomes the second evaluated element, it is read as eval_res[1] below
        if self.dump or calc_probability:
            eval_list.append(self.ctc_in_3d_tbc)

        # sequence length depends on input image size (model downsizes width by 4)
        max_text_len = batch.imgs[0].shape[0] // 4

        # dict containing all tensor fed into the model
        feed_dict = {self.input_imgs: batch.imgs, self.seq_len: [max_text_len] * num_batch_elements,
                     self.is_train: False}

        # evaluate model
        eval_res = self.sess.run(eval_list, feed_dict)

        # TF decoders: decoding already done in TF graph
        if self.decoder_type != DecoderType.WordBeamSearch:
            decoded = eval_res[0]
        # word beam search decoder: decoding is done in C++ function compute()
        else:
            decoded = self.decoder.compute(eval_res[0])

        # map labels (numbers) to character string
        texts = self.decoder_output_to_text(decoded, num_batch_elements)

        # feed RNN output and recognized text into CTC loss to compute labeling probability
        probs = None
        if calc_probability:
            sparse = self.to_sparse(batch.gt_texts) if probability_of_gt else self.to_sparse(texts)
            ctc_input = eval_res[1]
            eval_list = self.loss_per_element
            feed_dict = {self.saved_ctc_input: ctc_input, self.gt_texts: sparse,
                         self.seq_len: [max_text_len] * num_batch_elements, self.is_train: False}
            loss_vals = self.sess.run(eval_list, feed_dict)
            # the probability is computed as exp of the negated loss
            probs = np.exp(-loss_vals)

        # dump the output of the NN to CSV file(s)
        if self.dump:
            self.dump_nn_output(eval_res[1])

        return texts, probs

    def save(self) -> None:
        """Save model to file, using the incremented snapshot counter as global step."""
        self.snap_ID += 1
        self.saver.save(self.sess, '../model/snapshot', global_step=self.snap_ID)


In [None]:
%%writefile src/preprocessor.py
import random
from typing import Tuple

import cv2
import numpy as np

from dataloader_iam import Batch


class Preprocessor:
    """Brings images (and their ground truth texts) into the form expected by the NN.

    Images are scaled into the target size on a white background, optionally with random data
    augmentation, then transposed and normalized. In line mode, word images are first pasted
    together to simulate images of text lines.
    """

    def __init__(self,
                 img_size: Tuple[int, int],
                 padding: int = 0,
                 dynamic_width: bool = False,
                 data_augmentation: bool = False,
                 line_mode: bool = False) -> None:
        """
        Args:
            img_size: target image size as (width, height).
            padding: added to the computed width when dynamic width is used.
            dynamic_width: keep the aspect ratio and derive the target width from the image
                instead of using the fixed width of img_size.
            data_augmentation: apply random photometric and geometric distortions.
            line_mode: simulate text lines by concatenating the word images of a batch.
        """
        # dynamic width only supported when no data augmentation happens
        assert not (dynamic_width and data_augmentation)
        # when padding is on, we need dynamic width enabled
        assert not (padding > 0 and not dynamic_width)

        self.img_size = img_size
        self.padding = padding
        self.dynamic_width = dynamic_width
        self.data_augmentation = data_augmentation
        self.line_mode = line_mode

    @staticmethod
    def _truncate_label(text: str, max_text_len: int) -> str:
        """
        Function ctc_loss can't compute loss if it cannot find a mapping between text label and input
        labels. Repeat letters cost double because of the blank symbol needing to be inserted.
        If a too-long label is provided, ctc_loss returns an infinite gradient.

        Returns the longest prefix of text whose cost does not exceed max_text_len.
        """
        cost = 0
        for i in range(len(text)):
            if i != 0 and text[i] == text[i - 1]:
                cost += 2
            else:
                cost += 1
            if cost > max_text_len:
                return text[:i]
        return text

    def _simulate_text_line(self, batch: Batch) -> Batch:
        """Create image of a text line by pasting multiple word images into an image."""

        default_word_sep = 30
        default_num_words = 5

        # go over all batch elements
        res_imgs = []
        res_gt_texts = []
        for i in range(batch.batch_size):
            # number of words to put into current line
            num_words = random.randint(1, 8) if self.data_augmentation else default_num_words

            # concat ground truth texts (words are taken cyclically, starting at batch element i)
            curr_gt = ' '.join([batch.gt_texts[(i + j) % batch.batch_size] for j in range(num_words)])
            res_gt_texts.append(curr_gt)

            # put selected word images into list, compute target image size
            sel_imgs = []
            word_seps = [0]  # gap in front of each word, the first word has none
            h = 0
            w = 0
            for j in range(num_words):
                curr_sel_img = batch.imgs[(i + j) % batch.batch_size]
                curr_word_sep = random.randint(20, 50) if self.data_augmentation else default_word_sep
                h = max(h, curr_sel_img.shape[0])
                w += curr_sel_img.shape[1]
                sel_imgs.append(curr_sel_img)
                if j + 1 < num_words:
                    w += curr_word_sep
                    word_seps.append(curr_word_sep)

            # put all selected word images into target image (white background, vertically centered)
            target = np.ones([h, w], np.uint8) * 255
            x = 0
            for curr_sel_img, curr_word_sep in zip(sel_imgs, word_seps):
                x += curr_word_sep
                y = (h - curr_sel_img.shape[0]) // 2
                target[y:y + curr_sel_img.shape[0]:, x:x + curr_sel_img.shape[1]] = curr_sel_img
                x += curr_sel_img.shape[1]

            # put image of line into result
            res_imgs.append(target)

        return Batch(res_imgs, res_gt_texts, batch.batch_size)

    def process_img(self, img: np.ndarray) -> np.ndarray:
        """Resize to target size, apply data augmentation.

        Args:
            img: grayscale image, or None for an image that could not be loaded.

        Returns:
            The transposed image as float array, scaled by 1/255 and shifted by -0.5.
        """

        # there are damaged files in IAM dataset - just use black image instead
        if img is None:
            img = np.zeros(self.img_size[::-1])

        # data augmentation
        # np.float64 is what the np.float alias stood for; the alias itself was removed from NumPy
        img = img.astype(np.float64)
        if self.data_augmentation:
            # photometric data augmentation
            if random.random() < 0.25:
                def rand_odd():
                    return random.randint(1, 3) * 2 + 1
                img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
            if random.random() < 0.25:
                img = cv2.dilate(img, np.ones((3, 3)))
            if random.random() < 0.25:
                img = cv2.erode(img, np.ones((3, 3)))

            # geometric data augmentation
            wt, ht = self.img_size
            h, w = img.shape
            f = min(wt / w, ht / h)
            fx = f * np.random.uniform(0.75, 1.05)
            fy = f * np.random.uniform(0.75, 1.05)

            # random position around center
            txc = (wt - w * fx) / 2
            tyc = (ht - h * fy) / 2
            freedom_x = max((wt - fx * w) / 2, 0)
            freedom_y = max((ht - fy * h) / 2, 0)
            tx = txc + np.random.uniform(-freedom_x, freedom_x)
            ty = tyc + np.random.uniform(-freedom_y, freedom_y)

            # map image into target image
            M = np.float32([[fx, 0, tx], [0, fy, ty]])
            target = np.ones(self.img_size[::-1]) * 255
            img = cv2.warpAffine(img, M, dsize=self.img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT)

            # photometric data augmentation
            if random.random() < 0.5:
                img = img * (0.25 + random.random() * 0.75)
            if random.random() < 0.25:
                img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 25), 0, 255)
            if random.random() < 0.1:
                img = 255 - img

        # no data augmentation
        else:
            if self.dynamic_width:
                ht = self.img_size[1]
                h, w = img.shape
                f = ht / h
                wt = int(f * w + self.padding)
                wt = wt + (4 - wt) % 4  # round width up to the next multiple of 4
                tx = (wt - w * f) / 2
                ty = 0
            else:
                wt, ht = self.img_size
                h, w = img.shape
                f = min(wt / w, ht / h)
                tx = (wt - w * f) / 2
                ty = (ht - h * f) / 2

            # map image into target image
            M = np.float32([[f, 0, tx], [0, f, ty]])
            target = np.ones([ht, wt]) * 255
            img = cv2.warpAffine(img, M, dsize=(wt, ht), dst=target, borderMode=cv2.BORDER_TRANSPARENT)

        # transpose for TF
        img = cv2.transpose(img)

        # convert to range [-0.5, 0.5] (for pixel values between 0 and 255)
        img = img / 255 - 0.5
        return img

    def process_batch(self, batch: Batch) -> Batch:
        """Process all images of a batch and truncate the ground truth texts to a feasible length."""
        if self.line_mode:
            batch = self._simulate_text_line(batch)

        res_imgs = [self.process_img(img) for img in batch.imgs]
        # images are transposed, so shape[0] is the width; taken from the first image of the batch
        max_text_len = res_imgs[0].shape[0] // 4
        res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts]
        return Batch(res_imgs, res_gt_texts, batch.batch_size)


def main():
    """Show a test image next to its augmented version as a visual check."""
    import matplotlib.pyplot as plt

    original = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE)
    augmenter = Preprocessor((256, 32), data_augmentation=True)
    augmented = augmenter.process_img(original)

    # (subplot position, picture, extra imshow arguments)
    panels = (
        (121, original, {}),
        # undo the transpose and the -0.5 shift applied by process_img
        (122, cv2.transpose(augmented) + 0.5, {'vmin': 0, 'vmax': 1}),
    )
    for position, picture, extra in panels:
        plt.subplot(position)
        plt.imshow(picture, cmap='gray', **extra)
    plt.show()


if __name__ == '__main__':
    main()


In [None]:
%%writefile content.txt
AAA BIG HAIRY DOG SAT ON MY HEAD
DOG           IT WAS VERY HEAVY
SAT
ON                   I EXCLAIMED LOUDLY
MY



      HELP THERE IS  A DOG ON MY HEAD

In [None]:
!python src/character_recognition.py

In [None]:
%%writefile app.py
import streamlit as st
import os
import subprocess
import threading
import tempfile
import shutil
from PIL import Image, ImageFilter, ImageTk, ImageDraw
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from io import BytesIO
import zipfile
import time
import json
from pathlib import Path
import numpy as np
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter, A4
from reportlab.lib.utils import ImageReader
import soundfile as sf

# Configure Streamlit page.
# NOTE: this must run before the optional-dependency checks below. Their
# ImportError branches call st.warning(), and st.set_page_config() is
# documented as having to be the first Streamlit command on a page.
st.set_page_config(
    page_title="HandCraft AI",
    page_icon="✍️",
    layout="wide",
    initial_sidebar_state="expanded"
)

# OCR integration (optional: needs `transformers`)
try:
    from transformers import pipeline
    OCR_AVAILABLE = True

    @st.cache_resource
    def load_ocr_model():
        """Load OCR model with caching"""
        return pipeline('image-to-text', model="microsoft/trocr-base-handwritten")
except ImportError:
    # OCR_AVAILABLE is checked by the OCR helpers before load_ocr_model is used
    OCR_AVAILABLE = False
    st.warning("OCR functionality not available. Install transformers: pip install transformers torch")

# Speech-to-Text integration (optional: needs `torch` and `transformers`)
try:
    import torch
    from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
    STT_AVAILABLE = True

    @st.cache_resource
    def load_speech_model():
        """Load Speech-to-Text model with caching"""
        model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
        processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
        return model, processor
except ImportError:
    # STT_AVAILABLE is checked by transcribe_audio before load_speech_model is used
    STT_AVAILABLE = False
    st.warning("Speech-to-Text not available. Install: pip install torch transformers datasets soundfile")

# Custom CSS for better styling
st.markdown("""
<style>
    .main-header {
        font-size: 3rem;
        font-weight: bold;
        text-align: center;
        color: #ffffff;
        margin-bottom: 2rem;
        text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
    }
    .step-header {
        font-size: 1.8rem;
        font-weight: bold;
        color: #64b5f6;
        margin-top: 2rem;
        margin-bottom: 1rem;
        border-bottom: 2px solid #64b5f6;
        padding-bottom: 0.5rem;
    }
    .info-box {
        background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%);
        color: #ffffff;
        padding: 1.5rem;
        border-radius: 12px;
        border-left: 4px solid #64b5f6;
        margin: 1rem 0;
        box-shadow: 0 4px 12px rgba(0,0,0,0.3);
    }
    .success-box {
        background: linear-gradient(135deg, #2e7d32 0%, #4caf50 100%);
        color: #ffffff;
        padding: 1.5rem;
        border-radius: 12px;
        border-left: 4px solid #81c784;
        margin: 1rem 0;
        box-shadow: 0 4px 12px rgba(0,0,0,0.3);
    }
    .warning-box {
        background: linear-gradient(135deg, #f57c00 0%, #ff9800 100%);
        color: #ffffff;
        padding: 1.5rem;
        border-radius: 12px;
        border-left: 4px solid #ffb74d;
        margin: 1rem 0;
        box-shadow: 0 4px 12px rgba(0,0,0,0.3);
    }
    .stTabs [data-baseweb="tab-list"] {
        gap: 8px;
    }
    .stTabs [data-baseweb="tab"] {
        background: linear-gradient(135deg, #424242 0%, #616161 100%);
        color: #ffffff;
        border-radius: 8px 8px 0 0;
        padding: 12px 24px;
        border: none;
    }
    .stTabs [aria-selected="true"] {
        background: linear-gradient(135deg, #1976d2 0%, #42a5f5 100%);
        color: #ffffff;
    }
    .metric-card {
        background: linear-gradient(135deg, #263238 0%, #37474f 100%);
        padding: 1rem;
        border-radius: 12px;
        border: 1px solid #546e7a;
        margin: 0.5rem 0;
    }
    .stSelectbox > div > div {
        background-color: #424242;
        color: #ffffff;
    }
    .stTextInput > div > div > input {
        background-color: #424242;
        color: #ffffff;
        border: 1px solid #616161;
    }
    .stButton > button {
        background: linear-gradient(135deg, #1976d2 0%, #42a5f5 100%);
        color: #ffffff;
        border: none;
        border-radius: 8px;
        padding: 0.5rem 1rem;
        font-weight: bold;
        transition: all 0.3s ease;
    }
    .stButton > button:hover {
        background: linear-gradient(135deg, #1565c0 0%, #1976d2 100%);
        box-shadow: 0 4px 12px rgba(25, 118, 210, 0.4);
        transform: translateY(-2px);
    }
</style>
""", unsafe_allow_html=True)

# Initialize session state: every key is seeded once and left untouched on
# later reruns. Defaults are produced by factories so that side effects
# (creating the temp directory) only happen when the key is really missing.
_SESSION_DEFAULT_FACTORIES = (
    ('processed_image', lambda: None),
    ('characters_extracted', list),
    ('character_mappings', dict),
    ('processing_complete', lambda: False),
    ('temp_dir', tempfile.mkdtemp),
    ('use_ocr', lambda: OCR_AVAILABLE),
    ('processed_image_hash', lambda: None),
    ('dataset_characters', dict),
    ('use_dataset', lambda: False),
    ('paper_background', lambda: None),
    ('generated_pages', list),
)
for _state_key, _make_default in _SESSION_DEFAULT_FACTORIES:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Reference palette: each entry is (R, G, B, name).
# nearest_colour() compares the RGB components and returns the name (index 3);
# process_image_advanced() treats "black" and "grey" matches as ink.
colours = ((255, 255, 255, "white"),
           (200, 100, 50, "orange"),
           (128, 0, 0, "red"),
           (0, 255, 0, "green"),
           (0, 0, 0, "black"),
           (64, 64, 64, "grey"))

def get_image_hash(image_file):
    """Return the MD5 hex digest of an uploaded file's full contents.

    The stream is rewound before reading and again afterwards, so the
    caller gets it back positioned at the start.
    """
    from hashlib import md5

    image_file.seek(0)
    payload = image_file.read()
    image_file.seek(0)

    digest = md5(payload)
    return digest.hexdigest()

def add_white_background_to_processed_image(image):
    """Flatten an image onto white for display.

    RGBA and LA images are pasted over a white RGB canvas using their alpha
    channel as the mask. Any other mode is returned as RGB: unchanged when
    it already is RGB, otherwise via a plain ``convert("RGB")``.
    """
    if image.mode not in ("RGBA", "LA"):
        return image if image.mode == "RGB" else image.convert("RGB")

    # LA is promoted to RGBA so both alpha modes share one compositing path
    layer = image if image.mode == "RGBA" else image.convert("RGBA")
    backdrop = Image.new("RGB", image.size, "white")
    alpha = layer.split()[-1]
    backdrop.paste(layer, mask=alpha)
    return backdrop

def _load_glyph(source, char):
    """Open a glyph (file path or PIL image) as RGBA; lowercase glyphs are shrunk to 70% of 50px."""
    glyph = (Image.open(source) if isinstance(source, str) else source).convert("RGBA")
    if char.strip().islower():
        target_height = int(50 * 0.7)  # 70% size for lowercase
        scale_factor = target_height / glyph.height
        new_width = int(glyph.width * scale_factor)
        glyph = glyph.resize((new_width, target_height), Image.Resampling.LANCZOS)
    return glyph


def _render_text_line(line, glyphs, char_spacing):
    """Render one non-blank line of text as a transparent RGBA strip."""
    pieces = []
    for char in line:
        if char == ' ':
            # Space character - fully transparent
            pieces.append(Image.new('RGBA', (char_spacing * 2, 50), (0, 0, 0, 0)))
            continue
        # Exact case first, then lowercase, then uppercase
        for candidate in (char, char.lower(), char.upper()):
            if candidate in glyphs:
                pieces.append(glyphs[candidate])
                break
        else:
            # Character not found - red semi-transparent placeholder
            pieces.append(Image.new('RGBA', (20, 30), (255, 0, 0, 128)))

    total_width = sum(piece.width for piece in pieces) + char_spacing * (len(pieces) - 1)
    max_height = max(piece.height for piece in pieces)

    strip = Image.new('RGBA', (total_width, max_height), (255, 255, 255, 0))
    x_offset = 0
    for piece in pieces:
        strip.paste(piece, (x_offset, 0), piece)
        x_offset += piece.width + char_spacing
    return strip


def generate_handwritten_text_with_background(text, char_mappings, dataset_chars, line_spacing, char_spacing, paper_bg, use_dataset=False):
    """Generate handwritten text with paper background.

    Args:
        text: Text to render; ``\\n`` separates lines. Whitespace-only lines
            produce no glyphs but still advance the page by ``line_spacing``.
        char_mappings: Mapping of glyph image path -> character for the
            user's own glyphs. Entries whose character is blank are ignored.
        dataset_chars: Mapping of character -> glyph image path or PIL image.
            Only used when ``use_dataset`` is true; custom glyphs override it.
        line_spacing: Vertical gap in pixels between lines.
        char_spacing: Horizontal gap in pixels between glyphs; a space is
            rendered ``2 * char_spacing`` wide.
        paper_bg: PIL image used as the page, or a falsy value for a plain
            white page. It is enlarged when the text plus a 100px margin
            does not fit.
        use_dataset: Whether to include ``dataset_chars``.

    Returns:
        An RGB PIL image, or ``None`` when there is nothing to render or an
        error occurred (the reason is reported through ``st.error``).
    """
    try:
        # Combine custom and dataset characters
        all_characters = {}

        # Dataset characters first so custom ones can override them
        if use_dataset and dataset_chars:
            for char, source in dataset_chars.items():
                try:
                    all_characters[char.strip()] = _load_glyph(source, char)
                except Exception as e:
                    print(f"Error loading dataset character {char}: {e}")

        for img_path, char in char_mappings.items():
            if not char.strip():
                continue
            try:
                all_characters[char.strip()] = _load_glyph(img_path, char)
            except Exception as e:
                print(f"Error loading custom character {char}: {e}")

        if not all_characters:
            st.error("No character mappings available.")
            return None

        # Lay the lines out top to bottom. Blank lines place nothing but still
        # move the cursor, so the content height is measured from the real
        # offsets; otherwise text following blank lines could run off the page.
        placements = []  # (line image, y offset relative to the first line)
        cursor = 0
        content_height = 0
        for line in text.split('\n'):
            if not line.strip():
                cursor += line_spacing
                continue
            line_img = _render_text_line(line, all_characters, char_spacing)
            placements.append((line_img, cursor))
            content_height = max(content_height, cursor + line_img.height)
            cursor += line_img.height + line_spacing

        if not placements:
            st.error("No valid lines generated")
            return None

        content_width = max(line_img.width for line_img, _ in placements)

        # Use paper background or create white background
        margin = 100
        if paper_bg:
            required_width = content_width + 2 * margin
            required_height = content_height + 2 * margin

            # Enlarge the paper (never shrink it) until the content fits
            bg_scale = max(required_width / paper_bg.width,
                           required_height / paper_bg.height,
                           1)
            if bg_scale > 1:
                new_bg_size = (int(paper_bg.width * bg_scale), int(paper_bg.height * bg_scale))
                final_image = paper_bg.resize(new_bg_size, Image.Resampling.LANCZOS)
            else:
                final_image = paper_bg.copy()

            # Keep as RGB - don't convert to RGBA
            if final_image.mode != "RGB":
                final_image = final_image.convert("RGB")
        else:
            final_image = Image.new('RGB', (content_width + 2 * margin, content_height + 2 * margin), (255, 255, 255))

        # Paste text lines onto background
        for line_img, offset in placements:
            alpha = line_img.split()[-1]
            # Flatten the line onto white first, then paste through its alpha
            flattened = Image.new("RGB", line_img.size, (255, 255, 255))
            flattened.paste(line_img, mask=alpha)
            final_image.paste(flattened, (margin, margin + offset), alpha)

        return final_image

    except Exception as e:
        st.error(f"Error in text generation: {e}")
        import traceback
        st.error(f"Traceback: {traceback.format_exc()}")
        return None

def nearest_colour(subjects, query):
    """Return the name (index 3) of the subject closest to ``query``.

    Closeness is the squared Euclidean distance over the components paired
    up by position; on a tie the earliest subject wins.
    """
    def squared_distance(candidate):
        total = 0
        for component, wanted in zip(candidate, query):
            total += (component - wanted) ** 2
        return total

    closest = min(subjects, key=squared_distance)
    return closest[3]

def save_character_mappings():
    """Save character mappings to prevent reprocessing.

    This is currently a no-op: the body only checks whether
    ``character_mappings`` exists on ``st.session_state`` and writes nothing.
    """
    if hasattr(st.session_state, 'character_mappings'):
        # Nothing to do: the mappings already live in session state,
        # which is expected to persist them automatically.
        pass

def load_character_mappings():
    """Return the character mappings held in session state ({} when unset)."""
    state = st.session_state
    if 'character_mappings' in state:
        return state['character_mappings']
    return {}

def load_dataset_characters(dataset_folder="Dataset"):
    """Map characters to the glyph image files found in a dataset folder.

    File naming: lowercase letters are stored as ``.a.png``, uppercase
    letters and digits as ``A.png`` / ``0.png``, and punctuation under a
    word name (``period.png``, ``comma.png``, ...). Only files that exist
    are included; a missing folder gives an empty dict.
    """
    if not os.path.exists(dataset_folder):
        return {}

    # (file stem, character) pairs, in lookup order
    wanted = [(f".{letter}", letter) for letter in 'abcdefghijklmnopqrstuvwxyz']
    wanted += [(symbol, symbol) for symbol in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ']
    wanted += [(digit, digit) for digit in '0123456789']
    wanted += [
        ('period', '.'), ('comma', ','), ('question', '?'), ('exclamation', '!'),
        ('colon', ':'), ('semicolon', ';'), ('apostrophe', "'"), ('quote', '"'),
        ('hyphen', '-'), ('underscore', '_'), ('space', ' '),
    ]

    found = {}
    for stem, character in wanted:
        candidate = os.path.join(dataset_folder, f"{stem}.png")
        if os.path.exists(candidate):
            found[character] = candidate
    return found

def load_paper_background(dataset_folder="Dataset"):
    """Return ``paperbg.png`` from the dataset folder as an RGB image, or None if it is absent."""
    candidate = os.path.join(dataset_folder, "paperbg.png")
    if not os.path.exists(candidate):
        return None

    paper = Image.open(candidate)
    # Every other mode is normalised to RGB for background handling
    return paper if paper.mode == "RGB" else paper.convert("RGB")

def align_character_sizes(characters_dict, target_height=50):
    """Resize glyphs to a common height, keeping their aspect ratio.

    Lowercase characters are scaled to 70% of ``target_height``; everything
    else to the full height. Values may be file paths or PIL images. A glyph
    that fails to load or resize is reported with ``st.error`` and left out.
    """
    resized = {}

    for symbol, source in characters_dict.items():
        try:
            opened = Image.open(source) if isinstance(source, str) else source
            glyph = opened.convert("RGBA")

            # 70% size for lowercase, full size otherwise
            wanted_height = int(target_height * 0.7) if symbol.islower() else target_height

            ratio = wanted_height / glyph.height
            wanted_width = int(glyph.width * ratio)
            resized[symbol] = glyph.resize((wanted_width, wanted_height), Image.Resampling.LANCZOS)
        except Exception as e:
            st.error(f"Error processing character {symbol}: {e}")

    return resized

def transcribe_audio(audio_file):
    """Transcribe audio file to text using Speech2Text.

    Args:
        audio_file: Anything ``soundfile.read`` accepts (path or file-like).

    Returns:
        The transcription string. Failures are not raised: a descriptive
        message string is returned instead (STT unavailable, no
        transcription produced, or the error text).
    """
    if not STT_AVAILABLE:
        return "Speech-to-Text not available"

    try:
        model, processor = load_speech_model()

        # Read audio file
        audio_data, sample_rate = sf.read(audio_file)

        # Multi-channel files come back as a 2-D (frames, channels) array,
        # which the processor would treat as a batch of tiny clips rather
        # than one recording. Average the channels into a mono waveform.
        if getattr(audio_data, "ndim", 1) > 1:
            audio_data = audio_data.mean(axis=1)

        # Process audio
        inputs = processor(
            audio_data,
            sampling_rate=sample_rate,
            return_tensors="pt"
        )

        # Generate transcription
        generated_ids = model.generate(
            inputs["input_features"],
            attention_mask=inputs["attention_mask"]
        )

        # Decode transcription
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
        return transcription[0] if transcription else "No transcription available"

    except Exception as e:
        return f"Error transcribing audio: {str(e)}"

def create_directories():
    """Ensure the session's working folders exist and return their paths.

    Returns ``(data_dir, characters_dir)``: ``data`` sits inside the
    session's temp directory and ``characters`` inside ``data``.
    """
    data_dir = os.path.join(st.session_state.temp_dir, "data")
    characters_dir = os.path.join(data_dir, "characters")
    for folder in (data_dir, characters_dir):
        os.makedirs(folder, exist_ok=True)
    return data_dir, characters_dir

def _is_ink_pixel(pixel):
    """Return True when a pixel's RGB is closest to black or grey in the ``colours`` palette."""
    return nearest_colour(colours, pixel[:3]) in ("black", "grey")


def _blank_rows(image):
    """Return the y indices of the rows of ``image`` that contain no ink pixel."""
    width, height = image.size
    return [y for y in range(height)
            if not any(_is_ink_pixel(image.getpixel((x, y))) for x in range(width))]


def _blank_columns(image):
    """Return the x indices of the columns of ``image`` that contain no ink pixel."""
    width, height = image.size
    return [x for x in range(width)
            if not any(_is_ink_pixel(image.getpixel((x, y))) for y in range(height))]


def _filled_spans(blank_indices):
    """Yield half-open (start, stop) spans lying between non-adjacent blank indices."""
    for current, following in zip(blank_indices, blank_indices[1:]):
        if current + 1 != following:
            yield current + 1, following


def process_image_advanced(image_file):
    """Extract individual character images from an uploaded handwriting sheet.

    Pipeline: upscale 2x, turn every pixel into opaque black ink or
    transparent white, run edge-enhance/smooth/sharpen filters, split the
    sheet into text rows at blank pixel rows, split each row into characters
    at blank pixel columns, then trim blank space above and below each
    character.

    Characters are saved as ``Cropped<N>.png`` in the session's characters
    directory, numbered in reading order (row by row, left to right). The
    segmentation works on in-memory images, so leftovers of an earlier run
    cannot leak into the result.

    Args:
        image_file: Uploaded file object providing ``getvalue()``.

    Returns:
        ``(True, message)`` on success, ``(False, message)`` on any error.
        On success ``processed_image``, ``characters_extracted`` and
        ``processing_complete`` are updated in ``st.session_state``.
    """
    _, characters_dir = create_directories()

    # Save uploaded file temporarily
    temp_image_path = os.path.join(st.session_state.temp_dir, "input_image.png")
    with open(temp_image_path, "wb") as f:
        f.write(image_file.getvalue())

    progress_bar = st.progress(0)
    status_text = st.empty()

    try:
        # Step 1: Load and enhance image
        status_text.text("Loading and enhancing image...")
        progress_bar.progress(10)

        with Image.open(temp_image_path) as original_picture:
            width, height = original_picture.size
            # Upscale with high-quality resampling
            picture = original_picture.resize((width * 2, height * 2), Image.Resampling.LANCZOS)
        width, height = picture.size

        # Step 2: Background removal
        status_text.text("Removing background and isolating text...")
        progress_bar.progress(25)

        rgba_image = picture.convert("RGBA")
        picture.close()

        for y in range(height):
            for x in range(width):
                if _is_ink_pixel(rgba_image.getpixel((x, y))):
                    rgba_image.putpixel((x, y), (0, 0, 0, 255))
                else:
                    rgba_image.putpixel((x, y), (255, 255, 255, 0))

        # Step 3: Image enhancement
        status_text.text("Enhancing image quality...")
        progress_bar.progress(40)

        detail = rgba_image.filter(ImageFilter.EDGE_ENHANCE)
        smooth = detail.filter(ImageFilter.SMOOTH)
        sheet = smooth.filter(ImageFilter.SHARPEN)
        for stage in (rgba_image, detail, smooth):
            stage.close()

        # Add white background for display
        st.session_state.processed_image = add_white_background_to_processed_image(sheet.copy())

        # Step 4: Vertical cropping (text rows)
        status_text.text("Segmenting text rows...")
        progress_bar.progress(55)

        row_images = [sheet.crop((0, top, width, bottom))
                      for top, bottom in _filled_spans(_blank_rows(sheet))]

        # Step 5: Horizontal cropping (characters within each row)
        status_text.text("Segmenting individual characters...")
        progress_bar.progress(70)

        glyphs = []
        for row_image in row_images:
            row_height = row_image.height
            glyphs.extend(row_image.crop((left, 0, right, row_height))
                          for left, right in _filled_spans(_blank_columns(row_image)))

        # Step 6: Final character cropping
        status_text.text("Finalizing character extraction...")
        progress_bar.progress(85)

        extracted_characters = []
        for glyph in glyphs:
            glyph_width, glyph_height = glyph.size
            ink_rows = [y for y in range(glyph_height)
                        if any(_is_ink_pixel(glyph.getpixel((x, y))) for x in range(glyph_width))]
            if not ink_rows:
                continue

            # The crop box's lower edge is exclusive, hence +1 to keep the
            # last ink row (and to avoid a zero-height image for 1-row glyphs).
            trimmed = glyph.crop((0, ink_rows[0], glyph_width, ink_rows[-1] + 1))
            char_path = os.path.join(characters_dir, f"Cropped{len(extracted_characters) + 1}.png")
            trimmed.save(char_path)
            extracted_characters.append(char_path)
            trimmed.close()

        # Cleanup of in-memory intermediates
        for image in (*glyphs, *row_images, sheet):
            image.close()

        status_text.text("Processing complete!")
        progress_bar.progress(100)

        st.session_state.characters_extracted = extracted_characters
        st.session_state.processing_complete = True

        return True, f"Successfully extracted {len(extracted_characters)} characters"

    except Exception as e:
        return False, f"Error during processing: {str(e)}"

def prepare_character_for_display(char_path, target_size=(120, 120), padding=10):
    """Prepare character image for uniform display with white background.

    The character is scaled (aspect ratio kept) to fit inside ``target_size``
    minus ``padding`` on every side, centred on a white canvas and framed
    with a light grey border.

    Args:
        char_path: Path of the character image.
        target_size: ``(width, height)`` of the returned image.
        padding: Free space in pixels kept around the character.

    Returns:
        An RGB PIL image of ``target_size``. If anything goes wrong, a white
        image carrying a red "Error" label is returned instead of raising.
    """
    try:
        # Load the character image (the file handle is released right away)
        with Image.open(char_path) as opened:
            char_img = opened.convert("RGBA")

        # Calculate scaling to fit within target size while maintaining aspect ratio
        original_width, original_height = char_img.size
        scale_factor = min(
            (target_size[0] - 2 * padding) / original_width,
            (target_size[1] - 2 * padding) / original_height
        )
        if scale_factor <= 0:
            raise ValueError("padding leaves no room for the character")

        # int() truncation would give 0 for very thin strokes (e.g. "|" or
        # "-"), which makes resize fail; keep at least one pixel per side.
        new_width = max(1, int(original_width * scale_factor))
        new_height = max(1, int(original_height * scale_factor))

        # Resize the character
        resized_char = char_img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Create white background
        display_img = Image.new("RGB", target_size, "white")

        # Calculate position to center the character
        x_offset = (target_size[0] - new_width) // 2
        y_offset = (target_size[1] - new_height) // 2

        # Paste through the alpha channel so transparent areas stay white
        display_img.paste(resized_char, (x_offset, y_offset), resized_char)

        # Add border
        draw = ImageDraw.Draw(display_img)
        draw.rectangle([0, 0, target_size[0]-1, target_size[1]-1], outline="lightgray", width=1)

        return display_img
    except Exception:
        # Return error placeholder
        error_img = Image.new("RGB", target_size, "white")
        draw = ImageDraw.Draw(error_img)
        draw.text((10, target_size[1]//2), "Error", fill="red")
        return error_img

def run_ocr_on_character(char_path):
    """Guess the character shown in a single character image via OCR.

    Returns the first character of the recognised text, upper-cased.
    Sentinel results: "N/A" when OCR is unavailable, "?" when nothing was
    recognised, and "Error" when the OCR run raised.
    """
    if not OCR_AVAILABLE:
        return "N/A"

    try:
        recogniser = load_ocr_model()

        # Prepare image for OCR: RGB, enlarged 2x
        source = Image.open(char_path).convert("RGB")
        doubled = (source.width * 2, source.height * 2)
        enlarged = source.resize(doubled, Image.Resampling.LANCZOS)

        predictions = recogniser(enlarged)
        if not predictions:
            return "?"

        # Only the first character of the recognised text is kept
        text = predictions[0]['generated_text'].strip()
        return text[0].upper() if text else "?"
    except Exception:
        return "Error"

def display_character_grid(characters, cols_per_row=6, show_ocr=False):
    """Display characters in an equal grid layout with uniform sizing.

    Args:
        characters: List of character image paths, in display order.
        cols_per_row: Number of grid columns.
        show_ocr: When true (and OCR is available) every character is run
            through OCR and the prediction is appended to its caption.
    """
    if not characters:
        st.info("No characters to display")
        return

    st.write(f"**Total Characters Extracted:** {len(characters)}")

    # Stays empty unless OCR is both available and requested
    ocr_results = {}
    if OCR_AVAILABLE and show_ocr:
        with st.spinner("Running OCR on characters..."):
            progress_bar = st.progress(0)
            for done, char_path in enumerate(characters, start=1):
                ocr_results[char_path] = run_ocr_on_character(char_path)
                progress_bar.progress(done / len(characters))
            progress_bar.empty()

    # Create grid layout
    for row_start in range(0, len(characters), cols_per_row):
        row = characters[row_start:row_start + cols_per_row]

        # Create columns with equal width
        cols = st.columns(cols_per_row)

        for col_idx, char_path in enumerate(row):
            # 1-based position in the whole list, used for caption and errors
            char_num = row_start + col_idx + 1
            with cols[col_idx]:
                try:
                    # Prepare uniform character image
                    display_img = prepare_character_for_display(char_path)

                    caption = f"#{char_num}"
                    if char_path in ocr_results:
                        caption += f" | OCR: {ocr_results[char_path]}"

                    st.image(display_img, caption=caption, use_container_width=True)

                except Exception as e:
                    # Report the character's number, not its column index
                    st.error(f"Error displaying character {char_num}: {str(e)}")

        # Fill empty columns in the last row
        for empty_col in range(len(row), cols_per_row):
            with cols[empty_col]:
                st.empty()

def create_character_mapping_interface():
    """Render the form that maps each extracted character image to its text.

    Reads ``st.session_state.characters_extracted`` (image paths) and
    reads/writes ``st.session_state.character_mappings`` (path -> text).
    When OCR assistance is enabled, ``run_ocr_on_character`` supplies a
    suggestion for every character that has no mapping yet.

    Returns:
        None. Returns early with a warning when nothing has been extracted.
    """
    st.markdown('<div class="step-header">Character Verification & Mapping</div>', unsafe_allow_html=True)

    if not st.session_state.characters_extracted:
        st.warning("No characters extracted yet. Please process an image first.")
        return

    # OCR options
    col1, col2, col3 = st.columns([2, 1, 1])
    with col1:
        st.markdown("**Options:**")
    with col2:
        auto_ocr = st.checkbox("Enable OCR Assistance", value=OCR_AVAILABLE, disabled=not OCR_AVAILABLE)
    with col3:
        if auto_ocr and st.button("Run OCR on All", type="secondary"):
            with st.spinner("Running OCR..."):
                for char_path in st.session_state.characters_extracted:
                    # Never overwrite a mapping the user already provided.
                    if st.session_state.character_mappings.get(char_path):
                        continue
                    ocr_result = run_ocr_on_character(char_path)
                    # "Error" is a failure sentinel, not a recognised
                    # character; the form below filters it the same way.
                    if ocr_result != "Error":
                        st.session_state.character_mappings[char_path] = ocr_result
                st.success("OCR completed!")
                st.rerun()

    # Character mapping form
    with st.form("character_mapping"):
        st.write("Map each extracted character to its actual character:")

        # Grid layout for mapping
        cols_per_row = 4
        rows = [st.session_state.characters_extracted[i:i + cols_per_row]
                for i in range(0, len(st.session_state.characters_extracted), cols_per_row)]

        mappings = {}

        for row_idx, row in enumerate(rows):
            cols = st.columns(cols_per_row)

            for col_idx, char_path in enumerate(row):
                # Computed before the ``try`` so the error handler can always
                # name the character that failed.
                char_num = row_idx * cols_per_row + col_idx + 1

                with cols[col_idx]:
                    try:
                        # Display character with uniform sizing
                        display_img = prepare_character_for_display(char_path, target_size=(100, 100))
                        st.image(display_img, caption=f"Character #{char_num}", use_container_width=True)

                        # Start from the saved mapping; fall back to the OCR
                        # suggestion only when nothing has been saved.
                        existing_mapping = st.session_state.character_mappings.get(char_path, "")

                        ocr_suggestion = ""
                        if auto_ocr and OCR_AVAILABLE:
                            ocr_suggestion = run_ocr_on_character(char_path)
                            if not existing_mapping and ocr_suggestion != "Error":
                                existing_mapping = ocr_suggestion

                        char_input = st.text_input(
                            f"Char #{char_num}:",
                            value=existing_mapping,
                            key=f"char_{char_num}",
                            max_chars=5,
                            placeholder=f"OCR: {ocr_suggestion}" if ocr_suggestion else "Enter character"
                        )
                        mappings[char_path] = char_input

                    except Exception as e:
                        # Best effort: a broken image only loses its own cell.
                        st.error(f"Error displaying character {char_num}: {e}")

            # Fill empty columns in the last row
            for empty_col in range(len(row), cols_per_row):
                with cols[empty_col]:
                    st.empty()

        col1, col2 = st.columns(2)
        with col1:
            if st.form_submit_button("Save Character Mappings", type="primary"):
                st.session_state.character_mappings = mappings
                st.success("Character mappings saved successfully!")

                # Display mapping summary (blank mappings are left out)
                st.write("### Mapping Summary:")
                mapping_data = [
                    {
                        "Character #": idx,
                        "File": os.path.basename(char_path),
                        "Mapped To": mapping.strip()
                    }
                    for idx, (char_path, mapping) in enumerate(mappings.items(), 1)
                    if mapping.strip()
                ]

                if mapping_data:
                    df = pd.DataFrame(mapping_data)
                    st.dataframe(df, use_container_width=True)

        with col2:
            if st.form_submit_button("Clear All Mappings", type="secondary"):
                st.session_state.character_mappings = {}
                st.warning("All mappings cleared!")
                st.rerun()

def create_text_generation_interface():
    """Render the "Text Generation" page: collect text, render it, manage pages.

    The user can type text, upload a ``.txt`` file or, when ``STT_AVAILABLE``
    is set, upload audio that is passed to ``transcribe_audio``. The text is
    rendered with ``generate_handwritten_text_with_background``; every result
    is appended to ``st.session_state.generated_pages`` and offered as an
    image download, a single-page PDF and, together with the other pages, a
    combined PDF.

    Returns:
        None. Returns early (after showing a message) when no text is
        available, when a text file cannot be read, when speech-to-text is
        selected but unavailable, or when generation is requested without
        mappings and without the dataset.
    """
    st.markdown('<div class="step-header">Text Generation</div>', unsafe_allow_html=True)

    # Load dataset characters if not already loaded
    if not st.session_state.dataset_characters:
        st.session_state.dataset_characters = load_dataset_characters()
        st.session_state.paper_background = load_paper_background()

    # Dataset status
    col1, col2, col3 = st.columns(3)
    with col1:
        if st.session_state.dataset_characters:
            st.success(f"✅ Dataset loaded ({len(st.session_state.dataset_characters)} characters)")
        else:
            st.warning("⚠️ No dataset found in 'Dataset' folder")

    with col2:
        if st.session_state.paper_background:
            st.success("✅ Paper background loaded")
        else:
            st.info("📄 No paperbg.png found")

    with col3:
        st.session_state.use_dataset = st.checkbox(
            "Use dataset characters",
            value=st.session_state.use_dataset,
            help="Use dataset characters for missing letters"
        )

    # Input options
    st.markdown("### Text Input Options")
    input_method = st.radio(
        "Choose input method:",
        ["Type Text", "Upload Text File", "Speech-to-Text"],
        horizontal=True
    )

    input_text = ""

    if input_method == "Type Text":
        input_text = st.text_area(
            "Enter text to convert to handwriting:",
            height=150,
            placeholder="Type your text here..."
        )

    elif input_method == "Upload Text File":
        uploaded_text_file = st.file_uploader(
            "Upload a text file",
            type=['txt'],
            help="Upload a .txt file with your content"
        )

        if uploaded_text_file is not None:
            try:
                input_text = uploaded_text_file.read().decode('utf-8')
                st.text_area(
                    "File content preview:",
                    value=input_text,
                    height=150,
                    disabled=True
                )
            except Exception as e:
                st.error(f"Error reading file: {e}")
                return

    elif input_method == "Speech-to-Text":
        if not STT_AVAILABLE:
            st.error("Speech-to-Text not available. Install required packages.")
            return

        uploaded_audio = st.file_uploader(
            "Upload audio file",
            type=['wav', 'flac', 'mp3'],
            help="Upload an audio file to transcribe"
        )

        if uploaded_audio is not None:
            with st.spinner("Transcribing audio..."):
                transcription = transcribe_audio(uploaded_audio)

            input_text = st.text_area(
                "Transcribed text:",
                value=transcription,
                height=150,
                help="Edit the transcription if needed"
            )

    # Nothing below makes sense without text. Note that this also hides the
    # page-management section further down until some text is entered.
    if not input_text.strip():
        st.info("Please enter text, upload a file, or provide audio to continue.")
        return

    # Generation options
    st.markdown("### Generation Settings")
    col1, col2 = st.columns(2)

    with col1:
        line_spacing = st.slider("Line Spacing", 1, 50, 20, key="line_spacing_gen")
        char_spacing = st.slider("Character Spacing", 1, 20, 5, key="char_spacing_gen")
        use_paper_bg = st.checkbox("Use paper background", value=bool(st.session_state.paper_background), key="use_paper_bg_gen")

    with col2:
        # NOTE(review): the selected width is not forwarded to the generator
        # call below - confirm whether the generator should receive it.
        output_width = st.slider("Output Width", 500, 2000, 1000, key="output_width_gen")
        output_format = st.selectbox("Output Format", ["PNG", "JPEG"], key="output_format_gen")
        page_title = st.text_input("Page Title (optional)", placeholder="Page 1", key="page_title_gen")

    # Generate button
    if st.button("Generate Page", type="primary"):
        if not st.session_state.character_mappings and not st.session_state.use_dataset:
            st.warning("Please complete character mapping first or enable dataset usage.")
            return

        try:
            # Align dataset characters if using dataset
            aligned_dataset = {}
            if st.session_state.use_dataset and st.session_state.dataset_characters:
                aligned_dataset = align_character_sizes(st.session_state.dataset_characters)

            paper_bg = st.session_state.paper_background if use_paper_bg else None

            generated_image = generate_handwritten_text_with_background(
                input_text,
                st.session_state.character_mappings,
                aligned_dataset,  # Pass aligned dataset, not raw dataset
                line_spacing,
                char_spacing,
                paper_bg,
                st.session_state.use_dataset
            )

            if generated_image:
                st.image(generated_image, caption=f"Generated: {page_title or 'Handwritten Text'}")

                # Add to pages
                page_data = {
                    'image': generated_image,
                    'title': page_title or f"Page {len(st.session_state.generated_pages) + 1}",
                    'text_preview': input_text[:100] + "..." if len(input_text) > 100 else input_text
                }
                st.session_state.generated_pages.append(page_data)

                # A previously built combined PDF no longer contains every
                # page, so drop it instead of offering a stale download.
                if hasattr(st.session_state, 'combined_pdf'):
                    del st.session_state.combined_pdf

                # Download options
                col1, col2, col3 = st.columns(3)

                with col1:
                    # JPEG cannot store an alpha channel or a palette; Pillow
                    # raises when asked to write such modes, so flatten first.
                    # (Assumes the generator returns a PIL image - the
                    # original code already relied on its ``save`` method.)
                    export_image = generated_image
                    if output_format == "JPEG" and generated_image.mode not in ("RGB", "L", "CMYK"):
                        export_image = generated_image.convert("RGB")

                    img_buffer = BytesIO()
                    export_image.save(img_buffer, format=output_format)
                    img_buffer.seek(0)

                    st.download_button(
                        label=f"📥 Download {output_format}",
                        data=img_buffer,
                        file_name=f"handwritten_{page_title or 'text'}.{output_format.lower()}",
                        mime=f"image/{output_format.lower()}"
                    )

                with col2:
                    # Download single page PDF
                    temp_pdf_path = os.path.join(st.session_state.temp_dir, "temp_single.pdf")

                    if create_pdf_from_images([generated_image], temp_pdf_path):
                        with open(temp_pdf_path, "rb") as f:
                            pdf_data = f.read()

                        st.download_button(
                            label="📄 Download PDF",
                            data=pdf_data,
                            file_name=f"handwritten_{page_title or 'text'}.pdf",
                            mime="application/pdf"
                        )

                with col3:
                    st.success(f"✅ Page added ({len(st.session_state.generated_pages)} total)")

            else:
                st.error("Failed to generate handwritten text.")

        except Exception as e:
            st.error(f"Error generating text: {e}")

    # Multi-page management
    if st.session_state.generated_pages:
        st.markdown("---")
        st.markdown("### Generated Pages")

        # Show all pages
        for idx, page_data in enumerate(st.session_state.generated_pages):
            with st.expander(f"📄 {page_data['title']}", expanded=False):
                col1, col2 = st.columns([3, 1])
                with col1:
                    st.image(page_data['image'], use_container_width=True)
                with col2:
                    st.write("**Preview:**")
                    st.write(page_data['text_preview'])
                    if st.button(f"🗑️ Remove", key=f"remove_{idx}"):
                        st.session_state.generated_pages.pop(idx)
                        # The cached combined PDF still contains the removed page.
                        if hasattr(st.session_state, 'combined_pdf'):
                            del st.session_state.combined_pdf
                        # Rerun immediately: the list changed under the loop.
                        st.rerun()

        # Multi-page PDF download
        st.markdown("### Combined PDF")
        col1, col2, col3 = st.columns([1, 1, 2])

        with col1:
            if st.button("📚 Generate Combined PDF", type="primary"):
                try:
                    temp_pdf_path = os.path.join(st.session_state.temp_dir, "combined_pages.pdf")
                    images = [page['image'] for page in st.session_state.generated_pages]

                    if create_pdf_from_images(images, temp_pdf_path):
                        with open(temp_pdf_path, "rb") as f:
                            pdf_data = f.read()

                        # Kept in session state so the download button
                        # survives the rerun triggered by clicking it.
                        st.session_state.combined_pdf = pdf_data
                        st.success("Combined PDF ready!")
                    else:
                        st.error("Failed to create combined PDF")
                except Exception as e:
                    st.error(f"Error creating combined PDF: {e}")

        with col2:
            if hasattr(st.session_state, 'combined_pdf'):
                st.download_button(
                    label="📥 Download Combined PDF",
                    data=st.session_state.combined_pdf,
                    file_name="handwritten_document.pdf",
                    mime="application/pdf"
                )

        with col3:
            if st.button("🗑️ Clear All Pages"):
                st.session_state.generated_pages = []
                if hasattr(st.session_state, 'combined_pdf'):
                    del st.session_state.combined_pdf
                st.rerun()

def generate_handwritten_text(text, char_mappings, line_spacing, char_spacing, output_width):
    """Compose an image of *text* from the user's mapped character images.

    Args:
        text: Text to render; ``'\\n'`` separates lines and blank lines are
            kept as vertical gaps of ``line_spacing`` pixels.
        char_mappings: Dict of image path -> mapped text. Entries whose text
            is blank, or whose image cannot be opened, are ignored. When
            several images map to the same text the last one wins.
        line_spacing: Vertical gap in pixels between consecutive lines.
        char_spacing: Horizontal gap in pixels between characters; a space
            in *text* is rendered as a blank ``2 * char_spacing`` wide.
        output_width: Maximum line width in pixels. Characters that would
            cross it are dropped (lines are truncated, not wrapped).

    Returns:
        An RGBA ``Image`` on a white background, or ``None`` when *text* has
        no non-blank line or an error occurred (the error is reported through
        ``st.error``).
    """
    try:
        # Load every mapped glyph once, already converted to RGBA so it can
        # serve as its own paste mask. The ``with`` closes the file handle;
        # ``convert`` returns an independent in-memory copy.
        glyphs = {}
        for img_path, char in char_mappings.items():
            label = char.strip()
            if not label:
                continue
            try:
                with Image.open(img_path) as source_img:
                    glyphs[label] = source_img.convert("RGBA")
            except Exception:
                # Best effort: an unreadable image just leaves its character
                # unmapped (it is rendered as a placeholder below).
                continue

        def find_glyph(char):
            """Return the glyph for *char*: exact case first, then lower/upper."""
            for candidate in (char, char.lower(), char.upper()):
                glyph = glyphs.get(candidate)
                if glyph is not None:
                    return glyph
            return None

        # One entry per input line; ``None`` marks a blank line.
        line_images = []

        for line in text.split('\n'):
            if not line.strip():
                line_images.append(None)
                continue

            char_images = []
            for char in line:
                if char == ' ':
                    # Space character - create blank space
                    char_images.append(Image.new('RGBA', (char_spacing * 2, 50), (255, 255, 255, 0)))
                    continue

                glyph = find_glyph(char)
                if glyph is None:
                    # Character not found - transparent placeholder keeps the
                    # remaining characters in position.
                    glyph = Image.new('RGBA', (20, 30), (255, 255, 255, 0))
                char_images.append(glyph)

            # Combine characters in line
            total_width = sum(img.width for img in char_images) + char_spacing * (len(char_images) - 1)
            max_height = max(img.height for img in char_images)

            line_img = Image.new('RGBA', (min(total_width, output_width), max_height), (255, 255, 255, 0))
            x_offset = 0

            for char_img in char_images:
                if x_offset + char_img.width > output_width:
                    break  # the rest of the line does not fit
                line_img.paste(char_img, (x_offset, 0), char_img)
                x_offset += char_img.width + char_spacing

            line_images.append(line_img)

        if all(img is None for img in line_images):
            return None

        # Lay the lines out first so the canvas height accounts for blank
        # lines as well; otherwise text after a blank line would be clipped.
        placements = []
        y_offset = 0
        total_height = 0
        for line_img in line_images:
            if line_img is None:
                y_offset += line_spacing
                continue
            placements.append((line_img, y_offset))
            total_height = y_offset + line_img.height
            y_offset = total_height + line_spacing

        final_width = max(line_img.width for line_img, _ in placements)
        final_image = Image.new('RGBA', (final_width, total_height), (255, 255, 255, 255))

        for line_img, y_pos in placements:
            final_image.paste(line_img, (0, y_pos), line_img)

        return final_image

    except Exception as e:
        st.error(f"Error in text generation: {e}")
        return None

def create_analytics_dashboard():
    """Render the analytics page: mapping progress and character size charts.

    Reads ``st.session_state.characters_extracted`` (image paths) and
    ``st.session_state.character_mappings`` (path -> mapped text). Shows
    three metrics, then a width/height scatter plot and an area histogram
    for every character image that can be opened.

    Returns:
        None. Returns early with an info message when nothing has been
        extracted yet.
    """
    st.markdown('<div class="step-header">Analytics Dashboard</div>', unsafe_allow_html=True)

    extracted = st.session_state.characters_extracted
    if not extracted:
        st.info("No data available. Process an image first.")
        return

    # Only mappings with non-blank text count as "mapped".
    mapped_chars = sum(1 for v in st.session_state.character_mappings.values() if v.strip())
    # ``extracted`` is non-empty here, so the division is safe.
    completion_rate = mapped_chars / len(extracted) * 100

    # Character statistics
    col1, col2, col3 = st.columns(3)

    with col1:
        st.metric("Total Characters Extracted", len(extracted))

    with col2:
        st.metric("Characters Mapped", mapped_chars)

    with col3:
        st.metric("Completion Rate", f"{completion_rate:.1f}%")

    # Character size analysis
    st.write("### Character Size Analysis")

    sizes = []
    for char_path in extracted:
        try:
            # Only the header is needed for the size; the context manager
            # releases the file handle right away.
            with Image.open(char_path) as img:
                width, height = img.size
        except Exception:
            # Best effort: unreadable images are left out of the charts.
            continue

        sizes.append({
            'Character': os.path.basename(char_path),
            'Width': width,
            'Height': height,
            'Area': width * height
        })

    if not sizes:
        return

    df_sizes = pd.DataFrame(sizes)

    fig_scatter = px.scatter(
        df_sizes,
        x='Width',
        y='Height',
        hover_data=['Character', 'Area'],
        title="Character Dimensions",
        color='Area'
    )
    st.plotly_chart(fig_scatter, use_container_width=True)

    # Size distribution
    fig_hist = px.histogram(
        df_sizes,
        x='Area',
        title="Character Size Distribution",
        nbins=20
    )
    st.plotly_chart(fig_hist, use_container_width=True)

def _export_safe_name(title):
    """Return *title* reduced to characters that are safe in a ZIP member name."""
    cleaned = "".join(ch if ch.isalnum() or ch in " -_" else "_" for ch in str(title)).strip()
    return cleaned or "untitled"


def _export_add_characters(zip_file, char_paths):
    """Add every existing character image to *zip_file* under ``characters/``."""
    for char_path in char_paths:
        if os.path.exists(char_path):
            zip_file.write(char_path, f"characters/{os.path.basename(char_path)}")


def _export_add_pages(zip_file, pages, prefix=""):
    """Add each generated page to *zip_file* as a PNG named after its title."""
    for page_num, page_data in enumerate(pages, 1):
        page_buffer = BytesIO()
        page_data['image'].save(page_buffer, format='PNG')
        member_name = f"{prefix}page_{page_num}_{_export_safe_name(page_data['title'])}.png"
        zip_file.writestr(member_name, page_buffer.getvalue())


def export_project():
    """Render the export page and prepare the selected download.

    The prepared payload is stored in ``st.session_state`` (``export_data``,
    ``export_filename``, ``export_mime`` and ``export_label``) so that the
    download button is still available on the rerun that follows a click.

    Returns:
        None. Returns early when there is nothing to export, when the
        "Generated Pages Only" format is chosen without any pages, or when
        preparing the export raised (reported through ``st.error``).
    """
    st.markdown('<div class="step-header">Export Project</div>', unsafe_allow_html=True)

    if not st.session_state.characters_extracted and not st.session_state.generated_pages:
        st.warning("No project data to export.")
        return

    export_format = st.selectbox(
        "Choose export format:",
        ["ZIP Archive", "JSON Mappings", "Character Images Only", "Generated Pages Only"]
    )

    if st.button("Prepare Export", type="primary"):
        try:
            if export_format == "ZIP Archive":
                # Create ZIP with all project files
                zip_buffer = BytesIO()
                with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
                    _export_add_characters(zip_file, st.session_state.characters_extracted)

                    # Add mappings as JSON
                    mappings_json = json.dumps(st.session_state.character_mappings, indent=2)
                    zip_file.writestr("character_mappings.json", mappings_json)

                    # Add processed image if available
                    if st.session_state.processed_image:
                        img_buffer = BytesIO()
                        st.session_state.processed_image.save(img_buffer, format='PNG')
                        zip_file.writestr("processed_image.png", img_buffer.getvalue())

                    _export_add_pages(zip_file, st.session_state.generated_pages, prefix="generated_pages/")

                st.session_state.export_data = zip_buffer.getvalue()
                st.session_state.export_filename = "handwriting_project.zip"
                st.session_state.export_mime = "application/zip"

            elif export_format == "JSON Mappings":
                mappings_json = json.dumps(st.session_state.character_mappings, indent=2)
                st.session_state.export_data = mappings_json.encode()
                st.session_state.export_filename = "character_mappings.json"
                st.session_state.export_mime = "application/json"

            elif export_format == "Character Images Only":
                zip_buffer = BytesIO()
                with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
                    _export_add_characters(zip_file, st.session_state.characters_extracted)

                st.session_state.export_data = zip_buffer.getvalue()
                st.session_state.export_filename = "character_images.zip"
                st.session_state.export_mime = "application/zip"

            elif export_format == "Generated Pages Only":
                if not st.session_state.generated_pages:
                    st.warning("No generated pages to export.")
                    return

                zip_buffer = BytesIO()
                with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
                    _export_add_pages(zip_file, st.session_state.generated_pages)

                st.session_state.export_data = zip_buffer.getvalue()
                st.session_state.export_filename = "generated_pages.zip"
                st.session_state.export_mime = "application/zip"

            # Remember which format the payload belongs to; the selectbox may
            # be changed afterwards without preparing a new export.
            st.session_state.export_label = export_format
            st.success("Export prepared successfully!")

        except Exception as e:
            st.error(f"Export preparation failed: {e}")
            return

    # Show download button if export data is ready
    if hasattr(st.session_state, 'export_data'):
        prepared_format = st.session_state.get('export_label', export_format)
        st.download_button(
            label=f"📥 Download {prepared_format}",
            data=st.session_state.export_data,
            file_name=st.session_state.export_filename,
            mime=st.session_state.export_mime,
            key="export_download"
        )

def create_pdf_from_images(images, output_path):
    """Write *images* to a PDF at *output_path*, one image per A4 page.

    Every image is scaled, keeping its aspect ratio, to 90% of the largest
    size that fits the page and is placed 50 points from the left and top
    page edges.

    Args:
        images: Iterable of image objects exposing ``save`` and ``size``
            (PIL images in this app); falsy entries are skipped.
        output_path: Destination path handed to ``canvas.Canvas``.

    Returns:
        True when the PDF was saved, False when any step raised (the error
        is reported through ``st.error``).
    """
    margin = 50  # offset from the left/top page edge, in PDF points

    try:
        c = canvas.Canvas(output_path, pagesize=A4)
        page_width, page_height = A4

        for img in images:
            if not img:
                continue

            # Convert PIL image to bytes
            img_buffer = BytesIO()
            img.save(img_buffer, format='PNG')
            img_buffer.seek(0)

            # Calculate scaling to fit page
            img_width, img_height = img.size
            scale = min(page_width / img_width, page_height / img_height) * 0.9  # 90% of page size

            new_width = img_width * scale
            new_height = img_height * scale

            # The PDF origin is the bottom-left corner, so the top margin is
            # expressed by subtracting from the page height.
            x = margin
            y = page_height - new_height - margin

            # Draw image
            c.drawImage(ImageReader(img_buffer), x, y, new_width, new_height)
            c.showPage()

        c.save()
        return True
    except Exception as e:
        st.error(f"Error creating PDF: {e}")
        return False
    # NOTE(review): the page-routing / "Reset Project" code that used to
    # follow here was unreachable (both paths above return) and referenced
    # an undefined ``page``; it looked like a pasted copy of main()'s body
    # and was removed. Confirm main() still contains that logic.

# Main App Layout
def main() -> None:
    """Render the HandCraft AI Streamlit page.

    Draws the header, then the sidebar (which yields the selected section
    name), then dispatches to the renderer for that section.

    NOTE(review): reads several ``st.session_state`` attributes
    (``processing_complete``, ``character_mappings``, ``dataset_characters``,
    ``processed_image_hash``, ...) without defaults, so they are presumably
    initialised elsewhere in the file before ``main()`` runs — confirm.
    """
    # Header
    st.markdown('<div class="main-header">🖋️ HandCraft AI</div>', unsafe_allow_html=True)

    # Sidebar
    page = _render_sidebar()

    # Main content based on selected page
    if page == "Image Processing & Mapping":
        _render_image_processing_page()

    elif page == "Text Generation":
        create_text_generation_interface()

    elif page == "Analytics":
        create_analytics_dashboard()

    elif page == "Export":
        export_project()


def _render_sidebar() -> str:
    """Draw navigation, settings, status and reset controls in the sidebar.

    Returns:
        The section name chosen in the "Choose a section:" select box.
    """
    with st.sidebar:
        st.header("Navigation")
        page = st.selectbox(
            "Choose a section:",
            ["Image Processing & Mapping", "Text Generation", "Analytics", "Export"]
        )

        st.markdown("---")
        st.header("Settings")

        # OCR Toggle
        if OCR_AVAILABLE:
            st.session_state.use_ocr = st.checkbox(
                "Enable OCR Assistant",
                value=st.session_state.get('use_ocr', True),
                help="Use AI to automatically recognize characters"
            )
        else:
            st.info("💡 Install `transformers torch` for OCR")
            st.session_state.use_ocr = False

        st.markdown("---")
        st.header("Project Status")
        if st.session_state.processing_complete:
            st.success("✅ Image Processed")
            st.write(f"📊 {len(st.session_state.characters_extracted)} characters extracted")
        else:
            st.info("⏳ Awaiting Image")

        if st.session_state.character_mappings:
            # Only mappings with non-blank text count as "mapped"
            mapped_count = len([v for v in st.session_state.character_mappings.values() if v.strip()])
            st.success("✅ Characters Mapped")
            st.write(f"📝 {mapped_count} characters mapped")
        else:
            st.info("⏳ Awaiting Mapping")

        # Dataset Status
        if st.session_state.dataset_characters:
            st.success(f"📚 Dataset: {len(st.session_state.dataset_characters)} chars")
        else:
            st.info("📚 No Dataset Found")
            st.caption("Create 'Dataset' folder with a.png, A.png, etc.")

        # OCR Status
        if OCR_AVAILABLE:
            st.success("🔍 OCR Available")
        else:
            st.warning("🔍 OCR Unavailable")
            st.caption("Install: `pip install transformers torch`")

        # Speech-to-Text Status
        if STT_AVAILABLE:
            st.success("🎤 Speech-to-Text Ready")
        else:
            st.info("🎤 STT Unavailable")

        st.markdown("---")
        if st.button("Reset Project", type="secondary"):
            # Iterate over a copy of the keys because entries are deleted in the loop
            for key in list(st.session_state.keys()):
                if key.startswith(('processed_', 'characters_', 'character_')):
                    del st.session_state[key]
            # 'processing_complete' matches none of the prefixes above; without
            # this the status panel keeps reporting "Image Processed" after a
            # reset even though the processed data is gone.
            st.session_state.processing_complete = False
            st.rerun()

    return page


def _render_image_processing_page() -> None:
    """Draw the "Image Processing" and "Character Mapping" tabs."""
    st.markdown('<div class="step-header">📸 Image Processing & Character Mapping</div>', unsafe_allow_html=True)

    # Two main sections
    tab1, tab2 = st.tabs(["🖼️ Image Processing", "🔤 Character Mapping"])

    with tab1:
        st.markdown("""
            <div class="info-box">
            <strong>Instructions:</strong><br>
            1. Upload an image containing handwritten characters<br>
            2. The system will automatically segment and extract individual characters<br>
            3. Characters will be processed and prepared for mapping<br>
            4. Optional: Use OCR for automatic character recognition assistance
            </div>
            """, unsafe_allow_html=True)

        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=['png', 'jpg', 'jpeg'],
            help="Upload a clear image with handwritten characters"
        )

        if uploaded_file is not None:
            # Check if this is a new image
            current_hash = get_image_hash(uploaded_file)
            is_new_image = st.session_state.processed_image_hash != current_hash

            # Display uploaded image
            col1, col2 = st.columns(2)
            with col1:
                st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)

            with col2:
                st.write("**Image Details:**")
                image = Image.open(uploaded_file)
                st.write(f"- **Size:** {image.size[0]} x {image.size[1]} pixels")
                st.write(f"- **Format:** {image.format}")
                st.write(f"- **Mode:** {image.mode}")

                # OCR availability info
                if OCR_AVAILABLE and st.session_state.use_ocr:
                    st.success("✅ OCR Enabled")
                elif OCR_AVAILABLE:
                    st.info("🔍 OCR Available (Disabled)")
                else:
                    st.info("💡 Install transformers for OCR")

            # Process button - only show if new image or not processed
            if is_new_image or not st.session_state.processing_complete:
                if st.button("Process Image", type="primary"):
                    with st.spinner("Processing image..."):
                        success, message = process_image_advanced(uploaded_file)

                    if success:
                        # Remember which upload produced the current results so
                        # the same file is not offered for processing again.
                        st.session_state.processed_image_hash = current_hash
                        st.success(message)
                        st.rerun()
                    else:
                        st.error(message)
            else:
                st.info("✅ Image already processed. Results shown in Character Mapping tab.")

        # Display existing results if available
        if st.session_state.processing_complete and st.session_state.processed_image:
            st.write("### Processed Image:")
            st.image(st.session_state.processed_image, caption="Processed Image (White Background)", use_container_width=True)

            if st.session_state.characters_extracted:
                st.write("### Extracted Characters:")

                # Options for character display (first column is left empty as a spacer)
                col1, col2, col3 = st.columns([2, 1, 1])
                with col2:
                    show_ocr = st.checkbox(
                        "Show OCR Predictions",
                        value=False,
                        disabled=not (OCR_AVAILABLE and st.session_state.use_ocr)
                    )
                with col3:
                    cols_per_row = st.selectbox("Characters per row:", [4, 6, 8], index=1)

                display_character_grid(
                    st.session_state.characters_extracted,
                    cols_per_row=cols_per_row,
                    show_ocr=show_ocr and st.session_state.use_ocr
                )

    with tab2:
        # Character mapping interface
        if not st.session_state.characters_extracted:
            st.warning("No characters extracted yet. Please process an image first in the Image Processing tab.")
        else:
            create_character_mapping_interface()

# Script entry point: build the UI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

In [None]:
!pip install pyngrok

In [22]:
ngrok_token = "30LsGu06oX4YgWEJd6z30DNO1kB_5C5VX4h5YGt3rFAUmRAqn"  # Replace with your actual token

# 4: Run Your App (With sharing - requires ngrok token)
from pyngrok import ngrok
import time
import threading

# Set your ngrok authentication token (replace ngrok_token with your actual token)
ngrok.set_auth_token(ngrok_token)

# Function to launch the Streamlit app using a system command
def run_app():
    !streamlit run app.py --server.headless true --server.port 8501

# Terminate any active ngrok tunnels before starting a new one
ngrok.kill()

# Start the Streamlit app in a separate thread so the script can continue running
app_thread = threading.Thread(target=run_app)
app_thread.start()

# Allow time for the Streamlit app to fully start before creating the tunnel
time.sleep(10)

# Create a public URL using ngrok and display it
try:
    public_url = ngrok.connect(8501)
    print("🚀 Your app is live!")
    print(f"🌐 Share this link: {public_url}")
    print("📱 Anyone can access your app with this link!")
except:
    print("⚠️ Need ngrok token for sharing. App is running locally.")


Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
2025-08-01 08:50:06.799 Port 8501 is already in use
🚀 Your app is live!
🌐 Share this link: NgrokTunnel: "https://814048ee5ab7.ngrok-free.app" -> "http://localhost:8501"
📱 Anyone can access your app with this link!
