In [1]:
import os
os.environ['PYTHONDONTWRITEBYTECODE'] = '1'
import pprint
from typing import Dict
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision"

model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)

processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

# Prepare finetuning dataset of receipts
It the dataset comes from kaggle `https://www.kaggle.com/datasets/dhiaznaidi/receiptdatasetssd300v2` and contains
- Images under `dataset/images` of receipts
- Extracted data under `dataset/gdt` containing `company`, `total`, `date` and `address`
- Data about extracted data in `data/info_data` contains the data with important words and their coordinates

The expected data should contain `receipt_path`, `schema`(for now just the company, date, total and address), and the `output`

## Load Dataset

In [2]:
from dataclasses import dataclass

@dataclass
class ReceiptData:
    receipt_path: str
    schema: Dict
    output: Dict

In [3]:
import json
from typing import List

fixed_schema = {
    "total": "number//total amount of the invoice",
    "company": "string//the name of the company or person doing the supply",
    "date": "date//the date of the invoice",
    "address": "string//address of the person or company doing the supply",
}

def prepare_dataset() -> List[ReceiptData]:
    dataset: List[ReceiptData] = []
    for receipt in os.listdir("dataset/info_data"):
        with open(f"dataset/info_data/{receipt}", "r") as f:
            receipt_info = json.loads(f.read())
            receipt_id  = receipt_info.get("image_path","").split("/")[5].replace(".jpg","")
            with open(f"dataset/gdt/{receipt_id}.json","r") as df:
                extracted_data = json.loads(df.read())
                for k,v in extracted_data.items():
                    if k in receipt_info.keys():
                        extracted_data[k] = {
                            "value": v,
                            "coordinates": receipt_info[k]
                        }
                    else:
                        extracted_data[k] = {
                            "value": v,
                            "coordinates": {
                                "xmin": 0,
                                "ymin": 0,
                                "xmax": 0,
                                "ymax": 0
                            }
                        }
            receipt_data = ReceiptData(
                receipt_path=receipt_info.get("image_path","").replace("/content/Dataset/train/","dataset/"),
                schema=fixed_schema,
                output=extracted_data
            )
            dataset.append(receipt_data)
    return dataset
data_list = prepare_dataset()
print(len(data_list))
pprint.pprint(data_list[0])

