In [None]:
!pip install jiwer

In [None]:
import re

# Vocabulary of LaTeX tokens: structural characters, letters, digits,
# blackboard-bold letters, punctuation and the supported backslash commands.
# The order matters because each token's integer label is derived from its
# position in this list.
CHARS = ['_', '^', '{', '}', '&', '\\\\', ' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
             's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q',
             'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '\\mathbb{A}', '\\mathbb{B}',
             '\\mathbb{C}', '\\mathbb{D}', '\\mathbb{E}', '\\mathbb{F}', '\\mathbb{G}', '\\mathbb{H}', '\\mathbb{I}', '\\mathbb{J}',
             '\\mathbb{K}', '\\mathbb{L}', '\\mathbb{M}', '\\mathbb{N}', '\\mathbb{O}', '\\mathbb{P}', '\\mathbb{Q}', '\\mathbb{R}',
             '\\mathbb{S}', '\\mathbb{T}', '\\mathbb{U}', '\\mathbb{V}', '\\mathbb{W}', '\\mathbb{X}', '\\mathbb{Y}', '\\mathbb{Z}',
             '\\mathbb', ',', ';', ':', '!', '?', '.', '(', ')', '[', ']', '\\{', '\\}', '*', '/', '+', '-', '\\_', '\\&', '\\#', '\\%', '|',
             '\\backslash', '\\alpha', '\\beta', '\\delta', '\\Delta', '\\epsilon', '\\eta', '\\chi', '\\gamma', '\\Gamma', '\\iota',
             '\\kappa', '\\lambda', '\\Lambda', '\\nu', '\\mu', '\\omega', '\\Omega', '\\phi', '\\Phi', '\\pi', '\\Pi', '\\psi', '\\Psi',
             '\\rho', '\\sigma', '\\Sigma', '\\tau', '\\theta', '\\Theta', '\\upsilon', '\\Upsilon', '\\varphi', '\\varpi', '\\varsigma',
             '\\vartheta', '\\xi', '\\Xi', '\\zeta', '\\frac', '\\sqrt', '\\prod', '\\sum', '\\iint', '\\int', '\\oint', '\\hat', '\\tilde',
             '\\vec', '\\overline', '\\underline', '\\prime', '\\dot', '\\not', '\\begin{matrix}', '\\end{matrix}', '\\langle', '\\rangle',
             '\\lceil', '\\rceil', '\\lfloor', '\\rfloor', '\\|', '\\ge', '\\gg', '\\le', '\\ll', '<', '>', '=', '\\approx', '\\cong', '\\equiv',
             '\\ne', '\\propto', '\\sim', '\\simeq', '\\in', '\\ni', '\\notin', '\\sqsubseteq', '\\subset', '\\subseteq', '\\subsetneq',
             '\\supset', '\\supseteq', '\\emptyset', '\\times', '\\bigcap', '\\bigcirc', '\\bigcup', '\\bigoplus', '\\bigvee', '\\bigwedge',
             '\\cap', '\\cup', '\\div', '\\mp', '\\odot', '\\ominus', '\\oplus', '\\otimes', '\\pm', '\\vee', '\\wedge', '\\hookrightarrow',
             '\\leftarrow', '\\leftrightarrow', '\\Leftrightarrow', '\\longrightarrow', '\\mapsto', '\\rightarrow', '\\Rightarrow',
             '\\rightleftharpoons', '\\iff', '\\bullet', '\\cdot', '\\circ', '\\aleph', '\\angle', '\\dagger', '\\exists', '\\forall',
             '\\hbar', '\\infty', '\\models', '\\nabla', '\\neg', '\\partial', '\\perp', '\\top', '\\triangle', '\\triangleleft',
             '\\triangleq', '\\vdash', '\\Vdash', '\\vdots']  
# Token -> integer label. Labels start at 1, so 0 is never assigned to a token
# (the CTC decoders in this notebook default to blank=0).
CHAR2LABEL = {char: i + 1 for i, char in enumerate(CHARS)}
# Inverse mapping: integer label -> token.
LABEL2CHAR = {label: char for char, label in CHAR2LABEL.items()}
# Matches one LaTeX command starting at a backslash: \mathbb{X}, \begin{env},
# \end{env}, \operatorname*, a run of letters, or (last resort) any single
# character other than a newline.
_COMMAND_RE = re.compile(r'\\(mathbb{[a-zA-Z]}|begin{[a-z]+}|end{[a-z]+}|operatorname\*|[a-zA-Z]+|.)')


