In [35]:
import os
import json
import imagesize
import torch
import glob
import random
import cv2
import numpy as np
from torch.utils.data import Dataset, RandomSampler, DataLoader, Sampler
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from torchvision import transforms
from PIL import Image
from collections import defaultdict


In [36]:
class DebugCIGroupDataset(Dataset):
    """Dataset pairing document images with their ``*_geo.json`` annotations.

    Each sample is built from an image file and the sibling annotation file
    named ``<image file name>_geo.json``. The annotation words are sorted by
    position, concatenated into an OCR prompt, and followed by a tag-based
    ground-truth string; prompt + ground truth are tokenized together.

    Args:
        root_dir: Directory that is walked recursively for annotation files.
        tokenizer: Callable tokenizer (Hugging Face style) used in
            ``__getitem__``; must also provide ``decode`` when ``debug`` is on.
        max_length: Fixed token length (padding and truncation target).
        image_size: Images are resized to ``image_size x image_size``.
        debug: When true, ``__getitem__`` prints the prompt, the token ids,
            the decoded text and the image path.
    """

    # Suffix appended to an image file name to get its annotation file name.
    GEO_SUFFIX = '_geo.json'
    # Label value used for positions that must not contribute to the loss.
    IGNORE_INDEX = -100

    def __init__(self, root_dir='/mnt/efs/RaghavWork/CIGroups', tokenizer=None, max_length=2048, image_size=1024, debug=False):
        self.samples = self.load_samples(root_dir)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.image_size = image_size
        self.debug = debug
        self.transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),  # Converts to [C, H, W] and [0, 1]
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # Normalize to [-1, 1]
        ])

    def load_samples(self, dirp):
        """Collect ``(image_path, json_path)`` pairs found under ``dirp``.

        An annotation file is kept only when the image it refers to (its own
        name without the ``_geo.json`` suffix) exists next to it. The result is
        sorted so that an index always maps to the same sample, independent of
        the directory listing order returned by the filesystem.
        """
        suffix = self.GEO_SUFFIX
        all_samples = []
        for root, _, files in os.walk(dirp):
            for file in files:
                if not file.endswith(suffix):
                    continue
                # Strip only the trailing suffix; str.replace would also remove
                # occurrences in the middle of the file name.
                img_path = os.path.join(root, file[:-len(suffix)])
                json_path = os.path.join(root, file)
                if os.path.isfile(img_path) and os.path.isfile(json_path):
                    all_samples.append((img_path, json_path))
        all_samples.sort()
        return all_samples

    def extract_data(self, json_path):
        """Read an annotation file and return ``(bbox, text, class, row)`` tuples.

        One tuple is produced for every word id referenced under
        ``data['parse']['class']``. ``bbox`` is the concatenation of the first
        and third entries of the word's ``boundingBox``
        (presumably top-left + bottom-right corners -- TODO confirm).
        ``row`` is the first element of ``row_label`` and falls back to 0 when
        the key is missing or empty.
        """
        with open(json_path, encoding="utf8") as f:
            data = json.load(f)

        words = data['words']
        word_boxes = []
        for clas, items in data['parse']['class'].items():
            for item in items:
                for wrd_id in item:
                    entry = words[wrd_id]
                    corners = entry['boundingBox']
                    bbox = corners[0] + corners[2]
                    # `or [0]` also covers an empty/None row_label, which would
                    # otherwise raise on the [0] lookup.
                    row_label = entry.get('row_label') or [0]
                    word_boxes.append((bbox, entry['text'], clas, row_label[0]))
        return word_boxes

    def sort_by_position(self, words):
        """Return ``words`` ordered by ``bbox[1]`` and then ``bbox[0]``."""
        return sorted(words, key=lambda x: (x[0][1], x[0][0]))

    def create_prompt_and_gt(self, word_boxes):
        """Build the text prompt and the tagged ground-truth string.

        Returns:
            ``(prompt, ground_truth)``. The prompt contains every non-blank
            word in the given order. The ground truth holds one
            ``<s_tag>...</s_tag>`` element per ``*_VALUE`` class seen with
            row 0, followed by one ``<s_line>`` group per non-zero row that has
            at least three non-MISC ``*_VALUE`` fields.
        """
        key_values = defaultdict(list)
        line_items = defaultdict(lambda: defaultdict(list))
        ocr_words = []

        for _bbox, word, cls, row in word_boxes:
            word = word.strip()
            if not word:
                continue
            ocr_words.append(word)
            if row == 0 and cls.endswith('_VALUE'):
                key_values[cls].append(word)
            elif cls.endswith('_VALUE') and 'MISC' not in cls:
                line_items[str(row)][cls].append(word)

        gt_parts = []
        for cls, words in key_values.items():
            tag = cls.lower().replace('_value', '')
            gt_parts.append(f"<s_{tag}>{' '.join(words)}</s_{tag}>")

        for row in sorted(line_items, key=int):
            fields = line_items[row]
            # Rows with fewer than three fields are not emitted as line items.
            if len(fields) < 3:
                continue
            gt_parts.append("<s_line>")
            for cls, words in fields.items():
                tag = cls.lower().replace('_value', '')
                content = ' '.join(words)
                if tag == 'unit_price':
                    content = content.replace(',', '.').replace('/', '').replace('ST', '1.00').strip()
                gt_parts.append(f"<s_{tag}>{content}</s_{tag}>")
            gt_parts.append("</s_line>")

        final_output = ''.join(gt_parts)
        ocr_text = ' '.join(ocr_words)
        prompt = f"<|image|>\n{ocr_text.strip()}\n\nGenerate structured output:\n"
        return prompt, final_output

    @staticmethod
    def _build_labels(input_ids, attention_mask, ignore_index=-100):
        """Return a copy of ``input_ids`` with masked-out positions ignored.

        Positions where ``attention_mask`` is 0 (padding) are set to
        ``ignore_index``. A new tensor is returned, so later in-place edits of
        the labels cannot alter ``input_ids``.
        """
        return input_ids.masked_fill(attention_mask == 0, ignore_index)

    def __getitem__(self, idx):
        """Load, transform and tokenize the sample at ``idx``.

        Returns a dict with ``pixel_values`` (image tensor), ``input_ids``,
        ``attention_mask``, ``labels`` (``input_ids`` with padding set to
        ``IGNORE_INDEX``), ``prompt_raw`` and ``image_path``.
        """
        img_path, json_path = self.samples[idx]

        # Load & process image; the context manager releases the file handle.
        # The image is already resized here, so the Resize step inside
        # self.transform should have nothing left to do.
        with Image.open(img_path) as img:
            image = img.convert("RGB").resize((self.image_size, self.image_size))
        image = self.transform(image)  # Now it's a tensor!
        word_boxes = self.extract_data(json_path)
        word_boxes = self.sort_by_position(word_boxes)
        prompt, target_xml = self.create_prompt_and_gt(word_boxes)
        full_prompt = prompt + target_xml

        # Tokenize full prompt
        # NOTE(review): with truncation enabled, a very long OCR prompt can
        # push part of the target past max_length -- confirm this is acceptable.
        tokenized = self.tokenizer(
            full_prompt,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Decode back for debugging
        if self.debug:
            print(f"\n🔹 Prompt before tokenization:\n{full_prompt}")
            print(f"\n🔹 Token IDs:\n{tokenized['input_ids'][0]}")
            print(f"\n🔹 Decoded text:\n{self.tokenizer.decode(tokenized['input_ids'][0], skip_special_tokens=True)}")
            print(f"\n🔹 Image Path: {img_path}")

        input_ids = tokenized["input_ids"].squeeze(0)
        attention_mask = tokenized["attention_mask"].squeeze(0)

        return {
            "pixel_values": image,  # Tensor [3, H, W]
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            # Independent tensor with padding masked out of the loss.
            "labels": self._build_labels(input_ids, attention_mask, self.IGNORE_INDEX),
            "prompt_raw": full_prompt,  # Optional for external test
            "image_path": img_path
        }

    def __len__(self):
        """Number of (image, annotation) pairs discovered at construction."""
        return len(self.samples)

In [37]:
## tokenizer

# Checkpoint to load and the local directory its files are cached in.
name = "microsoft/Phi-4-mini-instruct"
local_cache = './phi_4_model/'

# Alternative checkpoint (disabled):
# local_cache = './mistral_models'
# name = "mistralai/Mistral-7B-Instruct-v0.3"

# Reference link (local Jupyter server) to the installed Mistral modeling source:
#http://localhost:8889/edit/mnt/efs/RaghavWork/VirtualENV/1shot/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    name,
    cache_dir=local_cache,
)
# For causal LM: pad with the end-of-sequence token.
tokenizer.pad_token = tokenizer.eos_token

