In [None]:
# Fine-Tuning LLaVA 1.5 7B (HF version) for Instagram Captioning on Custom JSONL Data
# Compatible with RTX 3060 / T4 (12-16 GB GPUs)

!pip install -U "transformers>=4.39.0"
!pip install peft bitsandbytes
!pip install -U "trl>=0.8.3"

Loading Libraries

In [None]:
import os
import json
import torch
import wandb
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
from trl import SFTTrainer
from peft import LoraConfig
from datasets import Dataset
from torch.optim import AdamW
from datasets import load_metric
from datasets import load_dataset
from torch.utils.data import Dataset
from multiprocessing import Pool, cpu_count
from transformers.integrations import WandbCallback
from transformers import get_cosine_schedule_with_warmup
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration, BitsAndBytesConfig

Setting Up WandB

In [None]:
# Never hard-code credentials in a notebook: the key that used to live here was
# exposed in plain text and should be revoked. Use an existing WANDB_API_KEY from
# the environment, or prompt for it once without echoing it.
import getpass

if not os.environ.get("WANDB_API_KEY"):
    os.environ["WANDB_API_KEY"] = getpass.getpass("Enter your WandB API key: ")

Setting Up Script Configs

In [None]:
# --- Data configuration ---
DATASET = "InstaCities1M"
BASE_IMAGES_DIR = "/mnt/InstaCities1M/img_resized_1M/cities_instagram/"
BASE_CAPTIONS_DIR = "/mnt/InstaCities1M/captions_resized_1M/cities_instagram/"
OUTPUT_JSONL_PATH = './datasset_v1.jsonl'
CITIES = ['newyork']  # sub-directories of both roots to include

# Instruction text used as the user turn of every training record.
# (Despite the name, this is a plain prompt string, not a Jinja chat template.)
LLAVA_CHAT_TEMPLATE = (
    "You are a social media influencer. Write a captivating Instagram caption for this image "
    "that will engage more viewers and boost interaction. Analyze the image to decide the tone of the caption."
)

# Quick mount check: one line per root (images, then captions), True when reachable.
print(os.path.exists(BASE_IMAGES_DIR), os.path.exists(BASE_CAPTIONS_DIR), sep="\n")

# --- Experimentation details (WandB) ---
PROJECT, RUN_NAME = "Snap2Caption", "llava-7b-ft-instagram-v1"
GROUP = "llava-instagram-experiments"
TAGS = ["llava", "image-captioning", "LoRA", "fine-tuning"]
NOTES = "Baseline fine-tuning on InstaCities1M with 4-bit quantized model and LoRA adapters."

# --- Model / LoRA configuration ---
MODEL_NAME = "LLaVA-7B-HF (LLaVA-1.5-7B)"  # human-readable name, only logged to WandB
TASK = "Image Captioning"
MODEL_ID = "llava-hf/llava-1.5-7b-hf"      # checkpoint passed to from_pretrained
MODEL_SAVE_PATH = "./llava_lora_instagram"
LORA_R, LORA_ALPHA, LORA_DROPOUT = 16, 32, 0.05
TARGET_MODULES = ["q_proj", "v_proj"]

# --- Optimization strategy ---
TRAIN_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 2  # effective batch = TRAIN_BATCH_SIZE * this, per device
NO_OF_EPOCHS = 5
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
LOGGING_STEPS = 100

Verifying GPU

In [None]:
# The 4-bit (bitsandbytes) model load later in the notebook expects a CUDA GPU;
# this cell should display True.
cuda_ready = torch.cuda.is_available()
cuda_ready

Dataset Parsing

In [None]:
# Collect the full image and caption file paths for every configured city.
# os.listdir() returns entries in arbitrary, filesystem-dependent order; both
# listings are sorted because later cells pair images and captions by position.
images_files = []
captions_files = []

for city in CITIES:
    city_images_dir = os.path.join(BASE_IMAGES_DIR, city)
    city_captions_dir = os.path.join(BASE_CAPTIONS_DIR, city)
    # os.path.join instead of `str + np.array(...)`: string concatenation through
    # NumPy's `add` ufunc is not supported on NumPy releases before 2.0.
    images_files.extend(
        os.path.join(city_images_dir, name) for name in sorted(os.listdir(city_images_dir))
    )
    captions_files.extend(
        os.path.join(city_captions_dir, name) for name in sorted(os.listdir(city_captions_dir))
    )

