In [None]:
import zipfile
import tarfile
import os
from torchtext.vocab import build_vocab_from_iterator
from string import printable
from functools import reduce
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
import torchvision.transforms.functional as fn
from sklearn.model_selection import train_test_split
!pip install torchinfo -q
from torchinfo import summary
!pip install torchmetrics -q
!pip install Unidecode
from torchmetrics.functional.text import char_error_rate

In [None]:
# Locations of the source archives and of the metadata CSV.
imgur_zip_path: str = "drive/MyDrive/handwritten/imgur/sub_images.zip"
iam_words_tgz_path: str = "drive/MyDrive/handwritten/words.tgz"
metadata_path: str = "drive/MyDrive/handwritten/metadata.csv"

# Local directories the two archives are extracted into.
imgur_path: str = "imgur-images"
iam_words_path: str = 'iam-images'

def unpack_archive_tgz(path: str, extract_path: str):
    """Extract the tar archive stored at ``path`` into ``extract_path``.

    Args:
        path: Location of the archive to read.
        extract_path: Destination directory; created when it does not exist.
    """
    os.makedirs(extract_path, exist_ok=True)
    with tarfile.open(name=path, mode='r') as archive:
        archive.extractall(path=extract_path)

def unpack_archive_zip(path: str, extract_path: str):
  """Extract the zip archive stored at ``path`` into ``extract_path``.

  Args:
      path: Location of the zip file to read.
      extract_path: Destination directory; created when it does not exist.
  """
  os.makedirs(extract_path, exist_ok=True)
  with zipfile.ZipFile(path, mode='r') as archive:
    archive.extractall(path=extract_path)

def get_metadata(imgur_path: str, iam_path: str, metadata_path: str):
    """Load the metadata CSV and prefix every image path with its database root.

    Args:
        imgur_path: Directory prepended to ``image_path`` of rows whose
            ``database`` column equals ``'imgur'``.
        iam_path: Directory prepended to ``image_path`` of rows whose
            ``database`` column equals ``'iam'``.
        metadata_path: CSV file to read; its first column is used as the index
            and then discarded in favour of a fresh 0..n-1 index.

    Returns:
        The metadata DataFrame with rewritten ``image_path`` values. Rows from
        any other database are left untouched.
    """
    metadata = pd.read_csv(metadata_path, index_col=0).reset_index(drop=True)

    # Same treatment for both databases: select the rows, prepend the root.
    for database, root in (('iam', iam_path), ('imgur', imgur_path)):
        rows = metadata['database'] == database
        metadata.loc[rows, 'image_path'] = root + '/' + metadata.loc[rows, 'image_path']

    return metadata


# Build the image roots from the extraction-directory constants instead of
# repeating the literals, so the metadata paths cannot drift away from the
# directories the archives are actually extracted into.
# NOTE(review): 'sub_images' is presumably the top-level folder inside the
# imgur zip - confirm against the archive layout.
metadata = get_metadata(
    imgur_path=f'{imgur_path}/sub_images',
    iam_path=iam_words_path,
    metadata_path=metadata_path
)

# Extract both archives into their local directories.
unpack_archive_tgz(iam_words_tgz_path, iam_words_path)
unpack_archive_zip(imgur_zip_path, imgur_path)

In [None]:
# Inspect the metadata table returned by get_metadata.
metadata

In [None]:
import seaborn as sns
def print_shape(w, h, word):
  """Draw four histograms side by side: width, height, width/height, label length.

  Args:
      w: Image widths (must support division by ``h``).
      h: Image heights.
      word: Labels; accessed through ``.str.len()``, so a pandas string Series
          is expected.
  """
  _, axes = plt.subplots(1, 4, figsize=(16, 5))
  panels = (
      (w, 'width'),
      (h, 'height'),
      (w / h, 'ratio'),
      (word.str.len(), 'label_length'),
  )
  for axis, (values, name) in zip(axes, panels):
    sns.histplot(values, label=name, ax=axis)
    axis.legend()

# Size and label-length distributions restricted to the IAM rows.
print('words metadata')
iam_only = metadata['database'] == 'iam'
iam_metadata = metadata.loc[iam_only]
print_shape(iam_metadata['w'], iam_metadata['h'], iam_metadata['word'])
plt.show()

