In [None]:
from datasets import load_dataset
from  transformers import AutoTokenizer, TrainingArguments, BitsAndBytesConfig, AutoProcessor, LlavaForConditionalGeneration
import bitsandbytes as bnb
import numpy as np
import torch
from dotenv import find_dotenv, load_dotenv
from trl import SFTTrainer
import os
import json
from sklearn.model_selection import train_test_split
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import Dataset, random_split


In [None]:
# Read the run configuration from the environment (optionally populated from
# a local .env file by load_dotenv()).
load_dotenv()

MODEL_NAME = os.getenv("MODEL_NAME")
if not MODEL_NAME:
    # Fail here with a clear message instead of deep inside from_pretrained(None).
    raise ValueError("Environment variable MODEL_NAME must be set (e.g. in .env)")

_max_length_env = os.getenv("MAX_LENGTH")
if _max_length_env is None:
    # int(None) would otherwise raise an opaque TypeError.
    raise ValueError("Environment variable MAX_LENGTH must be set (e.g. in .env)")
MAX_LENGTH = int(_max_length_env)

# Size handed to the resize transform applied to every receipt image.
OUTPUT_SIZE = (384, 384)

In [None]:
# Llava 1.5, images are 336x336
# `token=True` replaces the deprecated `use_auth_token=True`: it reuses the
# locally stored Hugging Face credentials.
processor = AutoProcessor.from_pretrained(MODEL_NAME,
                                         torch_dtype=torch.float16,
                                         token=True)
processor.tokenizer.padding_side = "right" # always on right for training

In [None]:
# Fine-tuning mode switches read by the model-loading cells.
# USE_QLORA takes precedence there: when it is True the model is loaded with a
# 4-bit BitsAndBytesConfig; USE_LORA alone loads without a quantization config.
USE_LORA = False
USE_QLORA = True

### Load Model
Load the model from Hugging Face, with 4-bit quantization when `USE_QLORA` is enabled.

In [None]:
# Load LLaVA according to the mode selected by USE_LORA / USE_QLORA.
# The fp16 dtype and automatic device placement are shared by every mode, so
# they are collected once and from_pretrained is called in a single place.
bnb_config = None
load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}

if USE_QLORA:
    print("Using QLoRA")
    # Load the model with 4-bit (NF4) quantization, computing in fp16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    load_kwargs["quantization_config"] = bnb_config
elif USE_LORA:
    print("Using LoRA")
else:
    print("Using full precision")
    # Use the public `attn_implementation` argument rather than the private
    # `_attn_implementation` config attribute.
    load_kwargs["attn_implementation"] = "flash_attention_2"

model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, **load_kwargs)

### Parameter-Efficient Fine-Tuning (PEFT)

Add LoRA adapters to all linear layers of the model except those inside `multi_modal_projector` and `vision_model`; `lm_head` is excluded as well.

In [None]:
def find_all_linear_names(model):
    """Collect the leaf names of the linear layers that should receive LoRA adapters.

    Every ``torch.nn.Linear`` (or subclass) found by ``model.named_modules()`` is
    considered, except modules whose qualified name contains
    ``multi_modal_projector`` or ``vision_model``. ``lm_head`` is dropped too.

    Args:
        model: A ``torch.nn.Module`` to inspect.

    Returns:
        list[str]: Unique leaf module names (the last component of the dotted
        name), sorted so the result is identical from run to run.
    """
    excluded_keywords = ("multi_modal_projector", "vision_model")

    leaf_names = {
        # The last dotted component is the attribute name LoRA matches on.
        qualified_name.rsplit(".", 1)[-1]
        for qualified_name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and not any(keyword in qualified_name for keyword in excluded_keywords)
    }

    # The output head is never adapted (original note: "needed for 16-bit").
    leaf_names.discard("lm_head")

    # Sets iterate in arbitrary order; sort so target_modules is deterministic.
    return sorted(leaf_names)

# Build the LoRA configuration: rank-8 adapters (alpha 8, dropout 0.1,
# gaussian init) on every linear layer selected by find_all_linear_names.
lora_config = LoraConfig(
    target_modules=find_all_linear_names(model),
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    init_lora_weights="gaussian",
)

# NOTE(review): prepare_model_for_kbit_training runs regardless of
# USE_LORA / USE_QLORA — confirm that is intended for the non-quantized modes.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)

# Report how many parameters the adapters left trainable.
model.print_trainable_parameters()

