# Train a LaTeX OCR model
In this brief notebook I show how you can fine-tune or train a LaTeX OCR model.

I've opted to mix handwritten data into the regular PDF LaTeX images. To do that, I started from the released pretrained model and continued training on the slightly larger combined corpus.

In [None]:
import torch
print(torch.__version__)

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(torch.cuda.get_device_name())
else:
    device = torch.device("cpu")

print(device)

In [None]:
%pip install pix2tex[train] -qq
%pip install tensorflow
%pip install gpustat -q
%pip install opencv-python-headless==4.1.2.30 -U -q
%pip install --upgrade --no-cache-dir gdown -q

In [None]:
!gpustat

In [None]:
import tensorflow as tf

# Query the GPU list once and reuse it
gpus = tf.config.list_physical_devices('GPU')
print(gpus)

if gpus:
    print("GPU is available")
else:
    print("No GPU detected")

In [None]:
import os
# exist_ok=True keeps this cell safe to re-run
os.makedirs("dataset/data", exist_ok=True)
os.makedirs("images", exist_ok=True)

In [None]:
# Note: take a look, then remove the handwritten dataset so that training runs on the 100k set

import os
import subprocess

os.makedirs('dataset/data', exist_ok=True)
os.makedirs('images', exist_ok=True)

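def gdown(file_id, output):
    # Pass a full URL: the --id flag is deprecated in recent gdown releases
    url = f'https://drive.google.com/uc?id={file_id}'
    subprocess.run(['gdown', url, '-O', output], check=True)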

handwritten_id = '13vjxGYrFCuYnwgDIUqkxsNGKk__D_sOM'
pdf_images_id = '176PKaCUDWmTJdQwc-OfkO0y8t4gLsIvQ'
pdf_math_id = '1QUjX6PFWPa-HBWdcY-7bA5TRVUnbyS1D'

gdown(handwritten_id, 'dataset/data/crohme.zip')
gdown(pdf_images_id, 'dataset/data/pdf.zip')
gdown(pdf_math_id, 'dataset/data/pdfmath.txt')

os.chdir('dataset/data')

subprocess.run(['unzip', '-q', 'crohme.zip'], check=True)
subprocess.run(['unzip', '-q', 'pdf.zip'], check=True)

os.chdir('images')

os.makedirs('../valimages', exist_ok=True)

# List the images, then pick 1000 of them at random for the validation set
listing = subprocess.run(['ls'], capture_output=True, text=True, check=True)
files = listing.stdout.split()
selected_files = subprocess.run(['shuf', '-n', '1000'], input="\n".join(files), text=True, capture_output=True, check=True)
selected_files_list = selected_files.stdout.split()

for file in selected_files_list:
    os.rename(file, f'../valimages/{file}')

os.chdir('../../..')
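
The hold-out step in the cell above can also be sketched in pure Python, which avoids the dependency on `ls` and `shuf` and makes the split reproducible. This is only a sketch: the function name `hold_out` and its `n` and `seed` parameters are hypothetical and not part of pix2tex.

```python
import random
from pathlib import Path


def hold_out(image_dir, val_dir, n=1000, seed=42):
    """Move up to n randomly chosen files from image_dir into val_dir.

    Hypothetical helper for illustration; returns the moved file names.
    """
    image_dir, val_dir = Path(image_dir), Path(val_dir)
    val_dir.mkdir(parents=True, exist_ok=True)

    # Sort first so that the same seed always selects the same files
    files = sorted(p for p in image_dir.iterdir() if p.is_file())
    rng = random.Random(seed)
    chosen = rng.sample(files, min(n, len(files)))

    for src in chosen:
        src.rename(val_dir / src.name)
    return [p.name for p in chosen]
```

In this notebook the equivalent call would be `hold_out('dataset/data/images', 'dataset/data/valimages')`, run from the notebook's working directory so that no `os.chdir` is needed.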

Now we generate the datasets. We can string multiple datasets together to get one large lookup table. The only things saved in these pkl files are the image sizes, the image locations, and the ground-truth LaTeX code. That way we can serve batches of images that all share the same dimensions.
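
To make the idea concrete, here is a minimal sketch of such a lookup table: records from several sources are merged and grouped by image size, and each batch is drawn from a single size group. It is not pix2tex's actual implementation; the function names, file paths, and sizes are made up for illustration.

```python
from collections import defaultdict


def build_lookup(samples):
    """Group (path, (width, height), latex) records by image size."""
    table = defaultdict(list)
    for path, size, latex in samples:
        table[size].append((path, latex))
    return dict(table)


def iter_batches(table, batch_size):
    """Yield (size, items) batches; every batch holds one image size only."""
    for size, items in table.items():
        for start in range(0, len(items), batch_size):
            yield size, items[start:start + batch_size]


# Two hypothetical sources, merged into a single table
printed = [
    ("pdf/0001.png", (128, 32), r"x^2"),
    ("pdf/0002.png", (128, 32), r"\frac{a}{b}"),
]
handwritten = [
    ("crohme/0001.png", (128, 32), r"a+b"),
    ("crohme/0002.png", (256, 64), r"\sqrt{x}"),
]

table = build_lookup(printed + handwritten)
batches = list(iter_batches(table, batch_size=2))
```

Because only paths, sizes, and labels are stored, the table stays small, and the images themselves only need to be read when a batch is requested.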

In [None]:
!python -m pix2tex.dataset.dataset -i dataset/data/images dataset/data/train -e dataset/data/CROHME_math.txt dataset/data/pdfmath.txt -o dataset/data/train.pkl

In [None]:
!python -m pix2tex.dataset.dataset -i dataset/data/valimages dataset/data/val -e dataset/data/CROHME_math.txt dataset/data/pdfmath.txt -o dataset/data/val.pkl

In [None]:
# download the weights we want to fine-tune
!curl -L -o weights.pth https://github.com/lukas-blecher/LaTeX-OCR/releases/download/v0.0.1/weights.pth

In [None]:
# If using wandb
%pip install -Uq wandb
# You can skip this if you don't want to use W&B or don't have an account.
#!wandb login

In [None]:
# generate colab specific config (set 'debug' to true if wandb is not used)
!echo {backbone_layers: [2, 3, 7], betas: [0.9, 0.999], batchsize: 10, bos_token: 1, channels: 1, data: dataset/data/train.pkl, debug: true, decoder_args: {'attn_on_attn': true, 'cross_attend': true, 'ff_glu': true, 'rel_pos_bias': false, 'use_scalenorm': false}, dim: 256, encoder_depth: 4, eos_token: 2, epochs: 50, gamma: 0.9995, heads: 8, id: null, load_chkpt: 'weights.pth', lr: 0.001, lr_step: 30, max_height: 192, max_seq_len: 512, max_width: 672, min_height: 32, min_width: 32, name: mixed, num_layers: 4, num_tokens: 8000, optimizer: Adam, output_path: outputs, pad: false, pad_token: 0, patch_size: 16, sample_freq: 2000, save_freq: 1, scheduler: StepLR, seed: 42, temperature: 0.2, test_samples: 5, testbatchsize: 20, tokenizer: dataset/tokenizer.json, valbatches: 100, valdata: dataset/data/val.pkl} > colab.yaml

In [None]:
import wandb

# Only needed when wandb is used ('debug' set to false in colab.yaml); skip this cell otherwise
wandb.login()

In [None]:
!python -m pix2tex.train --config colab.yaml