In [38]:
# Debug-enabled dataset over the default root directory, 3072-token sequences.
obj = DebugCIGroupDataset(
    tokenizer=tokenizer,
    max_length=3 * 1024,
    debug=True,
)

In [39]:
# Number of (image, annotation) pairs found; use the built-in rather than the dunder.
len(obj)

660

In [41]:
# Inspect one randomly chosen sample. The previous `for ... range(3)` loop ended
# in an unconditional `break`, so it only ever ran a single iteration.
# With debug=True the dataset prints its own prompt/token diagnostics first.
sample_idx = np.random.randint(0, len(obj))
batch = obj[sample_idx]
print("$$$$$$"*20)
print()
print()
print(batch)


🔹 Prompt before tokenization:
<|image|>
JIANGSU WANHENG CASTING INDUSTRY CO . , LTD . NO . 9 , CENTURY AVENUE , NORTH INDUSTRIAL PARK , BINHAI 224500 , JIANGSU PROVINCE , CHINA TEL : 0515-8413-3888 / FAX : 0515-8413-4555 COMMERCIAL INVOICE 江 苏万恒 铸 业 有限公司 / 3209967039 Customer Name Cameron , A Schlumberger Company Invoice No BH24211 Address 845 SOUTHEAST 29TH STREET OKLAHOMA Date of Issue Apr.22.2024 : CITY OK 73129 USA Port of Shipment : Shanghai China Contact Person MR.Mark : Vaughan Port of Destination : : Oklahoma , Phone / Fax : 405-629-0448 / 405-629-0495 USA Description Item P.O. NO . Item Part No. QTY . Price U / Amount VALVE PARTS 2 10 2398052-03-01-01 BODY , 320F , 3FP , CL600 , MACH , RF , CS , NACE 4513544779 $ 97.03 5,821.80 1 60 $ 2398053-03-01-01 TAILPIECE , 320F , 3FP , CL600 , MACH , RF , CS , NACE 4513544779 20 2 60 $ 85.06 5,103.60 $ 4513647446 BODY 3 " 300 RF # 75 WCC PISTON CHK VLV 10 2325210-01 $ 116.96 $ 467.84 4 3 2398052-01-01-02 BODY , 320F , 3FP , CL150 , MAC