In [None]:
!pip install transformers sentencepiece bitsandbytes accelerate --quiet

In [None]:
!pip install torch torchvision torchaudio --quiet

In [None]:
!pip install ultralytics --quiet

In [None]:
!pip install gtts --quiet

In [None]:
!pip install optimum auto-gptq --quiet

In [None]:
!pip install optimum --quiet

In [None]:
!pip install gradio --quiet

In [None]:
!pip install peft --quiet

In [None]:
!wget https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5m.pt

In [None]:
!pip install evaluate --quiet
!pip install rouge_score --quiet
!pip install textstat --quiet

In [None]:
import torch, gc
import gradio as gr
from PIL import Image
from gtts import gTTS
from nltk.translate.bleu_score import sentence_bleu
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    Blip2Processor, Blip2ForConditionalGeneration
)
from ultralytics import YOLO
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from auto_gptq import AutoGPTQForCausalLM
from collections import Counter

In [None]:
# Global numeric settings: allow TF32 for CUDA matmuls and make float32 the
# default floating-point dtype.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_default_dtype(torch.float32)

# Single device handle reused by the model-loading and inference cells.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Object detector. The `ultralytics` package's YOLO class works with the
# anchor-free "u" YOLOv5 checkpoints (here yolov5mu.pt), which it fetches itself
# on first use. The legacy yolov5m.pt released in the ultralytics/yolov5
# repository targets that repository's own loader, not this API.
yolo_model = YOLO("yolov5mu.pt")

In [None]:
# BLIP-2 captioner (Flan-T5-XL language model), loaded in float32 on `device`.
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float32
).to(device)
blip_model.gradient_checkpointing_enable()
blip_model.language_model = prepare_model_for_kbit_training(blip_model.language_model)

# LoRA adapters on the language model's attention projections.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    # The wrapped language model is Flan-T5, an encoder-decoder network, so the
    # sequence-to-sequence task type is the matching PEFT wrapper (the
    # causal-LM wrapper is meant for decoder-only models).
    task_type=TaskType.SEQ_2_SEQ_LM,
)

blip_model.language_model = get_peft_model(blip_model.language_model, lora_config)
blip_model.language_model.print_trainable_parameters()

In [None]:
# GPTQ-quantized Llama-2-7B-Chat checkpoint and its tokenizer.
llama_model_name = "TheBloke/Llama-2-7B-Chat-GPTQ"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_name, use_fast=True)

# Loader options gathered in one mapping; the model is placed on "cuda:0".
gptq_load_options = dict(
    device="cuda:0",
    torch_dtype=torch.float32,
    trust_remote_code=True,
    use_safetensors=True,
    device_map=None,
)
llama_model = AutoGPTQForCausalLM.from_quantized(llama_model_name, **gptq_load_options)

In [None]:
# NLLB-200 (distilled, 600M) translation checkpoint: tokenizer first, then the
# seq2seq model, which is moved to `device` in a separate step.
nllb_model_name = "facebook/nllb-200-distilled-600M"
nllb_tokenizer = AutoTokenizer.from_pretrained(nllb_model_name)
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(nllb_model_name)
nllb_model = nllb_model.to(device)