### Download receipt data for training
The objective is to scan a receipt image and convert its content to JSON format.

In [None]:
from datasets import load_dataset
# Load the CORD v2 receipt dataset; with no `split` argument the result is
# indexed by split name below (e.g. dataset["train"]).
dataset = load_dataset("naver-clova-ix/cord-v2")

In [None]:
# Display the loaded dataset object.
dataset

In [None]:
# Number of training samples. Reuse the already-loaded dataset instead of
# calling load_dataset a second time just to take a length.
len(dataset["train"])

In [None]:
# Inspect the first raw training record.
dataset["train"][0]

In [None]:
import json

# "ground_truth" is a JSON string; its "gt_parse" entry is what the dataset
# class below serialises as the training target text.
data = json.loads(dataset["train"][0]["ground_truth"])["gt_parse"]
data

In [None]:
# Preview the first receipt image, resized to 400x400 for display.
dataset["train"][0]["image"].resize((400, 400))

In [None]:
from typing import Dict
import random
from PIL import Image

class CustomImageDataset(Dataset):
    """Map-style dataset yielding ``(image, text)`` pairs from a Hugging Face dataset.

    Each record is expected to provide an ``"image"`` entry exposing
    ``.convert("RGB")`` and a ``"ground_truth"`` JSON string with a
    ``"gt_parse"`` key.

    Args:
        dataset_name: Name passed to ``load_dataset``.
        split: Dataset split to load (default ``"train"``).
        transform: Optional callable applied to the RGB image.
    """

    def __init__(self, dataset_name:str, split = "train", transform=None):
        self.data = load_dataset(dataset_name, split=split)
        self.transform = transform

    def __getitem__(self, idx: int) -> tuple:
        """Return ``(image, text)`` for the sample at ``idx``.

        ``text`` is the ``gt_parse`` part of the ground truth re-serialised as
        a JSON string.

        Raises:
            RuntimeError: If the image cannot be converted to RGB.
        """
        sample = self.data[idx]
        image = sample["image"]

        # Normalise the image to RGB.
        try:
            image = image.convert("RGB")
        except Exception as e:
            # The original message referenced an undefined `image_path`, which
            # raised NameError and hid the real failure; report the index.
            raise RuntimeError(f"Error loading image at index {idx}: {e}") from e

        if self.transform is not None:
            image = self.transform(image)

        # Keep only the parsed ground truth and serialise it back to a string.
        text = json.dumps(json.loads(sample["ground_truth"])["gt_parse"])

        return image, text

    def __len__(self) -> int:
        """Return the number of samples in the loaded split."""
        return len(self.data)

In [None]:
# Test the dataset class
from torchvision import transforms

# Only resize here: the collate functions defined later require PIL images,
# so no tensor conversion or normalisation is applied at this stage.
output_size = OUTPUT_SIZE  # Define the output size for resizing
transform = transforms.Compose([transforms.Resize(output_size)])

# Wrap the CORD train split and sanity-check it.
dataset = CustomImageDataset(
    "naver-clova-ix/cord-v2",
    split="train",
    transform=transform,
)
print(f"Dataset size: {len(dataset)}")
print(f"First sample: {dataset[0]}")

In [None]:
# Split the dataset into train and validation sets (80% / 20%).
SPLIT_SEED = 42  # fixed seed so the same samples land in each split on every run
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    # Without an explicit generator the split changes on every run, which makes
    # validation scores and later val_dataset[...] lookups non-reproducible.
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(f"Train size: {len(train_dataset)}")
print(f"Validation size: {len(val_dataset)}")

print("Train dataset sample:", train_dataset[0])
print("Validation dataset sample:", val_dataset[0])

In [None]:
# Report the dimensions of one training sample. The original labels said
# "shape" but printed the whole image object and the full text.
sample_image, sample_text = train_dataset[0]
print("Train dataset image size:", sample_image.size)
print("Train dataset text length:", len(sample_text))

In [None]:
# Process the dataset
def preprocess_image(image_path, output_size=(224, 224)):
    """Load an image file and return it as a resized, batched tensor.

    Args:
        image_path: Path of the image file to open.
        output_size: Size passed to ``transforms.Resize`` (default ``(224, 224)``).

    Returns:
        The result of ``Resize`` + ``ToTensor`` on the RGB image, with a
        leading batch dimension added by ``unsqueeze(0)``.
    """
    from PIL import Image
    from torchvision import transforms

    # Open inside a context manager so the file handle is released right away;
    # convert() returns a new, fully loaded image that outlives the handle.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(output_size),
        transforms.ToTensor(),
    ])

    return transform(image).unsqueeze(0)  # Add batch dimension

