In [6]:
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import json
import numpy as np
import os
import time

# --- Config ---
captions_path = "./coco_dataset/llava_captions.json"
checkpoint_path = "./coco_dataset/llava_embeddings_checkpoint.json"
final_output_path = "./coco_dataset/llava_embeddings.npz"
batch_size = 64  # Much faster with batches

# --- Load all captions ---
with open(captions_path, "r") as f:
    all_captions = json.load(f)

# --- Load checkpoint if exists ---
if os.path.exists(checkpoint_path):
    with open(checkpoint_path, "r") as f:
        checkpoint = json.load(f)
else:
    checkpoint = {}

# --- Setup ---
model = SentenceTransformer("all-MiniLM-L6-v2")
pending_items = [(img_id, cap) for img_id, cap in all_captions.items() if img_id not in checkpoint]

print(f"Total: {len(all_captions)} | Done: {len(checkpoint)} | Remaining: {len(pending_items)}")

# --- Batching loop ---
for i in tqdm(range(0, len(pending_items), batch_size), desc="Embedding in batches"):
    batch = pending_items[i:i + batch_size]
    ids = [img_id for img_id, _ in batch]
    texts = [caption for _, caption in batch]

    try:
        embeddings = model.encode(texts, show_progress_bar=False)
    except Exception as e:
        print(f"❌ Error embedding batch starting at index {i}: {e}")
        continue

    # Add results to checkpoint
    for img_id, embedding in zip(ids, embeddings):
        checkpoint[img_id] = embedding.tolist()

    # Save safely using a temp file to avoid corruption
    tmp_path = checkpoint_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(checkpoint, f)
    os.replace(tmp_path, checkpoint_path)

# --- Final Save ---
image_ids = list(checkpoint.keys())
embedding_matrix = np.array([checkpoint[img_id] for img_id in image_ids])
np.savez_compressed(final_output_path, image_ids=image_ids, embeddings=embedding_matrix)

print(f"\n✅ All {len(image_ids)} embeddings saved to {final_output_path}")


Total: 118287 | Done: 0 | Remaining: 118287


Embedding in batches: 100%|██████████| 1849/1849 [13:02:18<00:00, 25.39s/it]  



✅ All 118287 embeddings saved to ./coco_dataset/llava_embeddings.npz
