In [None]:
import os
import pprint
from typing import Dict
import torch
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import MllamaForConditionalGeneration, AutoProcessor
from transformers.utils.quantization_config import BitsAndBytesConfig
from utils import load_model_quantized

In [None]:
# --- Paths and identifiers --------------------------------------------------
# Hub id of the base model and the local directory handed to the loader.
hf_model_id = "meta-llama/Llama-3.2-11B-Vision"
base_model_path: str = "models/llama/Llama-3.2-11B-Vision-base"

# Root folder of the receipt dataset used for fine-tuning.
finetune_dataset_path = "dataset"

# Output directory for the fine-tuned model. `exist_ok=True` makes the cell
# idempotent on re-run without a separate (racy) existence check.
finetuned_model_path = "models/llama/Llama-3.2-11B-Vision-ReceiptIQ-tuned"
os.makedirs(finetuned_model_path, exist_ok=True)

# --- Base model -------------------------------------------------------------
# `load_model_quantized` is a project helper from `utils`; it returns the model
# together with its processor (presumably loading from `base_model_path` and
# falling back to `hf_model_id` - confirm in utils).
model, processor = load_model_quantized(base_model_path, hf_model_id)
tokenizer = processor.tokenizer

# PEFT helper that prepares a k-bit (quantized) model for training, here with
# gradient checkpointing switched on.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# --- LoRA adapter -----------------------------------------------------------
# Rank-8 adapters on the attention query/value projections only; bias terms
# are left untouched.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)

finetuned_model = get_peft_model(model, lora_config)
finetuned_model.print_trainable_parameters()

# Prepare finetuning dataset of receipts
The dataset comes from Kaggle (`https://www.kaggle.com/datasets/dhiaznaidi/receiptdatasetssd300v2`) and contains:
- Images under `dataset/images` of receipts
- Extracted data under `dataset/gdt` containing `company`, `total`, `date` and `address`
- Metadata about the extracted data under `dataset/info_data`, containing the important words and their coordinates

Each prepared record should contain `receipt_path`, `schema` (for now just the company, date, total and address), and the `output`.

## Load Dataset

In [None]:
import json
from typing import List
from dataclasses import dataclass

@dataclass
class ReceiptData:
    """One fine-tuning example: a receipt image, the extraction schema and the expected result."""
    # Path to the receipt image file.
    receipt_path: str
    # Field name -> extraction hint (see `fixed_schema`).
    schema: Dict
    # Field name -> {"value": ..., "coordinates": ...}, as built by `prepare_dataset`.
    output: Dict

# Extraction schema attached to every record. Each value is a hint string of
# the form "<type>//<description>" for the field named by the key.
fixed_schema = {
    "total": "number//total amount of the invoice",
    "company": "string//the name of the company or person doing the supply",
    "date": "date//the date of the invoice",
    "address": "string//address of the person or company doing the supply",
}