In [None]:
def train_collate_fn(examples, verbose=False):
    """Collate ``(image, label)`` pairs into a training batch.

    Each label is appended to a fixed ``USER: ... ASSISTANT: `` prompt and the
    prompts and images are encoded together by the global ``processor``.

    Args:
        examples (list): ``(PIL image, label string)`` tuples.
        verbose (bool): When True, print the prompts and the tensor shapes.
            Off by default so that training output is not flooded.

    Returns:
        dict: ``input_ids``, ``attention_mask``, ``pixel_values`` and ``labels``.
        ``labels`` is a copy of ``input_ids`` with padding positions set to -100.

    Raises:
        ValueError: If an example's image is not a PIL image.
    """
    images = []
    texts = []

    for image, label in examples:
        if not isinstance(image, Image.Image):
            raise ValueError(f"Expected PIL Image, got {type(image)}")

        images.append(image)
        texts.append("USER: <image>\n Extract JSON\n" + "ASSISTANT: " + label)

    # Truncation is left disabled, as in the original cell.
    batch = processor(text=texts,
                      images=images,
                      padding=True,
                      return_tensors="pt")

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    pixel_values = batch["pixel_values"]

    # Clone before masking: assigning -100 through an alias of input_ids would
    # also overwrite the token ids that are fed to the model.
    labels = input_ids.clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    if verbose:
        print(f"Number of examples: {len(examples)}")
        for prompt in texts:
            print(f"Prompt: {prompt}")
        print("Input IDs shape:", input_ids.shape)
        print("Attention mask shape:", attention_mask.shape)
        print("Pixel values shape:", pixel_values.shape)
        print("Labels shape:", labels.shape)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
        "labels": labels
    }

In [None]:
def eval_collate_fn(examples, verbose=False):
    """Collate ``(image, ground_truth)`` pairs into a generation batch.

    Only the prompt (without the answer) is encoded; the ground-truth strings
    are passed through untouched so they can be compared with the generations.

    Args:
        examples (list): ``(PIL image, ground-truth string)`` tuples.
        verbose (bool): When True, print the number of examples in the batch.

    Returns:
        dict: ``input_ids``, ``attention_mask``, ``pixel_values`` and
        ``answers`` (a plain list of ground-truth strings).

    Raises:
        ValueError: If an example's image is not a PIL image.
    """
    # we only feed the prompt to the model
    images = []
    texts = []
    answers = []
    if verbose:
        print(f"eval_collate_fn: Number of examples: {len(examples)}")

    for image, ground_truth in examples:
        if not isinstance(image, Image.Image):
            raise ValueError(f"Expected PIL Image, got {type(image)}")

        images.append(image)
        texts.append("USER: <image>\n Extract JSON\n" + "ASSISTANT: ")
        answers.append(ground_truth)

    # Batched generation needs left padding so every prompt ends right where
    # generation starts. Switch temporarily and always restore the previous
    # setting, because training relies on right padding.
    tokenizer = processor.tokenizer
    previous_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        batch = processor(text=texts,
                          images=images,
                          padding=True,
                          return_tensors="pt")
    finally:
        tokenizer.padding_side = previous_padding_side

    return {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"],
        "pixel_values": batch["pixel_values"],
        "answers": answers  # Keep as list
    }

In [None]:
import lightning as L
from torch.utils.data import DataLoader

