In [None]:
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms

from coin_ai.data.inference import InferenceImageDataset, ResizeAndKeepRatio
from coin_ai.augmentations import CircleCrop
from coin_ai.config import load_config

In [None]:
# Experiment configuration: which trained run to load and which images to embed.
config_root = 'checkpoints/24.02.20/attn-readout-rotations-flips'
# Defaults to the original author's local image folder; set COIN_DATA_ROOT to point
# at the images on any other machine instead of editing the notebook.
data_root = os.environ.get(
    'COIN_DATA_ROOT',
    '/Users/jatentaki/Data/archeo/coins/FMP/slices-high-res/just_coins',
)
# Cache file written by run_inference() and reloaded further down.
embedding_save_name = f'{config_root}/embeddings.pt'
version = 0  # selects lightning_logs/version_<N> when locating the checkpoint
config = load_config(f'{config_root}/config.py')

In [None]:
# Build the learner from the config and restore the weights stored in the run's
# `val_1_acc_at_1.ckpt` checkpoint (loaded onto CPU; moved to a device later).
checkpoint_path = f'{config_root}/lightning_logs/version_{version}/checkpoints/val_1_acc_at_1.ckpt'
model = config.learner
state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']
model.load_state_dict(state_dict)

In [None]:
# Display the validation-time augmentation from the training config; the next cell
# appends it to the inference transform.
config.val_augmentation

In [None]:
# Inference preprocessing: the project's ResizeAndKeepRatio / CircleCrop transforms
# at a fixed size, followed by the config's validation augmentation.
image_size = 518  # shared by both resize and crop; NOTE(review): origin of 518 not documented here
preprocessing_steps = [
    ResizeAndKeepRatio(image_size),
    CircleCrop(image_size),
    config.val_augmentation,
]
transform = transforms.Compose(preprocessing_steps)

dataset = InferenceImageDataset(data_root, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,  # iterate in dataset order
    num_workers=4,
)

In [None]:
# Sanity check: show the first batch exactly as the model will receive it after `transform`.
# NOTE(review): if val_augmentation normalizes pixel values, imshow will clip them;
# un-normalize first if the colours look wrong.
images, names = next(iter(dataloader))

fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax in axes.flat:
    ax.axis('off')
# zip stops at the shorter sequence: a first batch with fewer than 16 images (small
# dataset) leaves the remaining panels blank instead of raising IndexError.
for ax, image, name in zip(axes.flat, images, names):
    ax.imshow(image.permute(1, 2, 0))
    ax.set_title(name)

In [None]:
def run_inference(device=None) -> None:
    """Embed every image in the global `dataloader` with the global `model` and cache it.

    Saves ``{'embeddings': <concatenated model outputs>, 'names': <image names>}``
    to `embedding_save_name`, in dataloader order, for the following cells to load.

    Args:
        device: torch device (or device string) to run the model on. Defaults to
            'mps' when available, otherwise 'cpu'.
    """
    if device is None:
        device = 'mps' if torch.backends.mps.is_available() else 'cpu'
    device = torch.device(device)

    # Bind to a new local name: assigning to `model` inside this function would make
    # `model` local everywhere in it and raise UnboundLocalError on the right-hand side.
    net = model.to(device)
    # inference_mode only disables autograd; eval() is what switches dropout and
    # batch-norm layers to their deterministic inference behaviour.
    net.eval()

    embedding_batches, all_names = [], []
    with torch.inference_mode():
        for batch_images, batch_names in tqdm(dataloader):
            out = net(batch_images.to(device))
            embedding_batches.append(out.cpu())
            all_names.extend(batch_names)

    embeddings = torch.cat(embedding_batches)

    torch.save({'embeddings': embeddings, 'names': all_names}, embedding_save_name)


# Run the (slow) inference only when no cached embeddings exist, so that
# "Restart & Run All" works on a fresh checkout without recomputing every time.
if not os.path.exists(embedding_save_name):
    run_inference()

In [None]:
# Reload the cache written by run_inference(): a dict holding the concatenated
# 'embeddings' tensor and the matching 'names', in the same order.
_state = torch.load(embedding_save_name)
embeddings, names = _state['embeddings'], _state['names']