def tokenize_expression(s: str) -> list[str]:
    """Splits a LaTeX string into tokens.

    A backslash starts a command token as described by ``_COMMAND_RE``; every
    other character is a token by itself. A backslash that cannot start a
    command (it is the last character, or it is followed by a newline) is
    emitted as a single-character token instead of raising.

    Args:
        s: the LaTeX expression.

    Returns:
        The tokens in order; concatenating them gives back ``s``.
    """
    tokens = []
    pos = 0
    end = len(s)
    # Walk the string by position rather than re-slicing it for every token.
    while pos < end:
        token = s[pos]
        if token == '\\':
            match = _COMMAND_RE.match(s, pos)
            # No match means a dangling backslash: keep it as its own token.
            if match is not None:
                token = match.group(0)
        tokens.append(token)
        pos += len(token)
    return tokens


# Example: prints the token list for one expression (commands, braces and
# plain characters each become one token).
print(tokenize_expression(r'\frac{\alpha}{2}\not\in\mathbb{R}'))

In [None]:
import jiwer
     

def compute_cer(truth_and_output: list[tuple[str, str]]):
  """Computes the token-level error rate over (ground truth, model output) pairs.

  Both strings of every pair are split with `tokenize_expression`, so one
  LaTeX command counts as a single unit when jiwer computes the error rate.

  Args:
    truth_and_output: pairs whose first element is the ground truth and whose
      second element is the model output.

  Returns:
    The value returned by `jiwer.cer` for the tokenized pairs.
  """
  class TokenizeTransform(jiwer.transforms.AbstractTransform):
    """jiwer transform that replaces each string by its list of LaTeX tokens."""

    def process_string(self, s: str):
      return tokenize_expression(r'{}'.format(s))

    def process_list(self, tokens: list[str]):
      return [self.process_string(token) for token in tokens]

  ground_truth, model_output = zip(*truth_and_output)

  # `reference=` is the keyword that goes with `reference_transform=`; the
  # older `truth=` spelling is deprecated in jiwer.
  return jiwer.cer(reference=list(ground_truth),
                   hypothesis=list(model_output),
                   reference_transform=TokenizeTransform(),
                   hypothesis_transform=TokenizeTransform())
     

# Test data to run compute_cer().
# Each pair is (ground truth, model output): compute_cer() unpacks the first
# element as the ground truth and the second as the model output.
examples = [
    (r'\sqrt{2}', r'\sqrt{2}'),  # 0 mistakes, 4 tokens
    (r'\frac{1}{2}', r'\frac{i}{2}'),  # 1 mistake, 7 tokens
    (r'\alpha^{2}', 'a^{2}'),  # 1 mistake, 5 tokens
    ('abc', 'def'),  # 3 mistakes, 3 tokens
]

# 5 mistakes for 19 tokens: 26.3% error rate.
print(f"{compute_cer(examples)*100:.1f} %")

In [4]:
#config

# Settings shared by training and evaluation.
common_config = {
    # folder_001 ... folder_010
    'train_images_dirs': [f'./train_splits/folder_{n:03d}' for n in range(1, 11)],
    'valid_images_dirs': ['./train_splits/folder_099'],
    'labels_file': './train_splits/labels.json',
    'img_width': 448,
    'img_height': 336,
    'map_to_seq_hidden': 64,
    'rnn_hidden': 256,
    'leaky_relu': False,
}

# Training-only settings first, followed by the shared ones.
train_config = {
    'epochs': 50,
    'train_batch_size': 32,
    'eval_batch_size': 32,
    'lr': 1e-5,
    'show_interval': 50,
    'valid_interval': 300,
    'save_interval': 300,
    'cpu_workers': 4,
    'reload_checkpoint': None,
    'valid_max_iter': 100,
    'decode_method': 'greedy',
    'beam_size': 10,
    'checkpoints_dir': '/kaggle/working/checkpoints/',
    **common_config,
}

# Evaluation-only settings first, followed by the shared ones.
evaluate_config = {
    'eval_batch_size': 512,
    'cpu_workers': 4,
    'reload_checkpoint': '/kaggle/output/crnn-pytorch/checkpoints/crnn_synth90k.pt',
    'decode_method': 'beam_search',
    'beam_size': 10,
    **common_config,
}



In [5]:
#requirements

import torch.nn as nn
from collections import defaultdict
import torch
import numpy as np
from scipy.special import logsumexp  # log(p1 + p2) = logsumexp([log_p1, log_p2])
import os
import glob
from torch.utils.data import Dataset
from scipy import signal
from scipy.io import wavfile
import cv2
from PIL import Image
from torch.utils.data import DataLoader
from torch.nn import CTCLoss
from tqdm import tqdm
from docopt import docopt

In [6]:
#model

