In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

In [2]:
%%capture
%pip install --no-deps bitsandbytes accelerate xformers peft trl triton cut_cross_entropy unsloth_zoo
%pip install sentencepiece protobuf datasets huggingface_hub hf_transfer nltk python-Levenshtein
%pip install --no-deps unsloth
%pip install -q albumentations opencv-python scikit-image
%pip install -q torch torchvision torchaudio
%pip install uncertainty-toolbox

In [20]:
from io import BytesIO
import base64
from PIL import Image
import cv2
import numpy as np
from datasets import load_dataset
import torch
from unsloth import FastVisionModel
from trl import SFTTrainer, SFTConfig
from unsloth.trainer import UnslothVisionDataCollator
from tqdm import tqdm
import nltk
from nltk.translate.bleu_score import sentence_bleu
import Levenshtein
import subprocess
import tempfile
import os
from transformers import TextStreamer
nltk.download('punkt')

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [21]:
class ImagePreprocessor:
    def __init__(self):
        pass

    def binarize_image(self, image: Image.Image, method='otsu') -> Image.Image:
        gray = image.convert('L')
        img_array = np.array(gray)
        if method == 'otsu':
            _, binary = cv2.threshold(img_array, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        elif method == 'adaptive':
            binary = cv2.adaptiveThreshold(img_array, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 
                                        cv2.THRESH_BINARY, 11, 2)
        return Image.fromarray(binary)
    
    def noise_filtering(self, image: Image.Image) -> Image.Image:
        img_array = np.array(image)
        filtered = cv2.medianBlur(img_array, 3)
        kernel = np.ones((2, 2), np.uint8)
        filtered = cv2.morphologyEx(filtered, cv2.MORPH_CLOSE, kernel)
        filtered = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
        return Image.fromarray(filtered)
    
    def deskew_image(self, image: Image.Image) -> Image.Image:
        img_array = np.array(image.convert('L'))
        coords = np.column_stack(np.where(img_array > 0))
        if len(coords) == 0:
            return image
        angle = cv2.minAreaRect(coords)[-1]
        if angle < -45:
            angle = -(90 + angle)
        else:
            angle = -angle
        (h, w) = img_array.shape[:2]
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, angle, 1.0)
        rotated = cv2.warpAffine(img_array, M, (w, h), flags=cv2.INTER_CUBIC, 
                               borderMode=cv2.BORDER_REPLICATE)
        return Image.fromarray(rotated)
    
    def preprocess_image(self, image: Image.Image) -> Image.Image:
        filtered_img = self.noise_filtering(image)
        deskewed_img = self.deskew_image(filtered_img)
        binary_img = self.binarize_image(deskewed_img)
        return binary_img

## Image Encoding & Dataset Preprocessing

In [22]:
def pil_image_to_base64_str(img: Image.Image) -> str:
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return img_b64

In [23]:
def base64_str_to_pil(img_b64: str) -> Image.Image:
    img_bytes = base64.b64decode(img_b64)
    return Image.open(BytesIO(img_bytes))

In [24]:
def preprocess_dataset(dataset, preprocessor):
    """Add a new column with preprocessed images encoded as base64 strings."""
    new_images_b64 = []
    for sample in tqdm(dataset, desc="Preprocessing dataset"):
        image = sample["image"]
        preprocessed = preprocessor.preprocess_image(image)
        img_b64 = pil_image_to_base64_str(preprocessed)
        new_images_b64.append(img_b64)
    return dataset.add_column("preprocessed_image", new_images_b64)

## Evaluation Functions

In [25]:
def exact_match_accuracy(predictions: list, targets: list) -> float:
    matches = [pred.strip() == target.strip() for pred, target in zip(predictions, targets)]
    return np.mean(matches)

In [26]:
def calculate_bleu_scores(predictions: list, targets: list) -> list:
    bleu_scores = []
    for pred, target in zip(predictions, targets):
        pred_tokens = pred.strip().split()
        target_tokens = target.strip().split()
        if len(target_tokens) == 0:
            bleu_scores.append(0.0)
        else:
            try:
                score = sentence_bleu([target_tokens], pred_tokens, weights=(0.25, 0.25, 0.25, 0.25))
                bleu_scores.append(score)
            except:
                bleu_scores.append(0.0)
    return bleu_scores

