In [1]:
import os, json
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
import clip

# Pick the fastest available torch backend: CUDA GPU, then Apple-silicon MPS, then CPU.
# (The original only checked MPS, so CUDA machines silently ran on CPU.)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device  # last expression: display the selected backend



'mps'

# Embed images with CLIP ViT-B/32

In [2]:
# Load the CLIP ViT-B/32 checkpoint onto the selected backend; `preprocess`
# is the image transform that goes with it.
model, preprocess = clip.load(name="ViT-B/32", device=device)

In [3]:
# Input directory of COCO train2017 images, and the JSON file the embeddings are written to.
IMAGE_DIR, IMG_JSON = "data/coco/train2017", "data/image_clip_b32.json"

In [10]:
# Accumulator for the {"image_path", "embedding"} records built from embed_batch output.
results = []


def embed_batch(images, *, clip_model=None, transform=None, target_device=None):
    """Encode a batch of images with CLIP and return L2-normalised embeddings.

    Args:
        images: non-empty sequence of images accepted by `transform`
            (PIL images when the default CLIP `preprocess` is used).
        clip_model: object with an ``encode_image(tensor)`` method.
            Defaults to the notebook-global ``model``.
        transform: per-image callable returning a tensor.
            Defaults to the notebook-global ``preprocess``.
        target_device: torch device the batch is moved to before encoding.
            Defaults to the notebook-global ``device``.

    Returns:
        float32 numpy array with one unit-L2-norm row per input image, in input order.

    Raises:
        ValueError: if `images` is empty.
    """
    if clip_model is None:
        clip_model = model
    if transform is None:
        transform = preprocess
    if target_device is None:
        target_device = device

    tensors = [transform(img) for img in images]
    if not tensors:
        # torch.stack([]) would fail with a far less helpful message.
        raise ValueError("embed_batch() needs at least one image")

    # Preprocess on CPU and move the whole batch in one transfer,
    # instead of one host->device copy per image.
    batch = torch.stack(tensors).to(target_device)

    with torch.no_grad():
        vec = clip_model.encode_image(batch).float()

    # Unit-normalise each row so dot products between embeddings are cosine similarities.
    vec = vec / vec.norm(dim=-1, keepdim=True)

    return vec.cpu().numpy()


In [11]:
# Embed every .jpg in IMAGE_DIR with CLIP, BATCH_SIZE images per forward pass.
BATCH_SIZE = 32

# Start from an empty accumulator so re-running this cell does not append
# duplicate records to whatever a previous run left in `results`.
results = []


def _append_embeddings(imgs, paths, out):
    """Embed one batch and append {"image_path", "embedding"} records to `out`.

    If the batched call fails, retry image by image so a single bad input only
    costs that one image instead of the whole batch. Errors are printed, not raised.
    """
    try:
        vecs = embed_batch(imgs)
    except Exception as e:
        print("E: batch failed, retrying one by one:", e)
        for img, path in zip(imgs, paths):
            try:
                (vec,) = embed_batch([img])
            except Exception as e_single:
                print("E:", os.path.basename(path), e_single)
                continue
            out.append({"image_path": path, "embedding": vec.tolist()})
        return

    for path, vec in zip(paths, vecs):
        out.append({"image_path": path, "embedding": vec.tolist()})


files = sorted(os.listdir(IMAGE_DIR))
batch_imgs, batch_paths = [], []

for fname in tqdm(files):
    if not fname.endswith(".jpg"):
        continue

    img_path = os.path.join(IMAGE_DIR, fname)

    # Only image decoding is guarded here; embedding errors are handled per batch.
    try:
        # `with` closes the file handle; convert() returns an independent in-memory copy.
        with Image.open(img_path) as im:
            img = im.convert("RGB")
    except Exception as e:
        print("E:", fname, e)
        continue

    batch_imgs.append(img)
    batch_paths.append(img_path)

    if len(batch_imgs) >= BATCH_SIZE:
        _append_embeddings(batch_imgs, batch_paths, results)
        # Always reset, even after an error; otherwise the batch would grow unbounded.
        batch_imgs, batch_paths = [], []

# Flush the final partial batch.
if batch_imgs:
    _append_embeddings(batch_imgs, batch_paths, results)
    batch_imgs, batch_paths = [], []

100%|██████████| 118287/118287 [1:20:42<00:00, 24.43it/s]    


In [12]:
# Persist all embedding records to IMG_JSON as a single JSON list.
# Dump into a temporary file first and atomically rename it into place, so an
# interrupted (long) write never leaves a truncated IMG_JSON behind.
os.makedirs(os.path.dirname(IMG_JSON) or ".", exist_ok=True)
tmp_path = IMG_JSON + ".tmp"
with open(tmp_path, "w") as f:
    # Compact separators: identical JSON content, noticeably smaller for large float lists.
    json.dump(results, f, separators=(",", ":"))
os.replace(tmp_path, IMG_JSON)

print("Saved:", IMG_JSON)
print("Total images embedded:", len(results))


Saved: data/image_clip_b32.json
Total images embedded: 118287
