# config.ipynb

In [1]:
# login to huggingface hub

from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
!pip install uv -q

In [3]:
!uv pip install --system -q bitsandbytes peft accelerate lightning torch torchvision

In [4]:
%%capture 
%%bash 
TOKENIZERS_PARALLELISM=true

In [5]:
%%capture
%%bash

#!/bin/bash



DATA_DIR="data";

ZIPFILE_NAME="mathwriting.tgz";

ZIPFILE_PATH="$DATA_DIR/$ZIPFILE_NAME";



UNZIPDIR_NAME="mathwriting-2024";

UNZIPDIR_PATH="$DATA_DIR/$UNZIPDIR_NAME"



# make data directory if not exists

mkdir -p "$DATA_DIR"



# install math-writing dataset

if [ -e $ZIPFILE_PATH ]; then

	echo "$ZIPFILE_PATH - Already exists...";

else

	echo "$ZIPFILE_PATH - Doesn't exists. Starting download...";

	wget -O $ZIPFILE_PATH https://storage.googleapis.com/mathwriting_data/mathwriting-2024.tgz;

fi



# check if unzip directory exists

if [ -d "$UNZIPDIR_PATH" ]; then

	echo "$UNZIPDIR_PATH - Already exists...";

else

	echo "$UNZIPDIR_PATH - Doesnt exists. Starting untar...";

	tar -xvzf "$ZIPFILE_PATH" -C "$DATA_DIR";

fi


In [6]:
# Configuration

import torch

from datetime import datetime

import logging

from pathlib import Path

import os



# Configure Directory

project_dir = Path(os.getcwd()).parent / "working"

data_dir = project_dir / "data"

model_dir = project_dir / "models"

log_dir = project_dir / "logs"



data_dir.mkdir(parents=True, exist_ok=True)

model_dir.mkdir(parents=True, exist_ok=True)

log_dir.mkdir(parents=True, exist_ok=True)

print(f'project_dir: {project_dir}')

print(f'data_dir: {data_dir}')

print(f'model_dir: {model_dir}')

print(f'log_dir: {log_dir}')



# Configure logger

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

log_file = log_dir / f'log_{timestamp}.log'



logger = logging.getLogger('Handwriting2LaTeX')

logger.setLevel(logging.INFO)



file_handler = logging.FileHandler(log_file)

file_handler.setLevel(logging.INFO)



formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

file_handler.setFormatter(formatter)



logger.addHandler(file_handler)

# logger.addHandler(logging.StreamHandler())



# log the directory information

logger.info(f'PROJECT_DIR: {project_dir}')

logger.info(f'MODEL_DIR: {model_dir}')

logger.info(f'LOG_DIR: {log_dir}')



# Define Parameter for InkML parsing

TIME_SAMPLING_DELTA = 20

SEQ_MAX = 500

SEQ_MIN = -500

PADDING = 4



logger.info(f'TIME_SAMPLING_DELTA: {TIME_SAMPLING_DELTA}')

logger.info(f'SEQ_MAX: {SEQ_MAX}')

logger.info(f'SEQ_MIN: {SEQ_MIN}')

logger.info(f'PADDING: {PADDING}')



# Define Huggingface configuration

REPO_ID = "google/paligemma-3b-pt-224"

LOAD_FINETUNED_MODEL = False

FINETUNED_MODEL_ID = "ball1433/Handwriting2LaTeX"



# Define Training Parameter

MAX_LENGTH = 512

MAX_GENERATION_LENGTH = 128

BATCH_SIZE=1

EPOCHS = 200

IMG_SIZE=224

INIT_LR=3e-4

GRAD_CLIP=1.0

SEED=1234

WARMUP_STEPS=20

NUM_WORKERS=3

LORA_R=12

SHRINK_DATASET=True


logger.info(f'EPOCHS: {EPOCHS}')

logger.info(f'BATCH_SIZE: {BATCH_SIZE}')