In [None]:
def fine_tune_blip2(blip_model, blip_processor, correction_text, image, optimizer,
                    epochs=1, prompt="Describe this image."):
    """Take a few gradient steps so BLIP-2 produces `correction_text` for `image`.

    The image and `prompt` form the model input, and the tokenized correction is
    the training target. (Feeding the correction as both input and label would
    only teach the model to echo its text input.)

    Args:
        blip_model: Model updated in place. It is called as
            ``blip_model(**inputs, labels=labels)`` and must return an object
            exposing a ``.loss`` tensor.
        blip_processor: Processor that encodes the image/prompt pair; its
            ``tokenizer`` encodes the target caption.
        correction_text: Corrected caption used as the target. When it is
            ``None``, empty or whitespace-only, training is skipped.
        image: Image the correction refers to.
        optimizer: Optimizer already bound to the parameters to update.
        epochs: Number of optimisation steps taken on this single example.
        prompt: Text given to the model together with the image. The default is
            the same prompt string used for caption generation in this notebook.

    Returns:
        The same `blip_model` object. After training (or a failed training
        step) it is switched back to eval mode; on the skip path its mode is
        left untouched.
    """
    target_text = (correction_text or "").strip()
    if not target_text:
        # Return before touching the model so a skipped call cannot leave it in
        # train mode.
        print("No correction provided. Skipping BLIP2 fine-tuning.")
        return blip_model

    inputs = blip_processor(images=image, text=prompt, return_tensors="pt").to(device, dtype=torch.float32)
    labels = blip_processor.tokenizer(target_text, return_tensors="pt")["input_ids"].to(device)
    # Padding positions must not contribute to the loss.
    labels[labels == blip_processor.tokenizer.pad_token_id] = -100

    blip_model.train()
    try:
        for epoch in range(epochs):
            optimizer.zero_grad()
            outputs = blip_model(**inputs, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f"[LoRA Fine-tune] Epoch {epoch+1}/{epochs} | Loss: {loss.item():.4f}")
    finally:
        # Always restore inference mode, even if a step raised (e.g. out of memory).
        blip_model.eval()

    return blip_model

In [None]:
# UI language name -> (NLLB-200 target language code, gTTS language code).
_LANGUAGE_CODES = {
    "French": ("fra_Latn", "fr"),
    "Spanish": ("spa_Latn", "es"),
    "German": ("deu_Latn", "de"),
    "Hindi": ("hin_Deva", "hi"),
    "Arabic": ("arb_Arab", "ar"),
    "Chinese (Simplified)": ("zho_Hans", "zh-CN"),
    "Malayalam": ("mal_Mlym", "ml"),
}
# Used when the requested language is not in the table above.
_FALLBACK_LANGUAGE = "French"


