In [4]:
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests
import torch
import torch.nn as nn
import math

In [2]:
# Load the pretrained ViT classifier and the image preprocessor that belongs to the same checkpoint.
full_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Channel width of each patch embedding, read from the model config.
hidden_size = full_model.config.hidden_size
# Patch side length from the config; falls back to 16 when the config has no such field.
patch_size = getattr(full_model.config, 'patch_size', 16)
# The ViT embeddings sub-module; PatchCompressor uses its `patch_embeddings` to turn pixels into patches.
embedding_layer = full_model.vit.embeddings

In [5]:
class PatchCompressor(nn.Module):
    """Compress ViT patch embeddings (C x Gh x Gw) -> (C' x Gh x Gw).

    The module accepts either precomputed patch embeddings or raw pixel inputs.
    For pixel inputs it calls `embedding_layer.patch_embeddings` to obtain the
    patches before compressing them.

    Args:
        embedding_layer: optional module exposing `patch_embeddings` (and,
            for channel inference, `position_embeddings`). Required only when
            `forward` is called with `is_pixel_values=True`.
        in_channels: channel count C of the incoming patch embeddings. When
            omitted it is read from `embedding_layer.position_embeddings`.
        compressed_channels: channel count C' of the compressed output.

    Raises:
        ValueError: if `in_channels` is not given and cannot be inferred.
    """

    def __init__(self, embedding_layer=None, in_channels=None, compressed_channels=64):
        super().__init__()
        # Stored as a sub-module, so `.to()` / `.eval()` on the compressor also affect it.
        self.embedding_layer = embedding_layer
        # Infer in_channels from the last dim of the position embeddings when not given.
        if in_channels is None and embedding_layer is not None:
            try:
                in_channels = embedding_layer.position_embeddings.shape[-1]
            except Exception:
                # Best effort only - fall through to the explicit error below.
                in_channels = None
        if in_channels is None:
            raise ValueError('in_channels must be provided if embedding_layer does not expose channel size')

        # 1x1 convs act as channel compressors while preserving spatial layout
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, compressed_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(compressed_channels, compressed_channels, kernel_size=1),
        )

    @staticmethod
    def _to_grid(patch_emb):
        """Normalize patch embeddings to a (B, C, Gh, Gw) feature map.

        4-D input is returned unchanged. 3-D input is read as (B, N, C) and
        folded into a square grid, so N must be a perfect square.

        Raises:
            ValueError: if the tensor is neither 3-D nor 4-D, or if N tokens
                cannot be arranged into a square grid.
        """
        if patch_emb.ndim == 4:
            return patch_emb
        if patch_emb.ndim != 3:
            raise ValueError(f'Unexpected patch_emb shape: {patch_emb.shape}')

        batch, num_patches, channels = patch_emb.shape
        side = int(math.sqrt(num_patches))
        # Without this check a non-square N (e.g. a sequence that still carries
        # extra tokens) only fails later inside reshape with an opaque size error.
        if side * side != num_patches:
            raise ValueError(
                f'Cannot arrange {num_patches} patch tokens into a square grid '
                f'(patch_emb shape: {tuple(patch_emb.shape)}); remove any non-patch tokens '
                'or pass a (B, C, Gh, Gw) tensor instead'
            )
        return patch_emb.transpose(1, 2).reshape(batch, channels, side, side)

    def _patch_embeddings_from_pixels(self, pixel_values):
        """Run `embedding_layer.patch_embeddings` on pixels and return a (B, C, Gh, Gw) map."""
        # patch_embeddings may return (B, N, C) or (B, C, Gh, Gw); _to_grid handles both.
        return self._to_grid(self.embedding_layer.patch_embeddings(pixel_values))

    def forward(self, x, is_pixel_values=False):
        """Compress patch embeddings, computing them from pixels first if requested.

        Args:
            x: pixel_values (B,3,H,W) when `is_pixel_values` is True, otherwise
                patch embeddings shaped (B,N,C) or (B,C,Gh,Gw).
            is_pixel_values: whether `x` holds raw pixels rather than patch embeddings.

        Returns:
            Compressed feature map with shape (B, C', Gh, Gw).

        Raises:
            RuntimeError: if pixels are given but no `embedding_layer` was provided.
            ValueError: if the patch embeddings cannot be laid out as a grid.
        """
        if is_pixel_values:
            if self.embedding_layer is None:
                raise RuntimeError('No embedding_layer available to compute patch embeddings from pixels')
            patch_emb = self._patch_embeddings_from_pixels(x)
        else:
            patch_emb = self._to_grid(x)
        return self.net(patch_emb)

# Build the compressor instance (tune compressed_channels as needed).
# in_channels is passed explicitly from the model config (hidden_size), so the
# constructor does not need to infer it from embedding_layer.
compressor = PatchCompressor(
    embedding_layer=embedding_layer,
    in_channels=hidden_size,
    compressed_channels=64,
)