In [None]:
class CharacterTokenizer:

  pad_token: str = "<PAD>" # 0
  unknown_token: str = "<UNK>" # 1
  sep_token: str = "<SEP>" # 2
  void_token: str = "<VOID>" # 3 to guarantee 104 tokens in vocab

  def __init__(self, max_length: int):
    self.special_tokens = [self.pad_token, self.unknown_token, self.sep_token, self.void_token]
    self.max_length: int = max_length
    self.vocab = build_vocab_from_iterator(printable, specials=self.special_tokens)
    self.vocab.set_default_index(self.vocab[self.unknown_token])


  @property
  def vocab_size(self):
    return len(self.vocab)

  def get_vocab(self):
    return self.vocab.get_stoi()


  def __call__(self, sentence: str):
    indices = self.vocab.lookup_indices(list(sentence))
    sep_indices = []
    for indice, next_indice in zip(indices[:-1], indices[1:]):
      sep_indices.append(indice)
      if indice == next_indice:
        sep_indices.append(self.vocab[self.sep_token])

    if len(indices) > 0:
      sep_indices.append(indices[-1])

    if len(sep_indices) >= self.max_length:
      return sep_indices[:self.max_length]
    return sep_indices + [self.vocab[self.pad_token]]*(self.max_length - len(sep_indices))


  def filter_special_tokens(self, token):
    return token not in self.special_tokens

  def decode(self, ids: list[int] | torch.Tensor, skip_special_tokens: bool = False):
    tokens = self.vocab.lookup_tokens(list(ids))
    if skip_special_tokens:
      tokens = filter(
        self.filter_special_tokens,
        tokens
      )

    tokens = reduce(
        lambda acc, token: acc + token if acc[-1] != token else acc,
        tokens,
        " "
      )
    return tokens[1:]

In [None]:

class WordsDataset(Dataset):
  """Dataset yielding ``(image, label)`` pairs for word images.

  Rows are taken from the module-level ``metadata`` table. Images are converted
  to grayscale, resized to a fixed height while keeping the aspect ratio, scaled
  by 1/255 and then cropped or zero-padded on the right to ``max_img_width``.
  Labels are produced by calling ``tokenizer`` on the ``word`` column.
  """

  def __init__(
      self,
      tokenizer: CharacterTokenizer,
      iam_only: bool = False,
      max_length=20,
      img_height: int=40,
      max_img_width: int = 320,
  ):
    """Select the usable metadata rows.

    Args:
        tokenizer: Callable turning a word into a list of token ids.
        iam_only: Keep only rows whose ``database`` is ``'iam'``.
        max_length: Maximum number of characters allowed in a word.
        img_height: Height every image is resized to.
        max_img_width: Width every image is cropped or padded to.
    """
    self.tokenizer = tokenizer
    self.img_height = img_height
    self.max_length = max_length
    self.max_img_width = max_img_width

    # Plan for at least twice as many output tokens because of the separators.
    keep = metadata['word'].str.len() <= self.max_length
    if iam_only:
      keep = keep & (metadata['database'] == 'iam')
    self.metadata = metadata[keep]

  def __len__(self):
    """Number of selected rows."""
    return len(self.metadata)

  def get_resized_image(self, data: pd.Series):
    """Load the image of one metadata row as a 2-D float array.

    The array has ``img_height`` rows and exactly ``max_img_width`` columns.
    """
    grayscale = Image.open(data["image_path"]).convert('L')
    width, height = grayscale.size
    target_width = round(self.img_height * (width / height))
    resized = fn.resize(grayscale, size=[self.img_height, target_width], antialias=None) #type: ignore
    pixels = np.array(resized) / 255.

    missing_columns = self.max_img_width - pixels.shape[1]
    if missing_columns <= 0:
      # Wide enough already: keep only the leftmost max_img_width columns.
      return pixels[:, :self.max_img_width]

    # Too narrow: fill the right-hand side with zeros.
    return np.pad(
        pixels,
        pad_width=((0, 0), (0, missing_columns)),
        mode="constant",
        constant_values=0
    )


  def __getitem__(self, idx: int):
    """Return ``(image, label)`` tensors for the row at position ``idx``."""
    row: pd.Series = self.metadata.iloc[idx]
    label = torch.tensor(self.tokenizer(row['word']))
    image = torch.tensor(self.get_resized_image(row), dtype=torch.float32)
    return image, label