In [None]:
# Keep only image/caption pairs present on both sides, aligned position by position:
# process_pair(i) later reads images_files[i] together with captions_files[i].
def _pair_key(path):
    """Return (parent directory name, file name without extension) for `path`."""
    return (
        os.path.basename(os.path.dirname(path)),
        os.path.splitext(os.path.basename(path))[0],
    )

image_by_key = {_pair_key(p): p for p in images_files}
caption_by_key = {_pair_key(p): p for p in captions_files}

# Sorting the shared keys gives both lists one deterministic, identical order.
# (Filtering each list separately kept each directory's own order, which need not match.)
common_ids = sorted(image_by_key.keys() & caption_by_key.keys())

images_files = [image_by_key[key] for key in common_ids]
captions_files = [caption_by_key[key] for key in common_ids]

In [None]:
# Pairs are consumed by index in process_pair, so both lists must be the same length.
assert len(images_files) == len(captions_files), "image/caption lists are misaligned"
len(images_files), len(captions_files)

In [None]:
# --- Worker function ---
def process_pair(i, image_paths=None, caption_paths=None, prompt=None):
    """Build one JSONL training record for the i-th image/caption pair.

    Args:
        i: Index into the image and caption path lists.
        image_paths: Image paths; defaults to the global ``images_files``.
        caption_paths: Caption file paths aligned with ``image_paths``;
            defaults to the global ``captions_files``.
        prompt: Instruction text for the user turn; defaults to
            ``LLAVA_CHAT_TEMPLATE``.

    Returns:
        ``{"image_path": ..., "messages": [...]}`` holding a user turn (prompt
        text followed by an image placeholder) and an assistant turn with the
        caption (newlines replaced by spaces), or ``None`` when the caption file
        is empty, unreadable or not valid UTF-8.

    Raises:
        IndexError: if ``i`` is out of range, i.e. the path lists are misaligned;
            this is surfaced instead of being silently dropped.
    """
    # Defaults resolve at call time so Pool.imap(process_pair, ...) still works
    # with the single positional argument it passes.
    if image_paths is None:
        image_paths = images_files
    if caption_paths is None:
        caption_paths = captions_files
    if prompt is None:
        prompt = LLAVA_CHAT_TEMPLATE

    img_path = image_paths[i]
    caption_path = caption_paths[i]

    try:
        with open(caption_path, 'r', encoding='utf-8') as f:
            caption = f.read().strip().replace('\n', ' ')
    except (OSError, UnicodeDecodeError):
        # Best effort: unreadable caption files are skipped, not fatal.
        return None

    if not caption:
        return None

    # Create messages field directly
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image"}
            ]
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": caption}
            ]
        }
    ]

    return {
        "image_path": img_path,
        "messages": messages
    }


# --- Multiprocessing ---
# Workers read the global path lists, which relies on the "fork" start method
# (the Linux default); with "spawn", functions defined in a notebook may not be
# importable by the child processes.
# chunksize batches the many tiny per-file tasks instead of one IPC round-trip each.
with Pool(cpu_count()) as pool:
    results = list(
        tqdm(
            pool.imap(process_pair, range(len(images_files)), chunksize=256),
            total=len(images_files),
        )
    )


In [None]:
# Drop pairs that process_pair rejected (it returns None for them).
data = [record for record in results if record is not None]

# --- Write JSONL File --- one JSON object per line
with open(OUTPUT_JSONL_PATH, 'w', encoding='utf-8') as jsonl_file:
    jsonl_file.writelines(json.dumps(record) + "\n" for record in data)

print(f"JSONL created: {OUTPUT_JSONL_PATH} with {len(data)} samples.")

Configuring Model

In [None]:
# --- Configuration --- short aliases used by the training and inference cells
# (data_path is the formatted JSONL file written above).
model_id, data_path, output_dir = MODEL_ID, OUTPUT_JSONL_PATH, MODEL_SAVE_PATH

In [None]:
# --- Model Loading (4bit Quantization) ---
# Run the 4-bit matmuls in float16, matching torch_dtype and fp16 training;
# bitsandbytes otherwise defaults the compute dtype to float32 (slower, with a warning).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer

# Prompt format shared by training and inference:
#   "USER: <text><image> ASSISTANT: <caption></s>"
# - a space separates each user turn from the following "ASSISTANT:" marker;
# - with add_generation_prompt=True the prompt ends in "ASSISTANT:", so generation
#   starts exactly where an assistant caption starts in training. The previous
#   template ignored that flag and left inference prompts without an assistant marker.
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}"
    "{% for item in message['content'] %}"
    "{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}"
    "{% endfor %}"
    "{% if message['role'] == 'assistant' %}{{ eos_token }}{% else %} {% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}ASSISTANT:{% endif %}"
)
# processor.apply_chat_template() (used at inference) reads the processor's own
# template, not the tokenizer's; keep both identical so prompts match training.
processor.chat_template = tokenizer.chat_template

In [None]:
def load_image(example):
    """Decode ``example["image_path"]`` into an RGB PIL image stored under ``"image"``.

    Mutates and returns the same ``example`` dict (so it can be passed to ``Dataset.map``).
    NOTE(review): not called in the cells shown.
    """
    with Image.open(example["image_path"]) as source:
        example["image"] = source.convert("RGB")
    return example

# Pre-formatted records (image_path + messages) as a Hugging Face dataset ("train" split).
# This value is replaced by the manual JSONL read in a later cell.
dataset = load_dataset("json", data_files=data_path)["train"]

In [None]:
# Peek at the first record to confirm the JSONL schema (image_path + messages).
first_record = dataset[0]
first_record

In [None]:
# Read the JSONL manually and build a Hugging Face dataset from the records.
# The import cell imports torch.utils.data.Dataset after datasets.Dataset, so the
# bare name `Dataset` is the torch class, which has no `from_list`. Import the
# Hugging Face class under an explicit alias instead.
from datasets import Dataset as HFDataset

