In [1]:
import os
import json
import imagesize
import torch
import glob
import random
import cv2
import numpy as np
from torch.utils.data import Dataset, RandomSampler, DataLoader, Sampler
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
import pickle
from collections import defaultdict
# In[2]:


from peft import LoraConfig, get_peft_model, TaskType
import transformers
# from transformers import AdamW
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import bitsandbytes as bnb
from transformers import Trainer, TrainingArguments

## Training with a plain instruction plus one in-context example (one-shot)

In [3]:
class DLoader(Dataset):
    """One-shot line-item extraction dataset built from grouped OCR documents.

    Each item pairs two distinct documents drawn from the same group directory:
    the first is shown in the prompt as a worked example (OCR text + target
    JSON) and the second is the query whose JSON the model must produce.

    Args:
        sortocr: if truthy, order OCR words by (top y, left x) before joining.
        max_seq: token length to which training examples are truncated/padded.
        tokenizer: tokenizer used in ``__getitem__``; its ``pad_token`` is set
            to its ``eos_token`` in place.
        testmode: if truthy, use the last ``num_test_groups`` groups (held-out
            split); otherwise use all groups before them.
        data_dir: root directory walked for ``<file>`` / ``<file>_geo.json`` pairs.
        num_test_groups: number of trailing groups reserved for the test split.
    """

    def __init__(self, sortocr=0, max_seq=1024, tokenizer=None, testmode=0,
                 data_dir='/mnt/efs/RaghavWork/CIGroups', num_test_groups=20):
        if tokenizer is None:
            raise ValueError('DLoader requires a tokenizer')
        self.dyn_max_seq_len = max_seq
        self.tokenizer = tokenizer
        # Reuse EOS as padding; pad positions are masked out of the loss in __getitem__.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.testmode = testmode
        self.sortocr = sortocr
        self.data_dir = data_dir
        self.num_test_groups = num_test_groups
        self.data = self.load_cigroupdata()
        print('self.data',len(self.data))

    def __len__(self):
        return len(self.data)

    def load_cigroupdata(self):
        """Collect ``(file, file_geo.json)`` pairs grouped by parent-directory name.

        A non-JSON file is kept only when its ``<name>_geo.json`` sibling exists.
        Groups with fewer than two pairs are dropped because ``__getitem__``
        needs two distinct documents from the same group.

        Returns:
            list[list[tuple[str, str]]]: groups in the train split, or in the
            test split (last ``self.num_test_groups`` groups) when
            ``self.testmode`` is truthy.
        """
        samples = 0
        groups = {}
        for root, _dirs, files in os.walk(self.data_dir):
            key = os.path.basename(root)
            for file in files:
                if file.endswith('.json'):
                    continue
                fullpath = os.path.join(root, file)
                fullpath_v2 = os.path.join(root, file + '_geo.json')
                if os.path.isfile(fullpath) and os.path.isfile(fullpath_v2):
                    groups.setdefault(key, []).append((fullpath, fullpath_v2))
                    samples += 1

        # Only groups that can yield two distinct documents are usable.
        # NOTE(review): group order follows os.walk, which is filesystem-dependent;
        # sort the groups if the train/test split must be reproducible across machines.
        datalist = [pairs for pairs in groups.values() if len(pairs) >= 2]

        print('Total Samples:',samples, ' | datalist',len(datalist))
        # The trailing `num_test_groups` groups form the held-out split.
        split = max(0, len(datalist) - self.num_test_groups)
        if self.testmode:
            return datalist[split:]
        return datalist[:split]

    def sort_ocr_by_position(self, ocr_data):
        """Sort OCR entries by their top edge, then by their left edge.

        Args:
            ocr_data (list): tuples whose first element is a box
                ``(x1, y1, x3, y3)`` (top-left / bottom-right corners) and whose
                second element is the word text; extra elements are carried along.

        Returns:
            tuple: ``(sorted_entries, combined_text)`` where ``combined_text`` is
            the words of ``sorted_entries`` joined by single spaces.

        Note:
            Words on the same visual line whose ``y1`` differs even slightly
            are not grouped into one line; ordering is strictly by ``(y1, x1)``.
        """
        sorted_ocr = sorted(ocr_data, key=lambda item: (item[0][1], item[0][0]))
        combined_text = " ".join(item[1] for item in sorted_ocr)
        return sorted_ocr, combined_text

    def get_image_and_jsongt(self, sample):
        """Flatten a document's labelled words into ``word_boxes``.

        Args:
            sample: ``(image_path, json_source)`` where ``json_source`` is either
                a path to a ``_geo.json`` file or an already-loaded dict.

        Returns:
            tuple: ``(imgp, img, words, word_boxes, word_labels, row_labels)``.
            Only ``imgp`` and ``word_boxes`` are populated; ``img`` is always
            ``None`` and the other three lists are empty (kept for interface
            compatibility). Each ``word_boxes`` entry is
            ``(boundingBox[0] + boundingBox[2], text, class_name, row_label)``;
            ``row_label`` is the first element of the word's ``row_label`` list,
            defaulting to 0. Presumably the corner points are lists, so ``+``
            concatenates them into ``[x1, y1, x3, y3]`` — confirm against the data.

        Raises:
            FileNotFoundError: if ``json_source`` is a path that does not exist.
            TypeError: if ``json_source`` is neither a str path nor a dict.
        """
        imgp, jsonp = sample
        img = None  # image loading is disabled; only the OCR JSON is used
        if isinstance(jsonp, str):
            with open(jsonp, encoding="utf8") as f:
                jsondata = json.load(f)
        elif isinstance(jsonp, dict):
            jsondata = jsonp
        else:
            raise TypeError(f'Unsupported json source type: {type(jsonp).__name__}')

        words_meta = jsondata['words']
        word_boxes = []
        for clas, items in jsondata['parse']['class'].items():
            for item in items:
                for wrd_id in item:
                    word_info = words_meta[wrd_id]
                    bbox = word_info['boundingBox']
                    rlabel = word_info.get('row_label', [0])[0]
                    word_boxes.append((bbox[0] + bbox[2], word_info['text'], clas, rlabel))

        return imgp, img, [], word_boxes, [], []

    def get_prompt(self, ocr_formatted_1, ocr_formatted_2, example_output):
        """Build the one-shot prompt: task, example input/output, then the new input.

        ``example_output`` is interpolated with ``str()`` (a dict renders as
        its Python repr).
        """
        # NOTE(review): the backslash continuation below embeds the next line's
        # leading spaces in the string, and the triple-quoted template keeps its
        # indentation. Changing either changes the training prompt, so both are
        # left as-is for compatibility with existing runs/checkpoints.
        prompt = 'Your task is to extract line-items information from the attached invoice. Refer to the example \
        input and output and then return the output for the new input.'

        one_shot_prompt = f"""
            Task: {prompt}
            
            Example:
            Input: {ocr_formatted_1}
            Output: {example_output}
            
            New Input: {ocr_formatted_2}
            Output:
            """
        return one_shot_prompt

    def get_final_json(self, sample):
        """Group labelled words into the target structure.

        * Row-0 words of class PO/HTS/INVOICE_NUMBER_VALUE go to
          ``keyvalues[class]``; the key is omitted when no such word exists.
          Other row-0 words are dropped.
        * Words in other rows whose class ends with ``_VALUE`` and does not
          contain ``MISC`` go to ``lineitems[row][class]``.

        Words sharing a slot are space-joined in input order and stripped.
        Insertion order of classes and rows follows first occurrence, which
        matters because the result is rendered into prompts via ``str()``.
        """
        keyvalue_classes = ('PO_NUMBER_VALUE', 'HTS_NUMBER_VALUE', 'INVOICE_NUMBER_VALUE')
        keyvalues = {}
        lineitems = {}
        for _box, word, cls, row in sample:
            if row == 0:
                if cls in keyvalue_classes:
                    keyvalues.setdefault(cls, []).append(word)
                continue
            if not cls.endswith('_VALUE') or 'MISC' in cls:
                continue
            lineitems.setdefault(row, {}).setdefault(cls, []).append(word)

        row_wise_data = {}
        if keyvalues:
            row_wise_data['keyvalues'] = {cls: ' '.join(ws).strip() for cls, ws in keyvalues.items()}
        row_wise_data['lineitems'] = {
            row: {cls: ' '.join(ws).strip() for cls, ws in classes.items()}
            for row, classes in lineitems.items()
        }
        return row_wise_data

    def __getitem__(self, idx, testing_sample=0):
        """Build one one-shot example from group ``idx``.

        Returns:
            If ``testing_sample`` is falsy: a dict of 1-D tensors ``input_ids``,
            ``attention_mask`` and ``labels`` (padded/truncated to
            ``self.dyn_max_seq_len``), where prompt and padding positions in
            ``labels`` are -100. Otherwise: a dict with the raw
            ``input_prompt``, the target ``output`` string and both image paths.
        """
        # Sample WITHOUT replacement: with replacement the example and the query
        # could be the same document, leaking the target answer into the prompt.
        sample1, sample2 = random.sample(self.data[idx], 2)
        imgp1, _img1, _words1, word_boxes1, _labels1, _rows1 = self.get_image_and_jsongt(sample1)
        imgp2, _img2, _words2, word_boxes2, _labels2, _rows2 = self.get_image_and_jsongt(sample2)

        if self.sortocr:
            word_boxes1, ocrtxt1 = self.sort_ocr_by_position(word_boxes1)
            word_boxes2, ocrtxt2 = self.sort_ocr_by_position(word_boxes2)
        else:
            ocrtxt1 = ' '.join(item[1] for item in word_boxes1)
            ocrtxt2 = ' '.join(item[1] for item in word_boxes2)

        outjson1 = self.get_final_json(word_boxes1)
        outjson2 = self.get_final_json(word_boxes2)
        input_prompt = self.get_prompt(ocrtxt1, ocrtxt2, outjson1)
        output = f"{outjson2}"
        if testing_sample:
            return {'input_prompt':input_prompt, 'output':output, 'imgp1':imgp1, 'imgp2':imgp2}

        combined_text = input_prompt + output
        tokens = self.tokenizer(
            combined_text,
            max_length=self.dyn_max_seq_len,
            truncation=True,
            return_tensors="pt",
            padding="max_length",
            padding_side='right'
        )
        # Approximate prompt length in tokens (tokens at the prompt/answer boundary
        # may merge differently when tokenized together) so only the answer is scored.
        # NOTE(review): if the prompt alone reaches max_seq, every label is -100 and
        # the example carries no training signal (loss may be NaN) — consider filtering.
        input_length = len(self.tokenizer(input_prompt)["input_ids"])
        labels = tokens["input_ids"].clone()
        labels[:, :input_length] = -100  # ignore prompt tokens in the loss
        labels[tokens["attention_mask"] == 0] = -100  # ignore pad tokens in the loss

        return {
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "labels": labels.squeeze(0)
        }

In [2]:
# --- Tokenizer ---------------------------------------------------------------
# Hub model id and the local directory used as the download cache.
# (Previously tried alternative: name = "mistralai/Mistral-7B-Instruct-v0.3",
#  local_cache = './mistral_models')
name = "microsoft/Phi-4-mini-instruct"
local_cache = './phi_4_model/'

tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=local_cache)
# Use EOS as the padding token (the tokenizer may not define one of its own).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Load the causal-LM weights for `name`, sharing the tokenizer's cache directory.
# device_map="auto" lets transformers/accelerate place layers on available devices;
# torch_dtype="auto" picks the dtype from the checkpoint instead of float32.
# NOTE(review): trust_remote_code=True executes Python code shipped in the Hub
# repo — keep it only if the installed transformers lacks native support.
model_load_kwargs = dict(
    cache_dir=local_cache,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(name, **model_load_kwargs)