In [None]:
def build_datasets(
    batch_size: int,
    max_length: int,
    img_height: int,
    max_img_width: int,
    iam_only: bool = False,
    random_state: int = 42,
):
  """Create the tokenizer, the dataset and its train/test/valid loaders.

  The dataset indices are split 80/20, and the 20% part is split in half again
  into test and validation indices. Every loader drops its last incomplete
  batch; only the training loader shuffles.

  Returns:
      A dict with the keys ``'train_loader'``, ``'test_loader'``,
      ``'valid_loader'`` and ``'tokenizer'``.
  """
  tokenizer = CharacterTokenizer(max_length=max_length)
  dataset = WordsDataset(
      tokenizer=tokenizer,
      iam_only=iam_only,
      max_length=max_length,
      img_height=img_height,
      max_img_width=max_img_width,
  )

  all_indices = np.arange(len(dataset))
  train_indices, held_out_indices = train_test_split(all_indices, test_size=0.20, random_state=random_state)
  test_indices, valid_indices = train_test_split(held_out_indices, test_size=0.50, random_state=random_state)

  def make_loader(indices, shuffle):
    # One loader per index subset; incomplete final batches are dropped.
    return DataLoader(Subset(dataset, indices), batch_size=batch_size, shuffle=shuffle, drop_last=True)

  return {
      'train_loader': make_loader(train_indices, True),
      'test_loader': make_loader(test_indices, False),
      'valid_loader': make_loader(valid_indices, False),
      'tokenizer': tokenizer
  }

# Build the loaders for the IAM rows only: 40-pixel-high images cropped/padded
# to 320 pixels wide, labels of 20 tokens, batches of 32.
datasets = build_datasets(
    batch_size=32,
    max_length=20,
    img_height=40,
    iam_only=True,
    max_img_width=320,
    random_state=42
)

In [None]:
# Short aliases for the tokenizer and the three loaders built above.
tokenizer = datasets['tokenizer']
train = datasets['train_loader']
test = datasets['test_loader']
valid = datasets['valid_loader']

In [None]:
# Sanity check on one validation batch: print the decoded labels of the first
# samples, then display the matching images. Loops replace the copy-pasted
# statements, every figure is shown explicitly, and the preview size adapts to
# batches that hold fewer than four samples.
images, labels = next(iter(valid))
n_preview = min(4, len(images))

for i in range(n_preview):
    print(tokenizer.decode(labels[i], skip_special_tokens=True))

for i in range(n_preview):
    plt.imshow(images[i])
    plt.show()

In [None]:
class ResidualBlock(nn.Module):
  """Residual block applying BatchNorm -> ReLU -> Conv twice, plus a skip path.

  The skip path is the identity unless ``stride != 1`` or ``force_downsample``
  is set, in which case a 1x1 convolution followed by BatchNorm adapts the
  input to the shape of the main branch.
  """

  def __init__(self,
               input_channels: int,
               planes: int,
               kernel_size: int = 3,
               stride: int = 2,
               force_downsample: bool = False
    ):
    """Create the block.

    Args:
        input_channels: Number of channels of the input tensor.
        planes: Number of channels produced by the block.
        kernel_size: Odd kernel size used by both convolutions. It was
            previously accepted but ignored (3 was hard-coded); the default
            keeps that behaviour.
        stride: Stride of the first convolution and of the skip projection.
        force_downsample: Use the 1x1 projection on the skip path even when
            ``stride == 1`` (needed when the channel count changes).

    Raises:
        ValueError: If ``kernel_size`` is even, because "same" padding of
            ``kernel_size // 2`` would then change the spatial size and break
            the addition with the skip path.
    """
    super().__init__()
    if kernel_size % 2 == 0:
      raise ValueError(f"kernel_size must be odd, got {kernel_size}")
    padding = kernel_size // 2

    self.relu = nn.ReLU(inplace=True)
    self.bn1 = nn.BatchNorm2d(input_channels)
    self.conv1 = nn.Conv2d(input_channels, planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=padding, bias=False)

    if stride != 1 or force_downsample:
      self.downsample = nn.Sequential(
          nn.Conv2d(input_channels, planes, kernel_size=1, stride=stride, bias=False),
          nn.BatchNorm2d(planes)
      )
    else:
      self.downsample = None


  def forward(self, X: torch.Tensor):
    """Return ``main_branch(X) + skip(X)``."""
    x = self.bn1(X)
    x = self.relu(x)
    x = self.conv1(x)
    x = self.bn2(x)
    x = self.relu(x)
    x = self.conv2(x)
    skip_connection = self.downsample(X) if self.downsample is not None else X
    return x + skip_connection

