In [59]:
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import torch
import matplotlib.pyplot as plt
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision import io
from torchvision.transforms import v2 as transforms

from model import DinoWithHead, AttentionReadout
from inference import InferenceImageDataset

In [60]:
# Directory of coin images to embed. Set the COINS_ROOT environment variable to point
# elsewhere instead of editing this machine-specific default.
root = os.environ.get('COINS_ROOT', '/Users/jatentaki/Data/archeo/coins/FMP/slices-high-res/just_coins')
dataset = InferenceImageDataset(root)
dataloader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)

In [61]:
# Use the Apple-silicon GPU when it is available; otherwise fall back to CPU so the
# notebook still runs on other machines.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
model = DinoWithHead(AttentionReadout(head_dim=48, n_head=8)).to(device)
# Inference only: switch off any training-time behaviour (dropout etc.) the model may have.
model.eval()

# The checkpoint stores only the readout head. The `dino.*` backbone weights are absent
# from it (they show up in `missing_keys`), hence `strict=False`. Presumably DinoWithHead
# loads the backbone itself via torch.hub (see the "Using cache found in ..." message);
# TODO confirm.
head_state_dict = torch.load('attention_readout_8h_48dim_32out_5.pt', map_location=device, weights_only=True)
load_result = model.load_state_dict(head_state_dict, strict=False)
# strict=False would silently leave a mismatched head at random init, so fail loudly instead.
missing_head_keys = [key for key in load_result.missing_keys if not key.startswith('dino.')]
assert not missing_head_keys, f"head checkpoint does not cover: {missing_head_keys}"

Using cache found in /Users/jatentaki/.cache/torch/hub/facebookresearch_dinov2_main


