# SIMPLE/HTR <- colab version

In [None]:
import torch
from torch import nn
from torchvision import datasets

import pandas as pd
from PIL import Image
import numpy as np
import os
import csv
from io import BytesIO

"""
IAM Dataset download
"""

# Download the IAM line-level dataset (Teklia/IAM-line on the Hugging Face hub)
# and unpack it under DATA_ROOT: one image folder and one tab-separated label
# file per split. Every path is derived from DATA_ROOT so that the images land
# next to the label files regardless of the current working directory.
DATA_ROOT = '/content/data'
os.makedirs(DATA_ROOT, exist_ok=True)

splits = {'train': 'data/train.parquet', 'validation': 'data/validation.parquet', 'test': 'data/test.parquet'}

max_len = -1  # widest image seen over all splits (pixels)
str_len = -1  # longest transcription seen over all splits (characters)
for split_name in ['test', 'validation', 'train']:

    split_dir = os.path.join(DATA_ROOT, split_name)
    os.makedirs(split_dir, exist_ok=True)

    df = pd.read_parquet("hf://datasets/Teklia/IAM-line/" + splits[split_name])
    df_csv = df['text']

    if split_name == 'train':
        # The vocabulary is built from the training transcriptions only.
        char_set = set()
        for label in df_csv:
            char_set.update(label)
        char_set = sorted(char_set)

        # Indices start at 1: index 0 is left free for the CTC blank symbol
        # (the network is later built with len(char_set) + 1 classes).
        cToi = {char: index_ + 1 for index_, char in enumerate(char_set)}
        iToc = {index_ + 1: char for index_, char in enumerate(char_set)}

        # save data
        np.save("char_set.npy", char_set)
        np.save('cToi.npy', cToi)
        np.save('iToc.npy', iToc)

    # Label file: one "<row index>\t<transcription>" line per sample; the row
    # index doubles as the image file name (without extension).
    with open(os.path.join(DATA_ROOT, f"{split_name}.csv"), 'w', newline="") as csvfile:
        spanwriter = csv.writer(csvfile, delimiter='\t', quotechar="|")
        spanwriter.writerow(["_path", "text"])
        for row_id, txt in enumerate(df_csv):
            spanwriter.writerow([row_id, txt])
            if len(txt) > str_len:
                str_len = len(txt)

    df_imag = df['image']
    for idx, img in enumerate(df_imag):
        image = Image.open(BytesIO(img['bytes']))

        image.save(os.path.join(split_dir, f"{idx}.jpg"))

        width_size, height_size = image.size
        if width_size > max_len:
            max_len = width_size

        # Report any image whose height differs from 128 pixels.
        if height_size != 128:
            print(height_size)

print(max_len)
print(str_len)

#pre-processing
retain aspect ratio of images and use batches of padded images in order to effectively use mini-batch Stochastic Gradient Descent(SGD)


In [None]:
# All images are resized to a resolution of 128x1024 pixels for line images or 64x256 pixels.
# Initial images are padded in order to attain the aforementioned fixed size.

# IAM dataset image have a size width x 128(height), the max width size = 5027
import torch
import math
from torch import nn
from PIL import Image
import numpy as np
from torchvision.transforms import functional as F

# target_size : 1024 x 128
def preprocess(image) :
    """Fit an image into a 128 (height) x 1024 (width) canvas.

    The image is shrunk, keeping its aspect ratio, only when it exceeds the
    target height and/or width; it is then converted to a tensor and padded
    by ``padd_img``.

    Args:
        image: object exposing ``.size`` as ``(width, height)`` and accepted
            by ``F.resize`` / ``np.array`` (presumably a PIL image). Since
            ``padd_img`` unpacks a 2-D shape, a single-channel image is
            assumed -- TODO confirm for non-grayscale inputs.

    Returns:
        The padded tensor produced by ``padd_img``.
    """
    max_height, max_width = 128, 1024

    cur_w, cur_h = image.size

    # Too tall: scale down so that the height becomes exactly max_height.
    if cur_h > max_height:
        print(cur_w, cur_h)
        scale = max_height / cur_h
        image = F.resize(image, (max_height, math.ceil(cur_w * scale)))
        cur_w, cur_h = image.size

    # Too wide: scale down so that the width becomes exactly max_width.
    if cur_w > max_width:
        scale = max_width / cur_w
        image = F.resize(image, (math.ceil(cur_h * scale), max_width))

    pixels = torch.tensor(np.array(image))
    return padd_img(pixels)


    # padd it 1024x128