class TinyOCR(nn.Module):
  """Small residual CNN followed by a bidirectional LSTM and a log-softmax head.

  The CNN ends with ``seq_size`` channels; axes 1 and 2 of its output are
  flattened together before the LSTM, whose ``input_size`` is fixed to 80.
  NOTE(review): this presumably assumes (N, 1, 40, 320) inputs, as used by the
  ``summary`` calls in this notebook - confirm before changing the image size.
  """

  def __init__(self, vocab_size: int, seq_size: int):
    """
    Args:
        vocab_size: Number of classes produced for every sequence step.
        seq_size: Number of channels produced by the last residual block.
    """
    super().__init__()
    feature_layers = [
        ResidualBlock(1, 16),
        ResidualBlock(16, 16, stride=1),
        ResidualBlock(16, 32),
        nn.MaxPool2d((2, 1)),
        ResidualBlock(32, 32, stride=1),
        ResidualBlock(32, 64, stride=1, force_downsample=True),
        ResidualBlock(64, 64, stride=1),
        ResidualBlock(64, seq_size, stride=1, force_downsample=True),
        nn.MaxPool2d((5, 1)),
    ]
    self.cnn = nn.Sequential(*feature_layers)
    self.flatten = nn.Flatten(1, 2)
    self.rnn = nn.LSTM(input_size=80, hidden_size=128, num_layers=2, bidirectional=True, batch_first=True)
    # 256 = 2 directions * 128 hidden units.
    self.linear = nn.Linear(256, vocab_size)
    self.softmax = nn.LogSoftmax(-1)

  def forward(self, X: torch.Tensor):
    """Return log-probabilities over the vocabulary for every sequence step."""
    features = self.flatten(self.cnn(X))
    sequence, _ = self.rnn(features)
    return self.softmax(self.linear(sequence))



In [None]:
# nn.Sequential(
#         nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False),
#         nn.ReLU(),
#         nn.BatchNorm2d(8),
#         nn.MaxPool2d(2),
#         nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False),
#         nn.ReLU(),
#         nn.BatchNorm2d(16),
#         nn.MaxPool2d((2, 1)),
#         nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False),
#         nn.ReLU(),
#         nn.BatchNorm2d(32),
#         nn.MaxPool2d((2, 1)),
#         nn.Dropout2d(0.2),
#         nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
#         nn.ReLU(),
#         nn.BatchNorm2d(64),
#         nn.Dropout2d(0.2),
#         nn.Conv2d(64, 64, kernel_size=(4,2), padding='same', bias=False),
#         nn.ReLU(),
#         nn.BatchNorm2d(64),
#         nn.Dropout2d(0.2),
#         nn.MaxPool2d((4,2)),

#     )

In [None]:
from torchvision.models import resnet18

class MollyOCR(nn.Module):
  """OCR model built from part of a pretrained ResNet-18 plus a bidirectional LSTM.

  A fresh single-channel 7x7 convolution replaces the first child of the
  ResNet; the last three children are dropped as well. A final ``Linear(17, 128)``
  acts on the last axis of the feature map, and that axis becomes the sequence
  axis of the LSTM after the permute in ``forward``.
  NOTE(review): the sizes 17 and 768 presumably correspond to (N, 1, 40, 320)
  inputs - confirm before changing the image size.
  """

  def __init__(self, vocab_size: int):
    """
    Args:
        vocab_size: Number of classes produced for every sequence step.
    """
    super().__init__()
    backbone = resnet18(weights='IMAGENET1K_V1')
    # Skip the first child (replaced by the 1-channel conv below) and the last three.
    reused_layers = list(backbone.children())[1:-3]

    self.cnn1 = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
        *reused_layers,
        nn.Conv2d(256, 256, kernel_size=(3, 6), stride=1, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),
        nn.Linear(17, 128),
    )
    self.rnn = nn.LSTM(input_size=768, hidden_size=256, num_layers=2, bidirectional=True, batch_first=True, dropout=0.5)
    # 512 = 2 directions * 256 hidden units.
    self.dense = nn.Linear(512, vocab_size)
    self.softmax = nn.LogSoftmax(dim=-1)

  def forward(self, X: torch.Tensor):
    """Return log-probabilities over the vocabulary for every sequence step."""
    features = self.cnn1(X)
    # Move the last axis to position 1, then merge the two trailing axes.
    sequence_input = features.permute(0, 3, 1, 2).flatten(-2, -1)
    sequence, _ = self.rnn(sequence_input)
    return self.softmax(self.dense(sequence))


