# Vision Transformer for im2latex

The feature extractor and encoder are fine-tuned from Google's pretrained ViT model, stored locally at `./model/vit-base-patch16-224-in21k`.

## Build Dataset

In [9]:
import torch
from torch.utils.data import Dataset
from datasets import load_dataset, load_from_disk, Dataset as HuggingFaceDataset, DatasetDict, concatenate_datasets
from tqdm import tqdm
import os
import random

from tokenizers.implementations import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from transformers import ViTImageProcessor

In [10]:
# --- Pretrained checkpoints and tokenizer files ---
VIT_PATH = "./model/vit-base-patch16-224-in21k"
ROBERTA_PATH = "./model/roberta"
TOKENIZER_PATH = "./model/tokenizer"

# --- Output directory used for the fine-tuned OCR model ---
MODEL_PATH = "./model/ocr"

# --- Raw im2latex dataset previously saved to disk ---
DATA_DIR = "./data/im2latex-250k/"

# Load the raw dataset; its splits are indexed by name (e.g. "train") further down
im2latex_dataset = load_from_disk(DATA_DIR)

In [11]:
# Rebuild the byte-level BPE tokenizer from its saved vocabulary and merge files
vocab_file = os.path.join(TOKENIZER_PATH, "vocab.json")
merges_file = os.path.join(TOKENIZER_PATH, "merges.txt")
tokenizer = ByteLevelBPETokenizer(vocab_file, merges_file)

# Wrap every encoded sequence in the start (<s>) and end (</s>) special tokens.
# BertProcessing takes the separator entry first, then the classifier/start entry.
sep_entry = ("</s>", tokenizer.token_to_id("</s>"))
cls_entry = ("<s>", tokenizer.token_to_id("<s>"))
tokenizer.post_processor = BertProcessing(sep_entry, cls_entry)

# Every encoding is truncated and padded to the same fixed length of 512 tokens
tokenizer.enable_truncation(max_length=512)
tokenizer.enable_padding(length=512)

# Image processor matching the pretrained ViT checkpoint
image_processor = ViTImageProcessor.from_pretrained(VIT_PATH)

In [12]:
# Intermediate dataset generator
def intermediate_generator(dataset):
    """Yield one preprocessed example per row of ``dataset``.

    Tokenization and ViT image processing are done here, ahead of time, so the
    results can be stored and read back directly; doing this work during
    training is too slow.

    Args:
        dataset: iterable of rows exposing a ``"formula"`` entry (passed to the
            tokenizer) and an ``"image"`` entry (passed to the image processor).

    Yields:
        dict with ``"pixel_values"``, ``"labels"`` (token ids) and
        ``"attention_mask"`` tensors for each row.
    """
    for row in tqdm(dataset):
        encoded = tokenizer.encode(row["formula"])
        image_features = image_processor(row["image"], return_tensors="pt")

        yield dict(
            pixel_values=image_features.pixel_values.squeeze(),
            labels=torch.tensor(encoded.ids),
            attention_mask=torch.tensor(encoded.attention_mask),
        )

Generating train split: 3001 examples [00:40,  3.99 examples/s] 
  1%|▏         | 3001/200269 [00:40<43:55, 74.85it/s] 
Generating train split: 3001 examples [00:40, 74.84 examples/s]


KeyboardInterrupt: 

In [None]:
INTERMEDIATE_PATH = "./data/im2latex-intermediate"

# The training split is preprocessed in two halves so that only one half has to
# be held in memory at a time. The seed keeps the partition reproducible.
train_partition = im2latex_dataset["train"].train_test_split(test_size=0.5, seed=32)

# Shard name on disk -> raw split it is generated from.
# Each training shard is built from its OWN half of the partition; feeding the
# full training split to both would store every training example twice.
raw_splits = {
    "train_1": train_partition["train"],
    "train_2": train_partition["test"],
    "test": im2latex_dataset["test"],
    "val": im2latex_dataset["val"],
}