def padd_img(img):
    """Centre a 2-D image tensor on a 128 x 1024 canvas filled with 255.

    Args:
        img: 2-D tensor laid out as (height, width).

    Returns:
        The tensor padded with the constant 255 on all four sides. When the
        missing amount along an axis is odd, the extra pixel goes to the
        right / bottom side.
    """
    target_h, target_w = 128, 1024
    cur_h, cur_w = img.shape

    # e.g. height 51 -> 77 rows missing -> 38 on top, 39 at the bottom
    missing_w = target_w - cur_w
    pad_left = missing_w // 2
    pad_right = missing_w - pad_left

    missing_h = target_h - cur_h
    pad_top = missing_h // 2
    pad_bottom = missing_h - pad_top

    # F.pad expects the pads for the last dimension first: (left, right, top, bottom)
    return nn.functional.pad(
        img,
        (pad_left, pad_right, pad_top, pad_bottom),
        mode="constant",
        value=255,
    )


# During training, image augmentation is performed,
# considering only rotation and skew of small magnitude in order to generate valid images.
# Additionally, Gaussian noise is added to the images.
from torchvision.transforms import v2
"""
transforms = v2.RandomApply(
    torch.nn.ModuleList([
        v2.RandomAffine(degrees=(-10,10),translate=(0.01,0.05), scale = (0.8,1.2),shear = (-10,10))
        ]) , p=0.5 )
"""
# Training-time augmentation: with probability 0.5 apply a single random
# affine warp (no rotation, small translation, down-scaling to 0.8-1.0,
# shear within +/-10 degrees). Uncovered pixels are filled with 255.
_affine_jitter = v2.RandomAffine(
    degrees=(0, 0),
    translate=(0.01, 0.05),
    scale=(0.8, 1),
    shear=(-10, 10),
    fill=255,
)
transforms = v2.RandomApply(torch.nn.ModuleList([_affine_jitter]), p=0.5)

# Rotated images get cropped, so some characters are no longer visible in the image.
# If even I cannot see a character, the sample is not a good training example.
# Candidate distortions: rotation, translation, scaling and shearing, plus gray-scale erosion and dilation.
# The paper "Are Multidimensional Recurrent Layers Really Necessary for Handwritten Text Recognition?" says it uses its toolkit's default values, but I could not find them.

"""
We perform adequate random distortions on the input images,
in order to artificially augment the training samples and reduce overfitting.

These distortions include: rotation,translation, scaling and shearing
(all performed as a single affine transform) and gray-scale erosion and dilation.
Each of these operations is applied dynamically and independently on each image of the training batch (each with 0.5 probability).
Thus, the exact same image is virtually never observed twice
during training.
The parameters controlling each distortion (e.g. rotation angle, scaling factor, erosion kernel, etc.) are sampled from a fixed distribution.
    -- Are Multidimensional Recurrent Layers Really Necessary for Handwritten Text Recognition?
"""
# transforms  => preprocessing


#Each word/line transcription has spaces added before and after, "He rose from" => " He rose from "

def pre_processing_target(label):
    """Return the transcription wrapped in one leading and one trailing space."""
    return f" {label} "


# Target transform applied to the training labels.
target_transforms = pre_processing_target

#This operation aims to assist the system during the training phase, For the testing phase, these additional spaces are discarded


In [None]:
import torch
from torchvision.transforms import functional as F
import os
from PIL import Image
import pandas as pd

class IAMData():
    """Map-style dataset over the unpacked IAM line images of one split.

    Expects ``<root_path>/<split>.csv`` (tab-separated, ``|`` as quote char,
    columns ``_path`` and ``text``) and ``<root_path>/<split>/<_path>.jpg``.

    Args:
        split: one of ``'train'``, ``'test'`` or ``'validation'``.
        root_path: directory holding the label files and image folders.
        transform: optional callable applied to the opened image before the
            resize/pad preprocessing.
        target_transform: optional callable applied to the transcription.
    """

    def __init__(self,split ,root_path = '/content/data' , transform =None,target_transform = None ):

        assert os.path.exists(root_path), "valueError: IAMData 'wrong root_path' "

        assert  split in ['train', 'test', 'validation'] , "valueError: IAMData 'split must be one of ['train', 'test', 'validation]'"

        self.preprocess = preprocess

        self.annotations_file = os.path.join(root_path, f"{split}.csv")
        self.img_dir = os.path.join(root_path, split)

        self.transform = transform
        self.target_transform = target_transform

        # keep_default_na=False: a transcription such as "NA", "null" or an
        # empty string must stay a string instead of being parsed as float NaN,
        # which would break the string handling of the labels downstream.
        self.img_labels = pd.read_csv(
            self.annotations_file,
            delimiter='\t',
            quotechar='|',
            keep_default_na=False,
        )

    def __len__(self):
        """Number of samples listed in the label file."""
        return len(self.img_labels)

    def __getitem__(self,idx):
        """Return ``(image, label)`` for row ``idx``.

        ``image`` is a float tensor reshaped to ``(1, 128, 1024)``; ``label``
        is the (optionally transformed) transcription.
        """
        img_path = os.path.join(self.img_dir, f"{self.img_labels.iloc[idx,0]}.jpg")

        image = Image.open(img_path)
        label = self.img_labels.iloc[idx,1]

        # The augmentation runs on the opened image, before resizing/padding.
        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        # preprocess() already returns a tensor; only the dtype conversion and
        # the leading channel axis are still needed.
        image = self.preprocess(image)
        image = image.float().reshape(1, 128, 1024)

        return image, label