In [None]:
# Layer-by-layer summary of MollyOCR for a batch of two 40x320 single-channel
# images; 104 is the vocabulary size noted on CharacterTokenizer.void_token.
summary(MollyOCR(104), input_size=(2,1,40,320))


In [None]:
# Same summary for TinyOCR with a vocabulary of 104 and seq_size=128.
summary(TinyOCR(104, 128), input_size=(2,1,40,320))

In [None]:
def get_batch_cer(preds, labels, tokenizer, batch_size):
    """Character error rate between the greedy-decoded predictions and the labels.

    Args:
        preds: Model outputs; ``preds[i]`` is reduced with an argmax over axis 1.
        labels: Token ids of the targets; ``labels[i]`` is decoded as is.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        batch_size: Number of leading samples to score.

    Returns:
        The character error rate over the decoded samples, as a Python float.
    """
    scores = preds.cpu().detach().numpy()
    targets = labels.cpu().detach().numpy()

    decoded_preds = [
        tokenizer.decode(np.argmax(scores[i], axis=1), skip_special_tokens=True)
        for i in range(batch_size)
    ]
    decoded_labels = [
        tokenizer.decode(targets[i], skip_special_tokens=True)
        for i in range(batch_size)
    ]

    return char_error_rate(decoded_preds, decoded_labels).item()

In [None]:
from tqdm import tqdm

def _ctc_forward(model, loss, image, label, im_dims, input_length, device):
  """Run one batch through ``model`` and return ``(preds, label, cost)``.

  The batch size is read from the tensors themselves, so incomplete batches
  work too. ``label`` is returned after being moved to ``device``.
  """
  label = label.to(device)
  image = image.view(-1, 1, im_dims[0], im_dims[1]).to(device)
  preds = model(image)
  n_samples = image.size(0)
  cost = loss(
      preds.transpose(0, 1),  # the loss receives the sequence axis first
      label,
      torch.full(size=(n_samples,), fill_value=input_length),
      torch.count_nonzero(label, dim=1),  # target length = number of non-zero ids
  )
  return preds, label, cost


def train_epoch(
    model,
    optim,
    loss,
    batch_size: int,
    im_dims: tuple[int, int],
    loaders: tuple,
    input_length: int,
    tokenizer: CharacterTokenizer
):
  """Train for one epoch on ``loaders[0]`` and then evaluate on ``loaders[1]``.

  Args:
      model: Network mapping a (N, 1, H, W) image batch to per-step scores.
      optim: Optimizer stepped once per training batch.
      loss: Loss called as ``loss(preds_T, labels, input_lengths, target_lengths)``.
      batch_size: Kept for backward compatibility; the actual batch size is now
          taken from each batch.
      im_dims: ``(height, width)`` of the images.
      loaders: ``(train_loader, valid_loader)``.
      input_length: Input length reported to the loss for every sample.
      tokenizer: Tokenizer used to decode predictions and labels for the CER.

  Returns:
      ``(train_loss, valid_loss, train_cer, valid_cer)``, each the mean of the
      per-batch values. The model is left in eval mode.
  """
  # Follow the device the model lives on instead of assuming CUDA; fall back to
  # CUDA (the previous behaviour) for a model without parameters.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cuda')

  train_loader, valid_loader = loaders[0], loaders[1]
  total_train_cost = 0
  total_valid_cost = 0
  total_train_cer = 0
  total_valid_cer = 0

  model.train()
  for image, label in tqdm(train_loader, leave=True):
    optim.zero_grad()
    preds, label, cost = _ctc_forward(model, loss, image, label, im_dims, input_length, device)
    cost.backward()
    optim.step()
    total_train_cost += cost.item()
    total_train_cer += get_batch_cer(preds, label, tokenizer, preds.size(0))

  # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
  n_train_batches = max(len(train_loader), 1)
  train_cer = total_train_cer / n_train_batches
  train_loss = total_train_cost / n_train_batches

  model.eval()
  with torch.no_grad():
    for image, label in valid_loader:
      preds, label, cost = _ctc_forward(model, loss, image, label, im_dims, input_length, device)
      total_valid_cost += cost.item()
      total_valid_cer += get_batch_cer(preds, label, tokenizer, preds.size(0))

  n_valid_batches = max(len(valid_loader), 1)
  valid_loss = total_valid_cost / n_valid_batches
  valid_cer = total_valid_cer / n_valid_batches

  tqdm.write('\n')
  tqdm.write(f'mean loss train: {train_loss}')
  tqdm.write(f'mean loss valid: {valid_loss}')
  tqdm.write('\n')
  tqdm.write(f'mean cer train: {train_cer}')
  tqdm.write(f'mean cer valid: {valid_cer}')

  return train_loss, valid_loss, train_cer, valid_cer