project_dir: /kaggle/working
data_dir: /kaggle/working/data
model_dir: /kaggle/working/models
log_dir: /kaggle/working/logs


# InkML-parser.ipynb

In [7]:
from dataclasses import dataclass

import numpy as np

from xml.etree import ElementTree

import matplotlib.pyplot as plt

import matplotlib.patches as mpl_patches

from PIL import Image

import io

from pprint import pprint



# Define Ink class

@dataclass

class Ink:

    """Represents a single ink, as read from an InkML file."""

    # Every stroke in the ink.

    # Each stroke array has shape (3, number of points), where the first

    # dimensions are (x, y, timestamp), in that order.



    strokes: list[np.ndarray]

    # Metadata present in the InkML.

    annotations: dict[str, str]



    min_x: int

    min_y: int

    max_x: int

    max_y: int

    min_t: int

    max_t: int



    max_delta_x: int

    max_delta_y: int





# Define function that reads inkml file, and outputs Ink object

def read_inkml_file(filename: str) -> Ink:

    """Simple reader for MathWriting's InkML files."""

    with open(filename, "r") as f:

        root = ElementTree.fromstring(f.read())



        strokes = []

        annotations = {}



        max_x, max_y, max_t, min_x, min_y, min_t = None, None, None, None, None, None



        max_delta_x, max_delta_y = 0, 0



        for element in root:

            tag_name = element.tag.removeprefix('{http://www.w3.org/2003/InkML}')

            if tag_name == 'annotation':

                annotations[element.attrib.get('type')] = element.text



            elif tag_name == 'trace':

                points = element.text.split(',')

                stroke_x, stroke_y, stroke_t = [], [], []



                prev_x, prev_y = None, None

                for point in points:

                    x, y, t = [float(p) for p in point.split(' ')]

                    stroke_x.append(x)

                    stroke_y.append(y)

                    stroke_t.append(t)



                    if prev_x == None:

                        prev_x = x

                        prev_y = y



                    if max_x == None:

                        max_x = x

                        min_x = x

                        max_y = y

                        min_y = y

                        max_t = t

                        min_t = t

                    if x > max_x:

                        max_x = x

                    if x < min_x:

                        min_x = x

                    if y > max_y:

                        max_y = y

                    if y < min_y:

                        min_y = y

                    if t > max_t:

                        max_t = t

                    if t < min_t:

                        min_t = t

                    if abs(x - prev_x) > max_delta_x:

                        max_delta_x = abs(x - prev_x)

                    if abs(y - prev_y) > max_delta_y:

                        max_delta_y = abs(y - prev_y)



                    prev_x = x

                    prev_y = y

                strokes.append(np.array((stroke_x, stroke_y, stroke_t)))



    return Ink(strokes=strokes,

             annotations=annotations,

             max_x=max_x,

             min_x=min_x,

             max_y=max_y,

             min_y=min_y,

             max_t=max_t,

             min_t=min_t,

             max_delta_x=max_delta_x,

             max_delta_y=max_delta_y)



# display inkml file into image

def display_ink(

    ink: Ink,

    *,

    figsize: tuple[int, int]=(15, 10),

    linewidth: int=2,

    color=None):

  """Simple display for a single ink."""

  plt.figure(figsize=figsize)

  for stroke in ink.strokes:

    plt.plot(stroke[0], stroke[1], linewidth=linewidth, color=color)

    plt.title(

        f"{ink.annotations.get('sampleId', '')} -- "

        f"{ink.annotations.get('splitTagOriginal', '')} -- "

        f"{ink.annotations.get('normalizedLabel', ink.annotations['label'])}"

    )

  plt.gca().invert_yaxis()

  plt.gca().axis('equal')