#architectural
Replace the column-wise concatenation step between the CNN backbone and the recurrent head with a max-pooling step. Such a choice not only reduces the required parameters but also has an intuitive motivation: we care only about the existence of a character and not its vertical position.

#Convolutional Backbone

In [None]:
"""
In our model, the convolutional backbone is made
up of standard convolutional layers and ResNet blocks [12], interspersed with
max-pooling and dropout layers.
"""

"""class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

"""
import torch
from torch.nn import functional

class Residual_Block(nn.Module):
    """Two 3x3 conv + batch-norm layers with an additive skip connection.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        stride: stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    def __init__(self,in_channels, out_channels,stride =1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The skip path is the identity unless the main path changes the
        # spatial size or the channel count; then a 1x1 projection matches it.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = functional.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))

        merged = main + self.shortcut(x)
        return functional.relu(merged)

class ConvolutionalBackbone(nn.Module):
    """CNN feature extractor: a 7x7 convolution followed by three stages of
    2x2 max-pooling plus residual blocks (2, 4 and 4 blocks, with 64, 128 and
    256 channels respectively).
    """

    def __init__(self):
        super().__init__()

        # Stem: 7x7 convolution with 32 output channels.
        self.CNNLayer = nn.Conv2d(1, 32, kernel_size=(7, 7), stride=1, padding=3)

        self.MaxPoolFirst = nn.MaxPool2d((2, 2))
        self.ResBlockFirst = nn.Sequential(
            Residual_Block(32, 64),
            Residual_Block(64, 64),
        )

        self.MaxPoolSecond = nn.MaxPool2d((2, 2))
        self.ResBlockSecond = nn.Sequential(
            Residual_Block(64, 128),
            Residual_Block(128, 128),
            Residual_Block(128, 128),
            Residual_Block(128, 128),
        )

        self.MaxPoolThird = nn.MaxPool2d((2, 2))
        self.ResBlockThird = nn.Sequential(
            Residual_Block(128, 256),
            Residual_Block(256, 256),
            Residual_Block(256, 256),
            Residual_Block(256, 256),
        )

    def forward(self,x):
        """Map images to a feature map.

        The input is first reshaped to ``[-1, 1, 128, 1024]``, so any tensor
        with a compatible number of elements is accepted. Three 2x2 poolings
        reduce 128 x 1024 to 16 x 128, giving ``[batch, 256, 16, 128]``.
        """
        x = x.reshape([-1, 1, 128, 1024])

        feats = self.ResBlockFirst(self.MaxPoolFirst(self.CNNLayer(x)))
        feats = self.ResBlockSecond(self.MaxPoolSecond(feats))
        feats = self.ResBlockThird(self.MaxPoolThird(feats))

        # [batch, channel, height, width] = [batch, 256, 16, 128]
        return feats




# the first layer is a 7x7 convolution with 32 output channels


#

In [None]:
# Smoke test: push a random batch of two 1x128x1024 inputs through the backbone.
model = ConvolutionalBackbone()


a = torch.rand(2,1,128,1024)
a = model(a)
# Print the shape of the backbone output.
print(a.shape)

#Flattening Operation => Column-wise Max-Pooling

In [None]:
"""
    The convolutional backbone output should be transformed into a sequence of features in order to processed by recurrent networks.
    Typical HTR approaches, assume a column-wise approach (towards the writing direction) to ideally simulate a character by character processing.
    the CNN output is flattened by a max-pooling operation in a column-wise manner

    ... Apart from the apparent computational advantage,
    column-wise max-pooling achieves model translation invariance in the vertical direction