967
ReceiptData(receipt_path='dataset/images/627.jpg',
            schema={'address': 'string//address of the person or company doing '
                               'the supply',
                    'company': 'string//the name of the company or person '
                               'doing the supply',
                    'date': 'date//the date of the invoice',
                    'total': 'number//total amount of the invoice'},
            output={'address': {'coordinates': {'xmax': 0,
                                                'xmin': 0,
                                                'ymax': 0,
                                                'ymin': 0},
                                'value': '12, JALAN TAMPOI 7/4,KAWASAN '
                                         'PERINDUSTRIAN TAMPOI,81200 JOHOR '
                                         'BAHRU,JOHOR'},
                    'company': {'coordinates': {'xmax': 0,
                                                'xmin': 0,


## Tokenize the dataset

In [4]:
from dataclasses import asdict
from datasets import Dataset

def convert_to_dataset(data_list: List[ReceiptData]) -> Dataset:
    data_dict = {
        "prompt": [asdict(record) for record in data_list], 
        "completion": [{"output": record.output} for record in data_list]
    }
    dataset = Dataset.from_dict(data_dict)
    return dataset

dataset = convert_to_dataset(data_list=data_list)
dataset

Dataset({
    features: ['prompt', 'completion'],
    num_rows: 967
})

In [5]:
from typing import Dict
from PIL import Image

def tokenize_batched(examples: Dict[str, List]) -> Dict[str, torch.Tensor]:
    receipt_paths = [item.get("receipt_path") for item in examples["prompt"]]
    schemas = [item.get("schema") for item in examples["prompt"]]
    outputs = [f'{item.get("output")}' for item in examples["completion"]]
    
    images = [[Image.open(path)] for path in receipt_paths] 
    input_prompts = [f'<|image|><|begin_of_text|>{schema}' for schema in schemas]
    
    model_inputs = processor(
        images=images,
        text=input_prompts,
        return_tensors="pt",
        padding="longest",
        truncation=True
    )
      
    completion_tokens = tokenizer(
        outputs,
        return_tensors="pt",
        padding="longest",
        truncation=True
    )

    batch_input_ids = []
    batch_attention_mask = []
    batch_labels = []

    # Determine max length across all combined sequences in this batch
    max_len = 0
    for i in range(len(input_prompts)):
        combined_len = model_inputs["input_ids"][i].shape[0] + completion_tokens["input_ids"][i].shape[0]
        if combined_len > max_len:
            max_len = combined_len
    
    # Optional: Cap max_len to a global maximum for consistency
    # if max_len > 1024: max_len = 1024 

    for i in range(len(input_prompts)):
        prompt_input_ids = model_inputs["input_ids"][i]
        prompt_attention_mask = model_inputs["attention_mask"][i]
        
        output_input_ids = completion_tokens["input_ids"][i]
        output_attention_mask = completion_tokens["attention_mask"][i]

        combined_input_ids = torch.cat((prompt_input_ids, output_input_ids))
        combined_attention_mask = torch.cat((prompt_attention_mask, output_attention_mask))

        labels = combined_input_ids.clone()
        prompt_length = prompt_input_ids.shape[0]
        labels[:prompt_length] = -100

        # Pad to max_len of the batch
        current_len = combined_input_ids.shape[0]
        if current_len < max_len:
            padding_len = max_len - current_len
            combined_input_ids = torch.cat([combined_input_ids, torch.full((padding_len,), tokenizer.pad_token_id, dtype=torch.long)])
            combined_attention_mask = torch.cat([combined_attention_mask, torch.zeros(padding_len, dtype=torch.long)])
            labels = torch.cat([labels, torch.full((padding_len,), -100, dtype=torch.long)])

        batch_input_ids.append(combined_input_ids)
        batch_attention_mask.append(combined_attention_mask)
        batch_labels.append(labels)

    return {
        "input_ids": torch.stack(batch_input_ids),
        "attention_mask": torch.stack(batch_attention_mask),
        "labels": torch.stack(batch_labels),
        "pixel_values": model_inputs["pixel_values"],
        "aspect_ratio_ids": model_inputs["aspect_ratio_ids"],
        "aspect_ratio_mask": model_inputs['aspect_ratio_mask'],
    }

tokenized_dataset = dataset.map(
    tokenize_batched,
    batched=True,
    batch_size=4,
    remove_columns=dataset.column_names
)

Map:   0%|          | 0/967 [00:00<?, ? examples/s]

### Test tokenized dataset

In [6]:
sample_data = {
    k: torch.tensor(v, dtype=torch.int32).unsqueeze(0)
    if k in ["input_ids", "attention_mask", "labels"]
    else torch.tensor(v).unsqueeze(0)  # for pixel_values and others
    for k, v in tokenized_dataset[0].items()
}
if 'pixel_values' in sample_data:
    sample_data['pixel_values'] = sample_data['pixel_values'].to(torch.bfloat16)
output = model.generate(**sample_data, max_new_tokens=30)
print(processor.decode(output[0]))

<|begin_of_text|><|image|><|begin_of_text|>{'address':'string//address of the person or company doing the supply', 'company':'string//the name of the company or person doing the supply', 'date': 'date//the date of the invoice', 'total': 'number//total amount of the invoice'}<|begin_of_text|>{'address': {'coordinates': {'xmax': 0, 'xmin': 0, 'ymax': 0, 'ymin': 0}, 'value': '12, JALAN TAMPOI 7/4,KAWASAN PERINDUSTRIAN TAMPOI,81200 JOHOR BAHRU,JOHOR'}, 'company': {'coordinates': {'xmax': 0, 'xmin': 0, 'ymax': 0, 'ymin': 0}, 'value': 'UNIHAKKA INTERNATIONAL SDN BHD'}, 'date': {'coordinates': {'xmax': 2743, 'xmin': 2259, 'ymax': 1868, 'ymin': 1792}, 'value': '03 APR 2018'}, 'total': {'coordinates': [], 'value': '$7.10'}}<|finetune_right_pad_id|><|finetune_right_pad_id|><|finetune_right_pad_id|><|finetune_right_pad_id|><|finetune_right_pad_id|><|finetune_right_pad_id|><|finetune_right_pad_id|><|finetune_right_pad_id|><|finetune_right_pad_id|><|finetune_right_pad_id|><|finetune_right_pad_id|><