In [27]:
def calculate_edit_distances(predictions: list, targets: list) -> list:
    return [Levenshtein.distance(p.strip(), t.strip()) for p, t in zip(predictions, targets)]

In [28]:
def check_latex_compilation(latex_code: str) -> bool:
    latex_document = f"""
    \\documentclass{{article}}
    \\usepackage{{amsmath, amssymb, amsfonts}}
    \\begin{{document}}
    $${latex_code}$$
    \\end{{document}}
    """
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.tex', delete=False) as f:
            f.write(latex_document)
            tex_file = f.name
        result = subprocess.run(['pdflatex', '-interaction=nonstopmode', tex_file], 
                              capture_output=True, timeout=10, cwd='/tmp')
        base_name = tex_file[:-4]
        for ext in ['.tex', '.pdf', '.log', '.aux']:
            try:
                os.unlink(base_name + ext)
            except:
                pass
        return result.returncode == 0
    except Exception as e:
        return False

In [29]:
def compilation_success_rate(predictions: list) -> float:
    success_count = sum(check_latex_compilation(pred) for pred in predictions)
    return success_count / len(predictions) if len(predictions) > 0 else 0.0

In [30]:
def analyze_errors(predictions: list, targets: list, top_k=5):
    errors = []
    for pred, target in zip(predictions, targets):
        if pred.strip() != target.strip():
            errors.append({
                'prediction': pred,
                'target': target,
                'edit_distance': Levenshtein.distance(pred, target)
            })
    errors.sort(key=lambda x: x['edit_distance'], reverse=True)
    print(f"\nTop {top_k} most different predictions:")
    for i, error in enumerate(errors[:top_k]):
        print(f"{i+1}. Edit distance: {error['edit_distance']}")
        print(f"   Target: {error['target']}")
        print(f"   Prediction: {error['prediction']}\n")
    return errors

In [31]:
def convert_to_conversation(sample, instruction="Write the LaTeX representation for this image."):
    conversation = [
        {"role": "user", "content": [
            {"type": "text", "text": instruction},
            {"type": "image", "image": sample["image"]}
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": sample["text"]}
        ]}
    ]
    return {"messages": conversation}

## Model Eval Function

