In [9]:
import pickle
import matplotlib.pyplot as plt
from torch import optim, nn, utils, Tensor
from torchvision.transforms import ToTensor
import pytorch_lightning as L
from pytorch_lightning import loggers as pl_loggers
import torch
import pdb
import numpy as np
from torchmetrics.text import CharErrorRate, WordErrorRate
import torchvision.transforms as transforms
from torchvision.io import read_image, ImageReadMode
from torch.utils.data import Dataset
import xml.etree.ElementTree as ET
import glob
import os.path
import hashlib

In [8]:
from hashlib import md5

# Scratch cell: try out the md5 API on a throwaway string.
string = "test"
hash_obj = md5(string.encode())
# Printing the hash object shows its repr; .hexdigest() would give the hex string.
print(hash_obj)

<md5 _hashlib.HASH object @ 0x7f6f65a3f090>


In [2]:
# Alphabet of every character that may appear in a transcription.
all_chars = " ',-.:;ABCDEFGHIJKLMNOPQRSTUVWXabcdefghijklmnopqrstuvwxyz°¶–’"
print(len(all_chars))

# Character ids start at 1; id 0 is never assigned here (the decoder in
# LatinTranscriber skips label 0).
num_to_char = dict(enumerate(all_chars, start=1))
char_to_num = {char: num for num, char in num_to_char.items()}

61