for shard_name, raw_split in raw_splits.items():
    shard = HuggingFaceDataset.from_generator(intermediate_generator, gen_kwargs={"dataset": raw_split})
    shard.save_to_disk(os.path.join(INTERMEDIATE_PATH, shard_name))
    del shard  # Drop the reference so the shard can be garbage collected before the next one

# Reassemble the final dataset from the shards saved above. The directory names
# must match the ones used in save_to_disk, and concatenate_datasets expects a
# single list of datasets.
processed_data = DatasetDict({
    "train": concatenate_datasets([
        HuggingFaceDataset.load_from_disk(os.path.join(INTERMEDIATE_PATH, "train_1")),
        HuggingFaceDataset.load_from_disk(os.path.join(INTERMEDIATE_PATH, "train_2")),
    ]),
    "test": HuggingFaceDataset.load_from_disk(os.path.join(INTERMEDIATE_PATH, "test")),
    "val": HuggingFaceDataset.load_from_disk(os.path.join(INTERMEDIATE_PATH, "val")),
})

In [None]:
PROCESSED_DATA_PATH = "./data/processed-im2latex-250k/"

# Persist the fully preprocessed DatasetDict so later runs can load it directly
# instead of repeating tokenization and image processing.
processed_data.save_to_disk(PROCESSED_DATA_PATH)

In [None]:
# Load the processed data back from disk
processed_data = load_from_disk(PROCESSED_DATA_PATH)

# Pick a random training example as a sanity check
random_index = random.randint(0, processed_data["train"].num_rows - 1)
sample = processed_data["train"][random_index]

# Show the shape of each field rather than the raw values: a full row contains
# the entire pixel array and would flood the cell output.
print({name: tuple(torch.as_tensor(value).shape) for name, value in sample.items()})

In [None]:
# Create the Im2latexData torch dataset class
      
class Im2latexData(Dataset):
    """Torch dataset yielding ``pixel_values`` / ``labels`` pairs for training.

    Two kinds of rows are supported:

    * Raw rows with ``"formula"`` and ``"image"`` entries. Pass both
      ``tokenizer`` and ``feature_extractor``; each item is then tokenized and
      image-processed on access.
    * Preprocessed rows that already hold ``"pixel_values"``, ``"labels"`` and
      (optionally) ``"attention_mask"``. Leave ``tokenizer`` and
      ``feature_extractor`` unset; the stored values are converted to tensors.

    Args:
        latex_data: indexable collection of rows (supports ``len`` and ``[idx]``).
        tokenizer: optional object whose ``encode(text)`` result exposes
            ``ids`` and ``attention_mask``.
        feature_extractor: optional callable invoked as
            ``feature_extractor(image, return_tensors="pt")`` whose result
            exposes ``pixel_values`` with a leading batch dimension.

    Raises:
        ValueError: if only one of ``tokenizer`` / ``feature_extractor`` is given.
    """

    # Label value used for padded positions; -100 is the default ignore index
    # of torch's cross-entropy loss, so padding does not contribute to the loss.
    IGNORE_INDEX = -100

    def __init__(self, latex_data, tokenizer=None, feature_extractor=None):
        if (tokenizer is None) != (feature_extractor is None):
            raise ValueError("tokenizer and feature_extractor must be provided together")
        self.examples = latex_data
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor

    def __len__(self):
        """Return the number of rows in the wrapped collection."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": Tensor, "labels": Tensor}`` for row ``idx``."""
        example = self.examples[idx]

        if self.tokenizer is None:
            # Preprocessed row: values are already numeric, just make them tensors
            pixel_values = torch.as_tensor(example["pixel_values"])
            labels = torch.as_tensor(example["labels"])
            mask = example.get("attention_mask")
            attention_mask = None if mask is None else torch.as_tensor(mask)
        else:
            # Raw row: tokenize the formula and run the image through the processor
            encoding = self.tokenizer.encode(example["formula"])
            processed = self.feature_extractor(example["image"], return_tensors="pt")
            # Drop only the leading batch dimension added by return_tensors="pt"
            pixel_values = processed.pixel_values.squeeze(0)
            labels = torch.tensor(encoding.ids)
            attention_mask = torch.tensor(encoding.attention_mask)

        if attention_mask is not None:
            # Positions the tokenizer marked as padding are excluded from the loss
            labels = labels.masked_fill(attention_mask == 0, self.IGNORE_INDEX)

        return {"pixel_values": pixel_values, "labels": labels}