_IncompatibleKeys(missing_keys=['dino.cls_token', 'dino.pos_embed', 'dino.register_tokens', 'dino.mask_token', 'dino.patch_embed.proj.weight', 'dino.patch_embed.proj.bias', 'dino.blocks.0.norm1.weight', 'dino.blocks.0.norm1.bias', 'dino.blocks.0.attn.qkv.weight', 'dino.blocks.0.attn.qkv.bias', 'dino.blocks.0.attn.proj.weight', 'dino.blocks.0.attn.proj.bias', 'dino.blocks.0.ls1.gamma', 'dino.blocks.0.norm2.weight', 'dino.blocks.0.norm2.bias', 'dino.blocks.0.mlp.fc1.weight', 'dino.blocks.0.mlp.fc1.bias', 'dino.blocks.0.mlp.fc2.weight', 'dino.blocks.0.mlp.fc2.bias', 'dino.blocks.0.ls2.gamma', 'dino.blocks.1.norm1.weight', 'dino.blocks.1.norm1.bias', 'dino.blocks.1.attn.qkv.weight', 'dino.blocks.1.attn.qkv.bias', 'dino.blocks.1.attn.proj.weight', 'dino.blocks.1.attn.proj.bias', 'dino.blocks.1.ls1.gamma', 'dino.blocks.1.norm2.weight', 'dino.blocks.1.norm2.bias', 'dino.blocks.1.mlp.fc1.weight', 'dino.blocks.1.mlp.fc1.bias', 'dino.blocks.1.mlp.fc2.weight', 'dino.blocks.1.mlp.fc2.bias', 'dino.

In [62]:
# from tqdm.auto import tqdm

# embeddings, names = [], []

# for batch_images, batch_names in tqdm(dataloader):
#     batch_images = batch_images.to(device)
#     with torch.inference_mode():
#         out = model(batch_images)
    
#     embeddings.append(out.cpu())
#     names.extend(batch_names)

In [63]:
# embeddings = torch.cat(embeddings)

# torch.save({'embeddings': embeddings, 'names': names}, 'embeddings.pt')

In [64]:
# Load the cached embeddings written by the (commented-out) extraction cells above.
# The file holds only a tensor and a list of names, so the restricted
# `weights_only=True` unpickler is enough and arbitrary pickled objects are refused.
state = torch.load('embeddings.pt', weights_only=True)
embeddings = state['embeddings']
names = state['names']
# Row i of `embeddings` must belong to names[i]; everything downstream relies on that.
assert len(embeddings) == len(names), (len(embeddings), len(names))

In [73]:
# Choose which coins act as queries for the similarity tables. Clamp the count so a
# dataset with fewer than 64 coins does not raise an IndexError below.
n_examples = min(64, len(embeddings))
# False: take the first `n_examples` coins; True: a reproducible random subset.
random_examples = False
if random_examples:
    rng = torch.Generator().manual_seed(42)
    indices = torch.randperm(len(embeddings), generator=rng)[:n_examples]
else:
    indices = torch.arange(n_examples)
example_names = [names[i] for i in indices.tolist()]
example_embeddings = embeddings[indices]

In [76]:
from torch import nn
def similarity(a: Tensor, b: Tensor) -> Tensor:
    """Cosine similarity between every row of ``a`` and every row of ``b``.

    Args:
        a: ``(n, c)`` tensor.
        b: ``(m, c)`` tensor.

    Returns:
        ``(n, m)`` tensor whose entry ``[i, j]`` is the cosine similarity of
        ``a[i]`` and ``b[j]``.
    """
    # L2-normalise along the feature axis so the dot products below are cosines.
    unit_a, unit_b = (nn.functional.normalize(t, dim=-1, p=2) for t in (a, b))
    return torch.einsum("ic,jc->ij", unit_a, unit_b)

from dataclasses import dataclass
@dataclass
class SimilarCoin:
    """One neighbour from a similarity search: an image file name and its score."""

    # Image file name, as accepted by the global `dataset.load_by_name`.
    file_name: str
    # Similarity score to the query coin (cosine scores, as computed in this notebook).
    similarity: float

    def load_image(self) -> Tensor:
        """Load this coin's image from the global ``dataset`` with the channel axis moved last.

        The ``(1, 2, 0)`` permutation presumably turns a CHW tensor into the HWC
        layout that ``imshow`` expects; TODO confirm against ``load_by_name``.
        """
        image = dataset.load_by_name(self.file_name)
        return image.permute(1, 2, 0)

@dataclass
class SimilarSeries:
    """The nearest neighbours found for one example (query) coin."""

    # Name of the example coin.
    example: str
    # Neighbours in the order supplied (descending similarity when built from `topk`).
    similar_coins: list[SimilarCoin]

    @classmethod
    def from_values_and_indices(cls, example: str, values: Tensor, indices: Tensor) -> "SimilarSeries":
        """Build a series from one row of ``topk`` output.

        Args:
            example: name of the query coin.
            values: similarity scores, presumably 1-D (one row of ``topk`` values).
            indices: positions into the module-level ``names`` list, aligned with ``values``.
        """
        similar_coins = [
            SimilarCoin(names[i], v)
            for i, v in zip(indices.tolist(), values.tolist())
        ]
        return cls(example, similar_coins)

    def plot(self):
        """Draw the neighbours as a near-square grid of images, each titled with its id and score.

        The grid has ``ceil(sqrt(n))`` columns and just enough rows for all ``n``
        neighbours; leftover cells are blanked. Returns the matplotlib Figure.
        """
        import math

        n_similar = len(self.similar_coins)
        # Smallest column count whose square fits every coin (25 -> 5x5, 10 -> 4 columns).
        # Rounding sqrt(n) down would leave too few cells and silently drop coins.
        n_cols = max(1, math.isqrt(n_similar))
        if n_cols * n_cols < n_similar:
            n_cols += 1
        n_rows = max(1, -(-n_similar // n_cols))  # ceiling division

        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` always works.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 20), tight_layout=True, squeeze=False)
        grid = list(axes.flat)
        for ax, coin in zip(grid, self.similar_coins):
            ax.imshow(coin.load_image())
            coin_id = coin.file_name.removesuffix('.png')
            ax.set_title(f"{coin_id}\n({coin.similarity:.2f})")
            ax.axis('off')
        # Blank out cells beyond the last coin.
        for ax in grid[n_similar:]:
            ax.axis('off')

        return fig

# Nearest neighbours of each example among all coins, by cosine similarity. The examples
# are themselves rows of `embeddings`, so each normally ranks itself first (score ~1.00).
n_neighbours = min(25, len(embeddings))  # topk raises if k exceeds the number of coins
similarities = similarity(example_embeddings, embeddings)
values, similar_indices = similarities.topk(n_neighbours, dim=-1)

# A comprehension keeps its loop variables local, so the module-level `indices`
# (the selected example rows) is not overwritten by the last row of `similar_indices`.
series = [
    SimilarSeries.from_values_and_indices(example, row_values, row_indices)
    for example, row_values, row_indices in zip(example_names, values, similar_indices)
]

In [77]:
# Expected shape: (number of examples, neighbours per example).
similar_indices.size()

torch.Size([64, 25])

In [78]:
# Save one neighbour grid per example, named after the example coin.
os.makedirs('similarity_tables', exist_ok=True)  # savefig does not create missing directories
for s in series:
    fig = s.plot()
    fig.savefig(f'similarity_tables/{s.example}')
    plt.close(fig)  # free each figure right away; otherwise memory grows with len(series)