#### Setup & Imports

In [19]:
import os
from pathlib import Path
from PIL import Image
import imagehash
from tqdm import tqdm
import json
import torch

from transformers import AutoProcessor, LlavaForConditionalGeneration


In [22]:
# Choose the best open-weight annotation 
LLAVA_MODEL = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(
    LLAVA_MODEL, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True, 
).to(0)

processor = AutoProcessor.from_pretrained(LLAVA_MODEL)


Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 45.63it/s]


##### Configure Directories

In [23]:
SRC_DIR = Path("../data/stickers_png")     # Each folder contains frame_XXXXX.png
OUT_DIR = Path("../data/icons_256")        # Flattened deduped dataset
META_PATH = Path("../data/icons_metadata.jsonl")

OUT_DIR.mkdir(exist_ok=True)


##### Process All Stickers -> Resize + Depublicate + Save

In [None]:
def annotate_image_llava(img: Image.Image, prompt="Describe this icon in detail."):
    # Build conversation prompt for LLaVA-HF
    full_prompt = f"USER: <image>\n{prompt}\nASSISTANT:"

    inputs = processor(
        text=full_prompt,
        images=img,
        return_tensors="pt"
    ).to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=120
    )

    return processor.decode(output[0], skip_special_tokens=True)

def annotate_batch_llava(images, prompts):
    """
    images: list[PIL.Image]
    prompts: list[str] – same length as images
    """
    # Format prompts with proper LLaVA conversation format
    full_prompts = [f"USER: <image>\n{prompt}\nASSISTANT:" for prompt in prompts]
    
    # Preprocess multimodal batch
    inputs = processor(
        text=full_prompts,
        images=images,
        return_tensors="pt",
        padding=True
    ).to(model.device)

    with torch.cuda.amp.autocast(dtype=torch.float16):
        output_ids = model.generate(
            **inputs,
            max_new_tokens=120
        )

    # Decode each generated caption
    captions = processor.batch_decode(output_ids, skip_special_tokens=True)
    return captions

def is_duplicate(phash, seen_hashes, threshold=5):
    """
    phash       : imagehash object for current frame
    seen_hashes : list of previous perceptual hashes
    threshold   : Hamming distance threshold for perceptual duplicates
    """
    for prev in seen_hashes:
        if phash - prev < threshold:
            return True
    return False

In [25]:
import time

total_images = 0
start_time = time.time()
gpu_logs = []

all_metadata = []

BATCH_SIZE = 16
TARGET_SIZE = 256

for sticker_dir in tqdm(list(SRC_DIR.iterdir()), desc="Processing packs"):
    if not sticker_dir.is_dir():
        continue

    frames = sorted(sticker_dir.glob("*.png"))
    if not frames:
        continue

    sticker_name = sticker_dir.name
    seen_hashes = []
    batch_images = []
    batch_paths = []

    saved_count = 0

    for frame_path in frames:
        try:
            img = Image.open(frame_path).convert("RGBA")
        except:
            print("Skipping corrupted image:", frame_path)
            continue

        img = img.resize((TARGET_SIZE, TARGET_SIZE), Image.LANCZOS)

        # perceptual dedupe
        ph = imagehash.phash(img)
        if is_duplicate(ph, seen_hashes):
            continue

        seen_hashes.append(ph)

        # save resized version
        out_filename = f"{sticker_name}_{saved_count:05d}.png"
        img.save(OUT_DIR / out_filename)

        batch_images.append(img)
        batch_paths.append((out_filename, frame_path.name))
        saved_count += 1
        total_images += 1

        # ---- If batch full, run LLaVA ----
        if len(batch_images) == BATCH_SIZE:
            prompts = ["Describe this icon in detail."] * BATCH_SIZE
            captions = annotate_batch_llava(batch_images, prompts)

            # record performance
            gpu_logs.append({
                "time": time.time(),
                "gpu_util": torch.cuda.utilization(),
                "mem": torch.cuda.memory_allocated()
            })

            for (out_filename, orig_name), caption in zip(batch_paths, captions):
                all_metadata.append({
                    "image": out_filename,
                    "original_frame": orig_name,
                    "caption": caption,
                    "sticker": sticker_name,
                })

            batch_images = []
            batch_paths = []

# finish last partial batch
if batch_images:
    prompts = ["Describe this icon in detail."] * len(batch_images)
    captions = annotate_batch_llava(batch_images, prompts)
    for (out_filename, orig_name), caption in zip(batch_paths, captions):
        all_metadata.append({
            "image": out_filename,
            "original_frame": orig_name,
            "caption": caption,
            "sticker": sticker_name,
        })

total_time = time.time() - start_time
print(f"\nProcessed {total_images} images in {total_time:.2f} sec")
print(f"Throughput: {total_images / total_time:.2f} images/sec")

  with torch.cuda.amp.autocast(dtype=torch.float16):
Processing packs:   0%|          | 0/237 [00:00<?, ?it/s]


ValueError: Image features and image tokens do not match: tokens: 0, features 37748736