class LlavaModelModule(L.LightningModule):
    """Lightning wrapper that fine-tunes and validates the LLaVA model.

    Training minimises the language-modelling loss returned by the model.
    Validation generates an answer per sample and scores it by exact JSON
    match against the ground truth.

    Args:
        config (dict): Hyper-parameters; ``batch_size`` and ``lr`` are read here.
        processor: Processor used to decode generated token ids.
        model: The model to train.
    """

    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.batch_size = config.get("batch_size", 8)

    def on_train_start(self):
        """Print a one-off summary of the run configuration."""
        print(f"Training started")
        print(f"Model config: {self.model.config}")
        print(f"Processor config: {self.processor.tokenizer}")
        print(f"Batch size: {self.batch_size}")
        print(f"Training arguments: {self.config}")
        print(f"Training dataset size: {len(train_dataset)}")
        print(f"Validation dataset size: {len(val_dataset)}")

    def training_step(self, batch, batch_idx):
        """Run one forward pass and return the model's loss."""
        outputs = self.model(
            input_ids=batch["input_ids"].to(self.device),
            attention_mask=batch["attention_mask"].to(self.device),
            pixel_values=batch["pixel_values"].to(self.device),
            labels=batch["labels"].to(self.device),
        )

        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def compute_score(self, pred: str, label: str) -> float:
        """Return 1.0 when both strings parse to equal JSON values, else 0.0.

        A string that is not valid JSON scores 0.0.
        """
        try:
            pred_json = json.loads(pred.strip())
            label_json = json.loads(label.strip())
            return 1.0 if pred_json == label_json else 0.0
        except json.JSONDecodeError:
            return 0.0

    def validation_step(self, batch, batch_idx):
        """Generate answers for a batch and score them against the ground truth.

        Returns:
            list[float]: One exact-match score per sample.
        """
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        pixel_values = batch["pixel_values"].to(self.device)
        answers = batch["answers"]

        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            max_new_tokens=MAX_LENGTH,
        )
        # The generated sequences start with the prompt tokens. Decoding them
        # whole yields "USER: ... ASSISTANT: {json}", which never parses as
        # JSON, so keep only the newly generated tokens.
        new_token_ids = generated_ids[:, input_ids.size(1):]
        predictions = self.processor.batch_decode(new_token_ids, skip_special_tokens=True)

        scores = [
            self.compute_score(pred, label)
            for pred, label in zip(predictions, answers)
        ]

        # Log the batch mean so the metric shows up in the logger/progress bar.
        if scores:
            self.log("val_exact_match", sum(scores) / len(scores), batch_size=len(scores))

        return scores

    def configure_optimizers(self):
        """Create AdamW using the learning rate from ``config`` (default 5e-5)."""
        # Previously the rate was hard-coded and config["lr"] was ignored.
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get("lr", 5e-5))
        return optimizer

    def train_dataloader(self):
        """Return a shuffled loader over the global ``train_dataset``."""
        print(f"train_dataloader called")
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=train_collate_fn,
        )
        return train_loader

    def val_dataloader(self):
        """Return an ordered loader over the global ``val_dataset``."""
        print(f"val_dataloader called")
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=eval_collate_fn,
        )
        return val_loader

In [None]:
# Hyper-parameters for the run. The Trainer cell reads max_epochs,
# check_val_every_n_epoch, gradient_clip_val and accumulate_grad_batches;
# the Lightning module reads batch_size.
config = dict(
    max_epochs=10,
    # val_check_interval=0.2,  # how many times we want to validate during an epoch
    check_val_every_n_epoch=1,
    gradient_clip_val=1.0,
    accumulate_grad_batches=8,
    lr=1e-4,
    batch_size=1,
    # seed=2022,
    num_nodes=1,
    warmup_steps=50,
    result_path="./result",
    verbose=True,
)

In [None]:
# Wrap the model and processor in the Lightning module defined above.
model_module = LlavaModelModule(config, processor, model)

In [None]:
from lightning.pytorch.loggers import WandbLogger

WANDB_PROJECT = "LLaVa"
WANDB_NAME = "llava-demo-cord"

# Weights & Biases logging is disabled; uncomment to enable it and pass
# `logger=wandb_logger` to the Trainer below.
# wandb_logger = WandbLogger(project=WANDB_PROJECT, name=WANDB_NAME)

# Trainer options that come straight from the shared config dict.
trainer_options = {
    option: config.get(option)
    for option in (
        "max_epochs",
        "accumulate_grad_batches",
        "check_val_every_n_epoch",
        "gradient_clip_val",
    )
}

# Single-GPU, fp16 mixed-precision run; validation is capped at 5 batches and
# the pre-training sanity validation pass is skipped.
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    precision="16-mixed",
    limit_val_batches=5,
    num_sanity_val_steps=0,
    **trainer_options,
)

In [None]:
# Run training; the dataloaders come from the module's own
# train_dataloader / val_dataloader methods.
trainer.fit(model_module)

In [None]:
# Print the module structure held by the trainer.
print(trainer.model)

