# config.ipynb

In [None]:
# Log in to the Hugging Face Hub.
# notebook_login() is called without arguments; how credentials are requested
# and stored is decided by huggingface_hub, not by this notebook.
# NOTE(review): presumably needed before the pretrained checkpoint is fetched in
# Train.ipynb -- confirm whether that checkpoint actually requires a login.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Configuration
import torch 
from datetime import datetime
import logging 
from pathlib import Path 
import os

# Project layout: the project root is the parent of the current working
# directory, and the three working folders sit directly underneath it.
project_dir = Path(os.getcwd()).parent
data_dir, model_dir, log_dir = (
    project_dir / folder_name for folder_name in ("data", "models", "logs")
)

# Create the working folders when they are missing; existing ones are kept.
for _working_dir in (data_dir, model_dir, log_dir):
    _working_dir.mkdir(parents=True, exist_ok=True)

# Echo the resolved locations, one per line.
print(
    f'project_dir: {project_dir}',
    f'data_dir: {data_dir}',
    f'model_dir: {model_dir}',
    f'log_dir: {log_dir}',
    sep='\n',
)

# Configure logger
# One log file per run of this cell, named after the time the cell was executed.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = log_dir / f'log_{timestamp}.log'

logger = logging.getLogger('Handwriting2LaTeX')

# logging.getLogger() hands back the same logger object for the same name, so
# re-running this cell would stack another file/stream handler pair on top of
# the earlier ones: every message would be emitted several times and the older
# log files would stay open. Detach and close anything attached by a previous
# run before configuring the logger again.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)
    _stale_handler.close()

logger.setLevel(logging.INFO)

# File output: timestamped, with logger name and level.
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)
# Console output: no formatter is attached, so only the bare message is shown.
logger.addHandler(logging.StreamHandler())

# log the directory information
logger.info(f'PROJECT_DIR: {project_dir}')
logger.info(f'DATA_DIR: {data_dir}')
logger.info(f'MODEL_DIR: {model_dir}')
logger.info(f'LOG_DIR: {log_dir}')

# Parameters for InkML parsing.
# NOTE(review): TIME_SAMPLING_DELTA is presumably the `timedelta_` argument
# handed to get_ink_sequence_token -- the call site is not visible here.
TIME_SAMPLING_DELTA = 30
# Clamp range applied to every emitted coordinate token.
SEQ_MAX = 500
SEQ_MIN = -500
# Margin kept on each side when ink coordinates are rescaled to IMG_SIZE.
PADDING = 4

# Dataset / dataloader parameter.
NUM_WORKERS = 2

# Record the values above in the run log, one entry per parameter.
for _log_line in (
    f'TIME_SAMPLING_DELTA: {TIME_SAMPLING_DELTA}',
    f'SEQ_MAX: {SEQ_MAX}',
    f'SEQ_MIN: {SEQ_MIN}',
    f'PADDING: {PADDING}',
    f'NUM_WORKERS: {NUM_WORKERS}',
):
    logger.info(_log_line)

# Checkpoint loading switches: whether to load a saved model, and under which
# names. Both names are empty by default.
LOAD_MODEL = False
LOADING_VIT_MODEL_NAME = LOADING_UL2_MODEL_NAME = ""

# Image-side (ViT / SigLIP) hyperparameters.
IMG_SIZE = 224
PATCH_SIZE = 16
IMG_IN_CHANNELS = 3
# Original note: "also used in mT5".
# NOTE(review): no mT5 model is visible in this part of the notebooks -- confirm
# which other component shares this width.
D_MODEL = 512
SIGLIP_N_LAYERS = 6
SIGLIP_N_HEADS = 8
SIGLIP_FFN_HIDDEN = 1024
SIGLIP_DROPOUT = 0.1


# Training-loop parameters.
EPOCHS = 100
BATCH_SIZE = 16

# Gemma-side hyperparameters.
GEMMA_N_LAYERS = 3
GEMMA_N_HEADS = 8
GEMMA_FFN_HIDDEN = 1024
GEMMA_DROPOUT = 0.1
GEMMA_MAX_SEQ_LEN = 1024

# Only these three values are written to the run log.
for _log_line in (
    f'EPOCHS: {EPOCHS}',
    f'BATCH_SIZE: {BATCH_SIZE}',
    f'IMG_SIZE: {IMG_SIZE}',
):
    logger.info(_log_line)