def describe_image(image: Image.Image, correction_text: str, trigger_finetune: bool,target_language: str):
    """Caption an image, expand the caption, translate it and synthesise speech.

    Steps, in order: YOLO object detection, BLIP-2 caption, optional fine-tuning
    on the user's correction, Llama paragraph generation, NLLB translation and
    gTTS audio. Fine-tuning runs after the caption is generated, so a
    correction influences later calls rather than the caption returned by this one.

    Args:
        image: Uploaded picture.
        correction_text: Optional user-supplied caption used as the fine-tuning target.
        trigger_finetune: Whether to fine-tune BLIP-2 on `correction_text`.
        target_language: One of the keys of `_LANGUAGE_CODES`; any other value
            falls back to French.

    Returns:
        ``(initial_caption, refined_caption, translated_caption, audio_path)``.
        `audio_path` is ``"caption_audio.mp3"``, or ``None`` when the
        translation is empty and there is nothing to speak.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")
    image = image.convert("RGB")
    correction_text = (correction_text or "").strip()

    # --- 1. Object detection -------------------------------------------------
    yolo_results = yolo_model(image)
    object_details = [
        yolo_model.names[int(cls)]
        for result in yolo_results
        if result.boxes is not None
        for cls in result.boxes.cls.tolist()
    ]

    # --- 2. Initial caption --------------------------------------------------
    with torch.no_grad():
        blip_inputs = blip_processor(images=image, text=["Describe this image."], return_tensors="pt").to(device, dtype=torch.float32)
        blip_output = blip_model.generate(**blip_inputs, max_new_tokens=60)
        initial_caption = blip_processor.tokenizer.decode(blip_output[0], skip_special_tokens=True)

    # Collect Python garbage first so the CUDA allocator can release those blocks.
    gc.collect()
    torch.cuda.empty_cache()

    # --- 3. Optional fine-tuning on the user's correction ----------------------
    if trigger_finetune and correction_text:
        print("Fine-tuning BLIP2 with user correction...")
        # Optimise only the trainable parameters of the LoRA-wrapped language
        # model, so the vision encoder and Q-Former weights are not updated.
        trainable_params = [p for p in blip_model.language_model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable_params, lr=1e-4)
        fine_tune_blip2(blip_model, blip_processor, correction_text, image, optimizer)

    # --- 4. Paragraph generation ---------------------------------------------
    object_summary = ", ".join(f"{obj} ({count})" for obj, count in Counter(object_details).items())
    prompt_parts = ["Describe the scene in detail.", f"Caption: '{initial_caption}'."]
    if object_summary:
        # Omitted when nothing was detected, to avoid "The scene includes: .".
        prompt_parts.append(f"The scene includes: {object_summary}.")
    prompt_parts.append("Write a descriptive paragraph (not a list).")
    prompt = " ".join(prompt_parts)

    with torch.no_grad():
        llama_inputs = llama_tokenizer(prompt, return_tensors="pt").to(device)
        llama_ids = llama_model.generate(**llama_inputs, max_new_tokens=200, do_sample=True, temperature=0.8, top_k=50, top_p=0.95, repetition_penalty=1.2)
        # A causal LM returns prompt + continuation; keep only the newly
        # generated tokens so the prompt is not echoed into the description.
        prompt_length = llama_inputs["input_ids"].shape[1]
        refined_caption = llama_tokenizer.decode(llama_ids[0][prompt_length:], skip_special_tokens=True).strip()

    if not refined_caption:
        # Nothing was generated; fall back to the BLIP-2 caption.
        refined_caption = initial_caption

    # --- 5. Translation ------------------------------------------------------
    nllb_code, tts_lang = _LANGUAGE_CODES.get(target_language, _LANGUAGE_CODES[_FALLBACK_LANGUAGE])
    lang_token_id = nllb_tokenizer.convert_tokens_to_ids(nllb_code)

    with torch.no_grad():
        nllb_inputs = nllb_tokenizer(refined_caption, return_tensors="pt").to(device)
        nllb_tokens = nllb_model.generate(**nllb_inputs, forced_bos_token_id=lang_token_id)
        translated_caption = nllb_tokenizer.decode(nllb_tokens[0], skip_special_tokens=True)

    # --- 6. Speech -----------------------------------------------------------
    audio_path = None
    if translated_caption.strip():
        # Speak the text in the language it was translated into; gTTS defaults
        # to English when `lang` is omitted.
        tts = gTTS(translated_caption, lang=tts_lang)
        tts.save("caption_audio.mp3")
        audio_path = "caption_audio.mp3"

    return initial_caption, refined_caption, translated_caption, audio_path

In [None]:
# Translation targets offered in the dropdown.
language_choices = ["French", "Spanish", "German", "Hindi", "Arabic", "Chinese (Simplified)", "Malayalam"]

# Input widgets, in the order of describe_image's parameters.
input_components = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Textbox(
        label="✏️ Your Caption Correction (Optional)",
        placeholder="Enter improved caption if needed...",
    ),
    gr.Checkbox(label="✅ Fine-Tune BLIP2 using this correction?"),
    gr.Dropdown(
        choices=language_choices,
        label="🌍 Select Translation Language",
        value="French",
    ),
]

# Output widgets, in the order of describe_image's return values.
output_components = [
    gr.Textbox(label="📌 Initial Caption (BLIP2)"),
    gr.Textbox(label="🪄 Refined Description (LLaMA)"),
    gr.Textbox(label="🌍 Translated Description"),
    gr.Audio(label="🔊 Description Audio"),
]

interface = gr.Interface(
    fn=describe_image,
    inputs=input_components,
    outputs=output_components,
    title="🖼️ Image Detailed Description Generator",
    description="Upload an image and optionally provide a better caption to improve the model using fine-tuning.",
)
interface = interface.queue()

interface.launch(debug=True)