In [None]:
def train_model(
    model,
    optim,
    loss,
    epochs: int,
    batch_size: int,
    im_dims: tuple[int, int],
    loaders: tuple,
    input_length: int,
    tokenizer: CharacterTokenizer,
    scheduler,
    patience: int = 4
):
  """Run up to ``epochs`` epochs with LR scheduling and early stopping.

  After every epoch the scheduler is stepped with the validation CER, and
  training stops once the validation loss has failed to improve for
  ``patience`` consecutive epochs.

  Args:
      model, optim, loss, batch_size, im_dims, loaders, input_length, tokenizer:
          Forwarded unchanged to ``train_epoch``.
      epochs: Maximum number of epochs.
      scheduler: Scheduler whose ``step`` accepts the validation CER.
      patience: Number of non-improving epochs tolerated before stopping.

  Returns:
      Four lists with one entry per completed epoch:
      ``(train_losses, valid_losses, train_cers, valid_cers)``.
  """
  total_train_costs = []
  total_valid_costs = []
  total_train_cer = []
  total_valid_cer = []
  stale_epochs = 0
  best_valid_loss = np.inf
  for i in range(epochs):
    train_loss, valid_loss, train_cer, valid_cer = (
        train_epoch(model, optim, loss, batch_size, im_dims, loaders, input_length, tokenizer)
    )
    total_train_costs.append(train_loss)
    total_valid_costs.append(valid_loss)
    total_train_cer.append(train_cer)
    total_valid_cer.append(valid_cer)

    before_lr = optim.param_groups[0]["lr"]
    scheduler.step(valid_cer)
    after_lr = optim.param_groups[0]["lr"]
    # Scientific notation: "%.4f" showed small rates such as 5e-5 as 0.0000/0.0001,
    # and the message no longer names an optimizer the caller may not be using.
    print("Epoch %d: lr %.2e -> %.2e" % (i, before_lr, after_lr))

    # Check the validation loss (early stopping).
    if valid_loss < best_valid_loss:
      best_valid_loss = valid_loss
      stale_epochs = 0
    else:
      stale_epochs += 1
      if stale_epochs >= patience:
          print(f"Early stopping at epoch {i}. Validation loss did not improve.")
          break

  return (
      total_train_costs,
      total_valid_costs,
      total_train_cer,
      total_valid_cer,
  )

In [None]:
from tqdm import tqdm
import os
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Model, optimizer, LR scheduler and loss used by the training cells below.
# 104 is the vocabulary size noted on CharacterTokenizer.void_token.
# NOTE(review): `.to()` is called without arguments before `.cuda()`; it looks
# redundant - confirm and drop it.
model = MollyOCR(104).to().cuda()
optim = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-3)
# The scheduler is stepped with the validation CER inside train_model.
scheduler = ReduceLROnPlateau(optim, verbose=True, patience=5)
# Default CTCLoss settings are used.
# NOTE(review): the default blank index is presumably 0, i.e. the <PAD> id - confirm.
loss = nn.CTCLoss()
epochs = 20
# Alternative schedule kept for reference (linear warm-up, then step decay):
# scheduler = torch.optim.lr_scheduler.SequentialLR(
#     optim,
#     schedulers=[
#         torch.optim.lr_scheduler.LinearLR(optim, start_factor=1/3, end_factor=1.0, total_iters=3),
#         torch.optim.lr_scheduler.StepLR(optim, step_size=3, gamma=0.5)
#     ],
#     milestones=[3]
# )