# Configure device: CUDA, MPS, CPU
#
# Preference order: CUDA when it is available, CPU otherwise. Apple's MPS
# backend is still detected, but it is deliberately not used ("for mps, we just
# use cpu"). That CPU fallback used to be an unconditional assignment placed
# after the whole if/else, which also threw away a detected CUDA device; it now
# applies to the MPS case only, and the printed message matches the device that
# is really selected.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA as device")
elif torch.backends.mps.is_available():
    # for mps, we just use cpu
    device = torch.device("cpu")
    print("MPS is available but not used; using CPU as device")
else:
    # Explain why MPS is missing before falling back to the CPU.
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    device = torch.device("cpu")
    print("Using CPU as device")

    
# Register the chosen `device` as PyTorch's default device and record the
# choice in the run log.
torch.set_default_device(device)
logger.info(f'device: {device}')



# InkML-parser.ipynb

In [None]:
from dataclasses import dataclass
import numpy as np
from xml.etree import ElementTree
import matplotlib.pyplot as plt
import matplotlib.patches as mpl_patches
from PIL import Image
import io
from pprint import pprint

# Define Ink class
@dataclass
class Ink:
    """Represents a single ink, as read from an InkML file.

    The numeric fields are annotated as float because read_inkml_file parses
    every coordinate and timestamp with float(). Annotations are not enforced
    at runtime: for a file without any trace, read_inkml_file leaves the six
    min/max fields as None and both max_delta fields as 0.
    """
    # Every stroke in the ink.
    # Each stroke array has shape (3, number of points), where the first
    # dimensions are (x, y, timestamp), in that order.
    strokes: list[np.ndarray]
    # Metadata present in the InkML, keyed by the annotation's "type" attribute.
    annotations: dict[str, str]

    # Extremes of x, y and timestamp over all points of all strokes.
    min_x: float
    min_y: float
    max_x: float
    max_y: float
    min_t: float
    max_t: float

    # Largest absolute x / y step between two consecutive points of one stroke.
    max_delta_x: float
    max_delta_y: float


# Define function that reads inkml file, and outputs Ink object
def read_inkml_file(filename: str) -> Ink:
    """Simple reader for MathWriting's InkML files.

    Only the direct children of the document root are inspected:

    * ``annotation`` elements are stored in ``annotations`` under the value of
      their ``type`` attribute;
    * ``trace`` elements become strokes. The trace text is a comma-separated
      list of points, each point being three whitespace-separated numbers
      (x, y, timestamp). Blank chunks (for example after a trailing comma) are
      ignored, and a trace without any point contributes no stroke.

    Args:
        filename: Path of the InkML file.

    Returns:
        An Ink whose strokes are float arrays of shape (3, number of points),
        plus the extremes of x / y / timestamp over all points and the largest
        absolute x / y step between consecutive points of the same stroke.
        When the file holds no point at all, the six min/max fields are None
        and both max_delta fields are 0.

    Raises:
        OSError: The file cannot be opened.
        xml.etree.ElementTree.ParseError: The content is not well-formed XML.
        ValueError: A point is not made of exactly three numbers.
    """
    # Give the parser raw bytes so that it applies the encoding declared by the
    # XML document itself instead of the platform's default text encoding.
    with open(filename, "rb") as f:
        root = ElementTree.fromstring(f.read())

    strokes = []
    annotations = {}
    max_delta_x, max_delta_y = 0, 0

    for element in root:
        tag_name = element.tag.removeprefix('{http://www.w3.org/2003/InkML}')
        if tag_name == 'annotation':
            annotations[element.attrib.get('type')] = element.text

        elif tag_name == 'trace':
            stroke_x, stroke_y, stroke_t = [], [], []
            # element.text is None for an empty <trace/> element.
            for point in (element.text or '').split(','):
                if not point.strip():
                    continue
                # split() without an argument tolerates newlines and repeated
                # blanks; the unpacking enforces exactly three values.
                x, y, t = (float(value) for value in point.split())
                stroke_x.append(x)
                stroke_y.append(y)
                stroke_t.append(t)

            if not stroke_x:
                continue

            stroke = np.array((stroke_x, stroke_y, stroke_t))
            strokes.append(stroke)

            # Steps are measured inside a stroke only, never across strokes.
            if stroke.shape[1] > 1:
                steps = np.abs(np.diff(stroke[:2], axis=1))
                max_delta_x = max(max_delta_x, float(steps[0].max()))
                max_delta_y = max(max_delta_y, float(steps[1].max()))

    if strokes:
        all_points = np.concatenate(strokes, axis=1)
        min_x, min_y, min_t = (float(value) for value in all_points.min(axis=1))
        max_x, max_y, max_t = (float(value) for value in all_points.max(axis=1))
    else:
        min_x = min_y = min_t = max_x = max_y = max_t = None

    return Ink(strokes=strokes,
               annotations=annotations,
               max_x=max_x,
               min_x=min_x,
               max_y=max_y,
               min_y=min_y,
               max_t=max_t,
               min_t=min_t,
               max_delta_x=max_delta_x,
               max_delta_y=max_delta_y)