In [3]:
class LineImageDataset(Dataset):
    """Text-line images paired with their PAGE-XML transcriptions.

    Every ``*.xml`` file in ``dirname`` is parsed. For each ``TextLine`` the
    pre-computed array ``line_<lineno>_<image stem>.npy`` is loaded from the
    same directory. Lines are assigned to the "train" or "val" split by
    hashing their image path (see ``classify``); only lines belonging to the
    requested ``data_type`` are kept.

    Each item is a dict with the keys ``"image"`` (float32 tensor with a
    leading singleton axis added), ``"target"`` (int64 tensor of character
    ids) and ``"text"`` (the stripped transcription). ``"target"`` encodes
    exactly the characters of ``"text"``.
    """

    def classify(self, line_im_filename):
        """Return "train" or "val" for a line-image path.

        The decision depends only on the md5 hash of the path string, so it
        is stable across runs; about 90 of every 100 hash buckets map to
        "train".
        """
        str_hash = hashlib.md5(line_im_filename.encode()).hexdigest()
        hash_num = int(str_hash[:8], 16) % 100  # pseudo-random bucket in [0, 99]
        return "train" if hash_num < 90 else "val"

    @staticmethod
    def _read_transcription(text_line, ns):
        """Return the raw Unicode text of a TextLine, or None if it has none."""
        text_equiv = text_line.find('.//ns:TextEquiv', ns)
        if text_equiv is None:
            return None
        unicode_el = text_equiv.find('.//ns:Unicode', ns)
        if unicode_el is None:
            return None
        return unicode_el.text

    def _encode(self, label, source):
        """Map ``label`` to an int64 tensor of character ids.

        Raises:
            KeyError: if a character is missing from ``char_to_num``; the
                message names the character and ``source``.
        """
        try:
            ids = [self.char_to_num[c] for c in label]
        except KeyError as err:
            raise KeyError(
                "character {!r} in {} is missing from char_to_num".format(err.args[0], source)
            ) from None
        return torch.tensor(ids, dtype=torch.long)

    def __init__(self, dirname, char_to_num, num_to_char, data_type, transform=None):
        """Load every line of the requested split into memory.

        Args:
            dirname: directory holding the PAGE-XML files and the
                ``line_*.npy`` arrays.
            char_to_num: mapping from character to integer id.
            num_to_char: inverse mapping (stored, not used while loading).
            data_type: "train" or "val"; selects which split to keep.
            transform: optional callable applied to the image in
                ``__getitem__``.

        Lines without a transcription (missing elements, or text that is
        empty after stripping) are skipped.
        """
        self.transform = transform
        self.char_to_num = char_to_num
        self.num_to_char = num_to_char
        self.data_type = data_type
        self.line_images = []
        self.line_image_filenames = []
        self.labels = []
        self.num_labels = []

        ns = {'ns': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15'}
        ET.register_namespace('', ns['ns'])

        # Iterate over all lines of all XML files.
        for filename in sorted(glob.glob(dirname + "/*.xml")):
            root = ET.parse(filename).getroot()
            image_filename = root.find('ns:Page', ns).get('imageFilename')

            for text_region in root.findall('.//ns:TextRegion', ns):
                for lineno, text_line in enumerate(text_region.findall('.//ns:TextLine', ns)):
                    # The path is built by plain concatenation on purpose:
                    # classify() hashes this exact string, so changing how it
                    # is spelled would reshuffle the train/val split.
                    line_base, _ = os.path.splitext(
                        dirname + "/line_{}_{}".format(lineno, image_filename)
                    )
                    line_im_filename = line_base + ".png"
                    if self.classify(line_im_filename) != data_type:
                        continue

                    text = self._read_transcription(text_line, ns)
                    label = text.strip() if text is not None else ""
                    if not label:
                        # Nothing to learn from a line without a transcription.
                        continue

                    # Encode before touching any list so a bad character
                    # cannot leave the parallel lists out of step.
                    encoded = self._encode(
                        label, "{} (line {})".format(filename, lineno)
                    )

                    self.line_image_filenames.append(line_im_filename)
                    self.line_images.append(
                        torch.tensor(np.load(line_base + ".npy"), dtype=torch.float32).unsqueeze(0)
                    )
                    self.labels.append(label)
                    self.num_labels.append(encoded)
                    if data_type == "val":
                        print(text)

    def __len__(self):
        """Number of lines kept for this split."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as an image/target/text dict."""
        image = self.line_images[idx]

        if self.transform is not None:
            image = self.transform(image)

        return {"image": image, "target": self.num_labels[idx], "text": self.labels[idx]}


In [4]:
# Normalisation statistics shared by the training and validation pipelines.
PIXEL_MEAN = 0.15
PIXEL_STD = 0.38

# Keep this string exactly as written: LineImageDataset hashes the path it
# builds from it to decide the train/val split.
DATA_DIR = "data/"

# Training augmentations: photometric jitter, a slight affine warp, then either
# a random sharpening or a Gaussian blur, followed by normalisation.
photometric_jitter = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
slight_warp = transforms.RandomAffine(degrees=0.7, translate=(0.01, 0.02), scale=(0.98, 1.02))
sharpen_or_blur = transforms.RandomChoice([
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
    transforms.GaussianBlur(kernel_size=21, sigma=(1, 6)),
])

train_transform = transforms.Compose([
    photometric_jitter,
    slight_warp,
    sharpen_or_blur,
    transforms.Normalize(mean=PIXEL_MEAN, std=PIXEL_STD),
])

# Validation images are only normalised.
val_transform = transforms.Compose([
    transforms.Normalize(mean=PIXEL_MEAN, std=PIXEL_STD),
])

train_dataset = LineImageDataset(
    DATA_DIR, char_to_num, num_to_char, data_type="train", transform=train_transform
)
val_dataset = LineImageDataset(
    DATA_DIR, char_to_num, num_to_char, data_type="val", transform=val_transform
)
print(val_dataset[0])

elemosina pertinens ad ecclesiam de Duddebir unde Willelmus de Ros est persona an laicum feodum Roberti de Furches 
mesuagio et duabus acris terre cum pertinenciis in Stokes die quo obilt etc. et si etc. Unde Adam de Stok unum mesuagium 
alius ipsam inde inplacitaret teneretur ei warantizare. Et predictus Colemannus le Blund non potest hoc 
dicunt enim quod idem Colemannus per longum tempus ante mortem suam tenementa illa dedit et concessit cuidam Colemanno 
predicta communa nisi ex gracia et pro suo dando et aliquando faciendo sectam ad molendinum predicti Thome de Brochton. Et ideo consideratum 
dicunt quod, si aliqua disseisina ei inde facta fuerit, facta fuit per predictum Galfridum patrem predicti Galfridi de Ledewyk et non per ipsum Galfridum 
Et perquirat sibi per aliud breve si voluerit. Pardonatur misericordia quia infra etatem.
Et Willelmus et Iohannes veniunt et nichil dicunt quare assisa remaneat.
super sacramentum suum hoc idem testantur. Et ideo consideratum est quod pred

In [None]:
# Print the mean/std of one transformed validation line, then show one
# augmented training line.
#plt.imshow(torch.tensor(np.load("data/line_0_0001_JUST1-734m5d.npy")))
idx = 40
val_plane = val_dataset[idx]["image"][0]
print(val_plane.mean(), val_plane.std())

plt.figure(figsize=(16, 6))
plt.imshow(train_dataset[20]["image"][0], cmap="gray")

In [None]:
# class MyNN(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.features = nn.Sequential(
#             nn.Conv2d(1, 32, (4,16), padding=(1,7)),
#             nn.ReLU(),
#             nn.Dropout2d(0.1),
#             nn.MaxPool2d(kernel_size=(2,2), stride=(2,2)),
            
#             nn.Conv2d(32, 32, (4,16), padding=(1,7)),
#             nn.ReLU(),
#             nn.Dropout2d(0.1),
#             nn.MaxPool2d(kernel_size=(2,2), stride=(2,2)),
            
#             nn.Conv2d(32, 64, (3,8), padding=(1,3)),
#             nn.ReLU(),
#             nn.Dropout2d(0.1),
#             nn.MaxPool2d(kernel_size=(2,2), stride=(2,2)),
            
#             nn.Conv2d(64, 64, (3,8), padding=(1,3)),
#             nn.ReLU(),
#             nn.Dropout2d(0.1)
#         )
        
#         self.lstms = nn.ModuleList([
#             nn.LSTM(960, 256, bidirectional=True, batch_first=True),
#             nn.Dropout(0.3),
#             nn.LSTM(512, 256, bidirectional=True, batch_first=True),
#             nn.Dropout(0.3),
#             nn.LSTM(512, 256, bidirectional=True, batch_first=True),
#             nn.Dropout(0.3),
#         ])
#         self.lin = nn.Linear(512, 62)

#     def forward(self, x):
#         x = self.features(x)
#         x = x.contiguous().view(-1, x.shape[1] * x.shape[2], x.shape[3]).transpose(1,2)
#         #x = x.contiguous().view(x.shape[0], x.shape[3], x.shape[1] * x.shape[2])
        
#         for layer in self.lstms:
#             if isinstance(layer, nn.LSTM):
#                 x, _ = layer(x)
#             else:
#                 x = layer(x)      
#         x = self.lin(x)
#         x = nn.functional.log_softmax(x, dim=2)
#         return x.transpose(1,2)

# net = MyNN()

In [None]:
class MyNN(nn.Module):
    """Convolutional feature extractor followed by three bidirectional LSTMs.

    ``forward`` takes a ``(batch, 1, height, width)`` tensor and returns
    log-probabilities of shape ``(batch, 62, columns)``. With the character
    table defined in this notebook, 62 is the 61 characters plus the unused
    id 0.
    """

    @staticmethod
    def _conv_stage(in_ch, out_ch, kernel, padding, pool):
        """Build Conv2d -> ReLU -> BatchNorm2d, optionally followed by 2x2 max-pooling."""
        stage = [
            nn.Conv2d(in_ch, out_ch, kernel, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
        ]
        if pool:
            stage.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
        return stage

    def __init__(self):
        super().__init__()

        # (in channels, out channels, kernel, padding, pool afterwards?)
        conv_specs = (
            (1, 32, (4, 16), (1, 7), True),
            (32, 32, (4, 16), (1, 7), True),
            (32, 64, (3, 8), (1, 3), True),
            (64, 64, (3, 8), (1, 3), False),
        )
        conv_layers = []
        for spec in conv_specs:
            conv_layers.extend(self._conv_stage(*spec))
        self.features = nn.Sequential(*conv_layers)

        # The first LSTM consumes 960 features per column, i.e. the 64 output
        # channels times a feature-map height of 15.
        recurrent_layers = []
        for in_features in (960, 512, 512):
            recurrent_layers.append(
                nn.LSTM(in_features, 256, bidirectional=True, batch_first=True)
            )
            recurrent_layers.append(nn.Dropout(0.3))
        self.lstms = nn.ModuleList(recurrent_layers)

        self.lin = nn.Linear(512, 62)

    def forward(self, x):
        """Return per-column class log-probabilities, classes on axis 1."""
        feats = self.features(x)
        channels, height, width = feats.shape[1], feats.shape[2], feats.shape[3]
        # Fold channels and rows into one feature axis, then put columns first
        # so each image column becomes one step of the sequence.
        seq = feats.contiguous().view(-1, channels * height, width).transpose(1, 2)

        for layer in self.lstms:
            result = layer(seq)
            # nn.LSTM returns (output, state); Dropout returns the tensor itself.
            seq = result[0] if isinstance(layer, nn.LSTM) else result

        log_probs = nn.functional.log_softmax(self.lin(seq), dim=2)
        return log_probs.transpose(1, 2)

net = MyNN()

In [None]:
#device = torch.device("cuda")

class LatinTranscriber(L.LightningModule):
    """Lightning module that trains ``net`` with CTC loss and tracks accuracy.

    Args:
        net: model mapping an image batch to log-probabilities with classes
            on axis 1 and time steps on the last axis.
        codec_l2c: mapping from integer label to character, used to turn
            targets and predictions back into strings.

    Batches are dicts with the keys ``"image"`` and ``"target"``. The loss
    passes a single input/target length, so batches must hold one sample.
    """

    def __init__(self, net, codec_l2c):
        super().__init__()
        self.codec_l2c = codec_l2c
        self.cer_calc = CharErrorRate()
        self.wer_calc = WordErrorRate()
        self.train_cer_calc = CharErrorRate()
        self.train_wer_calc = WordErrorRate()
        self.net = net

    def get_loss(self, batch, batch_idx):
        """Run the network on ``batch`` and return ``(ctc_loss, output)``."""
        target = batch["target"]
        target_length = target.shape[1]

        output = self.net(batch["image"])
        output_length = output.shape[-1]

        loss_func = nn.CTCLoss(reduction='sum', zero_infinity=True)
        # CTCLoss wants (time, batch, classes); the network emits (batch, classes, time).
        loss = loss_func(output.permute(2, 0, 1), target, (output_length,), (target_length,))
        return loss, output

    def on_train_epoch_start(self):
        """Start each training epoch with empty error-rate accumulators."""
        self.train_cer_calc.reset()
        self.train_wer_calc.reset()

    def _get_current_lr(self):
        """Learning rate of the first parameter group of the first optimizer."""
        for param_group in self.trainer.optimizers[0].param_groups:
            return param_group['lr']

    def on_train_epoch_end(self):
        """Log the epoch's training accuracies and the current learning rate."""
        char_accuracy = 1 - self.train_cer_calc.compute()
        word_accuracy = 1 - self.train_wer_calc.compute()
        lr = self._get_current_lr()

        self.log("train_char_acc", char_accuracy)
        self.log("train_word_acc", word_accuracy)
        self.log('lr-Adam', lr)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one sample and accumulate training error rates."""
        assert self.net.training
        loss, output = self.get_loss(batch, batch_idx)
        prediction, truth = self.get_prediction_and_truth(output, batch["target"])
        # torchmetrics expects update(preds, target): the error count is
        # normalised by the length of the reference, i.e. the ground truth.
        self.train_cer_calc.update(prediction, truth)
        self.train_wer_calc.update(prediction, truth)
        self.log("train_loss", loss)
        return loss

    def get_prediction_and_truth(self, output, target):
        """Decode a single sample to strings.

        Args:
            output: network output for one sample, ``(1, classes, time)`` or
                ``(classes, time)``.
            target: label ids for the same sample, ``(1, length)`` or
                ``(length,)``.

        Returns:
            ``(prediction, truth)``. The prediction is a greedy decode:
            consecutive repeats are collapsed and label 0 is dropped.
        """
        # reshape rather than squeeze: squeeze would also drop a length-1
        # target or time axis and leave a 0-d value that cannot be iterated.
        target_ids = target.reshape(-1).tolist()
        truth = ''.join(self.codec_l2c[label] for label in target_ids)

        frame_scores = output.reshape(output.shape[-2], output.shape[-1])
        labels = torch.argmax(frame_scores, dim=0).tolist()

        decoded = []
        previous = None
        for label in labels:
            if label != 0 and label != previous:
                decoded.append(self.codec_l2c[label])
            previous = label

        return ''.join(decoded), truth

    def validation_step(self, batch, batch_idx):
        """Score one validation sample and log the first 16 to TensorBoard."""
        assert not self.net.training
        assert batch["target"].shape[0] == 1
        loss, output = self.get_loss(batch, batch_idx)
        prediction, truth = self.get_prediction_and_truth(output, batch["target"])
        # Same argument order as in training_step: (preds, target).
        self.cer_calc.update(prediction, truth)
        self.wer_calc.update(prediction, truth)

        if batch_idx == 0:
            print("batch 0", truth)

        if batch_idx < 16:
            # Pick the TensorBoard writer if one is attached; skip the image
            # and text logging when the trainer runs with another logger.
            tb_logger = next(
                (
                    logger.experiment
                    for logger in self.trainer.loggers
                    if isinstance(logger, pl_loggers.TensorBoardLogger)
                ),
                None,
            )
            if tb_logger is not None:
                tag = f'Validation #{batch_idx}, target: {truth}'
                tb_logger.add_image(tag, batch['image'][0], self.global_step, dataformats="CHW")
                tb_logger.add_text(tag, prediction, self.global_step)

        return loss

    def on_validation_epoch_start(self):
        """Start each validation epoch with empty error-rate accumulators."""
        self.cer_calc.reset()
        self.wer_calc.reset()

    def on_validation_epoch_end(self):
        """Print and log the validation character and word accuracies."""
        char_accuracy = 1 - self.cer_calc.compute()
        word_accuracy = 1 - self.wer_calc.compute()
        print("Epoch, char acc, word acc:", self.current_epoch, round(char_accuracy.item(), 4), round(word_accuracy.item(), 4))
        self.log("val_char_acc", char_accuracy)
        self.log("val_word_acc", word_accuracy)

    def configure_optimizers(self):
        """AdamW with a plateau scheduler driven by validation word accuracy."""
        optimizer = optim.AdamW(self.parameters(), lr=1e-3)

        # Earlier alternative, kept for reference:
        # self.scheduler = optim.lr_scheduler.OneCycleLR(
        #                                 optimizer, max_lr=3e-3,
        #                                 steps_per_epoch=492,
        #                                 epochs=100)
        # sched = {
        #     'scheduler': self.scheduler,
        #     'interval': 'step'
        # }
        # return {"optimizer": optimizer, "lr_scheduler": sched}

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "name": "lr-scheduler",
                "scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.34, min_lr=1e-4, patience=15),
                "monitor": "val_word_acc",
                "frequency": 1
            },
        }

        #optimizer = optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-2)
        #return optimizer

transcriber = LatinTranscriber(net, num_to_char)

In [None]:
# Print the value range of one transformed validation line and display it
# with a fixed grey scale.
idx = 10
val_image = val_dataset[idx]["image"]
print(val_image.min(), val_image.max())

plt.figure(figsize=(16, 6))
plt.imshow(val_image[0], cmap="gray", vmin=-0.5, vmax=2)

In [None]:
#train_set_size = int(len(dataset) * 0.9)
#valid_set_size = len(dataset) - train_set_size

# split the train set into two
#seed = torch.Generator().manual_seed(40)
#train_set, valid_set = torch.utils.data.random_split(dataset, [train_set_size, valid_set_size], generator=seed)

# LatinTranscriber.get_loss passes a single input/target length to the CTC
# loss, so both loaders must yield one sample per batch; say so explicitly.
# The dataset lists its lines in sorted file order, so the training loader
# reshuffles them every epoch instead of replaying that fixed order.
train_loader = utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
valid_loader = utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

In [None]:
# Show the first validation sample (a dict with "image", "target" and "text")
# as the cell output.
val_dataset[0]

In [None]:
# Upper bound on the number of training epochs.
MAX_EPOCHS = 100

# One optimizer step per batch (no gradient accumulation).
trainer = L.Trainer(max_epochs=MAX_EPOCHS, accumulate_grad_batches=1)
trainer.fit(transcriber, train_loader, valid_loader)

In [None]:
#torch.save(net.state_dict(), "my_custom_best.pt") #94% char, 82% word
#torch.save(net.state_dict(), "my_aguillar_repro.pt") #91% char, 69% word
#torch.save(net.state_dict(), "my_all_custom_best.pt") #88% char, 66% word