In [None]:
# Summary of TinyOCR for a single 40x320 image (the variant for `model` is disabled).
# summary(model, input_size=(1,1,40,320))
summary(TinyOCR(104, 128), input_size=(1,1,40,320))
#

In [None]:
# Load a previously saved checkpoint (a dict read below through its "model"
# and "optim" keys).
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
# NOTE(review): the file name mentions "tinyocr" while `model` is a MollyOCR
# instance - confirm the checkpoint matches the model before loading it.
state_path = "drive/MyDrive/handwritten/tinyocr_more_output_train.pth"
state = torch.load(state_path)
# Displays the whole checkpoint dict as the cell output.
state

In [None]:
# Restore the model weights and the optimizer state from the checkpoint.
model.load_state_dict(state["model"])
optim.load_state_dict(state["optim"])

In [None]:
# Override the learning rate of every parameter group (this replaces whatever
# value the restored optimizer state carried).
for g in optim.param_groups:
    g['lr'] = 0.00005

In [None]:
# Train the model and collect the per-epoch losses and character error rates.
# NOTE(review): `loaders=(train, test)` uses the *test* loader for validation,
# while `valid` is only used for the final inspection cell - confirm the roles.
# NOTE(review): input_length=80, but both models in this notebook appear to
# emit 128 sequence steps for 40x320 inputs; with a smaller input length the
# loss would presumably ignore the trailing steps that decoding still reads -
# confirm.
(
    total_train_costs,
    total_valid_costs,
    total_train_cer,
    total_valid_cer,
) = train_model(
    model, optim, loss, epochs,
    batch_size=32, im_dims=(40, 320),
    loaders=(train, test),
    input_length=80,
    tokenizer=tokenizer,
    scheduler=scheduler,
    patience=10
    )

In [None]:
# Scratch arithmetic: 819200 / 320 / 40 evaluates to 64.0.
(819200/320)/40

In [None]:
# Histories accumulated across successive training runs. They are created the
# first time this cell runs (previously their initialisation was commented out,
# so a fresh kernel raised NameError) and are kept on later runs so the curves
# keep growing.
for _history_name in (
    "val_costs", "train_costs", "train_cer", "val_cer",
    "train_avg_loss", "val_avg_loss",
):
    globals().setdefault(_history_name, [])

# train_model returns four per-epoch lists and nothing named avg_*_costs;
# default both to empty lists unless an earlier cell defined them.
avg_train_costs = globals().get("avg_train_costs", [])
avg_valid_costs = globals().get("avg_valid_costs", [])

val_costs.extend(total_valid_costs)
train_costs.extend(total_train_costs)
train_cer.extend(total_train_cer)
val_cer.extend(total_valid_cer)
train_avg_loss.extend(avg_train_costs)
val_avg_loss.extend(avg_valid_costs)

In [None]:
# Plot the curves of the latest training run only.
# NOTE(review): print_model_result is defined in a later cell, so on a fresh
# kernel this call runs before the function exists - confirm the cell order.
# NOTE(review): train_model does not return avg_train_costs / avg_valid_costs -
# confirm where these two lists are meant to come from.
print_model_result(total_train_costs,
    total_valid_costs,
    total_train_cer,
    total_valid_cer,
    avg_train_costs,
    avg_valid_costs)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

def print_model_result(total_train_costs,
    total_valid_costs,
    total_train_cer,
    total_valid_cer,
    avg_train_costs,
    avg_valid_costs,
    train_skip: int = 500):
  """Plot CER and loss curves on a 3x2 grid of axes.

  Args:
      total_train_costs: Training losses.
      total_valid_costs: Validation losses.
      total_train_cer: Training character error rates.
      total_valid_cer: Validation character error rates.
      avg_train_costs: Averaged training losses (drawn on the fifth panel).
      avg_valid_costs: Averaged validation losses (same panel).
      train_skip: Number of leading training points to hide. It is applied only
          to series longer than ``train_skip``; shorter series (for example
          per-epoch histories) are drawn in full instead of producing an empty
          panel. The default matches the former hard-coded value.
  """
  def drop_warmup(values):
    # Hide the first `train_skip` points only when something remains to plot.
    return values[train_skip:] if len(values) > train_skip else values

  _, axes = plt.subplots(3, 2, figsize=(10, 7))
  axes = axes.flatten()

  curves = (
      (drop_warmup(total_train_cer), axes[0], 'train-cer'),
      (total_valid_cer, axes[1], 'valid-cer'),
      (drop_warmup(total_train_costs), axes[2], 'train-loss'),
      (total_valid_costs, axes[3], 'valid-loss'),
      (avg_train_costs, axes[4], 'avg-train-loss'),
      (avg_valid_costs, axes[4], 'avg-valid-loss'),
  )
  for values, axis, label in curves:
    sns.lineplot(x=np.arange(len(values)), y=values, ax=axis, label=label)
  plt.show()