"""

# NOTE: this module is very small; it may be better to apply it as a plain function inside the full network class.
class ColumnwiseMaxPool(nn.Module):
    """Collapse the height axis of a CNN feature map with a max operation."""

    def __init__(self):
        super(ColumnwiseMaxPool,self).__init__()

    def forward(self,CNNoutput):
        """Return the column-wise maximum of ``CNNoutput``.

        Args:
            CNNoutput: tensor whose dimension 2 is the height axis, i.e.
                ``[batch, channel, height, width]``.

        Returns:
            Tensor of shape ``[batch, channel, width]``.
        """
        # torch.max with ``dim`` returns a (values, indices) pair; only the
        # maxima are a valid feature tensor for the following layers.
        MaxPoolData = torch.max(CNNoutput, dim = 2).values
        #[batch,channel,width]
        return MaxPoolData

#Recurrent Head


In [None]:
"""
    The recurrent component consists of 3 stacked Bidirectional Long Short-Term Memory (BiLSTM) units of hidden size 256. These are followed
by a linear projection layer, which converts the sequence to a size equal to the number of possible character tokens, nclasses (including the blank character, required by CTC).

    The final output of the recurrent part can be translated into a sequence of probability distributions by applying a softmax operation.
"""

"""
 During evaluation, the aforementioned greedy decoding is performed by selecting the character
  with the highest probability at each step and then removing the blank characters from the resulting sequence [8].
