In [None]:
# Fine-Tuning LLaVA 1.5 7B (HF version) for Instagram Captioning on Custom JSONL Data
# Compatible with RTX 3060 / T4 (12-16 GB GPUs)

!pip install -U "transformers>=4.39.0"
!pip install peft bitsandbytes
!pip install -U "trl>=0.8.3"

In [None]:
import json
import os
from multiprocessing import Pool, cpu_count

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from peft import LoraConfig
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer


In [None]:
# Sanity check: the training setup below (4-bit bitsandbytes weights, fp16=True) expects a CUDA GPU.
torch.cuda.is_available()

In [None]:
# --- Paths & prompt configuration ---
# Files are expected at <base dir>/<city>/<file name>.
BASE_IMAGES_DIR = "/mnt/InstaCities1M/img_resized_1M/cities_instagram/"
BASE_CAPTIONS_DIR = "/mnt/InstaCities1M/captions_resized_1M/cities_instagram/"
# NOTE: "datasset" is the existing output file name; renaming it changes where the JSONL lands.
OUTPUT_JSONL_PATH = './datasset_v1.jsonl'
cities = ['newyork']

# Instruction placed in the user turn of every training sample.
LLAVA_CHAT_TEMPLATE = (
    "You are a social media influencer. Write a captivating Instagram caption for this image "
    "that will engage more viewers and boost interaction. Analyze the image to decide the tone of the caption."
)

# Both roots must exist for the directory listing in the next cell (one bool per line).
print(os.path.exists(BASE_IMAGES_DIR), os.path.exists(BASE_CAPTIONS_DIR), sep="\n")


In [None]:
# Collect the full path of every image and caption file for the selected cities.
# os.listdir() returns entries in an arbitrary, filesystem-dependent order, and
# images/captions are later paired by position, so sort each directory listing to
# get a deterministic order. Plain str paths via os.path.join also avoid the numpy
# string broadcast, which raised on an empty directory (np.array([]) is float64).
images_files = []
captions_files = []

for city in cities:
    image_dir = os.path.join(BASE_IMAGES_DIR, city)
    caption_dir = os.path.join(BASE_CAPTIONS_DIR, city)
    images_files.extend(os.path.join(image_dir, name) for name in sorted(os.listdir(image_dir)))
    captions_files.extend(os.path.join(caption_dir, name) for name in sorted(os.listdir(caption_dir)))

