In [None]:
import os
import json
import logging
from PIL import Image
import torch
from transformers import AutoProcessor,LlavaForConditionalGeneration, AutoTokenizer

In [None]:
# Mount Google Drive into the Colab filesystem so the dataset and output
# paths under /content/drive are reachable from this runtime.
from google.colab import drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Configure logging.
# logging.basicConfig() is a no-op when the root logger already has handlers,
# which is commonly the case inside notebook kernels; force=True removes any
# existing root handlers so the level and format below actually take effect.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s: %(message)s",
    force=True,
)

# Class id -> human-readable label used when building prompts.
# The "no objects" label is filtered out by build_prompt().
CLASS_MAPPING = {
    0: "hole",
    1: "pole",
    2: "stairs",
    3: "bottle/glass",
    4: "rock",
    5: "no objects"
}

# Input images, output captions file, and per-image annotation file.
IMG_DIR = "/content/drive/MyDrive/AIS/dataset/data_subset/images"
OUTPUT_JSON = "/content/drive/MyDrive/AIS/dataset/prepared/descriptions_llava.json"
ANNOT_JSON = "/content/drive/MyDrive/AIS/dataset/prepared/img_label.json"

# Hugging Face model id and the device the model and inputs are moved to.
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def load_model_and_processor():
    """Load the processor and model identified by MODEL_NAME.

    The model is moved to DEVICE and switched to eval mode before being
    returned.

    Returns:
        tuple: ``(processor, model)``.
    """
    logging.info(f"Loading LLAVA teacher '{MODEL_NAME}' on {DEVICE}")

    teacher_processor = AutoProcessor.from_pretrained(MODEL_NAME)

    teacher_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME)
    teacher_model = teacher_model.to(DEVICE)
    teacher_model.eval()

    return teacher_processor, teacher_model

In [None]:
def build_prompt(annotations, class_mapping=None):
    """Build the captioning prompt for one image from its annotations.

    Args:
        annotations: Iterable of annotation dicts. Each dict may carry a
            "class_label" entry holding either a list of class ids or a
            single integer class id. ``None`` or an empty iterable is treated
            as "no annotations".
        class_mapping: Optional mapping from class id to label name.
            Defaults to the module-level CLASS_MAPPING.

    Returns:
        str: A prompt that lists the detected object names (de-duplicated, in
        first-seen order) when at least one label other than "no objects" is
        present; otherwise a generic description prompt.
    """
    if class_mapping is None:
        class_mapping = CLASS_MAPPING

    # Collect every label name except unknown ids and "no objects".
    names = []
    for ann in annotations or ():
        raw_labels = ann.get("class_label")
        if raw_labels is None:
            continue
        if isinstance(raw_labels, int):
            # Tolerate a bare class id instead of a list of ids.
            raw_labels = [raw_labels]
        for class_id in raw_labels:
            name = class_mapping.get(class_id, "")
            if name and name != "no objects":
                names.append(name)

    if names:
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # prompt is identical across runs (a set has no stable string order).
        objs = ", ".join(dict.fromkeys(names))
        return (
            f"Detected objects: {objs}. "
            "Generate crisp, complete description of image and background for visually impaired users, "
            "mentioning count, shape, and approximate distance and position of the detected objects."
        )

    return (
        "Generate crisp, complete description of image and background for visually impaired users. "
        "Mention count, shape, approximate distance and position of the objects in the image."
    )

In [None]:
def generate_caption_for_image(
    image_path: str,
    prompt: str,
    processor,
    model,
    device: str,
    max_new_tokens: int = 64,
) -> str:
    """Generate a caption for a single image with a vision-language model.

    Args:
        image_path: Path of the image file to caption.
        prompt: Instruction text sent to the model together with the image.
        processor: Processor providing ``apply_chat_template``, tensor
            preparation via ``__call__`` and ``batch_decode``.
        model: Generative model exposing ``generate`` and ``dtype``.
        device: Device the prepared inputs are moved to.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        str: The decoded caption with surrounding whitespace removed, or an
        empty string if the image could not be opened.
    """
    try:
        # The context manager releases the file handle; convert() returns a
        # loaded copy that remains usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
    except Exception as e:
        logging.error(f"Error opening image '{image_path}': {e}")
        return ""

    # Single-turn user message: one image placeholder followed by the text.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt}
            ],
        }
    ]

    # The chat template inserts the image token and the generation prompt.
    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    # Floating-point tensors (pixel values) are cast to the model's dtype so a
    # half-precision model does not hit a dtype mismatch; integer tensors
    # (token ids, attention mask) are only moved to the device.
    inputs = processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    ).to(device, model.dtype)

    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = inputs.input_ids.shape[-1]
    generated_ids = output_ids[:, prompt_len:]
    captions = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    # Strip so that a whitespace-only generation is reported as "no caption".
    return captions[0].strip()

In [None]:
def main():
    """Caption every annotated image and write the results to OUTPUT_JSON.

    Loads the model and processor, reads the image-name -> annotations mapping
    from ANNOT_JSON, builds a prompt per image, generates a caption, and dumps
    the ``{image_name: caption}`` dict as JSON. Images that are missing on
    disk, yield an empty caption, or raise during generation are logged and
    skipped so that a single failure does not discard the captions already
    produced.
    """
    proc, mdl = load_model_and_processor()

    # Load the annotations mapping.
    with open(ANNOT_JSON, "r", encoding="utf-8") as f:
        ann_map = json.load(f)

    pseudo = {}
    for img_name, anns in ann_map.items():
        path = os.path.join(IMG_DIR, img_name)
        if not os.path.exists(path):
            logging.warning(f"Image missing: {path}, skipping")
            continue

        prompt = build_prompt(anns)
        logging.info(f"Prompt for '{img_name}': {prompt}")
        try:
            cap = generate_caption_for_image(path, prompt, proc, mdl, DEVICE)
        except Exception:
            # Keep going: losing one image is better than losing the whole run.
            logging.exception(f"Caption generation failed for '{img_name}'")
            continue

        if cap:
            pseudo[img_name] = cap
            logging.info(f"Caption: {cap}")
        else:
            logging.warning(f"No caption for '{img_name}'")

    # Write out the pseudo captions.
    out_dir = os.path.dirname(OUTPUT_JSON)
    if out_dir:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(out_dir, exist_ok=True)
    with open(OUTPUT_JSON, "w", encoding="utf-8") as out:
        json.dump(pseudo, out, indent=2)
    logging.info(f"Captioned {len(pseudo)} of {len(ann_map)} images")
    logging.info(f"Wrote pseudo captions to {OUTPUT_JSON}")

In [None]:
%%time
# Run the full captioning pipeline; the %%time cell magic on the first line
# reports the CPU and wall-clock time taken by this cell.
if __name__ == "__main__":
    main()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

CPU times: user 7min 51s, sys: 20.8 s, total: 8min 12s
Wall time: 7min 57s