"""

class RecurrentHead(nn.Module):
    """Three stacked bidirectional LSTM layers (hidden size 256) followed by
    a linear projection to ``nclasses`` scores per time step.

    Args:
        nclasses: number of output classes, including the blank character.
    """

    def __init__(self, nclasses):
        super().__init__()
        self.BiLSTM = nn.LSTM(256, 256, num_layers=3, bidirectional=True)
        # Forward and backward states are concatenated: 2 * 256 = 512 features.
        self.Projection = nn.Linear(512, nclasses)

    def forward(self,x):
        """Map ``[batch, channel, width]`` features to ``[batch, width, nclasses]`` scores."""
        # The LSTM is not batch_first: [batch, channel, width] => [width, batch, channel]
        sequence = x.permute(2, 0, 1)
        hidden_states, _ = self.BiLSTM(sequence)

        # [width, batch, 2 * channel] => [width, batch, nclasses]
        scores = self.Projection(hidden_states)

        # [width, batch, nclasses] => [batch, width, nclasses]
        return scores.transpose(0, 1)



# CTC shortcut


In [None]:
"""
Architecture-wise, the CTC shortcut module consists only of
a single 1D convolutional layer, with kernel size 3. Its output channels equal
to the number of the possible character tokens (nclasses). Therefore, the 1D
convolutional layer is responsible for straightforwardly encoding context-wise
information and providing an alternative decoding path.
"""

class CTCshortcut(nn.Module):
    """Auxiliary decoding path: a single 1-D convolution (kernel size 3) that
    maps the 256-channel column features straight to class scores.

    Args:
        nclasses: number of output classes, including the blank id.
    """

    def __init__(self, nclasses):
        super().__init__()
        # padding=1 keeps the sequence length unchanged for kernel size 3.
        self.ConLayer = nn.Conv1d(256, nclasses, kernel_size=3, padding=1)

    def forward(self,x):
        """Map ``[batch, channel, width]`` features to ``[batch, width, nclasses]`` scores."""
        # [batch, channel, width] => [batch, nclasses, width]
        scores = self.ConLayer(x)
        # [batch, nclasses, width] => [batch, width, nclasses]
        return scores.permute(0, 2, 1)


#CTC Loss

the multi-task loss is written as

$L_{CTC}(f_{rec}(f_{cnn}(I)); s) + 0.1 \, L_{CTC}(f_{shortcut}(f_{cnn}(I)); s)$



    The CTC shortcut is trained along with the main architecture using
    a multitask loss by adding the corresponding CTC losses of the two branches with the appropriate weights

    Since CTC shortcut acts only as an auxiliary training path, it is weighted by 0.1
    to reduce its relative contribution to the overall loss.



In [None]:
#total HTRnet

class HTRnet(nn.Module):
    """Full recognition network: CNN backbone, column-wise max-pooling, and
    two heads (recurrent head and CTC shortcut) that share the pooled features.

    Args:
        nclasses: number of output classes, including the blank id.
    """

    def __init__(self, nclasses):
        super().__init__()

        self.CNNBackbone = ConvolutionalBackbone()
        self.RecHead = RecurrentHead(nclasses)
        self.CTCshort = CTCshortcut(nclasses)

    def forward(self, x):
        """Return ``(recurrent_scores, shortcut_scores)``, each ``[batch, width, nclasses]``."""
        # [batch, channel, height, width]
        feature_map = self.CNNBackbone(x)

        # Column-wise max-pooling drops the height axis: [batch, channel, width]
        pooled, _ = torch.max(feature_map, dim=2)

        return self.RecHead(pooled), self.CTCshort(pooled)


In [None]:
# Smoke test: build the full network with 93 output classes and run a random
# batch of two 1x128x1024 inputs through both heads.
model = HTRnet(93)
a = torch.rand(2,1,128,1024)
Rec, CTC = model(a)
# Print the output shapes of the recurrent head and of the CTC shortcut.
print(Rec.shape)
print(CTC.shape)

## Metrics

This code is copied from HTR-best-practices:

https://github.com/georgeretsi/HTR-best-practices/blob/main/utils/metrics.py

In [None]:
import editdistance

import nltk
nltk.download('punkt_tab')
from nltk.tokenize import word_tokenize

# character error rate
class CER:
    """Running character error rate: summed edit distance over summed target length."""

    def __init__(self):
        self.reset()

    def update(self, prediction, target):
        """Accumulate the edit distance between ``prediction`` and ``target``."""
        self.total_dist += float(editdistance.eval(prediction, target))
        self.total_len += len(target)

    def score(self):
        """Return the accumulated distance divided by the accumulated target length."""
        return self.total_dist / self.total_len

    def reset(self):
        """Clear both accumulators."""
        self.total_dist = 0
        self.total_len = 0

# word error rate
# two supported modes: tokenizer & space
class WER:
    """Running word error rate.

    Args:
        mode: ``'tokenizer'`` splits texts with ``word_tokenize``; ``'space'``
            splits them on single space characters.

    Raises:
        ValueError: if ``mode`` is neither ``'tokenizer'`` nor ``'space'``.
    """

    def __init__(self, mode='tokenizer'):
        if mode not in ('tokenizer', 'space'):
            raise ValueError('mode must be either "tokenizer" or "space"')

        self.mode = mode
        self.reset()

    def _split(self, text):
        # Turn a text into the units that are compared, according to the mode.
        if self.mode == 'tokenizer':
            return word_tokenize(text)
        if self.mode == 'space':
            return text.split(' ')
        return text

    def update(self, prediction, target):
        """Accumulate the word-level edit distance between ``prediction`` and ``target``."""
        target_units = self._split(target)
        predicted_units = self._split(prediction)

        self.total_dist += float(editdistance.eval(predicted_units, target_units))
        self.total_len += len(target_units)

    def score(self):
        """Return the accumulated distance divided by the accumulated target length."""
        return self.total_dist / self.total_len

    def reset(self):
        """Clear both accumulators."""
        self.total_dist = 0
        self.total_len = 0



In [None]:
def decoder(tdec, tdict, blank_id=0):
    """Greedy CTC decoding of one index sequence.

    Consecutive duplicates are merged first, then every ``blank_id`` entry is
    dropped and the remaining indices are mapped through ``tdict``.

    Args:
        tdec: sequence of class indices, one per time step.
        tdict: mapping from class index to character.
        blank_id: index of the blank symbol.

    Returns:
        The decoded string.
    """
    collapsed = []
    previous = None
    for position, token in enumerate(tdec):
        # keep a token only when it starts a new run
        if position == 0 or token != previous:
            collapsed.append(token)
        previous = token

    chars = [tdict[token] for token in collapsed if token != blank_id]
    return ''.join(chars)

# Demo: decode two hand-written index sequences with the saved index -> char table.
a = [[1,2,3,5,7,5,8,9,2], [2,7,6,4,7,43,5,6,5]]
# The table was stored as a pickled dict inside a 0-d object array; tolist()
# unwraps it. NOTE: allow_pickle executes pickle, so only load files written
# by this notebook.
iTOc = np.load('iToc.npy',allow_pickle=True).tolist()
print(iTOc)
# decoder() handles ONE sequence at a time: passing the nested list would use
# whole lists as dictionary keys and fail, so each row is decoded separately.
for seq in a:
    dec_transcr = decoder(seq, iTOc)
    print(dec_transcr)

In [None]:
# Load the saved index -> character table; np.load returns an array wrapper here.
iTOc = np.load('iToc.npy', allow_pickle=True)
print(iTOc)
# tolist() unwraps the array into the plain Python object that was saved.
iTOc = iTOc.tolist()
print(iTOc)

# training
Add an extra shortcut branch, consisting of a single 1D convolution layer, at the output of the CNN backbone. This branch results in an extra character sequence estimation, trained in parallel to the recurrent branch. Both branches use a CTC loss. The motivation behind such a choice comes from the increased difficulty of training recurrent layers. However, if such a straightforward shortcut exists, the output of the CNN backbone should converge to more discriminative features, ideal for fully harnessing the power of recurrent layers compared to an end-to-end training scheme.

In [None]:
"""
  The training of the HTR system is performed via an Adam optimizer
  using an initial learning rate of 0.001, which gradually decreases using a multistep scheduler.
  The overall training epochs are 240 and the scheduler decreases the learning rate by a factor of 0.1 at 120 and 180 epochs.
