In [None]:
# Testing out different methods of pooling patch embeddings from DINOv2 hidden states
# Baselines were just average pooling and using the CLS token
# I try max pooling and combined max / avg pooling embeddings here

In [1]:
import numpy as np
import torch
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
model = AutoModel.from_pretrained('facebook/dinov2-base')
model.eval()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Dinov2Model(
  (embeddings): Dinov2Embeddings(
    (patch_embeddings): Dinov2PatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): Dinov2Encoder(
    (layer): ModuleList(
      (0-11): 12 x Dinov2Layer(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attention): Dinov2Attention(
          (attention): Dinov2SelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): Dinov2SelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (layer_scale1): Dinov2LayerScale()
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06,

In [3]:
# Approach #1: Average pooling (this was what was done in the baseline)

def get_avg_emb(path):
    image = Image.open(path).convert('RGB')
    inputs = processor(image, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    patches = outputs.last_hidden_state[:, 1:, :]
    # print(f"Patches shape: {patches.shape}")
    emb = torch.mean(patches, dim=1)

    # Normalize embeddings
    emb = emb / emb.norm(dim=-1, keepdim=True)
    # print(f"Embedding shape (before squeeze): {emb.shape}")
    return emb.squeeze(0).numpy()

In [4]:
# Approach #2: Max pooling; take max val from each patch

def get_max_emb(path):
    image = Image.open(path).convert("RGB")
    inputs = processor(image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    patches = outputs.last_hidden_state[:, 1:, :]
    emb = torch.max(patches, dim=1)[0] # Max pooling returns (max values, indices), onyl keep values
    emb = emb / emb.norm(dim=-1, keepdim=True)
    
    # Squeeze down to make a vec
    return emb.squeeze(0).numpy()

In [5]:
# Approach #3: Combined avg and max pooling
# Concatenate avg pool vectors and max pooled vectors

def get_comb_emb(path):
    image = Image.open(path).convert('RGB')
    inputs = processor(image, return_tensors='pt')

    with torch.no_grad():
        outputs = model(**inputs)
    
    patches = outputs.last_hidden_state[:, 1:, :]

    avg = torch.mean(patches, dim=1)
    maxed = torch.max(patches, dim=1)[0] 

    combined = torch.cat([avg, maxed], dim=1)
    combined = combined / combined.norm(dim=-1, keepdim=True)

    return combined.squeeze(0).numpy()

In [6]:
# Approach #4: CLS token only (another baseline measure)

def get_cls_emb(path):
    image = Image.open(path).convert("RGB")
    inputs = processor(image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    cls = outputs.last_hidden_state[:, 0, :]
    cls = cls / cls.norm(dim=-1, keepdim=True)
    
    # Squeeze down to make a vec
    return cls.squeeze(0).numpy()

In [7]:
# Import in dataframe of image payloads
payloads = pd.read_csv("payloads.csv")
# payloads.head()

image_paths = payloads['image_url'] 
materials = payloads['material']
color_labels = payloads["color label"]


In [8]:
import re
def clean_tuple(s):
    # extract all integers from the string
    if isinstance(s, str):
        s = s.replace("int64","")
        nums = list(map(int, re.findall(r"\d+", s)))
        return tuple(nums)
    return s

In [9]:
payloads["avg rgb"] = payloads["avg rgb"].apply(clean_tuple)
avg_rgb = payloads["avg rgb"]
avg_rgb

0     (230, 223, 199)
1     (223, 186, 158)
2     (162, 191, 156)
3     (241, 233, 218)
4     (144, 123, 100)
5        (94, 69, 46)
6        (23, 92, 87)
7     (176, 154, 126)
8        (92, 90, 78)
9     (188, 183, 171)
10    (121, 111, 102)
11    (200, 204, 203)
12    (223, 196, 185)
13    (167, 153, 124)
14    (154, 155, 146)
15    (167, 165, 158)
16    (199, 197, 189)
17       (94, 69, 46)
18    (234, 235, 222)
19       (74, 70, 63)
20    (233, 214, 181)
21    (240, 224, 202)
22    (233, 217, 185)
23    (224, 227, 201)
24    (229, 230, 203)
25    (237, 231, 200)
26    (234, 221, 187)
27    (242, 228, 192)
28    (194, 174, 134)
29    (153, 161, 138)
30     (90, 114, 101)
31    (174, 191, 176)
32    (204, 208, 189)
33    (202, 208, 187)
34    (172, 197, 169)
35    (230, 227, 202)
36    (236, 235, 211)
37    (201, 201, 188)
38    (229, 197, 168)
39    (236, 205, 176)
40    (189, 163, 127)
41    (225, 220, 186)
42    (227, 233, 207)
43    (202, 218, 191)
44    (229, 225, 213)
45    (234

In [10]:
# Calculate l2 distance btw two RGB colors
# 0 if exact match
# ~441 if exact opposite
def color_distance(rgb1, rgb2):
    return np.sqrt(sum((int(a) - int(b))**2 for a, b in zip(rgb1, rgb2)))

# Check threshold
def are_colors_similar(rgb1, rgb2, threshold=50):
    return color_distance(rgb1, rgb2) < threshold

In [None]:
methods = {
    'avg': get_avg_emb,
    'max': get_max_emb,
    'combined': get_comb_emb,
    'cls': get_cls_emb
}

for method_name, method_func in methods.items():
    print(f"\n{'='*50}")
    print(f"Testing: {method_name.upper()}")
    print(f"{'='*50}")
    
    # Generate embeddings
    embeddings = []
    for path in image_paths:
        emb = method_func(path)
        embeddings.append(emb)
    
    embeddings = np.array(embeddings)
    print(f"Embedding shape: {embeddings.shape}")
    
    # Compute similarity matrix
    sim_matrix = cosine_similarity(embeddings)
    
    # Get statistics
    upper_tri = np.triu_indices_from(sim_matrix, k=1)
    similarities = sim_matrix[upper_tri]

    # Analyze pairs with high similarity
    high_sim_pairs = []
    for i in range(len(sim_matrix)):
        for j in range(i+1, len(sim_matrix)):
            if sim_matrix[i, j] >= 0.6:
                high_sim_pairs.append({
                    'mat_match': materials[i] == materials[j],
                    'col_match': color_labels[i] == color_labels[j],
                    'col_rgb_distance': color_distance(avg_rgb[i], avg_rgb[j]),
                    'col_rgb_similarity': color_distance(avg_rgb[i], avg_rgb[j]) < 100
                })
    
    mat_match_rate = np.mean([p['mat_match'] for p in high_sim_pairs]) if high_sim_pairs else 0
    col_label_match_rate = np.mean([p['col_match'] for p in high_sim_pairs]) if high_sim_pairs else 0
    col_rgb_similar_rate = np.mean([p['col_rgb_similarity'] for p in high_sim_pairs]) if high_sim_pairs else 0
    
    
    print(f"  Mean similarity: {similarities.mean():.3f}")
    print(f"  Std similarity:  {similarities.std():.3f}")
    print(f"  High-sim pairs (≥0.6): {len(high_sim_pairs)}")
    print(f"  Material match rate: {mat_match_rate:.1%}")
    print(f"  Color label match: {col_label_match_rate:.1%}")
    print(f"  Color RGB similar: {col_rgb_similar_rate:.1%}")



Testing: AVG
Embedding shape: (50, 768)
  Mean similarity: 0.549
  Std similarity:  0.222
  High-sim pairs (≥0.6): 440
  Material match rate: 33.6%
  Color label match: 6.1%
  Color RGB similar: 68.4%

Testing: MAX
Embedding shape: (50, 768)
  Mean similarity: 0.913
  Std similarity:  0.035
  High-sim pairs (≥0.6): 1225
  Material match rate: 12.1%
  Color label match: 4.7%
  Color RGB similar: 59.4%

Testing: COMBINED
Embedding shape: (50, 1536)
  Mean similarity: 0.857
  Std similarity:  0.063
  High-sim pairs (≥0.6): 1225
  Material match rate: 12.1%
  Color label match: 4.7%
  Color RGB similar: 59.4%

Testing: CLS
Embedding shape: (50, 768)
  Mean similarity: 0.418
  Std similarity:  0.279
  High-sim pairs (≥0.6): 412
  Material match rate: 35.9%
  Color label match: 6.3%
  Color RGB similar: 70.1%


In [None]:
# Best so far is CLS token
# High sim pairs are showing 70% match in average color
# and 35% match for material
# 