# display inkml file into image
def display_ink(
    ink: Ink,
    *,
    figsize: tuple[int, int]=(15, 10),
    linewidth: int=2,
    color=None):
    """Simple display for a single ink.

    Opens a new pyplot figure, draws every stroke as a line and titles the
    figure "<sampleId> -- <splitTagOriginal> -- <label>". The label is the
    ``normalizedLabel`` annotation when that key is present, otherwise the
    ``label`` annotation; missing annotations are shown as empty strings. The
    y axis is inverted and both axes use the same scale.

    Args:
        ink: The ink to draw.
        figsize: Figure size passed to ``plt.figure``.
        linewidth: Line width used for every stroke.
        color: Line colour for every stroke; None leaves the choice to
            matplotlib.
    """
    plt.figure(figsize=figsize)
    for stroke in ink.strokes:
        plt.plot(stroke[0], stroke[1], linewidth=linewidth, color=color)

    annotations = ink.annotations
    # Look up 'label' only when 'normalizedLabel' is absent: using
    # annotations['label'] as the .get() default would evaluate it eagerly and
    # raise KeyError for inks that only carry a normalized label.
    if 'normalizedLabel' in annotations:
        label = annotations['normalizedLabel']
    else:
        label = annotations.get('label', '')

    # The title does not depend on the stroke, so it is set once, after the loop.
    plt.title(
        f"{annotations.get('sampleId', '')} -- "
        f"{annotations.get('splitTagOriginal', '')} -- "
        f"{label}"
    )
    plt.gca().invert_yaxis()
    plt.gca().axis('equal')

def _time_sample_indices(timestamps, min_gap):
    """Return the indices of the points kept by time sampling."""
    kept = []
    # Start two gaps in the past so that, for a non-negative gap, the first
    # point of the stroke is always kept.
    last_kept_t = timestamps[0] - (float(min_gap) * 2)
    for index, t in enumerate(timestamps):
        if t - last_kept_t >= min_gap:
            kept.append(index)
            last_kept_t = t
    return kept


def _scale_to_canvas(values, low, high):
    """Map values from [low, high] linearly onto [PADDING, IMG_SIZE - PADDING]."""
    span = high - low
    if span == 0:
        # Degenerate extent (a dot, or a perfectly straight horizontal/vertical
        # ink): every value equals `low`, so pin it to the padding edge instead
        # of dividing by zero.
        return np.full(len(values), float(PADDING))
    return (values - low) * (IMG_SIZE - 2 * PADDING) / span + PADDING


def _to_offsets(coordinates):
    """Keep the first coordinate absolute; replace the others by the step from the previous one."""
    return np.concatenate((coordinates[:1], np.diff(coordinates)))