# Plot the histories accumulated over all training runs.
print_model_result(train_costs,
    val_costs,
    train_cer,
    val_cer,
    train_avg_loss,
    val_avg_loss)

In [None]:
# Mean validation CER over the epochs of the latest training run.
sum(total_valid_cer)/len(total_valid_cer)


In [None]:
import json
# Bundle the accumulated histories into one dict for export.
results = {
    "train_costs": train_costs,
    "val_costs": val_costs,
    "train_cer": train_cer,
    "val_cer": val_cer,
    "train_avg_loss": train_avg_loss,
    "val_avg_loss": val_avg_loss
}


In [None]:
# Serialise the histories with a 4-space indent.
json_object = json.dumps(results, indent=4)

# Write the serialised histories to tiny_ocr_transposed.json
with open("tiny_ocr_transposed.json", "w") as outfile:
    outfile.write(json_object)

In [None]:
# Save the model weights and optimizer state under the same keys the loading
# cell reads ("model", "optim").
# NOTE(review): 'epoch' is hard-coded to 20; with early stopping or resumed
# training the real number of completed epochs may differ - confirm.
state = {
    'model': model.state_dict(),
    'optim': optim.state_dict(),
    'epoch': 20,
}
torch.save(state, 'tinyocr_transposed_train.pth')

In [None]:
# Mean validation CER over at most the 3000 most recent entries. Dividing by
# the number of entries actually present (instead of a fixed 3000) keeps the
# mean correct when the history is shorter; max(..., 1) covers an empty history.
recent_val_cer = val_cer[-3000:]
sum(recent_val_cer) / max(len(recent_val_cer), 1)

In [None]:
def compare_preds_labels(preds, labels, image, batch_size: int = 200):
  """Print the CER of a batch, then show every prediction next to its label and image.

  Uses the module-level ``tokenizer`` for decoding.

  Args:
      preds: Model outputs; ``preds[i]`` is reduced with an argmax over dim 1.
      labels: Token ids of the targets, one row per sample.
      image: Images to display, one per sample.
      batch_size: Maximum number of samples to compare. It is capped at the
          number of predictions and labels available, so the default no longer
          indexes past the end of a smaller batch.
  """
  n_samples = min(batch_size, len(preds), len(labels))

  decoded_preds = [
      # tolist() hands plain Python ints to the tokenizer, whatever the device.
      tokenizer.decode(preds[i].argmax(dim=1).tolist(), skip_special_tokens=True)
      for i in range(n_samples)
  ]
  decoded_labels = [
      tokenizer.decode(labels[i], skip_special_tokens=True)
      for i in range(n_samples)
  ]

  print('char-error-rate: ', char_error_rate(decoded_preds, decoded_labels).item())
  print('\n\n')
  for pred, label, im in zip(decoded_preds, decoded_labels, image):
    # Show the expected text as well; it was unpacked but never displayed before.
    print(f'pred:  {pred}')
    print(f'label: {label}')
    plt.imshow(im)
    plt.show()


# Qualitative check on one batch of the `valid` loader. The batch size is read
# from the batch itself: the loaders are built with batches of 32, so the
# former hard-coded view(16, ...) could not reshape the tensor.
image, label = next(iter(valid))
n_samples = image.size(0)

# Inference only: switch to eval mode and skip gradient tracking.
model.eval()
with torch.no_grad():
    pred = model(image.view(n_samples, 1, 40, 320).cuda())

compare_preds_labels(pred, label, image.view(n_samples, 40, 320), batch_size=n_samples)