In [44]:
def evaluate_model(model, tokenizer, eval_dataset, image_col="image", max_samples=None):
    """Evaluate model on dataset supporting both PIL images and base64 strings."""
    model.eval()
    predictions, targets = [], []
    samples = eval_dataset if max_samples is None else eval_dataset.select(range(min(max_samples, len(eval_dataset))))
    instruction = "Write the LaTeX representation for this image."
    for sample in samples:
        img_data = sample[image_col]
        # Determine if the image is a PIL image or a base64 string
        if isinstance(img_data, Image.Image):
            img = img_data
        elif isinstance(img_data, str):
            img = base64_str_to_pil(img_data)
        else:
            raise TypeError(f"Unsupported type for image: {type(img_data)}")
        
        messages = [
            {"role": "user", "content": [
                {"type": "text", "text": instruction},
                {"type": "image", "image": img}
            ]}
        ]
        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        inputs = tokenizer(
            img, input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to("cuda")
        with torch.no_grad():
            outputs = model.generate(
                **inputs, 
                max_new_tokens=128, 
                use_cache=True, 
                temperature=0.1, 
                pad_token_id=tokenizer.eos_token_id
            )
        prediction = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
        predictions.append(prediction)
        targets.append(sample["text"])
    exact_match = exact_match_accuracy(predictions, targets)
    bleu_scores = calculate_bleu_scores(predictions, targets)
    edit_distances = calculate_edit_distances(predictions, targets)
    compilation_rate = compilation_success_rate(predictions)
    metrics = {
        "exact_match_accuracy": exact_match,
        "average_bleu": np.mean(bleu_scores),
        "median_bleu": np.median(bleu_scores),
        "std_bleu": np.std(bleu_scores),
        "average_edit_distance": np.mean(edit_distances),
        "median_edit_distance": np.median(edit_distances),
        "std_edit_distance": np.std(edit_distances),
        "compilation_success_rate": compilation_rate,
        "num_samples": len(predictions)
    }
    return metrics, predictions, targets

In [41]:
from PIL import Image
import io

def load_and_convert_to_pil(path_or_bytes):
    if isinstance(path_or_bytes, bytes):
        return Image.open(io.BytesIO(path_or_bytes))
    else:
        return Image.open(path_or_bytes)

In [51]:
dataset = load_dataset("unsloth/Latex_OCR", split="train[:500]")

In [46]:
# dataset = dataset.map(lambda x: {"image": load_and_convert_to_pil(x["image"])}, batched=False)

In [52]:
split_data = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split_data["train"]
temp_dataset = split_data["test"]
eval_test_split = temp_dataset.train_test_split(test_size=0.5, seed=42)
eval_dataset = eval_test_split["train"]
test_dataset = eval_test_split["test"]

In [53]:
preprocessor = ImagePreprocessor()

In [54]:
# Step 3: Preprocess validation and test datasets
eval_dataset = preprocess_dataset(eval_dataset, preprocessor)  # Adds "preprocessed_image" column
test_dataset = preprocess_dataset(test_dataset, preprocessor)   # Adds "preprocessed_image" column

Preprocessing dataset: 100%|██████████| 50/50 [00:00<00:00, 365.52it/s]


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Preprocessing dataset: 100%|██████████| 50/50 [00:00<00:00, 415.07it/s]


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

In [55]:
# Step 4: Convert training dataset to conversation format
converted_train_dataset = [convert_to_conversation(sample) for sample in train_dataset]
converted_eval_dataset = [convert_to_conversation(sample) for sample in eval_dataset]

In [56]:
model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Qwen2-VL-7B-Instruct",
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth"
)

==((====))==  Unsloth 2025.7.11: Fast Qwen2 patching. Transformers: 4.52.4.
   \\   /|    Tesla T4. Num GPUs = 2. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [57]:
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers=True,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    random_state=3407,
    use_rslora=False,
    loftq_config=None
)

Unsloth: Making `model.base_model.model.model.visual` require gradients


In [58]:
print("\n=== BASELINE EVALUATION (BEFORE FINE-TUNING) ===")
print("Evaluating on raw images...")
baseline_metrics_raw, _, _ = evaluate_model(model, tokenizer, eval_dataset, image_col="image", max_samples=50)


=== BASELINE EVALUATION (BEFORE FINE-TUNING) ===
Evaluating on raw images...


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


In [59]:
print("Evaluating on preprocessed images...")
baseline_metrics_preprocessed, _, _ = evaluate_model(model, tokenizer, eval_dataset, image_col="preprocessed_image", max_samples=50)

Evaluating on preprocessed images...


The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


In [60]:
print("\n=== BASELINE RESULTS ===")
print("Metric               | Raw Images | Preprocessed Images | Difference")
print("---------------------|------------|---------------------|-----------")
for metric in baseline_metrics_raw.keys():
    raw_val = baseline_metrics_raw[metric]
    pre_val = baseline_metrics_preprocessed[metric]
    diff = pre_val - raw_val if isinstance(raw_val, float) else "-"
    print(f"{metric:20s} | {raw_val:.4f}      | {pre_val:.4f}            | {diff:+.4f}")


=== BASELINE RESULTS ===
Metric               | Raw Images | Preprocessed Images | Difference
---------------------|------------|---------------------|-----------
exact_match_accuracy | 0.0000      | 0.0000            | +0.0000
average_bleu         | 0.6916      | 0.0001            | -0.6914
median_bleu          | 0.7729      | 0.0000            | -0.7729
std_bleu             | 0.2448      | 0.0010            | -0.2439
average_edit_distance | 31.3000      | 157.3200            | +126.0200
median_edit_distance | 20.5000      | 161.5000            | +141.0000
std_edit_distance    | 34.6192      | 55.2885            | +20.6693
compilation_success_rate | 0.0000      | 0.0000            | +0.0000


ValueError: Unknown format code 'f' for object of type 'str'