def get_ink_sequence_token(ink: Ink, timedelta_: int):

    """

    Apply

    1. Time sampling

    2. Scale normalization

    3. Coordinate representation

    """



    # Time sampling

    time_sampled_strokes = []



    # time delta between two adjacent points should be at least timedelta_

    for stroke in ink.strokes:

        stroke_x, stroke_y, stroke_t = stroke[0], stroke[1], stroke[2]



        prev_t = stroke_t[0] - (float(timedelta_) * 2)

        sampled_stroke_x, sampled_stroke_y= [], []



        for x, y, t in zip(stroke_x, stroke_y, stroke_t):





            if t - prev_t >= timedelta_:

                prev_t = t

                # add this point to strokes

                sampled_stroke_x.append(x)

                sampled_stroke_y.append(y)



        time_sampled_strokes.append(np.array((sampled_stroke_x, sampled_stroke_y)))



    # Scale normalization

    scale_normalized_strokes = []



    # print(f'max_x: {max_x}, min_x: {min_x}, max_y: {max_y}, min_y: {min_y}')



    # for every point's x value, (x - min_x) * (IMG_SIZE - 2 * PADDING) / (max_x - min_x) + PADDING

    # for every point's y value, (y - min_y) * IMG_SIZE / (max_y - min_y)

    for stroke in time_sampled_strokes:

        stroke_x, stroke_y = stroke[0], stroke[1]



        normalized_stroke_x, normalized_stroke_y = [], []

        for x, y in zip(stroke_x, stroke_y):

            normalized_stroke_x.append(((x - ink.min_x) * (IMG_SIZE - 2 * PADDING) / (ink.max_x - ink.min_x)) + PADDING)

            normalized_stroke_y.append(((y - ink.min_y) * (IMG_SIZE - 2 * PADDING) / (ink.max_y - ink.min_y)) + PADDING)



        scale_normalized_strokes.append(np.array((normalized_stroke_x, normalized_stroke_y)))



    # pprint(scale_normalized_strokes)



    # Discretization

    # Converting all float coordinates into int

    discretized_strokes = []



    for stroke in scale_normalized_strokes:

        stroke_x, stroke_y = stroke[0], stroke[1]



        discretized_stroke_x, discretized_stroke_y = [], []



        for x, y in zip(stroke_x, stroke_y):

            discretized_stroke_x.append(round(x))

            discretized_stroke_y.append(round(y))



        discretized_strokes.append(np.array((discretized_stroke_x, discretized_stroke_y)))



    # pprint(discretized_strokes)



    # Coordinate representation

    relative_position_strokes = []



    for stroke in discretized_strokes:

        stroke_x, stroke_y = stroke[0], stroke[1]



        relative_stroke_x, relative_stroke_y = [], []



        prev_x, prev_y = None, None



        for x, y in zip(stroke_x, stroke_y):

            if prev_x == None and prev_y == None:

                relative_stroke_x.append(x)

                relative_stroke_y.append(y)

                prev_x = x

                prev_y = y

            else:

                relative_stroke_x.append(x - prev_x)

                relative_stroke_y.append(y - prev_y)



        relative_position_strokes.append(np.array((relative_stroke_x, relative_stroke_y)))

    # pprint(relative_position_strokes)



    # return string of sequences of points

    # new stroke starts with seperator <stroke>

    result = ""

    for stroke in relative_position_strokes:

        stroke_x, stroke_y = stroke[0], stroke[1]

        result += "<stroke> "



        for x, y in zip(stroke_x, stroke_y):

            if x > SEQ_MAX:

                x = SEQ_MAX

            if x < SEQ_MIN:

                x = SEQ_MIN



            if y > SEQ_MAX:

                y = SEQ_MAX

            if y < SEQ_MIN:

                y = SEQ_MIN



            result += f'{x} {y} '



    # print(f'token length: {len(result.split())}')

    return result





