# config.ipynb

In [1]:
# Create a uv virtual environment and install the training dependencies into it.
# NOTE(review): every `!` line runs in its own subshell, so the `source` line
# below cannot activate the venv for the following line or for the kernel.
# NOTE(review): a later cell fails with "No module named 'lightning'"; this
# list requests `pytorch-lightning` and does not request `lightning`,
# `transformers`, `nltk`, `torchvision` or `matplotlib`, which later cells
# import. Presumably the kernel does not import from `.venv` - confirm the
# target environment before changing the install command.
!uv venv
!source .venv/bin/activate
!uv pip install -q bitsandbytes peft accelerate pytorch-lightning torch==2.4.1

Using CPython 3.10.12 interpreter at: [36m/usr/bin/python3[39m
Creating virtual environment at: [36m.venv[39m
Activate with: [32msource .venv/bin/activate[39m


In [2]:
# Print the pinned package list of the environment uv resolves.
# NOTE(review): presumably this is the `.venv` created in the previous cell,
# not necessarily the environment the kernel imports from - confirm.
!uv pip freeze

[1maccelerate[0m==1.2.1
[1maiohappyeyeballs[0m==2.4.4
[1maiohttp[0m==3.11.10
[1maiosignal[0m==1.3.2
[1masync-timeout[0m==5.0.1
[1mattrs[0m==24.3.0
[1mbitsandbytes[0m==0.45.0
[1mcertifi[0m==2024.12.14
[1mcharset-normalizer[0m==3.4.0
[1mfilelock[0m==3.16.1
[1mfrozenlist[0m==1.5.0
[1mfsspec[0m==2024.10.0
[1mhuggingface-hub[0m==0.27.0
[1midna[0m==3.10
[1mjinja2[0m==3.1.4
[1mlightning-utilities[0m==0.11.9
[1mmarkupsafe[0m==3.0.2
[1mmpmath[0m==1.3.0
[1mmultidict[0m==6.1.0
[1mnetworkx[0m==3.4.2
[1mnumpy[0m==2.2.0
[1mnvidia-cublas-cu12[0m==12.1.3.1
[1mnvidia-cuda-cupti-cu12[0m==12.1.105
[1mnvidia-cuda-nvrtc-cu12[0m==12.1.105
[1mnvidia-cuda-runtime-cu12[0m==12.1.105
[1mnvidia-cudnn-cu12[0m==9.1.0.70
[1mnvidia-cufft-cu12[0m==11.0.2.54
[1mnvidia-curand-cu12[0m==10.3.2.106
[1mnvidia-cusolver-cu12[0m==11.4.5.107
[1mnvidia-cusparse-cu12[0m==12.1.0.106
[1mnvidia-nccl-cu12[0m==2.20.5
[1mnvidia-nvjitlink-cu12[0m==12.6.85
[1mnvidia-nvtx-cu

In [3]:
# Install data
# generate data_install.sh
"""
#!/bin/bash

DATA_DIR="data";
ZIPFILE_NAME="mathwriting.tgz";
ZIPFILE_PATH="$DATA_DIR/$ZIPFILE_NAME";

UNZIPDIR_NAME="mathwriting-2024-excerpt";
UNZIPDIR_PATH="$DATA_DIR/$UNZIPDIR_NAME"

# make data directory if not exists
mkdir -p "$DATA_DIR"

# install math-writing dataset
if [ -e $ZIPFILE_PATH ]; then
	echo "$ZIPFILE_PATH - Already exists...";
else
	echo "$ZIPFILE_PATH - Doesn't exists. Starting download...";
	wget -O $ZIPFILE_PATH https://storage.googleapis.com/mathwriting_data/mathwriting-2024-excerpt.tgz;
fi

# check if unzip directory exists
if [ -d "$UNZIPDIR_PATH" ]; then
	echo "$UNZIPDIR_PATH - Already exists...";
else
	echo "$UNZIPDIR_PATH - Doesnt exists. Starting untar...";
	tar -xvzf "$ZIPFILE_PATH" -C "$DATA_DIR";
fi
"""
! sh ./data_install.sh

data/mathwriting.tgz - Already exists...
data/mathwriting-2024-excerpt - Already exists...


In [4]:
# Log in to the Hugging Face Hub.
# notebook_login() takes no arguments here; the cell's output is the
# interactive login widget it renders.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
# Configuration
import torch
from datetime import datetime
import logging
from pathlib import Path
import os

# Configure Directory
# NOTE(review): this resolves to "<parent of cwd>/content", which equals the
# working directory only when the notebook runs from a folder named "content"
# (the recorded output shows /content). Confirm before running elsewhere.
project_dir = Path(os.getcwd()).parent / "content"
data_dir = project_dir / "data"
model_dir = project_dir / "models"
log_dir = project_dir / "logs"

data_dir.mkdir(parents=True, exist_ok=True)
model_dir.mkdir(parents=True, exist_ok=True)
log_dir.mkdir(parents=True, exist_ok=True)
print(f'project_dir: {project_dir}')
print(f'data_dir: {data_dir}')
print(f'model_dir: {model_dir}')
print(f'log_dir: {log_dir}')

# Configure logger: one timestamped log file per execution of this cell.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = log_dir / f'log_{timestamp}.log'

logger = logging.getLogger('Handwriting2LaTeX')
logger.setLevel(logging.INFO)

# logging.getLogger returns the same object on every run of this cell, so
# drop file handlers left by earlier runs; otherwise each record would be
# written once per previous execution.
for _stale_handler in [h for h in logger.handlers if isinstance(h, logging.FileHandler)]:
    logger.removeHandler(_stale_handler)
    _stale_handler.close()

file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)
# logger.addHandler(logging.StreamHandler())

# log the directory information
logger.info(f'PROJECT_DIR: {project_dir}')
logger.info(f'DATA_DIR: {data_dir}')
logger.info(f'MODEL_DIR: {model_dir}')
logger.info(f'LOG_DIR: {log_dir}')

# Define Huggingface configuration
REPO_ID = "google/paligemma-3b-pt-224"
# NOTE(review): the name looks truncated (FINETUNED_MODEL_ID?); kept as-is
# because other cells may reference it.
FINETUNED_MODEL_I = "ball1433/Handwriting2LaTeX"

# Define Parameter for InkML parsing
TIME_SAMPLING_DELTA = 40   # minimum timestamp gap between two kept points
SEQ_MAX = 500              # upper clamp for emitted coordinate tokens
SEQ_MIN = -500             # lower clamp for emitted coordinate tokens
PADDING = 4                # margin kept free around the normalised ink

logger.info(f'TIME_SAMPLING_DELTA: {TIME_SAMPLING_DELTA}')
logger.info(f'SEQ_MAX: {SEQ_MAX}')
logger.info(f'SEQ_MIN: {SEQ_MIN}')
logger.info(f'PADDING: {PADDING}')

# dataset, dataloader parameter
NUM_WORKERS=1
logger.info(f'NUM_WORKERS: {NUM_WORKERS}')

# Define Training Parameter
MAX_LENGTH = 1024
BATCH_SIZE=2
EPOCHS = 200
IMG_SIZE=224
INIT_LR=3e-4


logger.info(f'EPOCHS: {EPOCHS}')
logger.info(f'BATCH_SIZE: {BATCH_SIZE}')

# Configure device: CUDA when present, otherwise CPU.
# MPS is detected only to explain the fallback: this notebook uses the CPU
# instead of MPS. Previously an unconditional `device = cpu` assignment also
# discarded CUDA right after printing "Using CUDA as device".
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA as device")
else:
    if torch.backends.mps.is_available():
        print("MPS is available but not used by this notebook.")
    elif not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    device = torch.device("cpu")
    print("Using CPU as device")

# The default tensor device stays on CPU, which is what the previous code
# always ended up with. The collate functions build tensors inside DataLoader
# worker processes, so only `device` reflects the accelerator.
torch.set_default_device("cpu")
logger.info(f'device: {device}')



INFO:Handwriting2LaTeX:PROJECT_DIR: /content
INFO:Handwriting2LaTeX:MODEL_DIR: /content/models
INFO:Handwriting2LaTeX:LOG_DIR: /content/logs
INFO:Handwriting2LaTeX:TIME_SAMPLING_DELTA: 40
INFO:Handwriting2LaTeX:SEQ_MAX: 500
INFO:Handwriting2LaTeX:SEQ_MIN: -500
INFO:Handwriting2LaTeX:PADDING: 4
INFO:Handwriting2LaTeX:NUM_WORKERS: 1
INFO:Handwriting2LaTeX:EPOCHS: 200
INFO:Handwriting2LaTeX:BATCH_SIZE: 2
INFO:Handwriting2LaTeX:device: cpu


project_dir: /content
data_dir: /content/data
model_dir: /content/models
log_dir: /content/logs
MPS not available because the current PyTorch install was not built with MPS enabled.
Using CPU as device


# InkML-parser.ipynb

In [6]:
from dataclasses import dataclass
import numpy as np
from xml.etree import ElementTree
import matplotlib.pyplot as plt
import matplotlib.patches as mpl_patches
from PIL import Image
import io
from pprint import pprint

# Define Ink class
@dataclass
class Ink:
    """Represents a single ink, as read from an InkML file.

    The numeric fields are produced by ``read_inkml_file`` from values parsed
    with ``float``, so they are annotated as ``float``. That reader leaves the
    min/max fields as ``None`` when the file contains no trace points.
    """
    # Every stroke in the ink.
    # Each stroke array has shape (3, number of points), where the first
    # dimensions are (x, y, timestamp), in that order.
    strokes: list[np.ndarray]
    # Metadata present in the InkML, keyed by the annotation's `type` attribute.
    annotations: dict[str, str]

    # Extremes of x, y and timestamp over every point of every stroke.
    min_x: float
    min_y: float
    max_x: float
    max_y: float
    min_t: float
    max_t: float

    # Largest absolute x / y step between two consecutive points of one stroke.
    max_delta_x: float
    max_delta_y: float


# Define function that reads inkml file, and outputs Ink object
def read_inkml_file(filename: str) -> "Ink":
    """Simple reader for MathWriting's InkML files.

    Args:
        filename: path of the InkML file (anything ``ElementTree.parse``
            accepts as a file name).

    Returns:
        An ``Ink`` holding one ``(3, n_points)`` array of (x, y, t) rows per
        non-empty ``<trace>``, the ``<annotation>`` texts keyed by their
        ``type`` attribute, the x/y/t extremes over all points (``None`` when
        there are no points) and the largest per-stroke step in x and y.

    Raises:
        ValueError: if a point does not consist of exactly three numbers.
    """
    # ElementTree reads the file in binary mode and honours the encoding the
    # XML declares, instead of decoding with the platform's locale encoding.
    root = ElementTree.parse(filename).getroot()

    strokes = []
    annotations = {}

    min_x = max_x = min_y = max_y = min_t = max_t = None
    max_delta_x, max_delta_y = 0, 0

    for element in root:
        tag_name = element.tag.removeprefix('{http://www.w3.org/2003/InkML}')
        if tag_name == 'annotation':
            annotations[element.attrib.get('type')] = element.text

        elif tag_name == 'trace':
            stroke_x, stroke_y, stroke_t = [], [], []
            for point in (element.text or '').split(','):
                if not point.strip():
                    continue
                # split() with no argument tolerates newlines and repeated
                # blanks around the three numbers.
                x, y, t = (float(value) for value in point.split())
                stroke_x.append(x)
                stroke_y.append(y)
                stroke_t.append(t)

            if not stroke_x:
                # A trace without points contributes nothing.
                continue

            lo_x, hi_x = min(stroke_x), max(stroke_x)
            lo_y, hi_y = min(stroke_y), max(stroke_y)
            lo_t, hi_t = min(stroke_t), max(stroke_t)
            if min_x is None:
                min_x, max_x = lo_x, hi_x
                min_y, max_y = lo_y, hi_y
                min_t, max_t = lo_t, hi_t
            else:
                min_x, max_x = min(min_x, lo_x), max(max_x, hi_x)
                min_y, max_y = min(min_y, lo_y), max(max_y, hi_y)
                min_t, max_t = min(min_t, lo_t), max(max_t, hi_t)

            # Steps are measured between consecutive points of the same
            # stroke only, never across the gap between two strokes.
            max_delta_x = max(
                max_delta_x,
                max((abs(b - a) for a, b in zip(stroke_x, stroke_x[1:])), default=0))
            max_delta_y = max(
                max_delta_y,
                max((abs(b - a) for a, b in zip(stroke_y, stroke_y[1:])), default=0))

            strokes.append(np.array((stroke_x, stroke_y, stroke_t)))

    return Ink(strokes=strokes,
               annotations=annotations,
               max_x=max_x,
               min_x=min_x,
               max_y=max_y,
               min_y=min_y,
               max_t=max_t,
               min_t=min_t,
               max_delta_x=max_delta_x,
               max_delta_y=max_delta_y)

# display inkml file into image
def display_ink(
    ink: Ink,
    *,
    figsize: tuple[int, int]=(15, 10),
    linewidth: int=2,
    color=None):
  """Simple display for a single ink.

  Opens a new pyplot figure, draws every stroke as an x/y line and titles it
  with "<sampleId> -- <splitTagOriginal> -- <label>", preferring the
  'normalizedLabel' annotation over 'label'. Missing annotations render as
  empty strings. The y axis is inverted and the axes use equal scaling.

  Args:
    ink: the ink to draw.
    figsize: figure size forwarded to ``plt.figure``.
    linewidth: line width forwarded to ``plt.plot``.
    color: line colour forwarded to ``plt.plot``.
  """
  plt.figure(figsize=figsize)
  for stroke in ink.strokes:
    plt.plot(stroke[0], stroke[1], linewidth=linewidth, color=color)

  annotations = ink.annotations
  # Look up 'label' lazily: the previous eager annotations['label'] default
  # raised KeyError for inks that only carry 'normalizedLabel'.
  shown_label = annotations.get('normalizedLabel', annotations.get('label', ''))
  # The title does not depend on the stroke, so set it once.
  plt.title(
      f"{annotations.get('sampleId', '')} -- "
      f"{annotations.get('splitTagOriginal', '')} -- "
      f"{shown_label}"
  )
  plt.gca().invert_yaxis()
  plt.gca().axis('equal')

def get_ink_sequence_token(ink: "Ink", timedelta_: int):
    """Serialise an ink into the text sequence fed to the model.

    Each stroke goes through these steps:
    1. Time sampling: a point is kept only if its timestamp is at least
       ``timedelta_`` after the previously kept point. The first point of a
       stroke is always kept.
    2. Scale normalization: x and y are mapped from the ink's bounding box to
       ``[PADDING, IMG_SIZE - PADDING]``.
    3. Discretization: coordinates are rounded to integers.
    4. Coordinate representation: the first kept point is emitted as an
       absolute position, every following point as the offset from the
       previous kept point. Values are clamped to ``[SEQ_MIN, SEQ_MAX]``.

    Args:
        ink: object exposing ``strokes`` (rows x, y, t per stroke) and
            ``min_x``/``max_x``/``min_y``/``max_y``.
        timedelta_: minimum timestamp gap between two kept points.

    Returns:
        A string in which every stroke starts with ``<stroke>`` followed by
        its ``x y`` pairs. Every token, including the last, is followed by a
        single space. An ink without strokes yields ``""``.
    """
    if not ink.strokes:
        return ""

    drawable = IMG_SIZE - 2 * PADDING
    # A perfectly vertical or horizontal ink (or a single dot) has zero
    # extent on one axis; dividing by it produced NaN and a failing round().
    # Every numerator is 0 in that case, so any non-zero divisor pins the
    # axis at PADDING.
    span_x = (ink.max_x - ink.min_x) or 1
    span_y = (ink.max_y - ink.min_y) or 1

    pieces = []
    for stroke in ink.strokes:
        stroke_x, stroke_y, stroke_t = stroke[0], stroke[1], stroke[2]
        # new stroke starts with separator <stroke>
        pieces.append("<stroke>")
        if len(stroke_t) == 0:
            continue

        # Start far enough in the past that the first point always qualifies.
        last_t = stroke_t[0] - (float(timedelta_) * 2)
        # With a zero origin the first kept point comes out as its absolute
        # grid position.
        prev_x, prev_y = 0, 0

        for x, y, t in zip(stroke_x, stroke_y, stroke_t):
            if t - last_t < timedelta_:
                continue
            last_t = t

            grid_x = round(((x - ink.min_x) * drawable / span_x) + PADDING)
            grid_y = round(((y - ink.min_y) * drawable / span_y) + PADDING)

            delta_x = min(max(grid_x - prev_x, SEQ_MIN), SEQ_MAX)
            delta_y = min(max(grid_y - prev_y, SEQ_MIN), SEQ_MAX)
            pieces.append(f'{delta_x} {delta_y}')

            # Advance the reference point. It used to stay on the stroke's
            # first point, so offsets were relative to the stroke start
            # rather than to the previous point.
            prev_x, prev_y = grid_x, grid_y

    return " ".join(pieces) + " "


def get_ink_image(ink: Ink,
                  figsize: int = 800,
                  linewidth: int=3):
    """Render an ink as an RGB array with per-point information in the colour.

    The ink is drawn on a wide canvas (``2 * figsize`` by ``figsize // 2``
    before rounding to whole inches at 100 dpi). The left and right halves of
    the rendering are then stacked vertically. Each segment is coloured by its
    starting point: red is the timestamp normalised over the ink, green and
    blue are the absolute x / y step from the previous point normalised by
    the ink's largest step.

    Args:
        ink: the ink to render.
        figsize: target edge length in pixels.
        linewidth: stroke width forwarded to matplotlib.

    Returns:
        ``uint8`` array of shape (rows, cols, 3). This is
        ``(figsize, figsize, 3)`` when ``figsize`` is a multiple of 200.
        NOTE(review): other values are reduced by the integer inch division
        below (224 presumably renders 200x200) - confirm whether that is
        intended.
    """

    dpi = 100
    width = figsize * 2
    height = figsize // 2

    fig, ax = plt.subplots(figsize=(width // dpi, height // dpi), dpi=dpi)
    try:
        ax.axis('off')

        if ink.strokes:
            # A zero normaliser means every numerator is zero as well, so
            # substituting 1 yields a 0 channel instead of NaN.
            time_span = (ink.max_t - ink.min_t) or 1
            delta_x_scale = ink.max_delta_x or 1
            delta_y_scale = ink.max_delta_y or 1

        for stroke in ink.strokes:
            stroke_x, stroke_y, stroke_t = stroke[0], stroke[1], stroke[2]
            colors = []

            prev_x, prev_y = None, None
            for x, y, t in zip(stroke_x, stroke_y, stroke_t):
                if prev_x is None:
                    prev_x, prev_y = x, y

                # (r, g, b), each in the range 0 - 1
                colors.append((
                    (t - ink.min_t) / time_span,
                    abs(x - prev_x) / delta_x_scale,
                    abs(y - prev_y) / delta_y_scale,
                ))
                prev_x, prev_y = x, y

            # One two-point segment per point so each can carry its own
            # colour; the final one-point slice draws nothing visible.
            for i in range(len(stroke_x)):
                ax.plot(stroke_x[i:i+2], stroke_y[i:i+2], linewidth=linewidth, color=colors[i])

        ax.invert_yaxis()
        ax.axis('equal')

        fig.canvas.draw()
        # np.array copies the pixels, so the figure can be released next.
        img_array = np.array(fig.canvas.buffer_rgba())
    finally:
        # Close this exact figure. The earlier plt.tight_layout() call ran
        # after the figure was closed, which made pyplot create a fresh empty
        # figure on every call and never release it.
        plt.close(fig)

    img_array = img_array[:, :, :3]
    height, width, _ = img_array.shape

    left_img_array = img_array[:, :(width//2), :]
    right_img_array = img_array[:, (width//2):, :]

    # Stack the two halves vertically.
    img_array = np.concatenate((left_img_array, right_img_array), axis=0)

    return img_array


# Train.ipynb
## Define a collate function

In [7]:
from transformers import AutoProcessor

# Processor for the REPO_ID checkpoint. Both collate functions below call it
# to turn the prompt strings and images of a batch into model inputs.
processor = AutoProcessor.from_pretrained(REPO_ID)

def train_collate_fn(samples):
  """Collate dataset samples into one training batch.

  Every sample is a dict with the keys "image", "text" and "label". The text
  is wrapped as "<image><bos>" + text + "<latex>" and the label, followed by
  "<eos>", is handed to the processor as the suffix.

  Returns:
    The tuple (input_ids, token_type_ids, attention_mask, pixel_values,
    labels) taken from the processor output.
  """
  images, prompts, targets = [], [], []
  for sample in samples:
    images.append(sample["image"])
    prompts.append("<image><bos>" + sample["text"] + "<latex>")
    targets.append(sample["label"] + "<eos>")

  batch = processor(text=prompts, images=images, suffix=targets, return_tensors="pt",
                    padding=True, truncation=True, max_length=MAX_LENGTH)

  wanted = ("input_ids", "token_type_ids", "attention_mask", "pixel_values", "labels")
  return tuple(batch[key] for key in wanted)

def test_collate_fn(samples):
  """Collate dataset samples into one evaluation batch.

  Every sample is a dict with the keys "image", "text" and "label". The
  prompt is built exactly as in ``train_collate_fn``, but no suffix is passed
  to the processor, so its output has no "labels" entry (reading it raised
  KeyError). The reference labels are returned as plain strings, which is the
  form ``PaliGemmaModel.validation_step`` compares against the decoded
  predictions.

  Returns:
    The tuple (input_ids, attention_mask, pixel_values, labels), where
    ``labels`` is the list of each sample's "label" string.
  """
  images = [sample["image"] for sample in samples]
  text_sequences = ["<image><bos>" + sample["text"]+"<latex>" for sample in samples]
  # Raw label text without "<eos>": predictions are decoded to plain text
  # before being scored against these.
  labels = [sample["label"] for sample in samples]

  inputs = processor(text=text_sequences, images=images, return_tensors="pt",
                     padding=True, truncation=True, max_length=MAX_LENGTH)

  input_ids = inputs["input_ids"]
  attention_mask = inputs["attention_mask"]
  pixel_values = inputs["pixel_values"]

  return input_ids, attention_mask, pixel_values, labels

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


# Datasets and DataLoader

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import numpy as np
import time
import matplotlib.pyplot as plt

class MathWritingDataset(Dataset):
    """Dataset over the ``*.inkml`` files of a MathWriting directory.

    Each item is a dict with the keys 'image' (output of ``get_ink_image``),
    'text' (output of ``get_ink_sequence_token``) and 'label' (the ink's
    'normalizedLabel' annotation when present, otherwise 'label').
    """

    def __init__(self, dataset_dir, data_types=("train", "symbols", "synthetic"), transform=None):
        """Index the InkML files below ``dataset_dir``.

        Args:
            dataset_dir: root directory (``str`` or ``Path``) that contains
                one sub-directory per entry of ``data_types``.
            data_types: names of the sub-directories to include, in order.
            transform: optional callable applied to the whole sample dict.
        """
        self.dataset_dir = Path(dataset_dir)
        # Copy so the instance never shares a caller's (or the default) list.
        self.types = list(data_types)
        self.filenames = []
        self.transform = transform
        for type_ in self.types:
            # glob() yields files in file-system order; sorting makes the
            # index -> file mapping reproducible across machines and runs.
            self.filenames.extend(
                sorted(f'{type_}/{f.name}' for f in (self.dataset_dir / type_).glob("*.inkml")))

    def __len__(self):
        """Number of indexed InkML files."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load, tokenise and render the ink at position ``idx``.

        Raises:
            TypeError: if ``idx`` is not an integer (NumPy integers are
                accepted as well as ``int``).
        """
        if not isinstance(idx, (int, np.integer)):
            raise TypeError(f'index must be an integer, got {type(idx).__name__}')
        target_file_path = self.dataset_dir / self.filenames[idx]

        # read inkml file
        ink = read_inkml_file(target_file_path)

        # generate ink sequence
        text_sequence = get_ink_sequence_token(ink, TIME_SAMPLING_DELTA)

        image = get_ink_image(ink, IMG_SIZE)

        if "normalizedLabel" in ink.annotations:
            label = ink.annotations['normalizedLabel']
        else:
            label = ink.annotations['label']

        sample = {'image': image, 'text': text_sequence, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample

# Image transforms (ToTensor only).
# NOTE(review): neither transform is passed to the datasets constructed
# below, so both are currently unused. MathWritingDataset applies its
# transform to the whole sample dict rather than to the image alone -
# confirm the intent before wiring them in.
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
])

# FIXME
# change "mathwriting-2024-excerpt" into "mathwriting-2024" for real training
# Training uses the "train", "symbols" and "synthetic" folders; evaluation
# uses the "test" folder only.
train_dataset = MathWritingDataset(data_dir / "mathwriting-2024-excerpt", data_types=["train", "symbols", "synthetic"])
test_dataset = MathWritingDataset(data_dir / "mathwriting-2024-excerpt", data_types=["test"])

# Module-level loaders; PaliGemmaModel.train_dataloader / val_dataloader
# construct loaders with these same arguments.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, collate_fn=train_collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=test_collate_fn)


In [9]:
# TEST PURPOSE
# Pull one training batch, print the decoded inputs, then show the last five
# input tokens of the first sample next to the label tokens at the same
# positions.
input_ids, token_type_ids, attention_mask, pixel_values, labels = next(iter(train_dataloader))
pprint(processor.batch_decode(input_ids))
# pprint(labels.shape)
# pprint(labels)

# Loop variables are named token_id / label_id so the builtin `id` is not
# shadowed in the notebook namespace for every later cell.
for token_id, label_id in zip(input_ids[0][-5:], labels[0][-5:]):
  print(processor.decode([token_id.item()]), processor.decode([label_id.item()]))

['<image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><ima

## PyTorch LightningModule

In [10]:
import lightning as L
import re
from nltk import edit_distance
import numpy

class PaliGemmaModel(L.LightningModule):
  """LightningModule wrapping a model and its processor for fine-tuning.

  Training minimises the loss returned by the wrapped model. Validation
  generates text and logs the normalised edit distance between each decoded
  prediction and its reference label.
  """

  def __init__(self, config, processor, model):
    """Store the config, processor and model; the batch size is BATCH_SIZE."""
    super().__init__()
    self.config = config
    self.processor = processor
    self.model = model
    self.batch_size=BATCH_SIZE

  def training_step(self, batch, batch_idx):
    """Run one forward pass on a ``train_collate_fn`` batch and return its loss."""
    input_ids, token_type_ids, attention_mask, pixel_values, labels = batch

    outputs = self.model(input_ids=input_ids,
                         attention_mask=attention_mask,
                         token_type_ids=token_type_ids,
                         pixel_values=pixel_values,
                         labels=labels)

    loss = outputs.loss

    self.log("train_loss", loss)

    return loss

  def validation_step(self, batch, batch_idx, dataset_idx=0):
    """Generate predictions for a batch and score them against the labels.

    ``labels`` is expected to be a sequence of strings: each one is compared
    with ``edit_distance`` and measured with ``len``.

    Returns:
      The list of per-sample scores (edit distance divided by the longer of
      the two strings; 0.0 when both are empty).
    """
    input_ids, attention_mask, pixel_values, labels = batch
    generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                        pixel_values=pixel_values, max_new_tokens=MAX_LENGTH)

    # Decode only the newly generated part. The keyword used to be
    # misspelled "skip_special_tokes", so special tokens were not requested
    # to be stripped.
    predictions = self.processor.batch_decode(generated_ids[:, input_ids.size(1):], skip_special_tokens=True)

    scores = []

    for pred, label in zip(predictions, labels):
      longest = max(len(pred), len(label))
      # Two empty strings are identical; avoid dividing by zero.
      scores.append(edit_distance(pred, label) / longest if longest else 0.0)

    self.log("val_edit_distance", np.mean(scores))

    return scores

  def configure_optimizers(self):
    """AdamW over all parameters with learning rate INIT_LR."""
    optimizer = torch.optim.AdamW(self.parameters(), lr=INIT_LR)

    return optimizer

  def train_dataloader(self):
    """Shuffled loader over ``train_dataset`` using ``train_collate_fn``."""
    return DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, collate_fn=train_collate_fn)

  def val_dataloader(self):
    """Unshuffled loader over ``test_dataset`` using ``test_collate_fn``."""
    return DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=test_collate_fn)


ModuleNotFoundError: No module named 'lightning'

## Q-LoRA configuration

In [None]:
from transformers import BitsAndBytesConfig
from peft import get_peft_model, LoraConfig

# 4-bit NF4 quantisation with bfloat16 compute.
# The keyword is `bnb_4bit_compute_dtype`; the previous spelling
# `bnb_4bit_compute_type` is not a BitsAndBytesConfig parameter, so the
# requested bfloat16 compute dtype was not applied.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)