records = []
with open(data_path, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line:  # tolerate blank or trailing lines
            records.append(json.loads(line))

print(f"Loaded {len(records)} samples.")

dataset = HFDataset.from_list(records)

In [None]:
class SimpleDataset(Dataset):
    """Minimal map-style dataset that forwards ``len()`` and indexing to ``data``."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Wrap. NOTE(review): train_dataset is not used by the trainer cell, which is given `dataset`.
train_dataset = SimpleDataset(dataset)

In [None]:
bleu = load_metric("bleu")
rouge = load_metric("rouge")
try:
    # "cider" may not exist as a `datasets` metric script (only as a custom one);
    # without this guard the whole cell raises instead of just reporting 0.0.
    cider = load_metric("cider")
except Exception as exc:
    print(f"CIDEr metric unavailable, reporting 0.0 instead: {exc}")
    cider = None


def compute_metrics(eval_preds):
    """Score decoded predictions against references (BLEU, ROUGE-L, CIDEr, exact match).

    Args:
        eval_preds: ``(predictions, label_ids)`` as handed over by ``Trainer``.
            Predictions may be logits (one more dimension than the labels) or
            token ids, optionally wrapped in a tuple; ``-100`` marks ignored labels.

    Returns:
        Dict with float values under "BLEU", "ROUGE_L", "CIDEr" and "Exact_Match".

    NOTE(review): with eval_strategy="no" and no eval_dataset, the trainer as
    configured never calls this function.
    """
    preds, labels = eval_preds
    # Model outputs can arrive as a tuple (logits, ...); keep the first entry.
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    # Trainer passes numpy arrays: reduce logits with numpy's axis=, not torch's dim=,
    # and only when they are logits (token ids must not be argmax-ed again).
    if preds.ndim == labels.ndim + 1:
        preds = preds.argmax(axis=-1)

    # -100 (ignored positions) is not a valid token id; map it to padding before decoding.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    preds = np.where(preds < 0, pad_id, preds)
    labels = np.where(labels < 0, pad_id, labels)

    pred_texts = tokenizer.batch_decode(preds.tolist(), skip_special_tokens=True)
    label_texts = tokenizer.batch_decode(labels.tolist(), skip_special_tokens=True)

    # Token lists for BLEU (one reference per prediction), plain strings for the rest.
    pred_tokens = [text.strip().split() for text in pred_texts]
    ref_tokens = [[text.strip().split()] for text in label_texts]
    pred_strs = [" ".join(tokens) for tokens in pred_tokens]
    ref_strs = [" ".join(refs[0]) for refs in ref_tokens]

    bleu_score = bleu.compute(predictions=pred_tokens, references=ref_tokens)["bleu"]

    rouge_l = rouge.compute(predictions=pred_strs, references=ref_strs)["rougeL"]
    # The legacy `datasets` ROUGE returns an AggregateScore (low/mid/high); log the
    # mid F-measure as a plain float. Other implementations already return a float.
    rouge_score = float(rouge_l.mid.fmeasure) if hasattr(rouge_l, "mid") else float(rouge_l)

    cider_score = 0.0
    if cider is not None:
        try:
            cider_score = cider.compute(predictions=pred_strs, references=ref_strs)["CIDEr"]
        except Exception:
            cider_score = 0.0

    # Exact Match (identical whitespace-normalized captions); guard empty input.
    exact_matches = sum(p == r for p, r in zip(pred_strs, ref_strs))
    exact_match_score = exact_matches / len(pred_strs) if pred_strs else 0.0

    return {
        "BLEU": bleu_score,
        "ROUGE_L": rouge_score,
        "CIDEr": cider_score,
        "Exact_Match": exact_match_score,
    }

In [None]:
# --- LoRA Configuration --- adapters on the modules named in TARGET_MODULES.
# NOTE(review): in LLaVA these names may also match vision-tower layers; confirm intended.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=TARGET_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)

In [None]:
# --- SFT Trainer arguments ---
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=NO_OF_EPOCHS,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    fp16=True,
    eval_strategy="no",
    save_strategy="epoch",
    logging_steps=LOGGING_STEPS,
    report_to="wandb",
)

# Experiment metadata stored next to the full TrainingArguments in the run config.
experiment_config = {
    "model_name": MODEL_NAME,
    "dataset": DATASET,
    "task": TASK,
    "LoRA_r": LORA_R,
    "LoRA_alpha": LORA_ALPHA,
    "LoRA_dropout": LORA_DROPOUT,
    "target_modules": TARGET_MODULES,
}

# Initialize WandB
wandb.init(
    project=PROJECT,
    name=RUN_NAME,
    group=GROUP,
    tags=TAGS,
    notes=NOTES,
    mode="online",
    config={**training_args.to_dict(), "custom_config": experiment_config},
)

# NOTE(review): `model` is watched before SFTTrainer injects the LoRA adapters
# (peft_config is applied inside the trainer), so the adapter parameters created
# later are presumably not covered by these hooks; confirm.
wandb.watch(model, log="all", log_freq=100)

In [None]:
# Optimizer / LR-schedule settings.
# An AdamW built here from `model.parameters()` would only hold the pre-LoRA weights:
# SFTTrainer creates the LoRA adapter parameters later (from `peft_config`), and a
# pre-built optimizer passed via `optimizers=` never sees them, so the adapters would
# never be updated. Instead, the same hyper-parameters go into the TrainingArguments,
# and the Trainer builds AdamW + cosine-with-warmup over the trainable parameters once
# the adapters exist. (The Trainer exempts bias/norm weights from weight decay.)
total_steps = (len(dataset) // (TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)) * NO_OF_EPOCHS
warmup_steps = int(0.05 * total_steps)

training_args.weight_decay = WEIGHT_DECAY
training_args.adam_beta1 = 0.9
training_args.adam_beta2 = 0.95
training_args.warmup_steps = warmup_steps  # lr_scheduler_type="cosine" is already set

# (None, None) is the Trainer default: "create optimizer and scheduler from args".
optimizer = None
lr_scheduler = None

In [None]:
from transformers import TrainerCallback


class CustomWandbCallback(TrainerCallback):
    """Log histograms of the trainable weights to the active WandB run.

    Derives from TrainerCallback rather than WandbCallback: `report_to="wandb"`
    already registers a WandbCallback, and a second WandbCallback subclass logs
    every training metric twice.

    Args:
        log_every: Step interval between histogram snapshots (default 100).
    """

    def __init__(self, log_every=100):
        self.log_every = log_every

    def on_step_end(self, args, state, control, model=None, **kwargs):
        if model is not None and self.log_every > 0 and state.global_step % self.log_every == 0:
            # The Trainer zeroes gradients (set_to_none) before on_step_end, so filtering
            # on `param.grad is not None` skipped every parameter; select by requires_grad.
            histograms = {
                f"weights/{name}": wandb.Histogram(param.detach().float().cpu().numpy())
                for name, param in model.named_parameters()
                if param.requires_grad
            }
            if histograms:
                # A single log call per snapshot. The step is recorded as a field, as the
                # Trainer's own WandbCallback does, instead of forcing wandb's step counter.
                wandb.log({**histograms, "train/global_step": state.global_step})
        return super().on_step_end(args, state, control, model=model, **kwargs)

In [None]:
import dataclasses

from trl import SFTConfig


def collate_fn(examples):
    """Turn dataset rows into a LLaVA training batch with pixel values and labels.

    Each row needs "messages" (rendered with the tokenizer's chat template) and
    "image_path". Without this collator the trainer only tokenized the text, and
    no image was ever loaded, so the `<image>` placeholder got no image features.

    Returns:
        The processor's batch (input_ids, attention_mask, pixel_values, ...) plus
        "labels": input_ids with padding and image tokens set to -100 so they are
        excluded from the loss.
    """
    texts = [tokenizer.apply_chat_template(ex["messages"], tokenize=False) for ex in examples]
    images = []
    for ex in examples:
        with Image.open(ex["image_path"]) as img:
            images.append(img.convert("RGB"))

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    labels = batch["input_ids"].clone()
    if tokenizer.pad_token_id is not None:
        labels[labels == tokenizer.pad_token_id] = -100
    labels[labels == tokenizer.convert_tokens_to_ids("<image>")] = -100
    batch["labels"] = labels
    return batch


# Settings a custom multimodal collator needs:
#  - remove_unused_columns=False keeps "messages"/"image_path" in each row (the Trainer
#    otherwise drops columns that are not model.forward() arguments);
#  - skip_prepare_dataset=True stops SFTTrainer from tokenizing the text on its own.
# dataset_kwargs exists only on SFTConfig, so the TrainingArguments are copied into one.
sft_config_kwargs = {
    field.name: getattr(training_args, field.name)
    for field in dataclasses.fields(training_args)
    if field.init
}
sft_config_kwargs.update(
    remove_unused_columns=False,
    dataset_kwargs={"skip_prepare_dataset": True},
)
sft_args = SFTConfig(**sft_config_kwargs)

trainer = SFTTrainer(
    model=model,
    args=sft_args,
    train_dataset=dataset,
    peft_config=lora_config,
    processing_class=tokenizer,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    callbacks=[CustomWandbCallback()],
    optimizers=(optimizer, lr_scheduler),
)

Training

In [None]:
# --- Start Fine-tuning --- (the returned TrainOutput is shown as the cell output)
train_output = trainer.train()
train_output

In [None]:
# --- Save Final Model ---
# trainer.model is the PEFT-wrapped model, so this presumably writes only the LoRA
# adapter. The processor (tokenizer with the custom chat template, plus the image
# preprocessing config) is saved alongside, so a later reload from `output_dir`
# reproduces the training-time prompt format.
trainer.model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)
print(f"✅ Training complete. Model saved at {output_dir}")

In [None]:
# Zip the saved model directory for download. The archive name and source come from
# output_dir rather than a repeated literal, so changing MODEL_SAVE_PATH is enough.
import shutil

archive_base = os.path.basename(os.path.normpath(output_dir))
shutil.make_archive(archive_base, 'zip', output_dir)

Inference

In [None]:
# --- Inference function
def generate_caption(image_path, prompt=None, max_new_tokens=80):
    """Generate an Instagram caption for one image with the fine-tuned model.

    Args:
        image_path: Path to an image file readable by PIL.
        prompt: Instruction for the user turn; defaults to ``LLAVA_CHAT_TEMPLATE``,
            the same text placed in every training record.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded caption (prompt tokens excluded), stripped of surrounding whitespace.
    """
    if prompt is None:
        prompt = LLAVA_CHAT_TEMPLATE

    # Load the image and release the file handle right away.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Same message structure as the training records (text first, then the image).
    messages = [
        {"role": "user", "content": [
            {"type": "text", "text": prompt},
            {"type": "image"}
        ]}
    ]

    # Render to a string; the processor then tokenizes it together with the image.
    prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_tensors = processor(text=prompt_text, images=[image], return_tensors="pt", padding=True).to(model.device)

    # trainer.train() leaves the model in train mode; switch to eval so LoRA dropout is off.
    model.eval()

    with torch.no_grad():
        output = model.generate(
            **input_tensors,
            max_new_tokens=max_new_tokens,
            repetition_penalty=1.2,   # discourage repeated phrases
            temperature=0.7,          # add some randomness
            top_p=0.9,                # nucleus sampling
            do_sample=True            # sample instead of greedy decoding
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = input_tensors["input_ids"].shape[1]
    generated_text = processor.batch_decode(output[:, prompt_len:], skip_special_tokens=True)[0]

    return generated_text.strip()

In [None]:
# Example: caption a local test image with the fine-tuned model.
caption = generate_caption(image_path="./temp.jpg")
print(f"Generated Caption: {caption}")

In [None]:
# Example: caption a local test image with the fine-tuned model.
caption = generate_caption(image_path="./test1.jpg")
print(f"Generated Caption: {caption}")

In [None]:
# Example: caption a local test image with the fine-tuned model.
caption = generate_caption(image_path="./test3.jpg")
print(f"Generated Caption: {caption}")

In [None]:
# Example: caption a local test image with the fine-tuned model.
caption = generate_caption(image_path="./test4.jpg")
print(f"Generated Caption: {caption}")

In [None]:
# Example: caption a local test image with the fine-tuned model.
caption = generate_caption(image_path="./test5.jpg")
print(f"Generated Caption: {caption}")

In [None]:
# Example: caption a local test image with the fine-tuned model.
caption = generate_caption(image_path="./test6.jpg")
print(f"Generated Caption: {caption}")

In [None]:
# Example: caption a local test image with the fine-tuned model.
caption = generate_caption(image_path="./test7.jpg")
print(f"Generated Caption: {caption}")

In [None]:
def chat_with_model(messages, image=None):
    """Send a (possibly multi-turn) conversation to the model and return its reply.

    Args:
        messages: Conversation so far, as a list of ``{"role": ..., "content": [...]}``
            turns in the same format as the training records.
        image: PIL image for the ``{"type": "image"}`` placeholder. The placeholder
            stays in the history, so the image should be passed on every turn.

    Returns:
        Only the newly generated assistant text (the prompt is not echoed back),
        stripped of surrounding whitespace.
    """
    import warnings

    needs_image = any(
        isinstance(item, dict) and item.get("type") == "image"
        for message in messages
        for item in message.get("content", [])
    )
    if needs_image and image is None:
        warnings.warn(
            "Conversation contains an image placeholder but no image was passed; "
            "the model will not receive image features for this turn."
        )

    # Step 1: render the prompt; add_generation_prompt ends it with the assistant marker.
    prompt_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Step 2: encode the text (and the image, when one is given).
    if image is not None:
        inputs = processor(text=prompt_text, images=[image], return_tensors="pt", padding=True)
    else:
        inputs = processor(text=prompt_text, return_tensors="pt", padding=True)

    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # trainer.train() leaves the model in train mode; disable dropout for generation.
    model.eval()

    # Step 3: generate
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True
        )

    # Step 4: decode only the new tokens. Decoding output[0] in full would return
    # the whole prompt, which the chat loop then appends to the history again.
    prompt_len = inputs["input_ids"].shape[1]
    reply = processor.tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)

    return reply.strip()


In [None]:

# ------------------
# Start a conversation
# ------------------

# Step 1: Initial messages with an image
image_path = "test2.jpg"
with Image.open(image_path) as img:
    image = img.convert("RGB")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "You are a social media influencer. Write a catchy Instagram caption for this image."},
            {"type": "image"}
        ]
    }
]

# First model reply
caption = chat_with_model(messages, image=image)
print(f"\n🧠 Model: {caption}")

# Add model's response to messages
messages.append({
    "role": "assistant",
    "content": [{"type": "text", "text": caption}]
})

# Step 2: Loop for continuous chat
while True:
    user_input = input("\n💬 Your input (type 'quit' to stop): ").strip()

    if user_input.lower() == "quit":
        print("👋 Ending chat. Goodbye!")
        break
    if not user_input:
        continue  # ignore empty submissions

    # Add user's new message
    messages.append({
        "role": "user",
        "content": [{"type": "text", "text": user_input}]
    })

    # The first user turn still contains the image placeholder, so the image is sent
    # on every turn; otherwise follow-up prompts would carry an <image> token without
    # any image features.
    model_reply = chat_with_model(messages, image=image)
    print(f"\n🧠 Model: {model_reply}")

    # Add model's reply back to messages
    messages.append({
        "role": "assistant",
        "content": [{"type": "text", "text": model_reply}]
    })