"""

"""
  This optimizing scheme, with minor modifications, is commonly used for HTR systems.
"""


from torch.utils.data import DataLoader
from torch import nn
import tqdm

from datetime import datetime

class HTRTrainer(nn.Module):
    """Bundles data loaders, network, CTC loss and optimizer for training and
    evaluating ``HTRnet`` on the unpacked IAM line dataset.

    Constructing the trainer reads the three dataset splits and the saved
    vocabulary files (``char_set.npy``, ``cToi.npy``, ``iToc.npy``).
    """

    def __init__(self):
        super(HTRTrainer,self).__init__()

        self.prepare_dataloaders()
        self.prepare_net()
        self.prepare_losses()
        self.prepare_optimizers()

    def prepare_dataloaders(self):
        """Create the datasets, load the vocabulary and build ``self.loaders``."""
        # Only the training split gets augmentation and the space-padded targets.
        train_ds = IAMData('train',transform = transforms,target_transform= target_transforms)
        test_ds = IAMData('test')
        validation_ds = IAMData('validation')

        # NOTE: allow_pickle unpickles the stored dicts; only load files
        # written by this notebook.
        self.char_set = np.load('char_set.npy')
        self.cToi = np.load('cToi.npy', allow_pickle=True).tolist()
        self.iToc = np.load('iToc.npy', allow_pickle=True).tolist()

        train_loader = DataLoader(train_ds,batch_size = 16, shuffle=True)
        validation_loader = DataLoader(validation_ds,batch_size = 16, shuffle=True)
        test_loader = DataLoader(test_ds,batch_size = 16, shuffle=True)

        self.loaders = {'train':train_loader ,'test': test_loader,'validation': validation_loader}

    def prepare_net(self):
        """Instantiate the network with one class per character plus the blank."""
        nclasses = len(self.char_set)+1

        self.net = HTRnet(nclasses)

    def prepare_losses(self):
        """Create the CTC loss; it expects log-probabilities (log_softmax) as input."""
        self.CTCLoss = nn.CTCLoss(reduction='sum')

    def prepare_optimizers(self):
        """Create the AdamW optimizer and the multi-step LR scheduler.

        The scheduler multiplies the learning rate by 0.1 at milestones 120
        and 180; it is stepped once per epoch in ``train_one_epoch``.
        """
        self.optimizer = torch.optim.AdamW(self.net.parameters(), 0.001,weight_decay=0.00005)
        self.schedular = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,[120,180], gamma=0.1)

    def decoder(self,tdecs, tdict, blank_id=0):
        """Greedy CTC decoding: merge repeated indices, drop blanks, map to characters."""
        tt= [t for j, t in enumerate(tdecs) if j == 0 or t!=tdecs[j-1] ]
        dec_transcr = "".join([tdict[t] for t in tt if t != blank_id])

        return dec_transcr

    def encoder(self, labels, tdict, blank_id = 0):
        """Map each transcription to its list of class indices via ``tdict``.

        ``blank_id`` is accepted for symmetry with ``decoder`` but not used.
        """
        return [[tdict[c] for c in label] for label in labels]

    def padd_transcr(self, labels, max_length, blank_id =0):
        """Right-pad each index list to ``max_length`` and stack them into one tensor."""
        result=[]
        for label in labels:
            label = torch.tensor(label)
            out = torch.nn.functional.pad(label, (0,max_length -len(label)), mode = 'constant',value=blank_id )
            result.append(out)

        result = torch.stack(result,dim=0)
        return result

    def max_length(self, labels):
        """Return the length of the longest label, or -1 when ``labels`` is empty."""
        return max((len(label) for label in labels), default=-1)

    def target_lengths(self, labels):
        """Return a tensor holding the length of every label."""
        return torch.tensor([len(label) for label in labels])

    def sample_decoding(self):
        """Decode one random validation image on the CPU and print the result."""
        # get a random image from the validation set
        dataset = self.loaders['validation'].dataset
        img, transcr = dataset[np.random.randint(0,len(dataset))]
        self.net.eval()
        self.net.cpu()

        with torch.no_grad():
            tst_o, __ = self.net(img)

        tdec = tst_o.argmax(2).cpu().numpy()
        tdec = tdec[0]

        dec_transcr = self.decoder(tdec, self.iToc)
        target = transcr.strip()
        prediction = dec_transcr.strip()

        print('orig:: ' + target)
        print('pred:: ' + prediction)

        dist = float(editdistance.eval(prediction, target))
        # guard against an empty target so the report never divides by zero
        length = max(len(target), 1)
        print('dist:: {}'.format(dist/length))

    def load_model(self, _path):
        """Load network weights from ``_path``, mapping all tensors to the CPU."""
        self.net.load_state_dict(torch.load(_path, map_location=torch.device('cpu') ) )

    def test(self, epoch, tset = 'test'):
        """Evaluate the network on one split and print CER and WER.

        Args:
            epoch: value shown in the printed report only.
            tset: ``'test'`` or ``'validation'``.

        Raises:
            ValueError: if ``tset`` is not one of the two supported splits.
        """
        if tset not in ('test', 'validation'):
            raise ValueError(
                "tset must be either 'test' or 'validation', got {!r}".format(tset))
        loader = self.loaders[tset]

        self.net.eval()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print('####################### Evaluating {} set at epoch {} #######################'.format(tset, epoch))

        cer, wer = CER(), WER(mode='tokenizer')
        self.net.to(device)
        for (imgs, transcrs) in tqdm.tqdm(loader):

            imgs =imgs.to(device)
            with torch.no_grad():
                o, __  = self.net(imgs)

            # [batch, width, nclasses] -> best class per time step: [batch, width]
            tdecs = o.argmax(2).cpu().numpy()

            for tdec, transcr in zip(tdecs, transcrs):

                transcr = transcr.strip()
                dec_transcr = self.decoder(tdec, self.iToc).strip()

                cer.update(dec_transcr, transcr)
                wer.update(dec_transcr, transcr)

        cer_score = cer.score()
        wer_score = wer.score()

        print('CER at epoch {}: {:.3f}'.format(epoch, cer_score))
        print('WER at epoch {}: {:.3f}'.format(epoch, wer_score))

    def train_one_epoch(self,epoch_index, device):
        """Run one pass over the training set.

        Args:
            epoch_index: index of the current epoch (not used in the computation).
            device: device the batches are moved to; the network must already
                live on it.

        Returns:
            The mean of the per-batch losses (each batch loss is a sum over
            its samples, since the CTC loss uses ``reduction='sum'``).
        """
        running_loss = 0.
        num_batches = 0

        loader = self.loaders['train']
        self.net.train()
        for (imgs, labels) in tqdm.tqdm(loader):
            imgs = imgs.to(device)

            self.optimizer.zero_grad()

            RecTrancr, CTCTranscr = self.net(imgs)
            # The heads return raw scores; CTC needs log-probabilities laid
            # out as [width, batch, nclasses].
            RecTdecs = torch.nn.functional.log_softmax(RecTrancr, dim = 2).transpose(0,1)
            CTCTdecs = torch.nn.functional.log_softmax(CTCTranscr, dim = 2).transpose(0,1)

            # Every sample uses the full output width as its input length.
            width_size, batch_size, nclasses = RecTdecs.shape
            input_lengths = torch.full(size=(batch_size,), fill_value=width_size, dtype=torch.long)
            input_lengths = input_lengths.to(device)

            labels = self.encoder(labels, self.cToi)
            max_lengths = self.max_length(labels)
            target_lengths = self.target_lengths(labels).to(device)

            #padd labels to max_length
            labels = self.padd_transcr(labels, max_length=max_lengths, blank_id = 0)
            labels = labels.to(device)

            # Multi-task loss: recurrent branch plus 0.1 * shortcut branch.
            rec_loss = self.CTCLoss(RecTdecs, labels,input_lengths, target_lengths)
            shortcut_loss = self.CTCLoss(CTCTdecs,labels,input_lengths, target_lengths)
            loss = rec_loss + shortcut_loss * 0.1

            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # The milestones (120, 180) are expressed in epochs, so the scheduler
        # advances once per epoch, not once per batch.
        self.schedular.step()

        # Average over the batches actually seen (not a fixed constant).
        return running_loss / max(num_batches, 1)

    def train(self, epochs=800):
        """Train the network, evaluating on the validation split every epoch.

        Every 20th epoch the test split is evaluated as well and a checkpoint
        is written; a final checkpoint is written after the last epoch.

        Args:
            epochs: number of training epochs.
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.to(device)

        for epoch in range(epochs):
            print('EPOCH {}:'.format(epoch + 1))
            self.net.train()

            # Make sure gradient tracking is on, and do a pass over the data
            avg_loss = self.train_one_epoch(epoch, device)

            self.net.eval()

            print('LOSS train {}'.format(avg_loss))
            self.test(epoch= epoch,tset='validation')

            # Periodic test-set evaluation and checkpoint
            if epoch % 20 ==0 :
                self.test(epoch = epoch, tset='test')
                model_path = 'model_{}_{}'.format(timestamp, epoch)
                torch.save(self.net.state_dict(), model_path)

        # Final checkpoint, tagged with the index of the last epoch.
        model_path = 'model_{}_{}'.format(timestamp, epochs - 1)
        torch.save(self.net.state_dict(), model_path)





