# Generate Image Embeddings

This notebook generates vector embeddings for marketplace images using the DINOv2 (Vision Transformer) model:

- **Load DINOv2 model**: Uses the pre-trained `dinov2` model from Facebook Research
- **Process images**: Batch-processes all dataset images through the model
- **Extract embeddings**: Generates 768-dimensional feature vectors for each image
- **Storage**: Embeddings are stored in a PostgreSQL database with a vector-search extension

These embeddings enable image similarity search and retrieval in the marketplace image RAG system.

The code from this notebook is refined and integrated into the production codebase, where embeddings are persisted to PostgreSQL with an HNSW index for scalable vector search.

In [1]:
import os
import sys
from pathlib import Path

from dotenv import load_dotenv

# Run everything from the project root so relative paths (e.g. `data/...`) and
# `load_dotenv()` resolve from there rather than from the notebooks directory.
# Reuse `notebooks_dir` from a previous run of this cell: recomputing it from the
# (already changed) cwd would make `.parent` walk one directory too high on re-run.
notebooks_dir = globals().get("notebooks_dir", Path().absolute())
project_dir = notebooks_dir.parent
os.chdir(project_dir)
load_dotenv()
# Only `str` entries in sys.path are honored by the import system; a Path object is
# silently ignored. Guard against adding the same entry on every re-run.
if str(project_dir) not in sys.path:
    sys.path.append(str(project_dir))

In [2]:
import warnings

import torch
from tqdm import tqdm

from marketplace.const import DATA_ROOT
from marketplace.dataset import MarketplaceDataModule
from marketplace.db import ImageEmbedding, save_image_embeddings
from marketplace.model import DinoV2WithNormalize

In [3]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Load the DINOv2 embedding model.
# Suppress the "xFormers is available" warning (presumably raised while the
# backbone is built) before constructing the model.
warnings.filterwarnings("ignore", message="xFormers is available")

# Build the model, place it on the selected device, and switch it to evaluation
# mode; the trailing expression displays the module structure as the cell output.
model = DinoV2WithNormalize().to(device)
model.eval()

DinoV2WithNormalize(
  (model): DinoVisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      (norm): Identity()
    )
    (blocks): ModuleList(
      (0-11): 12 x NestedTensorBlock(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): MemEffAttention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): LayerScale()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
        (ls2): LayerScale()
        (drop_path2)

### Test with First Batch

Test the embedding generation with just the first batch to verify everything works correctly.

In [5]:
# Setup: a tiny batch with a single worker is enough to smoke-test the pipeline.
data_module = MarketplaceDataModule(batch_size=4, num_workers=1)
data_module.prepare_data()
data_module.setup(stage="predict")

# Get first batch
predict_loader = data_module.predict_dataloader()
batch_images, batch_labels = next(iter(predict_loader))
print(f"Batch shape: {batch_images.shape}")  # Should be (4, 3, 224, 224)

# Generate embeddings for first batch (inference only, so no autograd graph is built)
with torch.no_grad():
    batch_images = batch_images.to(device)
    embeddings = model(batch_images)

print(f"Embeddings shape: {embeddings.shape}")  # Should be (4, 768)
# Enforce the expectation instead of only printing it: one 768-d vector per image.
assert embeddings.shape == (batch_images.shape[0], 768), (
    f"unexpected embeddings shape {tuple(embeddings.shape)}"
)

# Print first batch results.
# Copy the whole batch to CPU in one transfer rather than one transfer per row.
# NOTE(review): pairing row i with `dataset.samples[i]` assumes the predict loader
# iterates the dataset in order (no shuffling) — confirm in predict_dataloader().
embeddings_cpu = embeddings.cpu()
for i, embedding in enumerate(embeddings_cpu):
    image_path = predict_loader.dataset.samples[i][0]
    print(f"\n{image_path}")
    print(f"Embedding (first 10 values): {embedding[:10]}")
    print(f"Embedding norm: {embedding.norm().item():.4f}")

Batch shape: torch.Size([4, 3, 224, 224])
Embeddings shape: torch.Size([4, 768])

data/adults handicrafts/261226323.jpg
Embedding (first 10 values): tensor([ 0.0478,  0.0214,  0.0208, -0.0664,  0.0213,  0.0381, -0.1012, -0.0127,
        -0.0391,  0.0240])
Embedding norm: 1.0000

data/adults handicrafts/261227883.jpg
Embedding (first 10 values): tensor([-0.0039,  0.0138, -0.0275, -0.0247, -0.0110, -0.0788, -0.0222, -0.0807,
        -0.0557,  0.0400])
Embedding norm: 1.0000

data/adults handicrafts/261233434.jpg
Embedding (first 10 values): tensor([ 0.0139,  0.0183,  0.0140, -0.0045, -0.0407, -0.0198, -0.0852, -0.0207,
         0.0271, -0.0188])
Embedding norm: 1.0000

data/adults handicrafts/261246318.jpg
Embedding (first 10 values): tensor([-0.0027,  0.0496,  0.0160, -0.0441, -0.0478,  0.0509, -0.0549, -0.0247,
        -0.0163,  0.0036])
Embedding norm: 1.0000


### Image Embeddings Pipeline

Compute embeddings for all images in the dataset and persist them to PostgreSQL.

In [13]:
# Setup: larger batches and more loader workers for the full run.
batch_size = 128
data_module = MarketplaceDataModule(batch_size=batch_size, num_workers=8)
data_module.prepare_data()
data_module.setup(stage="predict")

# Get all images
predict_loader = data_module.predict_dataloader()
samples = predict_loader.dataset.samples

print(f"Total images to process: {len(predict_loader.dataset)}")
print(f"Batch size: {predict_loader.batch_size}")
print(f"Number of batches: {len(predict_loader)}")
print("\nGenerating embeddings...\n")

# Iterate all batches, embed them, and persist each batch to the database.
# NOTE(review): pairing batch rows with `samples[global_idx + k]` assumes the predict
# loader iterates the dataset in order (no shuffling / custom sampler) — confirm.
# NOTE(review): whether re-running this cell duplicates rows depends on
# save_image_embeddings — verify before re-running against a populated database.
global_idx = 0
with torch.no_grad():
    for batch_images, _batch_labels in tqdm(predict_loader, total=len(predict_loader)):
        # Embed on the model's device, then copy the whole batch back to CPU in a
        # single transfer instead of one transfer per row.
        embeddings = model(batch_images.to(device)).cpu()

        # Length of *this* batch (the last one may be smaller). Kept under its own
        # name so the configured `batch_size` above is not overwritten.
        n_images = embeddings.shape[0]
        image_paths = [samples[i][0] for i in range(global_idx, global_idx + n_images)]

        # Save each image path (relative to DATA_ROOT) with its embedding.
        db_objects = [
            ImageEmbedding(
                image_path=str(Path(path).relative_to(DATA_ROOT)),
                embedding=embedding,
            )
            for path, embedding in zip(image_paths, embeddings, strict=True)
        ]
        save_image_embeddings(db_objects)

        global_idx += n_images

# Every sample of the dataset must have been embedded exactly once.
assert global_idx == len(predict_loader.dataset), (global_idx, len(predict_loader.dataset))
print(f"\nCompleted! Processed {global_idx} images.")

Total images to process: 138969
Batch size: 128
Number of batches: 1086

Generating embeddings...



100%|██████████| 1086/1086 [07:41<00:00,  2.35it/s]


Completed! Processed 138969 images.



