### Extract Embeddings

Based on https://github.com/rom1504/clip-retrieval

First, pip install clip-retrieval

In [1]:
import os
import torch
import clip
from PIL import Image
from tqdm import tqdm
import numpy as np

ModuleNotFoundError: No module named 'clip'

In [8]:
# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

AttributeError: module 'clip' has no attribute 'load'

#### iNaturalist Embeddings

In [5]:
# Settings
IMG_DIR = "inat_images" #should probably just load the huggingface dataset if possible
OUT_DIR = "inat_embs"
os.makedirs(OUT_DIR, exist_ok=True)

In [4]:
# Get image paths
valid_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
image_paths = sorted([os.path.join(IMG_DIR, f) for f in os.listdir(IMG_DIR) 
                      if f.lower().endswith(valid_exts)])

print(f"Processing {len(image_paths)} images on {device}...")

Processing 1238 images on cuda...


In [5]:
# Extraction loop
embeddings = []
metadata = []

for path in tqdm(image_paths):
    try:
        image = preprocess(Image.open(path)).unsqueeze(0).to(device)
        
        with torch.no_grad():
            features = model.encode_image(image)
            features /= features.norm(dim=-1, keepdim=True)    # normalize for cosine similarity
            
        embeddings.append(features.cpu().numpy())
        metadata.append(path)
        
    except Exception as e:
        print(f"Skipping corrupt image {path}: {e}")

100%|██████████| 1238/1238 [01:59<00:00, 10.38it/s]


In [6]:
if embeddings:
    emb_matrix = np.vstack(embeddings)
    
    np.save(os.path.join(OUT_DIR, "embeddings.npy"), emb_matrix)
    
    with open(os.path.join(OUT_DIR, "filenames.txt"), "w") as f:
        f.write("\n".join(metadata))
        
    print(f"Success! Saved {emb_matrix.shape} matrix to {OUT_DIR}/embeddings.npy")

Success! Saved (1238, 512) matrix to inat_emb/embeddings.npy


#### Kaggle Embeddings

In [2]:
# Settings
IMG_DIR = "clean_insect_images"
OUT_DIR = "clean_embs"
os.makedirs(OUT_DIR, exist_ok=True)
from sample_clean_data import sampled_clean_data 

In [3]:
# Extraction loop
embeddings = []
metadata = []

for path in tqdm(sampled_clean_data): # only look at the samples clean images
    try:
        image = preprocess(Image.open(path)).unsqueeze(0).to(device)
        
        with torch.no_grad():
            features = model.encode_image(image)
            features /= features.norm(dim=-1, keepdim=True)    # normalize for cosine similarity
            
        embeddings.append(features.cpu().numpy())
        metadata.append(path)
        
    except Exception as e:
        print(f"Skipping corrupt image {path}: {e}")

100%|██████████| 400/400 [00:00<00:00, 116041.06it/s]

Skipping corrupt image clean_insect_images/Ant/Ant_283.jpg: name 'preprocess' is not defined
Skipping corrupt image clean_insect_images/Ant/Ant_525.jpg: name 'preprocess' is not defined
Skipping corrupt image clean_insect_images/Ant/Ant_234.png: name 'preprocess' is not defined
Skipping corrupt image clean_insect_images/Ant/Ant_88.jpg: name 'preprocess' is not defined
Skipping corrupt image clean_insect_images/Ant/Ant_639.jpg: name 'preprocess' is not defined
Skipping corrupt image clean_insect_images/Ant/Ant_598.jpg: name 'preprocess' is not defined
Skipping corrupt image clean_insect_images/Ant/Ant_489.jpg: name 'preprocess' is not defined
Skipping corrupt image clean_insect_images/Ant/Ant_46.jpg: name 'preprocess' is not defined
Skipping corrupt image clean_insect_images/Ant/Ant_262.jpg: name 'preprocess' is not defined
Skipping corrupt image clean_insect_images/Ant/Ant_319.jpg: name 'preprocess' is not defined
Skipping corrupt image clean_insect_images/Ant/Ant_127.jpg: name 'prepro




In [None]:
if embeddings:
    emb_matrix = np.vstack(embeddings)
    
    np.save(os.path.join(OUT_DIR, "embeddings.npy"), emb_matrix)
    
    with open(os.path.join(OUT_DIR, "filenames.txt"), "w") as f:
        f.write("\n".join(metadata))
        
    print(f"Success! Saved {emb_matrix.shape} matrix to {OUT_DIR}/embeddings.npy")