In [None]:
# Build the trainer, restore a previously saved checkpoint and decode one
# random validation sample.
Trainer = HTRTrainer()
Trainer.load_model("model_20250503_121842_140")
Trainer.sample_decoding()

In [None]:
# Decode 100 randomly drawn validation samples for a qualitative check.
for i in range(100):
    Trainer.sample_decoding()

In [None]:
# Evaluate on the default ('test') split; the argument is only the epoch label shown in the report.
Trainer.test(0)

In [None]:
# Start the full training loop.
Trainer.train()

#Evaluation

In [None]:
class HTREval(nn.Module):
    """Evaluation-only counterpart of the trainer: builds the data loaders and
    the network, loads saved weights and reports CER / WER on a split.
    """

    def __init__(self):
        super(HTREval,self).__init__()
        self.prepare_dataloaders()
        self.prepare_net()

    def prepare_dataloaders(self):
        """Create the datasets, load the vocabulary and build ``self.loaders``."""
        train_ds = IAMData('train',transform = transforms,target_transform= target_transforms)
        test_ds = IAMData('test')
        validation_ds = IAMData('validation')

        # NOTE: allow_pickle unpickles the stored dicts; only load files
        # written by this notebook.
        self.char_set = np.load('char_set.npy')
        self.cToi = np.load('cToi.npy', allow_pickle=True).tolist()
        self.iToc = np.load('iToc.npy', allow_pickle=True).tolist()

        train_loader = DataLoader(train_ds,batch_size = 16, shuffle=True)
        validation_loader = DataLoader(validation_ds,batch_size = 16, shuffle=True)
        test_loader = DataLoader(test_ds,batch_size = 16, shuffle=True)

        self.loaders = {'train':train_loader ,'test': test_loader,'validation': validation_loader}

    def prepare_net(self):
        """Instantiate the network with one class per character plus the blank."""
        nclasses = len(self.char_set)+1

        self.net = HTRnet(nclasses)

    def load_model(self, _path):
        """Load network weights from ``_path``.

        Tensors are mapped to the CPU so that a checkpoint saved on a GPU can
        be loaded on a CPU-only machine; ``test`` moves the network to the
        evaluation device afterwards.
        """
        self.net.load_state_dict(torch.load(_path, map_location=torch.device('cpu')))

    def decoder(self,tdec, tdict, blank_id=0):
        """Greedy CTC decoding: merge repeated indices, drop blanks, map to characters."""
        tt = [v for j,v in enumerate(tdec) if j==0 or v !=tdec[j-1]]
        dec_transcr = ''.join([tdict[t] for t in tt if t!=blank_id])

        return dec_transcr

    def test(self, epoch, tset = 'test'):
        """Evaluate the network on one split and print CER and WER.

        Args:
            epoch: value shown in the printed report only.
            tset: ``'test'`` or ``'validation'``.

        Raises:
            ValueError: if ``tset`` is not one of the two supported splits.
        """
        if tset not in ('test', 'validation'):
            raise ValueError(
                "tset must be either 'test' or 'validation', got {!r}".format(tset))
        loader = self.loaders[tset]

        self.net.eval()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print('####################### Evaluating {} set at epoch {} #######################'.format(tset, epoch))

        cer, wer = CER(), WER(mode='tokenizer')
        # Network and batches must live on the same device.
        self.net.to(device)
        for (imgs, transcrs) in tqdm.tqdm(loader):

            # Tensor.to() is not in-place: keep the returned tensor.
            imgs = imgs.to(device)
            with torch.no_grad():
                o, __ = self.net(imgs)

            # The network output is [batch, width, nclasses], so argmax(2) is
            # already [batch, width]: one index sequence per sample. No
            # permute (it would pair time steps with transcriptions) and no
            # squeeze (it would break batches of size one).
            tdecs = o.argmax(2).cpu().numpy()

            for tdec, transcr in zip(tdecs, transcrs):
                transcr = transcr.strip()
                dec_transcr = self.decoder(tdec, self.iToc).strip()

                cer.update(dec_transcr, transcr)
                wer.update(dec_transcr, transcr)

        cer_score = cer.score()
        wer_score = wer.score()

        print('CER at epoch {}: {:.3f}'.format(epoch, cer_score))
        print('WER at epoch {}: {:.3f}'.format(epoch, wer_score))