In [None]:
def get_image_embedding(image_path, return_vector=True, device=None):
    """Load an image and return its compressed ViT patch embedding.

    The image is preprocessed with the module-level `processor`, turned into
    patch embeddings and compressed by the module-level `compressor`.

    Args:
        image_path: path of the image file to open.
        return_vector: when True, flatten everything after the batch dimension
            into one vector per image; otherwise return the feature map.
        device: torch device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        A CPU tensor: the flattened embedding (one row per image) or the
        compressed feature map.

    Note:
        This function does not disable autograd itself; wrap the call in
        `torch.no_grad()` for pure inference. It also moves the shared
        `compressor` module to `device` as a side effect.
    """
    # Image.open is lazy and keeps the file open; the context manager releases
    # the handle as soon as the RGB copy has been created.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    inputs = processor(images=image, return_tensors='pt')
    # Index instead of .get() so a missing entry fails here rather than as an
    # AttributeError on None further down.
    pixel_values = inputs['pixel_values']  # (B, 3, H, W)

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    pixel_values = pixel_values.to(device)
    compressor.to(device)

    compressed_map = compressor(pixel_values, is_pixel_values=True)

    if return_vector:
        # flatten spatial grid into a single vector per image
        return compressed_map.flatten(start_dim=1).cpu()
    return compressed_map.cpu()

In [9]:
# Switch both modules to evaluation mode before computing embeddings.
full_model.eval()
# Kept as the cell's last bare expression, so the notebook displays the compressor's module tree.
compressor.eval()

PatchCompressor(
  (embedding_layer): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (net): Sequential(
    (0): Conv2d(768, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [19]:
# Similarity helpers: cosine, L2 and a convenience compare_images() function
def cosine_similarity(a, b, eps=1e-8):
    """Compute cosine similarity between two vectors or batches of vectors.

    Each input is squeezed, cast to float32 and, if it is a single vector
    (or a scalar), promoted to a one-row batch on its own. A single vector can
    therefore be compared against a batch (one-vs-many) through broadcasting.

    Args:
        a: tensor holding one vector (D,) / (1, D) or a batch (N, D).
        b: tensor holding one vector (D,) / (1, D) or a batch (N, D).
        eps: small constant added to each norm to avoid division by zero.

    Returns:
        A Python float when the result is a single value, otherwise a 1D
        tensor with one similarity per row.
    """
    rows = []
    for t in (a, b):
        t = t.squeeze().float()
        # Promote each input independently; deciding from `a` alone breaks
        # one-vs-many and many-vs-one comparisons.
        if t.ndim < 2:
            t = t.reshape(1, -1)
        rows.append(t)
    a, b = rows

    a_norm = a / (a.norm(dim=1, keepdim=True) + eps)
    b_norm = b / (b.norm(dim=1, keepdim=True) + eps)
    sims = (a_norm * b_norm).sum(dim=1)
    return sims.item() if sims.numel() == 1 else sims

def l2_distance(a, b):
    """Compute L2 (Euclidean) distance between two vectors or batches of vectors.

    Each input is squeezed and, if it is a single vector (or a scalar),
    promoted to a one-row batch on its own, so a single vector can be compared
    against a batch through broadcasting. Integer/bool tensors are cast to
    float32 because the norm is only defined for floating-point or complex
    dtypes; floating-point inputs keep their dtype.

    Args:
        a: tensor holding one vector (D,) / (1, D) or a batch (N, D).
        b: tensor holding one vector (D,) / (1, D) or a batch (N, D).

    Returns:
        A Python float when the result is a single value, otherwise a 1D
        tensor with one distance per row.
    """
    rows = []
    for t in (a, b):
        t = t.squeeze()
        if not (t.is_floating_point() or t.is_complex()):
            t = t.float()
        # Promote each input independently; deciding from `a` alone breaks
        # one-vs-many and many-vs-one comparisons.
        if t.ndim < 2:
            t = t.reshape(1, -1)
        rows.append(t)
    a, b = rows

    d = (a - b).norm(p=2, dim=1)
    return d.item() if d.numel() == 1 else d

def compare_images(path1, path2, method='cosine', device=None):
    """Compute a similarity score between two images using compressed embeddings.

    Args:
        path1: path of the first image.
        path2: path of the second image.
        method: 'cosine' for cosine similarity, 'l2' for Euclidean distance,
            or 'dot' for the raw dot product (useful when vectors are not
            normalized).
        device: forwarded to `get_image_embedding`.

    Returns:
        The score produced by the selected method.

    Raises:
        ValueError: if `method` is not one of the supported names.
    """
    # Validate first so a typo in `method` fails before two embeddings are computed.
    if method not in ('cosine', 'l2', 'dot'):
        raise ValueError(f'Unknown method: {method}')

    # Flattened compressed embeddings, one per image.
    e1 = get_image_embedding(path1, return_vector=True, device=device)
    e2 = get_image_embedding(path2, return_vector=True, device=device)

    if method == 'cosine':
        return cosine_similarity(e1, e2)
    if method == 'l2':
        return l2_distance(e1, e2)
    # 'dot': element-wise product summed over all entries of the squeezed embeddings.
    return (e1.squeeze() * e2.squeeze()).sum().item()



In [None]:
# Sample images to compare; only the first two entries are used in this cell.
image_paths = [
    "images/lotr/img (12).jpg",
    "images/lotr/img (100).jpg",
    "images/lotr/img (101).jpg",
]

# Inference only: score the same image pair with each metric, with autograd disabled.
with torch.no_grad():
    sim_cos, dist_l2, dot_prod = [
        compare_images(image_paths[0], image_paths[1], method=metric)
        for metric in ('cosine', 'l2', 'dot')
    ]

print('Cosine similarity:', sim_cos)
print('L2 distance:', dist_l2)
print('Dot product:', dot_prod)


Cosine similarity: 0.5283985137939453
L2 distance: 16.557477951049805
Dot product: 153.58360290527344