def get_ink_image(ink: Ink,

                  figsize: int = 800,

                  linewidth: int=3):

    """

    returns a ink image of shape (figsize, figsize, 3)

    containing time, delta_x, delta_y information in color channel

    """



    dpi = 100

    width = figsize * 2

    height = figsize // 2



    fig, ax = plt.subplots(figsize=(width // dpi, height // dpi), dpi=dpi)

    ax.axis('off')



    for stroke in ink.strokes:

        stroke_x, stroke_y, stroke_t = stroke[0], stroke[1], stroke[2]

        colors = []



        prev_x, prev_y, prev_t = None, None, None



        for x, y, t in zip(stroke_x, stroke_y, stroke_t):

            if prev_x == None:

                prev_x = x

                prev_y = y

                prev_t = t



            # store img_drawing[(x, y)] = (r, g, b)

            # r, g, b range 0 - 1

            r = (t - ink.min_t) / (ink.max_t - ink.min_t)

            g = abs(x - prev_x) / ink.max_delta_x

            b = abs(y - prev_y) / ink.max_delta_y



            colors.append((r, g, b))



            prev_x = x

            prev_y = y

            prev_t = t

        for i in range(len(stroke_x)):

            ax.plot(stroke_x[i:i+2], stroke_y[i:i+2], linewidth=linewidth, color=colors[i])

    ax.invert_yaxis()

    ax.axis('equal')



    plt.close()

    fig.canvas.draw()

    plt.tight_layout()

    # plt.show()



    # width, height = fig.canvas.get_width_height()





    img_array = np.array(fig.canvas.buffer_rgba())

    img_array = img_array[:, :, :3]

    height, width, _ = img_array.shape



    left_img_array = img_array[:, :(width//2), :]

    right_img_array = img_array[:, (width//2):, :]

    # print(f'left_img_array shape: {left_img_array.shape}')

    # print(f'right_img_array shape: {right_img_array.shape}')



    img_array = np.concatenate((left_img_array, right_img_array), axis=0)



    # print(f'img_array shape: {img_array.shape}')



    return img_array



# Tokenizer.ipynb

In [8]:
from transformers import AutoTokenizer

custom_tokenizer = AutoTokenizer.from_pretrained(REPO_ID)

new_tokens = ["<bos>", "<eos>", "<stroke>", "<latex>", "!", "&", "(", ")", "*", "+", ",", "-", ".", "/", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ":", ";", "¡", "=", "¿", "?", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U","V", "W", "X", "Y", "Z", "[", r"\#", r"\%", r"\&", r"\Delta", r"\Gamma", r"\Lambda", r"\Leftrightarrow", r"\Omega", r"\Phi", r"\Pi", r"\Psi", r"\Rightarrow", r"\Sigma", r"\Theta", r"\Upsilon", r"\Vdash", r"\Xi", r"\ ", r"\aleph", r"\alpha", r"\angle", r"\approx", r"\backslash", r"\beginmatrix", r"\beta", r"\bigcap", r"\bigcirc", r"\bigcup", r"\bigoplus", r"\bigvee", r"\bigwedge", r"\bullet", r"\cap", r"\cdot", r"\chi", r"\circ", r"\cong", r"\cup", r"\dagger", r"\delta", r"\div", r"\dot", r"\emptyset", r"\endmatrix", r"\epsilon", r"\equiv", r"\eta", r"\exists", r"\forall", r"\frac", r"\gamma", r"\ge", r"\gg", r"\hat", r"\hbar", r"\hookrightarrow", r"\iff", r"\iint", r"\in", r"\infty", r"\int", r"\iota", r"\kappa", r"\lambda", r"\langle", r"\lceil", r"\le", r"\leftarrow", r"\leftrightarrow", r"\lfloor", r"\ll", r"\longrightarrow", r"\mapsto", r"\mathbb", r"\models", r"\mp", r"\mu", r"\nabla", r"\ne", r"\neg", r"\ni", r"\not", r"\notin", r"\nu", r"\odot", r"\oint", r"\omega", r"\ominus", r"\oplus", r"\otimes", r"\overline", r"\partial", r"\perp", r"\phi", r"\pi", r"\pm", r"\prime", r"\prod", r"\propto", r"\psi", r"\rangle", r"\rceil", r"\rfloor", r"\rho", r"\rightarrow", r"\rightleftharpoons", r"\sigma", r"\sim", r"\simeq", r"\sqrt", r"\sqsubseteq", r"\subset", r"\subseteq", r"\subsetneq", r"\sum", r"\supset", r"\supseteq", r"\tau", r"\theta", r"\tilde", r"\times", r"\top", r"\triangle", r"\triangleleft", r"\triangleq", r"\underline", r"\upsilon", r"\varphi", r"\varpi", r"\varsigma", r"\vartheta", r"\vdash", r"\vdots", r"\vec", r"\vee", r"\wedge", r"\xi", r"\zeta", r"\{", r"\—", r"\}", "]", "ˆ", " ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "{", "|", "}", "\\" ]

new_tokens = set(new_tokens) - set(custom_tokenizer.vocab.keys())

custom_tokenizer.add_tokens(list(new_tokens))


139

In [9]:
# TEST PURPOSE

for token in list(new_tokens)[:10]:

    token_id = custom_tokenizer.convert_tokens_to_ids(token)

    print(f"Token: {token}, ID: {token_id}")

Token: \triangle, ID: 257153
Token: \simeq, ID: 257154
Token: \upsilon, ID: 257155
Token: \ll, ID: 257156
Token: \Phi, ID: 257157
Token: \sim, ID: 257158
Token: \supseteq, ID: 257159
Token: \sum, ID: 257160
Token: \cdot, ID: 257161
Token: \overline, ID: 257162


# Train.ipynb

## Define a collate function

In [10]:
from transformers import AutoProcessor



processor = AutoProcessor.from_pretrained(REPO_ID)

# change the default tokenizer into custom tokenizer

processor.tokenizer = custom_tokenizer



def train_collate_fn(samples):

  images = [sample["image"] for sample in samples]

  text_sequences = ["<image>" + sample["text"] + "<latex>" for sample in samples]

  labels = [sample["label"]+"<eos>" for sample in samples]



  inputs = processor(text=text_sequences, images=images, suffix=labels, return_tensors="pt",

                     padding=True, truncation=True, max_length=MAX_LENGTH)



  input_ids = inputs["input_ids"]

  token_type_ids = inputs["token_type_ids"]

  attention_mask = inputs["attention_mask"]

  pixel_values = inputs["pixel_values"]

  labels = inputs["labels"]



  return input_ids, token_type_ids, attention_mask, pixel_values, labels



def test_collate_fn(samples):

    images = [sample["image"] for sample in samples]
    
    text_sequences = ["<image>" + sample["text"] + "<latex>"for sample in samples]
    
    labels = [sample["label"]+"<eos>" for sample in samples]
    
    
    
    inputs = processor(text=text_sequences, images=images, return_tensors="pt",
    
                         padding=True, truncation=True, max_length=MAX_LENGTH)

    input_ids = inputs["input_ids"]
    
    attention_mask = inputs["attention_mask"]
    
    pixel_values = inputs["pixel_values"]

    return input_ids, attention_mask, pixel_values, labels

# Datasets and DataLoader

In [11]:
import torch

from torch.utils.data import Dataset, DataLoader

from torchvision import transforms

import os

import numpy as np

import time

import matplotlib.pyplot as plt



class MathWritingDataset(Dataset):

    def __init__(self, dataset_dir, data_types=["train", "symbols", "synthetic"], transform=None):

        self.dataset_dir = dataset_dir

        self.types = data_types

        self.filenames = []

        self.transform = transform

        for type_ in self.types:

            filename = [f'{type_}/{f.name}' for f in (self.dataset_dir / type_).glob("*.inkml")]

            self.filenames.extend(filename)



    def __len__(self):

        return len(self.filenames)



    def __getitem__(self, idx):

        # start = time.time()

        assert type(idx) == int

        target_file_path = self.dataset_dir / self.filenames[idx]



        # read inkml file

        ink = read_inkml_file(target_file_path)



        # generate ink sequence

        text_sequence = get_ink_sequence_token(ink, TIME_SAMPLING_DELTA)



        image = get_ink_image(ink, IMG_SIZE)

        if "normalizedLabel" in ink.annotations:

          label = ink.annotations['normalizedLabel']

        else:

          label = ink.annotations['label']



        sample = {'image': image, 'text': text_sequence, 'label': label}



        if self.transform:

            sample = self.transform(sample)



        # print(f'time taken to load one data: {time.time() - start}s')



        return sample



# FIXME

# change "mathwriting-2024-excerpt" into "mathwriting-2024" for real training

train_dataset = MathWritingDataset(data_dir / "mathwriting-2024-excerpt", data_types=["train", "symbols", "synthetic"])

test_dataset = MathWritingDataset(data_dir / "mathwriting-2024-excerpt", data_types=["test"])

# FIXME
# shrink dataset size
if SHRINK_DATASET:"
    generator = torch.Generator().manual_seed(SEED)
    
    train_dataset, _ = torch.utils.data.random_split(train_dataset, [3_000, len(train_dataset) - 2_000])
    test_dataset, _ = torch.utils.data.random_split(test_dataset, [500, len(test_dataset) - 500])
    

# train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=train_collate_fn)

# test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=test_collate_fn)


## Pytorch LightningModule

In [12]:
import lightning as L

import re

from nltk import edit_distance

import numpy as np



class PaliGemmaModel(L.LightningModule):

    def __init__(self, config, processor, model):

        super().__init__()
    
        self.config = config
    
        self.processor = processor
    
        self.model = model
    
        self.batch_size=config.get("batch_size")
    
        self.train_losses = []
    
        self.val_losses = []
        self.val_scores = []


    def training_step(self, batch, batch_idx):
        input_ids, token_type_ids, attention_mask, pixel_values, labels = batch

        # outputs: transformers.models.paligemma.modeling_paligemma.PaliGemmaCausalLMOutputWithPast

        # outputs.loss is provided if labels is given.
        outputs = self.model(input_ids=input_ids,
        
                             attention_mask=attention_mask,
        
                             token_type_ids=token_type_ids,
        
                             pixel_values=pixel_values,
        
                             labels=labels)
        
        loss = outputs.loss
        
        self.train_losses.append(loss.item())
        
        self.log("train_loss", loss)
        
        return loss


    def validation_step(self, batch, batch_idx, dataset_idx=0):

        input_ids, attention_mask, pixel_values, labels = batch

        # look at https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/generation/utils.py#L1904 for detailed documentation of model.generate() method
        # to get the logits of the generated text, we should set return_dict_in_generate=True
        generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, max_new_tokens=MAX_GENERATION_LENGTH)

        # add 1 to input_ids size to skip the bos token
        predictions = self.processor.batch_decode(generated_ids[:, input_ids.size(1)+1:], skip_special_tokes=True)
    
        scores = []
        for pred, label in zip(predictions, labels):
    
            self.val_scores.append(edit_distance(pred, label) / max(len(pred), len(label)))
            scores.append(edit_distance(pred, label) / max(len(pred), len(label)))
    
            if self.config.get("verbose", False) and len(scores) == 1:
    
                print(f"Prediction: {pred}")
    
                print(f"    Answer: {label}")
    
                print(f" Normed ED: {self.val_scores[-1]}")
    
    
    
        self.log("val_edit_distance", np.mean(scores))
    
        return scores



    def configure_optimizers(self):
    
        optimizer = torch.optim.AdamW(self.parameters(), lr=INIT_LR)
        
        return optimizer



    def train_dataloader(self):

        return DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=train_collate_fn, num_workers=NUM_WORKERS)

    def val_dataloader(self):

        return DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=test_collate_fn, num_workers=NUM_WORKERS)


## Q-LoRA configuration

In [13]:
from transformers import BitsAndBytesConfig

from peft import get_peft_model, LoraConfig



bnb_config = BitsAndBytesConfig(

    load_in_4bit=True,

    bnb_4bit_quant_type="nf4",

    bnb_4bit_compute_dtype=torch.bfloat16

)



lora_config = LoraConfig(

    r=LORA_R,

    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],

    task_type="CAUSAL_LM",

)


## Define the Model

In [14]:
from transformers import PaliGemmaForConditionalGeneration

model = PaliGemmaForConditionalGeneration.from_pretrained(REPO_ID, quantization_config=bnb_config)

model.resize_token_embeddings(len(custom_tokenizer))

model = get_peft_model(model, lora_config)



model.print_trainable_parameters()

`low_cpu_mem_usage` was None, now default to True since model is quantized.
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


trainable params: 11,298,816 || all params: 2,934,920,944 || trainable%: 0.3850


In [15]:
config = {"max_epochs": EPOCHS,

          "check_val_every_n_epoch": 1,

          "gradient_clip_val": GRAD_CLIP,

          "accumulate_grad_batches": 8,

          "lr": INIT_LR,

          "batch_size": BATCH_SIZE,

          "seed":SEED,

          "num_nodes": 1,

          "warmup_steps": WARMUP_STEPS,

          "result_path": log_dir,

          "verbose": True,

}



model_module = PaliGemmaModel(config, processor, model)

In [16]:
from lightning.pytorch.callbacks import Callback

from lightning.pytorch.callbacks.early_stopping import EarlyStopping



from huggingface_hub import HfApi



api = HfApi()

class Print_TrainValidation_ResultCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # print the average of training loss 
        print(f'Average Training Loss: {np.mean(pl_module.train_losses)}')

        # print the average of edit distance score
        print(f'Average Validation Score: {np.mean(pl_module.val_scores)}')

        # reset the list
        pl_module.train_losses = []
        pl_module.val_scores = []


class PushToHubCallback(Callback):

    def on_train_epoch_end(self, trainer, pl_module):

        print(f"Pushing model to the hub, epoch {trainer.current_epoch}")

        pl_module.model.push_to_hub(FINETUNED_MODEL_ID,

                                    commit_message=f"Training in progress, epoch {trainer.current_epoch}")



    def on_train_end(self, trainer, pl_module):

        print(f"Pushing model to the hub after training")

        pl_module.processor.push_to_hub(FINETUNED_MODEL_ID,

                                    commit_message=f"Training done")

        pl_module.model.push_to_hub(FINETUNED_MODEL_ID,

                                    commit_message=f"Training done")



early_stop_callback = EarlyStopping(monitor="val_edit_distance", patience=20, verbose=False, mode="min")

## Train

In [None]:
trainer = L.Trainer(
        devices=-1, 
    
        accelerator="auto",

        max_epochs=config.get("max_epochs"),

        accumulate_grad_batches=config.get("accumulate_grad_batches"),

        check_val_every_n_epoch=config.get("check_val_every_n_epoch"),

        gradient_clip_val=config.get("gradient_clip_val"),

        precision="16-mixed",

        limit_val_batches=1.0,

        num_sanity_val_steps=2,

        callbacks=[PushToHubCallback(), Print_TrainValidation_ResultCallback(), early_stop_callback],

)



trainer.fit(model_module)

## Inference from test dataset

In [None]:
input_ids, attention_mask, pixel_values, labels = next(iter(model_module.val_dataloader))

generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, max_new_tokens=MAX_GENERATION_LENGTH)

# add 1 to input_ids size to skip the bos token
predictions = self.processor.batch_decode(generated_ids[:, input_ids.size(1)+1:], skip_special_tokes=True)

for pred, label in zip(predictions, labels):
    print(f"Prediction: {pred}")

    print(f"    Answer: {label}")

    print(f" Normed ED: {edit_distance(pred, label) / max(len(pred), len(label))}")