def prepare_dataset() -> List[ReceiptData]:
    """Build one ``ReceiptData`` record per annotation file in ``info_data``.

    For every file under ``<finetune_dataset_path>/info_data`` the matching
    ground-truth file ``<finetune_dataset_path>/gdt/<receipt_id>.json`` is
    loaded, and each ground-truth field is wrapped as
    ``{"value": ..., "coordinates": ...}``. Coordinates come from the
    annotation file when it has an entry for that field; otherwise an all-zero
    box is used.

    Files are visited in sorted order so the resulting list is reproducible
    across runs. Hidden entries and sub-directories (for example Jupyter's
    ``.ipynb_checkpoints``) are skipped.

    Returns:
        The list of prepared records, one per annotation file.

    Raises:
        ValueError: if an annotation file has no usable ``image_path``.
        FileNotFoundError: if the matching ground-truth file does not exist.
    """
    info_dir = os.path.join(finetune_dataset_path, "info_data")
    gdt_dir = os.path.join(finetune_dataset_path, "gdt")

    dataset: List[ReceiptData] = []
    for info_name in sorted(os.listdir(info_dir)):
        info_path = os.path.join(info_dir, info_name)
        if info_name.startswith(".") or not os.path.isfile(info_path):
            continue

        with open(info_path, "r", encoding="utf-8") as info_file:
            receipt_info = json.load(info_file)

        # The receipt id is the image file name without its extension; taking
        # the basename does not depend on how deep the recorded path is.
        image_path = receipt_info.get("image_path", "")
        receipt_id = os.path.splitext(os.path.basename(image_path))[0]
        if not receipt_id:
            raise ValueError(f"{info_path}: missing or empty 'image_path'")

        gdt_path = os.path.join(gdt_dir, f"{receipt_id}.json")
        with open(gdt_path, "r", encoding="utf-8") as gdt_file:
            extracted_fields = json.load(gdt_file)

        # Build a fresh mapping instead of rewriting the dict being iterated.
        output = {}
        for field, value in extracted_fields.items():
            if field in receipt_info:
                coordinates = receipt_info[field]
            else:
                coordinates = {"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0}
            output[field] = {"value": value, "coordinates": coordinates}

        dataset.append(
            ReceiptData(
                # Re-root the path recorded in the annotation onto the local dataset folder.
                receipt_path=image_path.replace("/content/Dataset/train/", f"{finetune_dataset_path}/"),
                schema=fixed_schema,
                output=output,
            )
        )
    return dataset
data_list = prepare_dataset()
# Fail early with a readable message rather than an IndexError on data_list[0].
assert data_list, f"no receipts found under {finetune_dataset_path}/info_data"
print(len(data_list))
pprint.pprint(data_list[0])

## Tokenize the dataset

In [None]:
from typing import Dict
from PIL import Image
from dataclasses import asdict
from datasets import Dataset

def convert_to_dataset(data_list: List[ReceiptData]) -> Dataset:
    """Pack the receipt records into a two-column ``Dataset``.

    Column ``prompt`` holds each record converted with ``asdict`` (all of its
    fields); column ``completion`` holds ``{"output": record.output}``.
    """
    prompts = []
    completions = []
    for record in data_list:
        prompts.append(asdict(record))
        completions.append({"output": record.output})
    return Dataset.from_dict({"prompt": prompts, "completion": completions})

# Materialise the prepared records as a Dataset with `prompt` / `completion` columns.
dataset = convert_to_dataset(data_list=data_list)

def tokenize_batched(examples: Dict[str, List]) -> Dict[str, torch.Tensor]:
    """Turn a batch of prompt/completion rows into padded causal-LM training tensors.

    Each sequence is laid out as ``prompt tokens + completion tokens + EOS`` and
    right-padded to the longest sequence of the batch. ``labels`` equal
    ``input_ids`` on the completion (and EOS) positions and ``-100`` everywhere
    else, so neither the prompt nor the padding contributes to the loss.

    Args:
        examples: batched rows with a ``prompt`` column (dicts providing
            ``receipt_path`` and ``schema``) and a ``completion`` column (dicts
            providing ``output``).

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels`` (all of
        shape ``(batch, max_len)``) plus the image tensors produced by the
        processor: ``pixel_values``, ``aspect_ratio_ids``, ``aspect_ratio_mask``.
    """
    receipt_paths = [item.get("receipt_path") for item in examples["prompt"]]
    schemas = [item.get("schema") for item in examples["prompt"]]
    # NOTE(review): formatting a dict through an f-string yields its Python repr
    # (single quotes), not JSON. Kept as-is because inference code may rely on
    # this exact format - confirm before switching to json.dumps.
    outputs = [f'{item.get("output")}' for item in examples["completion"]]

    # Copy the pixel data out while the file is open so no file handle outlives
    # this call; `copy()` keeps the image mode unchanged.
    images = []
    for path in receipt_paths:
        with Image.open(path) as image_file:
            images.append([image_file.copy()])
    input_prompts = [f'<|image|><|begin_of_text|>{schema}' for schema in schemas]

    model_inputs = processor(
        images=images,
        text=input_prompts,
        return_tensors="pt",
        padding="longest",
        truncation=True
    )
    # NOTE(review): the processor may also return a `cross_attention_mask`; it is
    # not forwarded (it would have to be extended over the completion tokens).
    # Confirm the model's behaviour without it is acceptable.

    # Completions are tokenized without padding and without special tokens:
    # padding is applied once below on the combined sequence, and a tokenizer
    # that prepends BOS by default would otherwise insert it mid-sequence.
    completion_tokens = tokenizer(
        outputs,
        add_special_tokens=False,
        truncation=True
    )

    eos_token_id = tokenizer.eos_token_id
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Padded positions are masked out of attention and loss, so the exact
        # id only needs to be a valid token id.
        pad_token_id = eos_token_id if eos_token_id is not None else 0

    # First pass: build the unpadded prompt+completion sequences and their labels.
    sequences = []
    for i, completion_ids in enumerate(completion_tokens["input_ids"]):
        # Drop any padding the processor added to the prompt so that padding
        # never ends up between the prompt and the completion.
        prompt_mask = model_inputs["attention_mask"][i].bool()
        prompt_input_ids = model_inputs["input_ids"][i][prompt_mask]

        completion_ids = list(completion_ids)
        if eos_token_id is not None:
            # Teach the model where the answer ends.
            completion_ids.append(eos_token_id)
        completion_input_ids = torch.tensor(completion_ids, dtype=prompt_input_ids.dtype)

        combined_input_ids = torch.cat((prompt_input_ids, completion_input_ids))
        labels = combined_input_ids.clone()
        labels[:prompt_input_ids.shape[0]] = -100
        sequences.append((combined_input_ids, labels))

    max_len = max((ids.shape[0] for ids, _ in sequences), default=0)

    # Second pass: right-pad everything to the batch maximum.
    batch_input_ids = []
    batch_attention_mask = []
    batch_labels = []
    for combined_input_ids, labels in sequences:
        current_len = combined_input_ids.shape[0]
        padding_len = max_len - current_len
        batch_input_ids.append(
            torch.cat([combined_input_ids, combined_input_ids.new_full((padding_len,), pad_token_id)])
        )
        batch_attention_mask.append(
            torch.cat([combined_input_ids.new_ones(current_len), combined_input_ids.new_zeros(padding_len)])
        )
        batch_labels.append(torch.cat([labels, labels.new_full((padding_len,), -100)]))

    return {
        "input_ids": torch.stack(batch_input_ids),
        "attention_mask": torch.stack(batch_attention_mask),
        "labels": torch.stack(batch_labels),
        "pixel_values": model_inputs["pixel_values"],
        "aspect_ratio_ids": model_inputs["aspect_ratio_ids"],
        "aspect_ratio_mask": model_inputs['aspect_ratio_mask'],
    }

# Tokenize in groups of 4 rows; padding to the longest sequence happens inside
# each group separately. The original `prompt` / `completion` columns are
# dropped so only the values returned by `tokenize_batched` remain.
tokenized_dataset = dataset.map(
    tokenize_batched,
    batched=True,
    batch_size=4,
    remove_columns=dataset.column_names
)

# LoRA Finetune the model

In [None]:
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
from transformers.data.data_collator import default_data_collator

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 1e-4
num_epochs = 1
batch_size = 1
# Upper bound on the total number of optimizer steps. 1 = single-step smoke
# test of the whole pipeline; set to None to train on every batch of every epoch.
max_train_steps = 1

optimizer = torch.optim.AdamW(finetuned_model.parameters(), lr=lr)
# NOTE(review): rows were padded per tokenization group, so sequence lengths can
# differ between rows; with batch_size > 1 the default collator may fail to
# stack them - confirm before raising batch_size.
train_dataloader = DataLoader(
    tokenized_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=batch_size,
    pin_memory=(device == "cuda"),
    num_workers=2
)

# The linear schedule must span the steps that will actually be taken.
total_training_steps = len(train_dataloader) * num_epochs
if max_train_steps is not None:
    total_training_steps = min(total_training_steps, max_train_steps)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=total_training_steps,
)

print(f"Training on device: {device}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB" if device == "cuda" else "Using CPU")

completed_steps = 0
for epoch in range(num_epochs):
    finetuned_model.train()
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = finetuned_model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
        print(f"{epoch=} {step=}: loss={loss.item():.4f}")
        if max_train_steps is not None and completed_steps >= max_train_steps:
            break
    # Leave the epoch loop as well once the step budget is used up.
    if max_train_steps is not None and completed_steps >= max_train_steps:
        break

# Save Model

In [None]:
import time

# Save the fine-tuned model into a run-specific sub-folder whose name carries
# the current Unix timestamp, so repeated runs never overwrite each other.
time_now = time.time()
checkpoint_dir = os.path.join(finetuned_model_path, f"receiptiq_model_{time_now}")
finetuned_model.save_pretrained(checkpoint_dir)