In [None]:
# Score every embedding against every other using the training loss's own
# similarity function; the following cells treat the result as a square matrix [i, j].
loss_fn = config.loss_fn
similarity = loss_fn.similarity
all_to_all = similarity(embeddings)

In [None]:
# Keep only the strictly-lower triangle (row i > column j), so that each unordered
# pair appears exactly once and self-similarity on the diagonal is excluded.
# Masked entries become -inf rather than 0 (which torch.tril would leave), so they
# can never outrank a genuine pair in the top-k below, even when similarities are <= 0.
# masked_fill returns a new tensor, so re-running this cell is idempotent.
row_ix = torch.arange(all_to_all.shape[0], device=all_to_all.device)
col_ix = torch.arange(all_to_all.shape[1], device=all_to_all.device)
on_or_above_diag = row_ix[:, None] <= col_ix[None, :]
all_to_all = all_to_all.masked_fill(on_or_above_diag, float('-inf'))

In [None]:
# Find the highest-scoring pairs and convert their flat positions back to
# (row, column) = (i, j) coordinates in the similarity matrix.
n_rows, n_cols = all_to_all.shape
# Only the strictly-lower triangle holds real pairs; asking for more than exist would
# either raise (k > numel) or return masked entries.
n_top = min(1000, n_rows * (n_rows - 1) // 2)
all_to_all_flat = all_to_all.flatten()
top_scores, top_ix = all_to_all_flat.topk(n_top)
# Row-major flattening: divide/modulo by the number of columns.
top_i, top_j = top_ix // n_cols, top_ix % n_cols

In [None]:
# Show the most similar pairs side by side, each titled with its similarity score.
n_pairs_to_show = 25
for i, j in zip(top_i[:n_pairs_to_show].tolist(), top_j[:n_pairs_to_show].tolist()):
    fig, pair_axes = plt.subplots(1, 2)
    fig.suptitle(f'similarity {all_to_all[i, j].item():.3f}')
    for ax, ix in zip(pair_axes, (i, j)):
        ax.imshow(dataset.load_by_name(names[ix]).permute(1, 2, 0))
        ax.set_title(names[ix])
        ax.axis('off')
    # Render each figure immediately rather than keeping all of them open until the
    # cell finishes.
    plt.show()

In [None]:
# n_examples = 64
# #rng = torch.Generator().manual_seed(42)
# #indices = torch.randperm(len(embeddings), generator=rng)[:n_examples]
# indices = torch.arange(n_examples)
# example_names = [names[i] for i in indices]
# #example_images = [dataset.load_by_name(name) for name in example_names]
# example_embeddings = embeddings[indices]

In [None]:

# from dataclasses import dataclass
# @dataclass
# class SimilarCoin:
#     file_name: str
#     similarity: float

#     def load_image(self) -> Tensor:
#         return dataset.load_by_name(self.file_name).permute(1, 2, 0)

# @dataclass
# class SimilarSeries:
#     example: str
#     similar_coins: list[SimilarCoin]

#     @classmethod
#     def from_values_and_indices(cls, example: str, values: Tensor, indices: Tensor) -> "SimilarSeries":
#         similar_coins = [SimilarCoin(names[i], v.item()) for i, v in zip(indices, values)]
#         return cls(example, similar_coins)

#     def plot(self):
#         n_similar = len(self.similar_coins)
#         nearest_square = int(n_similar ** 0.5)

#         fig, axes = plt.subplots(nearest_square, nearest_square, figsize=(20, 20), tight_layout=True)
#         for ax, coin in zip(axes.flat, self.similar_coins):
#             ax.imshow(coin.load_image())
#             coin_id = coin.file_name.removesuffix('.png')
#             ax.set_title(f"{coin_id}\n({coin.similarity:.2f})")
#             ax.axis('off')
        
#         return fig

# similarities = similarity(example_embeddings, embeddings)
# values, similar_indices = similarities.topk(25, dim=-1)

# series = []
# for example, value, indices in zip(example_names, values, similar_indices):
#     series.append(SimilarSeries.from_values_and_indices(example, value, indices))

In [None]:
# similar_indices.shape

In [None]:
# for s in series:
#     fig = s.plot()
#     fig.savefig(f'similarity_tables/{s.example}')
#     plt.close(fig)