## Welcome!
to the repo for

*Learning the Legibility of Visual Text Perturbations* (EACL 2023)

by Dev Seth, Rickard Stureborg, Danish Pruthi and Bhuwan Dhingra

### A `LEGIT` Introduction
This notebook is a starting point for working with the datasets and models presented in the Learning Legibility paper.

All assets are hosted on the HuggingFace Hub and can be used with the `transformers` and `datasets` libraries: 
  - TrOCR-MT Model: https://huggingface.co/dvsth/LEGIT-TrOCR-MT 
  - LEGIT Dataset: https://huggingface.co/datasets/dvsth/LEGIT
  - Perturbed Jigsaw Dataset: https://huggingface.co/datasets/dvsth/LEGIT-VIPER-Jigsaw-Toxic-Comment-Perturbed

#### Setup

In [None]:
# external imports -- use pip or conda to install these packages
import torch
from transformers import TrOCRProcessor, AutoModel, TrainingArguments
from datasets import load_dataset

# local imports -- make sure these files are in the same directory as this notebook
from LegibilityModel import LegibilityModel
from Trainer import MultiTaskTrainer
from Metrics import binary_classification_metric, ranking_metric
from utils import Renderer

#### Loading the Model and Dataset

In [None]:
# load the model schema and pretrained weights
# (this may take some time to download)
model = AutoModel.from_pretrained("dvsth/LEGIT-TrOCR-MT", revision='main', trust_remote_code=True)

An interactive preview of the dataset is available [here](https://huggingface.co/datasets/dvsth/LEGIT/viewer/dvsth--LEGIT/test).

In [None]:
dataset = load_dataset('dvsth/LEGIT').with_format('torch')

#### Training/Eval Loop

##### Trainer setup

In [None]:
# preprocessor provides image normalization and resizing
preprocessor = TrOCRProcessor.from_pretrained(
    "microsoft/trocr-base-handwritten")

# apply preprocessing batch-wise
def collate_fn(data):
    return {
        'choice': torch.tensor([d['choice'].item() for d in data]),
        'img0': preprocessor([d['img0'] for d in data], return_tensors='pt')['pixel_values'],
        'img1': preprocessor([d['img1'] for d in data], return_tensors='pt')['pixel_values']
    }


train_args = TrainingArguments(
    output_dir='runs',              # change this to a unique path for each run, e.g. f'runs/{run_id}'
    overwrite_output_dir=True,
    num_train_epochs=5,             # we found 3 epochs to be sufficient for convergence on the base models
    per_device_train_batch_size=26, # fits on 1 x NVIDIA A6000, 48GB VRAM
    per_device_eval_batch_size=26,  # can be increased to 32
    gradient_accumulation_steps=2,  # to fit on a smaller GPU, lower the batch size and increase this
    warmup_steps=0,             
    weight_decay=0.0,
    learning_rate=1e-5,             # we found this to be the best initial learning rate for the base models
    save_strategy="steps",
    save_steps=200,
    eval_steps=200,
    evaluation_strategy="steps",    # named `eval_strategy` in newer transformers releases
    logging_strategy='steps',
    logging_steps=50,
    fp16=False,                     
    load_best_model_at_end=True,    # load the best model at the end of training based on validation F1
    metric_for_best_model='f1_score')

trainer = MultiTaskTrainer(
    model=model,
    compute_metrics=binary_classification_metric, # see Metrics.py for the available metrics
    args=train_args,
    data_collator=collate_fn,
    train_dataset=dataset['train'],
    eval_dataset=dataset['valid'])
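
The batch-size comments above can be made concrete with a little arithmetic. The sketch below is pure Python and not part of the repo. It computes how many examples contribute to each optimizer step from the values in `train_args`. The helper name `effective_batch_size`, the single-device assumption, and the smaller-GPU setting (13 x 4) are illustrative assumptions, not settings from the paper.

```python
# Values copied from the TrainingArguments cell above.
per_device_train_batch_size = 26
gradient_accumulation_steps = 2
num_devices = 1  # assumption: a single GPU, as in the "1 x NVIDIA A6000" comment


def effective_batch_size(per_device, accumulation_steps, devices=1):
    """Number of examples that contribute to one optimizer step."""
    return per_device * accumulation_steps * devices


baseline = effective_batch_size(
    per_device_train_batch_size, gradient_accumulation_steps, num_devices)
print(baseline)  # 52

# Hypothetical smaller-GPU setting: halve the per-device batch and double
# the accumulation steps so each optimizer step still sees 52 examples.
smaller_gpu = effective_batch_size(13, 4, num_devices)
print(smaller_gpu)  # 52
```

Keeping this product constant when moving to a smaller GPU should leave the learning rate and schedule roughly comparable, at the cost of more forward and backward passes per optimizer step.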


##### Generate predictions and compute metrics

In [None]:
predictions = trainer.predict(dataset['test'].select(range(100))) # takes ~1-2 minutes on a laptop CPU
print(predictions.metrics)
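
For readers unfamiliar with the `f1_score` value used for model selection above, here is a minimal sketch of how a binary F1 score is typically computed. It is a generic illustration and not the implementation of `binary_classification_metric` in this repo. The function name `binary_f1` and the toy predictions and labels are made up for the example.

```python
def binary_f1(predictions, labels):
    """Generic binary F1: harmonic mean of precision and recall for class 1."""
    pairs = list(zip(predictions, labels))
    true_pos = sum(1 for p, y in pairs if p == 1 and y == 1)
    false_pos = sum(1 for p, y in pairs if p == 1 and y == 0)
    false_neg = sum(1 for p, y in pairs if p == 0 and y == 1)
    if true_pos == 0:
        # No correct positive predictions: define F1 as 0 to avoid dividing by zero.
        return 0.0
    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    return 2 * precision * recall / (precision + recall)


# Toy data, invented for illustration only.
toy_predictions = [1, 0, 1, 1, 0, 1]
toy_labels = [1, 0, 0, 1, 1, 1]

toy_f1 = binary_f1(toy_predictions, toy_labels)
print(toy_f1)  # 0.75 (3 true positives, 1 false positive, 1 false negative)
```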