def get_ink_sequence_token(ink: Ink, timedelta_: int) -> str:
    """Turn an ink into a whitespace-separated coordinate token string.

    Each stroke goes through the following steps:

    1. Time sampling: a point is kept only if its timestamp is at least
       ``timedelta_`` later than the previously kept point.
    2. Scale normalization: x and y are rescaled independently from the ink's
       [min, max] range to [PADDING, IMG_SIZE - PADDING]. An axis whose range
       is zero is mapped to PADDING.
    3. Discretization: coordinates are rounded to the nearest integer.
    4. Coordinate representation: the first point of a stroke is absolute and
       every later point is the offset from the point immediately before it.
       (Offsets used to be measured from the first point of the stroke because
       the previous point was never advanced.)
    5. Every value is clamped to [SEQ_MIN, SEQ_MAX].

    Args:
        ink: The ink to encode; ``min_x``/``max_x``/``min_y``/``max_y`` must
            describe its strokes.
        timedelta_: Minimum timestamp gap between two kept points, in the unit
            of the stroke timestamps.

    Returns:
        For every non-empty stroke, ``"<stroke> "`` followed by ``"x y "`` for
        each kept point, all concatenated (note the trailing space). An ink
        without strokes yields an empty string.
    """
    pieces = []
    for stroke in ink.strokes:
        stroke_x, stroke_y, stroke_t = stroke[0], stroke[1], stroke[2]
        if len(stroke_t) == 0:
            continue

        kept = _time_sample_indices(stroke_t, timedelta_)

        token_axes = []
        for values, low, high in ((stroke_x, ink.min_x, ink.max_x),
                                  (stroke_y, ink.min_y, ink.max_y)):
            scaled = _scale_to_canvas(np.asarray(values)[kept], low, high)
            discretized = np.rint(scaled).astype(int)
            token_axes.append(np.clip(_to_offsets(discretized), SEQ_MIN, SEQ_MAX))

        # new stroke starts with separator <stroke>
        pieces.append("<stroke> ")
        pieces.extend(f'{int(x)} {int(y)} ' for x, y in zip(*token_axes))

    return "".join(pieces)


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def get_ink_image(ink: Ink, 
                  figsize: int = 800, 
                  linewidth: int=3):
    """
    Render an ink as an RGB image whose colour encodes time and speed.

    Every point gets a colour (r, g, b) with components in the 0 - 1 range:
      r = (t - min_t) / (max_t - min_t)   position in time within the ink
      g = |x - previous x| / max_delta_x  horizontal step from the previous point
      b = |y - previous y| / max_delta_y  vertical step from the previous point
    The first point of a stroke has a step of 0, and a component whose
    denominator is zero is 0. The segment that starts at a point is drawn in
    that point's colour.

    The ink is drawn on a wide canvas (2 * figsize by figsize // 2, converted
    to whole inches at 100 dpi); the rendered picture is then cut into a left
    and a right half and the right half is stacked underneath the left one.
    With the default figsize=800 the canvas is 16 x 4 inches and the result has
    shape (800, 800, 3). Other sizes are rounded down to whole inches first, so
    the result is (figsize, figsize, 3) only when that rounding loses nothing.

    Args:
        ink: The ink to render; its min_t / max_t / max_delta_x / max_delta_y
            fields are used to normalise the colours.
        figsize: Target edge length of the result, in pixels.
        linewidth: Line width used for every segment.

    Returns:
        The stacked image as a numpy array with three colour channels.
    """

    dpi = 100
    width = figsize * 2 
    height = figsize // 2

    fig, ax = plt.subplots(figsize=(width // dpi, height // dpi), dpi=dpi)
    try:
        ax.axis('off')

        for stroke in ink.strokes:
            stroke_x, stroke_y, stroke_t = stroke[0], stroke[1], stroke[2]
            colors = []

            prev_x, prev_y = None, None
            for x, y, t in zip(stroke_x, stroke_y, stroke_t):
                if prev_x is None:
                    prev_x, prev_y = x, y

                # r, g, b range 0 - 1
                colors.append((
                    _safe_ratio(t - ink.min_t, ink.max_t - ink.min_t),
                    _safe_ratio(abs(x - prev_x), ink.max_delta_x),
                    _safe_ratio(abs(y - prev_y), ink.max_delta_y),
                ))
                prev_x, prev_y = x, y

            # One plot call per segment so that each segment has its own colour.
            for i in range(len(stroke_x)):
                ax.plot(stroke_x[i:i+2], stroke_y[i:i+2], linewidth=linewidth, color=colors[i])

        ax.invert_yaxis()
        ax.axis('equal')

        # plt.tight_layout() used to be called here, but only after the figure
        # had been closed and drawn: it acted on a freshly created empty figure
        # (which was then leaked) and never changed this image. It is left out
        # so the rendered pixels stay exactly as before.
        fig.canvas.draw()
        img_array = np.array(fig.canvas.buffer_rgba())
    finally:
        # Always release the figure, including when drawing fails.
        plt.close(fig)

    # Drop the alpha channel.
    img_array = img_array[:, :, :3]
    height, width, _ = img_array.shape

    left_img_array = img_array[:, :(width//2), :]
    right_img_array = img_array[:, (width//2):, :]

    # Stack the right half under the left half.
    img_array = np.concatenate((left_img_array, right_img_array), axis=0)

    return img_array
            

# Train.ipynb

In [None]:
import torch 
from torchinfo import summary
from pprint import pprint
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, PaliGemmaConfig, SiglipVisionConfig, GemmaConfig
from transformers import AutoProcessor

# Checkpoint to start from.
model_id = "google/paligemma2-3b-pt-224"

# Load the checkpoint in bfloat16 and move it to `device`.
# NOTE(review): `device` is not defined in this notebook section -- presumably
# it comes from running config.ipynb first; confirm.
model = (
    PaliGemmaForConditionalGeneration
    .from_pretrained(model_id, torch_dtype=torch.bfloat16)
    .to(device)
)

# freeze the pretrained model weight for LoRA training:
# the vision tower and the multi-modal projector get no gradients.
for _frozen_part in (model.vision_tower, model.multi_modal_projector):
    for param in _frozen_part.parameters():
        param.requires_grad = False