import torch.nn as nn

class CRNN(nn.Module):
    """Convolutional backbone followed by two bidirectional LSTMs and a linear classifier.

    The forward pass turns a batch of images into per-column class scores of
    shape (seq_len, batch, num_class).
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super().__init__()
        backbone, feature_shape = self._cnn_backbone(
            img_channel, img_height, img_width, leaky_relu)
        feat_channels, feat_height, _feat_width = feature_shape

        self.cnn = backbone
        # Each feature-map column (channels x height) is projected to one sequence step.
        self.map_to_seq = nn.Linear(feat_channels * feat_height, map_to_seq_hidden)
        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)
        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Builds the CNN and returns it with its output (channel, height, width).

        Requires img_height to be a multiple of 16 and img_width a multiple of 4.
        """
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        widths = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        # One row per conv stage:
        # (kernel, stride, padding, batch_norm, max-pool kwargs applied after the stage or None)
        stages = [
            (3, 1, 1, False, {'kernel_size': 2, 'stride': 2}),  # -> (64, H // 2, W // 2)
            (3, 1, 1, False, {'kernel_size': 2, 'stride': 2}),  # -> (128, H // 4, W // 4)
            (3, 1, 1, False, None),
            (3, 1, 1, False, {'kernel_size': (2, 1)}),          # -> (256, H // 8, W // 4)
            (3, 1, 1, True, None),
            (3, 1, 1, True, {'kernel_size': (2, 1)}),           # -> (512, H // 16, W // 4)
            (2, 1, 0, False, None),                             # -> (512, H // 16 - 1, W // 4 - 1)
        ]

        cnn = nn.Sequential()
        pool_count = 0
        for idx, (kernel, stride, padding, use_bn, pool_kwargs) in enumerate(stages):
            in_ch, out_ch = widths[idx], widths[idx + 1]
            cnn.add_module(f'conv{idx}', nn.Conv2d(in_ch, out_ch, kernel, stride, padding))

            if use_bn:
                cnn.add_module(f'batchnorm{idx}', nn.BatchNorm2d(out_ch))

            if leaky_relu:
                activation = nn.LeakyReLU(0.2, inplace=True)
            else:
                activation = nn.ReLU(inplace=True)
            cnn.add_module(f'relu{idx}', activation)

            if pool_kwargs is not None:
                cnn.add_module(f'pooling{pool_count}', nn.MaxPool2d(**pool_kwargs))
                pool_count += 1

        out_shape = (widths[-1], img_height // 16 - 1, img_width // 4 - 1)
        return cnn, out_shape

    def forward(self, images):
        """Maps images (batch, channel, height, width) to scores (seq_len, batch, num_class)."""
        features = self.cnn(images)
        batch, channel, height, width = features.size()

        # Fold channel and height together, then put width first: (width, batch, feature).
        columns = features.view(batch, channel * height, width).permute(2, 0, 1)
        seq = self.map_to_seq(columns)

        hidden, _ = self.rnn1(seq)
        hidden, _ = self.rnn2(hidden)

        return self.dense(hidden)


In [7]:
#ctc

# Log-probability of an impossible event: log(0.0) = -inf.
NINF = -1 * float('inf')
# Default emission threshold on the probability scale; the beam decoders below
# take np.log() of it and skip classes whose log-probability is lower.
DEFAULT_EMISSION_THRESHOLD = 0.01


def _reconstruct(labels, blank=0):
    new_labels = []
    # merge same labels
    previous = None
    for l in labels:
        if l != previous:
            new_labels.append(l)
            previous = l
    # delete blank
    new_labels = [l for l in new_labels if l != blank]

    return new_labels


def greedy_decode(emission_log_prob, blank=0, **kwargs):
    """Best-path decoding: takes the arg-max class along the last axis, then collapses the path.

    Extra keyword arguments are accepted and ignored so all decoders share one call signature.
    """
    best_path = np.argmax(emission_log_prob, axis=-1)
    return _reconstruct(best_path, blank=blank)


def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    """Beam search over raw label paths, merging paths that collapse to the same labels.

    Args:
        emission_log_prob: 2-D array indexed as [timestep, class] holding log-probabilities.
        blank: label treated as the CTC blank when collapsing paths.
        **kwargs: must contain ``beam_size``; may contain ``emission_threshold``
            (a log-probability; classes below it are not expanded). The default
            threshold is ``np.log(DEFAULT_EMISSION_THRESHOLD)``.

    Returns:
        The collapsed label list with the highest summed probability among the
        surviving beams.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    beams = [([], 0)]  # (prefix, accumulated_log_prob)
    for t in range(length):
        # The classes worth expanding depend only on the timestep, not on the beam.
        candidates = [c for c in range(class_count)
                      if not emission_log_prob[t, c] < emission_threshold]
        if not candidates:
            # A flat distribution can leave every class under the threshold;
            # keep the most likely class so the beam list never becomes empty.
            candidates = [int(np.argmax(emission_log_prob[t]))]

        new_beams = []
        for prefix, accumulated_log_prob in beams:
            for c in candidates:
                # log(p1 * p2) = log_p1 + log_p2
                new_beams.append((prefix + [c], accumulated_log_prob + emission_log_prob[t, c]))

        # sorted by accumulated_log_prob
        new_beams.sort(key=lambda x: x[1], reverse=True)
        beams = new_beams[:beam_size]

    # sum up beams to produce labels
    total_accu_log_prob = {}
    for prefix, accu_log_prob in beams:
        labels = tuple(_reconstruct(prefix, blank))
        # log(p1 + p2) = logsumexp([log_p1, log_p2])
        total_accu_log_prob[labels] = \
            logsumexp([accu_log_prob, total_accu_log_prob.get(labels, NINF)])

    labels_beams = [(list(labels), accu_log_prob)
                    for labels, accu_log_prob in total_accu_log_prob.items()]
    labels_beams.sort(key=lambda x: x[1], reverse=True)
    labels = labels_beams[0][0]

    return labels


def prefix_beam_decode(emission_log_prob, blank=0, **kwargs):
    """Prefix beam search: beams are collapsed prefixes with separate blank / non-blank scores.

    Args:
        emission_log_prob: 2-D array indexed as [timestep, class] holding log-probabilities.
        blank: label treated as the CTC blank.
        **kwargs: must contain ``beam_size``; may contain ``emission_threshold``
            (a log-probability; classes below it are not expanded). The default
            threshold is ``np.log(DEFAULT_EMISSION_THRESHOLD)``.

    Returns:
        The label list of the best-scoring prefix after the last timestep.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    beams = [(tuple(), (0, NINF))]  # (prefix, (blank_log_prob, non_blank_log_prob))
    # initial of beams: (empty_str, (log(1.0), log(0.0)))

    for t in range(length):
        new_beams_dict = defaultdict(lambda: (NINF, NINF))  # log(0.0) = NINF

        # The classes worth expanding depend only on the timestep, not on the beam.
        candidates = [c for c in range(class_count)
                      if not emission_log_prob[t, c] < emission_threshold]
        if not candidates:
            # A flat distribution can leave every class under the threshold;
            # keep the most likely class so the beam list never becomes empty.
            candidates = [int(np.argmax(emission_log_prob[t]))]

        for prefix, (lp_b, lp_nb) in beams:
            # Last label of the prefix (None for the empty prefix).
            end_t = prefix[-1] if prefix else None

            for c in candidates:
                log_prob = emission_log_prob[t, c]

                # if new_prefix == prefix
                new_lp_b, new_lp_nb = new_beams_dict[prefix]

                if c == blank:
                    new_beams_dict[prefix] = (
                        logsumexp([new_lp_b, lp_b + log_prob, lp_nb + log_prob]),
                        new_lp_nb
                    )
                    continue
                if c == end_t:
                    # Repeating the last label without a blank keeps the prefix unchanged.
                    new_beams_dict[prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_nb + log_prob])
                    )

                # if new_prefix == prefix + (c,)
                new_prefix = prefix + (c,)
                new_lp_b, new_lp_nb = new_beams_dict[new_prefix]

                if c != end_t:
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob, lp_nb + log_prob])
                    )
                else:
                    # The same label can only extend the prefix after a blank.
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob])
                    )

        # sorted by log(blank_prob + non_blank_prob)
        beams = sorted(new_beams_dict.items(), key=lambda x: logsumexp(x[1]), reverse=True)
        beams = beams[:beam_size]

    labels = list(beams[0][0])
    return labels