In [None]:
# Save training checkpoint locally.
# save_checkpoint expects the path of the checkpoint *file*; the original call
# passed the directory itself, so a file name is appended here.
trainer.save_checkpoint(os.path.join("/home/mahadev/code/deepspeed/checkpoints/", "last.ckpt"))

In [None]:
# Push adapter weights to HuggingFace
# (model_module.model is the PEFT-wrapped model created by get_peft_model).
model_module.model.push_to_hub("mngaonkar/Llava-receipt-json", commit_message="Training completed")

In [None]:
# Save adapter weights locally; the adapter-loading cell further down reads
# from this same directory.

model_module.model.save_pretrained("/home/mahadev/code/deepspeed/adapter")

In [None]:
# Reload a fresh base model to attach the saved adapter to, using the same
# USE_LORA / USE_QLORA mode as for training.
# NOTE(review): with USE_QLORA the adapter is later merged into 4-bit
# quantized weights — confirm that is intended rather than merging into an
# unquantized fp16 base model.
bnb_config = None
base_load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}

if USE_QLORA:
    print("Using QLoRA")
    # Load the model with 4-bit (NF4) quantization, computing in fp16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    base_load_kwargs["quantization_config"] = bnb_config
elif USE_LORA:
    print("Using LoRA")
else:
    print("Using full precision")
    # Use the public `attn_implementation` argument rather than the private
    # `_attn_implementation` config attribute.
    base_load_kwargs["attn_implementation"] = "flash_attention_2"

base_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, **base_load_kwargs)

In [None]:
from adapters import AutoAdapterModel
from peft import PeftModel

In [None]:
# Load the PEFT adapter saved after training and attach it to the freshly
# loaded base model.
adapter_path = "/home/mahadev/code/deepspeed/adapter"  # Local path or Hugging Face Hub repo
peft_model = PeftModel.from_pretrained(base_model, adapter_path)

In [None]:
# Fold the adapter weights into the base model and drop the PEFT wrappers.
fused_model = peft_model.merge_and_unload()

In [None]:
# Persist the merged model to a local directory.
fused_model.save_pretrained("./fused_models/")

In [None]:
# Inference

from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration
import torch

In [None]:
# Reload the processor for inference. `token=True` replaces the deprecated
# `use_auth_token=True`.
processor = AutoProcessor.from_pretrained(MODEL_NAME,
                                         torch_dtype=torch.float16,
                                         token=True)
# Kept on the right as in the training setup; the inference call below encodes
# a single prompt, so no padding is actually added.
processor.tokenizer.padding_side = "right"

In [None]:
# 4-bit NF4 quantization with fp16 compute, matching the QLoRA training setup.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

In [None]:
# Load the fine-tuned weights from the Hub in 4-bit.
# Device placement is done with `device_map` at load time: calling
# `.to("cuda")` afterwards is not supported for bitsandbytes-quantized models.
model = LlavaForConditionalGeneration.from_pretrained(
    "mngaonkar/Llava-receipt-json",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
    device_map="auto",
)

In [None]:
# Pick one held-out (image, ground-truth text) pair from the validation split.
test_example = val_dataset[1]

In [None]:
# Show the receipt image of the chosen sample.
test_example[0]

In [None]:
# Show the ground-truth JSON string of the chosen sample.
test_example[1]

In [None]:
# Use exactly the prompt format the collate functions used for training and
# evaluation; the original inference prompt differed in its whitespace.
prompt = "USER: <image>\n Extract JSON\n" + "ASSISTANT: "

In [None]:
# Encode the prompt together with the sample image and move the resulting
# tensors to the GPU.
inputs = processor(text=prompt, images=[test_example[0]], padding=True, return_tensors="pt").to("cuda")

In [None]:
# Sanity-check the shape of every encoded input tensor.
for field_name, tensor in inputs.items():
    print(field_name, tensor.shape)

In [None]:
# Generate up to 512 new tokens for the encoded prompt.
generated_ids = model.generate(**inputs, max_new_tokens=512)

In [None]:
# Decode the generated ids to text, dropping special tokens.
# NOTE(review): the decoded text presumably still contains the prompt before
# the generated answer — confirm, and slice off the prompt tokens if only the
# JSON is wanted.
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
# Show the decoded generations (a list with one string per input).
print(generated_text)

In [None]:
# NOTE(review): this saves the model loaded in the inference section into the
# same ./fused_models/ directory the merged model was saved to earlier in this
# notebook — confirm that overwriting it is intended.
model.save_pretrained("./fused_models/")