In [None]:
# Pair every image with the caption file that has the same basename (without extension).
# process_pair() below reads images_files[i] together with captions_files[i], so the
# two lists must be aligned index-by-index. Filtering each list independently keeps
# the (arbitrary) listdir order and does NOT guarantee that; build the pairs explicitly.
def _file_id(path):
    """Return the file name of ``path`` without directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


caption_by_id = {_file_id(cap): str(cap) for cap in captions_files}

# Only images that have a caption survive; sorting makes the order reproducible.
paired_files = sorted(
    (str(img), caption_by_id[_file_id(img)])
    for img in images_files
    if _file_id(img) in caption_by_id
)

images_files = [img for img, _ in paired_files]
captions_files = [cap for _, cap in paired_files]

In [None]:
# process_pair() pairs the two lists by position, so their lengths must match.
assert len(images_files) == len(captions_files), "image and caption lists differ in length"
len(images_files), len(captions_files)

In [None]:
# --- Worker function ---
def process_pair(i):
    """Build one JSONL record for the i-th image/caption pair.

    Reads ``captions_files[i]``, collapses all whitespace (line breaks, Windows
    line endings, tabs, repeated spaces) to single spaces, and wraps it in a
    user/assistant conversation whose user turn holds ``LLAVA_CHAT_TEMPLATE``
    and an image placeholder.

    Args:
        i: Index into the module-level ``images_files`` / ``captions_files``
            lists, which must be aligned by position.

    Returns:
        ``{"image_path": ..., "messages": [...]}``, or ``None`` when the caption
        file is missing, unreadable, not valid UTF-8, or empty (the caller
        drops such samples).
    """
    img_path = images_files[i]
    caption_path = captions_files[i]

    try:
        with open(caption_path, "r", encoding="utf-8") as f:
            raw_caption = f.read()
    except (OSError, UnicodeDecodeError):
        # Best-effort: a broken caption file just drops this sample. Any other
        # error (e.g. NameError when a worker cannot see the globals) is a bug
        # and must not silently turn into an empty dataset.
        return None

    caption = " ".join(raw_caption.split())
    if not caption:
        return None

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": LLAVA_CHAT_TEMPLATE},
                {"type": "image"}
            ]
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": caption}
            ]
        }
    ]

    return {
        "image_path": img_path,
        "messages": messages
    }


# --- Multiprocessing ---
# imap() yields results in input order, so results[i] belongs to pair i.
# With the default chunksize=1 every index is its own round-trip to a worker,
# which dominates runtime for many tiny tasks; batching indices amortises that
# overhead without changing the results.
with Pool(cpu_count()) as pool:
    results = list(
        tqdm(
            pool.imap(process_pair, range(len(images_files)), chunksize=256),
            total=len(images_files),
        )
    )


In [None]:
# Keep only the samples process_pair() accepted (it returns None for rejects).
data = [record for record in results if record is not None]

# --- Write JSONL File --- (one JSON object per line)
with open(OUTPUT_JSONL_PATH, 'w', encoding='utf-8') as f:
    f.writelines(f"{json.dumps(record)}\n" for record in data)

print(f"JSONL created: {OUTPUT_JSONL_PATH} with {len(data)} samples.")
# Legacy prompt; only referenced by commented-out code, not by the current pipeline.
PROMPT_TEMPLATE = "Write an Instagram caption for this image to be posted:"

In [None]:
# --- Configuration ---
model_id = "llava-hf/llava-1.5-7b-hf"  # base checkpoint passed to from_pretrained() for model and processor
data_path = OUTPUT_JSONL_PATH  # path to your formatted JSONL file (written by the JSONL export cell)
output_dir = "./llava_lora_instagram"  # trainer checkpoints and the final save_pretrained() output

In [None]:
# --- Model Loading (4bit Quantization) ---
# NF4 weights with fp16 compute. Without bnb_4bit_compute_dtype, bitsandbytes
# computes in float32, which is slower and heavier than the fp16 used by the
# rest of the model (torch_dtype) and by training (fp16=True).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    device_map="auto"
)

processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer

# Chat template used to render the training conversations in the collator.
# It reproduces the LLaVA-1.5 prompt layout
#     "USER: <image>\n{prompt} ASSISTANT: {caption}</s>"
# i.e. the image placeholder comes first in its turn, text follows, a space
# separates the user turn from "ASSISTANT:", and add_generation_prompt appends
# the "ASSISTANT:" cue (the previous template ignored it, so prompts rendered
# with add_generation_prompt=True never asked the model to answer).
# Literal spaces/newlines are emitted via {{ ... }} expressions so Jinja's
# trim_blocks/lstrip_blocks whitespace handling cannot swallow them.
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}{{ 'USER: ' }}{% else %}{{ 'ASSISTANT: ' }}{% endif %}"
    "{% for item in message['content'] if item['type'] == 'image' %}{{ '<image>\\n' }}{% endfor %}"
    "{% for item in message['content'] if item['type'] == 'text' %}{{ item['text'] }}{% endfor %}"
    "{% if message['role'] == 'assistant' %}{{ eos_token }}{% else %}{{ ' ' }}{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
)

In [None]:
def load_image(example):
    """Decode ``example["image_path"]`` as RGB and store it under ``example["image"]``.

    Mutates ``example`` in place and returns it. Not called anywhere in this
    notebook (the collator opens images itself).
    """
    rgb_image = Image.open(example["image_path"]).convert("RGB")
    example["image"] = rgb_image
    return example

# Load the exported JSONL (records with pre-formatted `messages` + `image_path`)
# directly as the "train" split.
dataset = load_dataset("json", data_files=data_path, split="train")

In [None]:
# Peek at the first record to check its fields (image_path + messages).
dataset[0]

In [None]:
# Read the JSONL by hand and build the training Dataset from the parsed records.
# NOTE(review): this replaces the `load_dataset(...)` result from the earlier
# cell; keep only one of the two loaders.
from datasets import Dataset

records = []
with open(data_path, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        # Tolerate blank/trailing lines instead of crashing in json.loads("").
        if not line:
            continue
        records.append(json.loads(line))

print(f"Loaded {len(records)} samples.")

dataset = Dataset.from_list(records)

In [None]:
# --- Data Collator ---
class LLavaDataCollator:
    """Turn raw dataset records into a padded LLaVA training batch.

    Each example must provide ``messages`` (rendered with the tokenizer's chat
    template) and ``image_path`` (the image the ``<image>`` placeholder refers
    to). Images are loaded lazily here, so the dataset only stores paths.
    """

    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        tokenizer = self.processor.tokenizer
        texts = []
        images = []

        for example in examples:
            texts.append(
                tokenizer.apply_chat_template(
                    example["messages"], tokenize=False, add_generation_prompt=False
                )
            )
            # Context manager releases the file handle; convert() returns an
            # independent, fully loaded RGB copy.
            with Image.open(example["image_path"]) as img:
                images.append(img.convert("RGB"))

        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        # Padding never contributes to the loss.
        if tokenizer.pad_token_id is not None:
            labels[labels == tokenizer.pad_token_id] = -100
        # Positions holding the <image> placeholder token are filled with vision
        # features inside the model; there is no text token to predict there.
        image_token = getattr(self.processor, "image_token", "<image>")
        image_token_id = tokenizer.convert_tokens_to_ids(image_token)
        if image_token_id is not None and image_token_id != tokenizer.unk_token_id:
            labels[labels == image_token_id] = -100
        # TODO(review): prompt tokens are still supervised; masking everything up
        # to "ASSISTANT:" would train on the caption only.
        batch["labels"] = labels
        return batch

In [None]:
# --- LoRA Configuration ---
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# --- SFT Trainer ---
# The dataset only holds `messages` + `image_path`; images are loaded and the
# text tokenized by LLavaDataCollator. Two settings make that path work:
#   * remove_unused_columns=False keeps those raw columns in the examples handed
#     to the collator (otherwise they are dropped as unknown forward() args);
#   * skip_prepare_dataset=True stops SFTTrainer from tokenizing the text itself,
#     which bypassed the collator and trained without any pixel values.
# NOTE(review): 2e-5 is a low learning rate for LoRA adapters; consider tuning.
training_args = SFTConfig(
    output_dir=output_dir,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=2,
    learning_rate=2e-5,
    fp16=True,
    logging_steps=100,
    save_strategy="epoch",
    eval_strategy="no",
    report_to="none",
    remove_unused_columns=False,
    dataset_kwargs={"skip_prepare_dataset": True},
)

data_collator = LLavaDataCollator(processor)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
    peft_config=lora_config,
    args=training_args
)

In [None]:
# --- Start Fine-tuning ---
# Runs the loop configured in the trainer cell; with save_strategy="epoch" a checkpoint lands in output_dir after each epoch.
trainer.train()

In [None]:
# --- Save Final Model ---
# trainer was built with a peft_config, so this presumably writes the LoRA adapter weights.
trainer.model.save_pretrained(output_dir)
# Also persist the processor: its tokenizer carries the custom chat template the
# adapter was trained with, which otherwise only exists in this kernel session.
processor.save_pretrained(output_dir)
print(f"✅ Training complete. Model saved at {output_dir}")

In [None]:
import shutil

# Zip the saved output directory for download. Deriving both the archive name and
# the source directory from output_dir keeps this in sync if output_dir changes.
shutil.make_archive(output_dir, 'zip', output_dir)

In [None]:
# --- Inference function
def generate_caption(image_path):
    """Generate an Instagram caption for the image at ``image_path``.

    Uses the training instruction (``LLAVA_CHAT_TEMPLATE``) as the user turn and
    samples up to 80 new tokens.

    Returns:
        Only the newly generated text, with surrounding whitespace stripped.
    """
    # The Trainer runs the model in train mode, where LoRA dropout
    # (lora_dropout=0.05) is active; disable it for generation.
    model.eval()

    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Same instruction as the training samples built by process_pair().
    messages = [
        {"role": "user", "content": [
            {"type": "text", "text": LLAVA_CHAT_TEMPLATE},
            {"type": "image"}
        ]}
    ]

    # tokenize=False returns the prompt string; the processor tokenizes it with the image.
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_tensors = processor(text=prompt, images=[image], return_tensors="pt", padding=True).to(model.device)

    with torch.no_grad():
        output = model.generate(
            **input_tensors,
            max_new_tokens=80,
            repetition_penalty=1.2,   # Encourage less repetition
            temperature=0.7,          # Add some randomness
            top_p=0.9,                # Top-p sampling (nucleus sampling)
            do_sample=True            # Enable sampling instead of greedy decoding
        )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = input_tensors["input_ids"].shape[1]
    generated_text = processor.batch_decode(output[:, prompt_len:], skip_special_tokens=True)[0]

    return generated_text.strip()

In [None]:
# --- Example usage: caption a set of local test images.
# (test2.jpg is used by the interactive chat further down.)
TEST_IMAGES = [
    "./temp.jpg",
    "./test1.jpg",
    "./test3.jpg",
    "./test4.jpg",
    "./test5.jpg",
    "./test6.jpg",
    "./test7.jpg",
]

for test_image in TEST_IMAGES:
    try:
        caption = generate_caption(test_image)
    except FileNotFoundError:
        # A missing sample image should not stop the remaining examples.
        print(f"{test_image}: file not found, skipped")
        continue
    print(f"[{test_image}] Generated Caption:", caption)

In [None]:
def chat_with_model(messages, image=None):
    """Generate the assistant's next reply for a (multi-turn) conversation.

    Args:
        messages: Conversation so far as ``{"role": ..., "content": [...]}``
            dicts, rendered with the processor's chat template plus the
            ``ASSISTANT:`` generation cue (the same path generate_caption uses).
        image: Optional PIL image. Pass it on every call whose ``messages``
            contain a ``{"type": "image"}`` entry; otherwise the ``<image>``
            placeholder reaches the model without pixel values.

    Returns:
        Only the newly generated reply, stripped. The prompt is not echoed back,
        so the result can be appended to ``messages`` as the assistant turn.
    """
    # LoRA dropout would otherwise stay active after training.
    model.eval()

    # Step 1: Render the conversation and cue the assistant to answer.
    prompt_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Step 2: Encode inputs. PIL images define no __bool__, so test against None explicitly.
    if image is not None:
        inputs = processor(text=prompt_text, images=[image], return_tensors="pt", padding=True)
    else:
        inputs = processor(text=prompt_text, return_tensors="pt", padding=True)

    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    # Step 3: Generate
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True
        )

    # Step 4: Decode only the continuation. Decoding output[0] whole returned the
    # entire prompt, which then got appended to `messages` and snowballed.
    reply = processor.tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)

    return reply.strip()


In [None]:

# ------------------
# Start a conversation
# ------------------

# Step 1: Initial messages with an image
image_path = "test2.jpg"
with Image.open(image_path) as raw_chat_image:
    image = raw_chat_image.convert("RGB")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "You are a social media influencer. Write a catchy Instagram caption for this image."},
            {"type": "image"}
        ]
    }
]

# First model reply
caption = chat_with_model(messages, image=image)
print(f"\n🧠 Model: {caption}")

# Add model's response to messages
messages.append({
    "role": "assistant",
    "content": [{"type": "text", "text": caption}]
})

# Step 2: Loop for continuous chat.
# The first user turn keeps its image placeholder in `messages`, so the image is
# sent on every turn; otherwise follow-up questions could not "see" it.
while True:
    user_input = input("\n💬 Your input (type 'quit' to stop): ")

    if user_input.strip().lower() == "quit":
        print("👋 Ending chat. Goodbye!")
        break

    # Ignore empty input instead of sending an empty user turn.
    if not user_input.strip():
        continue

    # Add user's new message
    messages.append({
        "role": "user",
        "content": [{"type": "text", "text": user_input}]
    })

    # Get model's reply
    model_reply = chat_with_model(messages, image=image)
    print(f"\n🧠 Model: {model_reply}")

    # Add model's reply back to messages
    messages.append({
        "role": "assistant",
        "content": [{"type": "text", "text": model_reply}]
    })