def ctc_decode(log_probs, label2char=None, blank=0, method='beam_search', beam_size=10):
    """Decodes a batch of CTC log-probabilities into label (or token) sequences.

    Args:
        log_probs: tensor laid out as (length, batch, class); it is transposed
            to (batch, length, class) before decoding.
        label2char: optional mapping applied to every decoded label.
        blank: label treated as the CTC blank.
        method: 'greedy', 'beam_search' or 'prefix_beam_search'.
        beam_size: forwarded to the selected decoder.

    Returns:
        One decoded list per batch element.

    Raises:
        KeyError: if ``method`` is not one of the supported names.
    """
    # detach() so tensors that still track gradients can be converted to numpy.
    emission_log_probs = np.transpose(log_probs.detach().cpu().numpy(), (1, 0, 2))
    # size of emission_log_probs: (batch, length, class)

    decoders = {
        'greedy': greedy_decode,
        'beam_search': beam_search_decode,
        'prefix_beam_search': prefix_beam_decode,
    }
    decoder = decoders[method]

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        decoded = decoder(emission_log_prob, blank=blank, beam_size=beam_size)
        if label2char:
            decoded = [label2char[l] for l in decoded]
        decoded_list.append(decoded)

    return decoded_list

In [8]:
#dataset

import re
import json