# Wrap each raw split in an Im2latexData torch dataset, in train / val / test order
train_dataset, val_dataset, test_dataset = (
    Im2latexData(im2latex_dataset[split_name], tokenizer=tokenizer, feature_extractor=image_processor)
    for split_name in ("train", "val", "test")
)

## Initialize Vision Transformer

In [None]:
from transformers import RobertaTokenizerFast, VisionEncoderDecoderModel
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, default_data_collator

In [None]:
# Tokenizer in the transformers format, loaded from the same tokenizer directory.
# Note: this rebinds `tokenizer`, which earlier cells used for the BPE tokenizer.
tokenizer = RobertaTokenizerFast.from_pretrained(TOKENIZER_PATH)

# Encoder-decoder model: ViT checkpoint as encoder, RoBERTa checkpoint as decoder.
# NOTE(review): tie_encoder_decoder=True is forwarded unchanged — confirm that
# weight tying is really intended for a ViT encoder / RoBERTa decoder pair.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path=VIT_PATH,
    decoder_pretrained_model_name_or_path=ROBERTA_PATH,
    tie_encoder_decoder=True,
)

Some weights of RobertaForCausalLM were not initialized from the model checkpoint at ./model/roberta and are newly initialized: ['roberta.encoder.layer.0.crossattention.output.LayerNorm.bias', 'roberta.encoder.layer.0.crossattention.output.LayerNorm.weight', 'roberta.encoder.layer.0.crossattention.output.dense.bias', 'roberta.encoder.layer.0.crossattention.output.dense.weight', 'roberta.encoder.layer.0.crossattention.self.key.bias', 'roberta.encoder.layer.0.crossattention.self.key.weight', 'roberta.encoder.layer.0.crossattention.self.query.bias', 'roberta.encoder.layer.0.crossattention.self.query.weight', 'roberta.encoder.layer.0.crossattention.self.value.bias', 'roberta.encoder.layer.0.crossattention.self.value.weight', 'roberta.encoder.layer.1.crossattention.output.LayerNorm.bias', 'roberta.encoder.layer.1.crossattention.output.LayerNorm.weight', 'roberta.encoder.layer.1.crossattention.output.dense.bias', 'roberta.encoder.layer.1.crossattention.output.dense.weight', 'roberta.encoder.

In [None]:
# Special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Make sure the top-level vocab size matches the decoder's
model.config.vocab_size = model.config.decoder.vocab_size

# Beam search parameters
model.config.eos_token_id = tokenizer.sep_token_id
# Generation must be able to emit a whole formula: targets are tokenized with a
# 512-token truncation/padding length, so a 20-token cap would cut nearly every
# decoded formula short.
# NOTE(review): assumes the decoder checkpoint supports 512 positions — confirm.
model.config.max_length = 512
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

In [None]:
# Hyperparameters, kept as module-level constants so they are easy to tune
TRAIN_EPOCHS, EVAL_STEPS = 5, 16384
LEARNING_RATE, WEIGHT_DECAY = 1e-4, 0.01
TRAIN_BATCH_SIZE = VAL_BATCH_SIZE = 4

# Arguments for the training run; checkpoints go to MODEL_PATH and only the
# most recent one is kept (save_total_limit=1).
# NOTE(review): eval_steps is supplied together with the 'epoch' evaluation
# strategy — confirm which evaluation schedule is actually intended.
training_settings = dict(
    output_dir=MODEL_PATH,
    evaluation_strategy='epoch',
    eval_steps=EVAL_STEPS,
    num_train_epochs=TRAIN_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=VAL_BATCH_SIZE,
    save_total_limit=1,
)
training_args = Seq2SeqTrainingArguments(**training_settings)

# The image processor is passed in the trainer's `tokenizer` slot
# (presumably so it is stored alongside checkpoints — confirm).
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=default_data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=image_processor,
)

trainer.train()

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 