class LatexOcrDataset(Dataset):
    """Map-style dataset of formula images with optional LaTeX token labels.

    Items are ``(image, target, target_length)`` when label texts are available
    and just ``image`` otherwise. Images are loaded as grey-scale, resized to
    (img_height, img_width) and scaled to the range [-1, 1].
    """

    # Token vocabulary; the order fixes each token's label.
    CHARS = ['_', '^', '{', '}', '&', '\\\\', ' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
             's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q',
             'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '\\mathbb{A}', '\\mathbb{B}',
             '\\mathbb{C}', '\\mathbb{D}', '\\mathbb{E}', '\\mathbb{F}', '\\mathbb{G}', '\\mathbb{H}', '\\mathbb{I}', '\\mathbb{J}',
             '\\mathbb{K}', '\\mathbb{L}', '\\mathbb{M}', '\\mathbb{N}', '\\mathbb{O}', '\\mathbb{P}', '\\mathbb{Q}', '\\mathbb{R}',
             '\\mathbb{S}', '\\mathbb{T}', '\\mathbb{U}', '\\mathbb{V}', '\\mathbb{W}', '\\mathbb{X}', '\\mathbb{Y}', '\\mathbb{Z}',
             '\\mathbb', ',', ';', ':', '!', '?', '.', '(', ')', '[', ']', '\\{', '\\}', '*', '/', '+', '-', '\\_', '\\&', '\\#', '\\%', '|',
             '\\backslash', '\\alpha', '\\beta', '\\delta', '\\Delta', '\\epsilon', '\\eta', '\\chi', '\\gamma', '\\Gamma', '\\iota',
             '\\kappa', '\\lambda', '\\Lambda', '\\nu', '\\mu', '\\omega', '\\Omega', '\\phi', '\\Phi', '\\pi', '\\Pi', '\\psi', '\\Psi',
             '\\rho', '\\sigma', '\\Sigma', '\\tau', '\\theta', '\\Theta', '\\upsilon', '\\Upsilon', '\\varphi', '\\varpi', '\\varsigma',
             '\\vartheta', '\\xi', '\\Xi', '\\zeta', '\\frac', '\\sqrt', '\\prod', '\\sum', '\\iint', '\\int', '\\oint', '\\hat', '\\tilde',
             '\\vec', '\\overline', '\\underline', '\\prime', '\\dot', '\\not', '\\begin{matrix}', '\\end{matrix}', '\\langle', '\\rangle',
             '\\lceil', '\\rceil', '\\lfloor', '\\rfloor', '\\|', '\\ge', '\\gg', '\\le', '\\ll', '<', '>', '=', '\\approx', '\\cong', '\\equiv',
             '\\ne', '\\propto', '\\sim', '\\simeq', '\\in', '\\ni', '\\notin', '\\sqsubseteq', '\\subset', '\\subseteq', '\\subsetneq',
             '\\supset', '\\supseteq', '\\emptyset', '\\times', '\\bigcap', '\\bigcirc', '\\bigcup', '\\bigoplus', '\\bigvee', '\\bigwedge',
             '\\cap', '\\cup', '\\div', '\\mp', '\\odot', '\\ominus', '\\oplus', '\\otimes', '\\pm', '\\vee', '\\wedge', '\\hookrightarrow',
             '\\leftarrow', '\\leftrightarrow', '\\Leftrightarrow', '\\longrightarrow', '\\mapsto', '\\rightarrow', '\\Rightarrow',
             '\\rightleftharpoons', '\\iff', '\\bullet', '\\cdot', '\\circ', '\\aleph', '\\angle', '\\dagger', '\\exists', '\\forall',
             '\\hbar', '\\infty', '\\models', '\\nabla', '\\neg', '\\partial', '\\perp', '\\top', '\\triangle', '\\triangleleft',
             '\\triangleq', '\\vdash', '\\Vdash', '\\vdots']
    # Token -> label, starting at 1 so that label 0 is never assigned to a token.
    CHAR2LABEL = {char: i + 1 for i, char in enumerate(CHARS)}
    # Label -> token.
    LABEL2CHAR = {label: char for char, label in CHAR2LABEL.items()}
    # One LaTeX command starting at a backslash (see tokenize_expression).
    _COMMAND_RE = re.compile(r'\\(mathbb{[a-zA-Z]}|begin{[a-z]+}|end{[a-z]+}|operatorname\*|[a-zA-Z]+|.)')

    def tokenize_expression(self, s: str) -> list[str]:
        """Splits a LaTeX string into tokens.

        A backslash starts a command token matched by ``_COMMAND_RE``; every
        other character is its own token. A backslash that cannot start a
        command (last character, or followed by a newline) becomes a
        single-character token instead of raising.
        """
        tokens = []
        pos = 0
        end = len(s)
        while pos < end:
            token = s[pos]
            if token == '\\':
                match = self._COMMAND_RE.match(s, pos)
                if match is not None:
                    token = match.group(0)
            tokens.append(token)
            pos += len(token)
        return tokens

    def __init__(self, images_dirs=None, labels_file=None, paths=None, mode=None, img_height=32, img_width=100):
        """Builds the dataset either from explicit image paths or from directories plus a labels file.

        Args:
            images_dirs: directories whose files are the images.
            labels_file: JSON file mapping image id (file name up to the first '.') to its LaTeX text.
            paths: explicit image paths; used as-is (without labels) only when
                neither images_dirs nor labels_file is given.
            mode: accepted but not used by this class.
            img_height, img_width: size every image is resized to.
        """
        if not images_dirs and not labels_file and paths:
            texts = None
        else:
            paths, texts = self._load_from_raw_files(images_dirs, labels_file)

        self.paths = paths
        self.texts = texts
        self.img_height = img_height
        self.img_width = img_width

    def _load_from_raw_files(self, images_dirs, labels_file):
        """Returns parallel lists (paths, texts) for every file found in images_dirs.

        Raises:
            KeyError: if a file's id has no entry in the labels file.
        """
        paths = []
        texts = []

        with open(labels_file, 'r') as file:
            labels = json.load(file)

        for images_dir in images_dirs:
            # os.listdir order is unspecified; sort so the item order is reproducible.
            for filename in sorted(os.listdir(images_dir)):
                image_path = os.path.join(images_dir, filename)
                image_id = filename.split('.')[0]
                paths.append(image_path)
                texts.append(labels[image_id])

        return paths, texts

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Returns the item at ``index``, skipping forward past unreadable images.

        If the image cannot be opened, the following items are tried in turn,
        wrapping around to the start of the dataset, and the first readable
        one is returned together with its own label.

        Raises:
            IndexError: if ``index`` is outside the dataset.
            IOError: if no image in the dataset can be opened.
        """
        total = len(self.paths)
        if not -total <= index < total:
            # Keep normal sequence semantics (also what stops plain iteration).
            raise IndexError('dataset index out of range')

        start = index % total
        image = None
        for offset in range(total):
            candidate = (start + offset) % total
            try:
                image = Image.open(self.paths[candidate]).convert('L')  # grey-scale
            except IOError:
                print('Corrupted image for %d' % candidate)
                continue
            index = candidate
            break
        if image is None:
            raise IOError('no readable image found in the dataset')

        image = image.resize((self.img_width, self.img_height), resample=Image.BILINEAR)
        image = np.array(image)
        image = image.reshape((1, self.img_height, self.img_width))
        image = (image / 127.5) - 1.0  # 0..255 -> -1..1
        image = torch.FloatTensor(image)

        if self.texts:
            text = self.texts[index]
            tokens = self.tokenize_expression(text)
            target = [self.CHAR2LABEL[t] for t in tokens]
            target_length = [len(target)]
            target = torch.LongTensor(target)
            target_length = torch.LongTensor(target_length)
            return image, target, target_length
        else:
            return image


def latex_ocr_collate_fn(batch):
    """Collates (image, target, target_length) items into batch tensors.

    Images are stacked along a new first dimension; targets and target lengths
    are concatenated, so the targets of the whole batch form one flat tensor.
    """
    image_items, target_items, length_items = zip(*batch)
    return (
        torch.stack(image_items, dim=0),
        torch.cat(target_items, dim=0),
        torch.cat(length_items, dim=0),
    )


In [9]:
#evaluation

# Switches cuDNN off for everything that runs after this cell.
# NOTE(review): the reason is not recorded here — presumably a workaround for a
# cuDNN issue on the LSTM/CTC path; confirm before removing.
torch.backends.cudnn.enabled = False


def evaluate(crnn, dataloader, criterion,
             max_iter=None, decode_method='beam_search', beam_size=10):
    """Runs the model over a dataloader and reports loss and exact-match accuracy.

    Args:
        crnn: model whose output is indexed as (seq_len, batch, num_class).
        dataloader: yields (images, targets, target_lengths) batches in which
            the targets of all samples are concatenated into one flat tensor.
        criterion: called as criterion(log_probs, targets, input_lengths, target_lengths).
        max_iter: optional cap on the number of batches; None or 0 means no cap.
        decode_method: decoder name forwarded to ctc_decode.
        beam_size: beam size forwarded to ctc_decode.

    Returns:
        dict with
          'loss': sum of the criterion values divided by the number of samples,
          'acc': fraction of samples whose decoded labels equal the target exactly,
          'wrong_cases': list of (real, pred) label lists for the mismatches.
        'loss' and 'acc' are NaN when no sample was processed.
    """
    crnn.eval()

    tot_count = 0
    tot_loss = 0
    tot_correct = 0
    wrong_cases = []

    with torch.no_grad():
        device = None
        for i, data in enumerate(tqdm(dataloader, desc="Evaluating")):
            if max_iter and i >= max_iter:
                break
            if device is None:
                # Looked up once; the model does not move during evaluation.
                device = 'cuda' if next(crnn.parameters()).is_cuda else 'cpu'

            images, targets, target_lengths = [d.to(device) for d in data]

            logits = crnn(images)
            log_probs = torch.nn.functional.log_softmax(logits, dim=2)

            batch_size = images.size(0)
            # Every sample uses the full output sequence length.
            input_lengths = torch.LongTensor([logits.size(0)] * batch_size)

            loss = criterion(log_probs, targets, input_lengths, target_lengths)

            preds = ctc_decode(log_probs, method=decode_method, beam_size=beam_size)
            reals = targets.cpu().numpy().tolist()
            lengths = target_lengths.cpu().numpy().tolist()

            tot_count += batch_size
            tot_loss += loss.item()

            # Targets are concatenated: slice out each sample's labels by its length.
            offset = 0
            for pred, length in zip(preds, lengths):
                real = reals[offset:offset + length]
                offset += length
                if pred == real:
                    tot_correct += 1
                else:
                    wrong_cases.append((real, pred))

    if tot_count:
        mean_loss = tot_loss / tot_count
        accuracy = tot_correct / tot_count
    else:
        # Empty dataloader: report NaN instead of dividing by zero.
        mean_loss = accuracy = float('nan')

    evaluation = {
        'loss': mean_loss,
        'acc': accuracy,
        'wrong_cases': wrong_cases
    }
    return evaluation


def eval_main():
    """Evaluates the checkpoint named in ``evaluate_config`` and prints loss and accuracy.

    NOTE(review): the evaluated dataset is built from config['train_images_dirs'];
    confirm that evaluating on the training folders is intended.
    """
    cfg = evaluate_config
    batch_size = cfg['eval_batch_size']
    workers = cfg['cpu_workers']
    checkpoint_path = cfg['reload_checkpoint']
    height, width = cfg['img_height'], cfg['img_width']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    dataset = LatexOcrDataset(
        images_dirs=cfg['train_images_dirs'],
        labels_file=cfg['labels_file'],
        img_height=height,
        img_width=width,
    )
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        collate_fn=latex_ocr_collate_fn,
    )

    # One class per token plus one extra class for the CTC blank.
    num_class = len(LatexOcrDataset.LABEL2CHAR) + 1
    model = CRNN(
        1, height, width, num_class,
        map_to_seq_hidden=cfg['map_to_seq_hidden'],
        rnn_hidden=cfg['rnn_hidden'],
        leaky_relu=cfg['leaky_relu'],
    )
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)

    loss_fn = CTCLoss(reduction='sum')
    loss_fn.to(device)

    result = evaluate(
        model, loader, loss_fn,
        decode_method=cfg['decode_method'],
        beam_size=cfg['beam_size'],
    )
    print('test_evaluation: loss={loss}, acc={acc}'.format(**result))


# if __name__ == '__main__':
#     eval_main()


In [10]:
#train

import torch.optim as optim
from tqdm import trange

def train_batch(crnn, data, optimizer, criterion, device):
    """Runs one optimisation step on a batch and returns the loss as a float.

    ``data`` is an (images, targets, target_lengths) triple; every tensor is
    moved to ``device``. Gradients are clipped to a norm of 5 before the step.
    """
    crnn.train()
    images, targets, target_lengths = (tensor.to(device) for tensor in data)

    log_probs = torch.nn.functional.log_softmax(crnn(images), dim=2)

    # Every sample uses the full output sequence length.
    seq_len, batch_size = log_probs.size(0), images.size(0)
    input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long)

    batch_loss = criterion(log_probs, targets, input_lengths, target_lengths.flatten())

    optimizer.zero_grad()
    batch_loss.backward()
    torch.nn.utils.clip_grad_norm_(crnn.parameters(), max_norm=5)
    optimizer.step()
    return batch_loss.item()


def train():
    """Trains a CRNN with the settings in ``train_config``.

    Builds the training and validation datasets, optionally reloads a
    checkpoint, then runs ``epochs`` passes over the training data. Every
    ``valid_interval`` iterations the model is validated on at most
    ``valid_max_iter`` batches, and every ``save_interval`` iterations the
    weights are written to ``checkpoints_dir``. The per-epoch mean training
    losses are written to /kaggle/working/loss_list.txt at the end.
    """
    config = train_config
    print(config)
    epochs = config['epochs']
    train_batch_size = config['train_batch_size']
    eval_batch_size = config['eval_batch_size']
    lr = config['lr']
    show_interval = config['show_interval']
    valid_interval = config['valid_interval']
    save_interval = config['save_interval']
    cpu_workers = config['cpu_workers']
    reload_checkpoint = config['reload_checkpoint']
    valid_max_iter = config['valid_max_iter']

    img_width = config['img_width']
    img_height = config['img_height']
    train_images_dirs = config['train_images_dirs']
    valid_images_dirs = config['valid_images_dirs']
    labels_file = config['labels_file']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    train_dataset = LatexOcrDataset(images_dirs=train_images_dirs, labels_file=labels_file,
                                    img_height=img_height, img_width=img_width)
    valid_dataset = LatexOcrDataset(images_dirs=valid_images_dirs, labels_file=labels_file,
                                    img_height=img_height, img_width=img_width)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=latex_ocr_collate_fn)
    # Shuffled so that a capped validation pass sees a different subset each time.
    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=eval_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=latex_ocr_collate_fn)

    # One class per token plus one extra class for the CTC blank.
    num_class = len(LatexOcrDataset.LABEL2CHAR) + 1
    crnn = CRNN(1, img_height, img_width, num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])
    if reload_checkpoint:
        print(reload_checkpoint)
        crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    optimizer = optim.RMSprop(crnn.parameters(), lr=lr)
    criterion = CTCLoss(reduction='sum', zero_infinity=True)
    criterion.to(device)

    # Checkpoints are only written on iterations that also validate.
    assert save_interval % valid_interval == 0

    loss_list = []
    # NOTE(review): the first entry is a placeholder, not a measured loss;
    # confirm whether readers of loss_list.txt expect it.
    previous_loss = 0.999999999999
    loss_list.append(previous_loss)
    i = 1  # global iteration counter across epochs
    for epoch in range(1, epochs + 1):
        print(f'epoch: {epoch}')
        tot_train_loss = 0.
        tot_train_count = 0
        for train_data in tqdm(train_loader, desc="Training"):
            loss = train_batch(crnn, train_data, optimizer, criterion, device)
            train_size = train_data[0].size(0)
            tot_train_loss += loss
            tot_train_count += train_size
            if i % show_interval == 0:
                print('train_batch_loss[', i, ']: ', loss / train_size)

            if i % valid_interval == 0:
                # Cap the validation pass at valid_max_iter batches, as configured.
                evaluation = evaluate(crnn, valid_loader, criterion,
                                      max_iter=valid_max_iter,
                                      decode_method=config['decode_method'],
                                      beam_size=config['beam_size'])
                print('valid_evaluation: loss={loss}, acc={acc}'.format(**evaluation))

                if i % save_interval == 0:
                    prefix = 'crnn'
                    valid_loss = evaluation['loss']
                    # Ensure the directory exists
                    os.makedirs(config['checkpoints_dir'], exist_ok=True)
                    save_model_path = os.path.join(config['checkpoints_dir'],
                                                   f'{prefix}_{i:06}_loss{valid_loss}.pt')
                    torch.save(crnn.state_dict(), save_model_path)
                    print('save model at ', save_model_path)

            i += 1

        current_loss = tot_train_loss / tot_train_count
        loss_list.append(current_loss)
        print('train_loss: ', current_loss)

    # Save the training loss
    file_path = '/kaggle/working/loss_list.txt'
    with open(file_path, 'w') as file:
        # Write each loss value to the file
        for epoch_loss in loss_list:
            file.write(f"{epoch_loss}\n")

    print(f"Loss list saved to {file_path}")





In [None]:
# Start training with the settings in train_config.
train()