## Retrieval

In [1]:
from typing import OrderedDict

import torch
import faiss
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.transforms import v2 as T
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import os
import copy
from typing import List, Dict, Optional
from torch.utils.data import Dataset
from PIL import Image

import json
from lightly.transforms.utils import IMAGENET_NORMALIZE

In [2]:
# Batch size used for embedding, and the feature width produced by the ResNet-50
# backbone once its classification head is removed.
BS, emb_dim = 64, 2048

In [3]:
# Preprocessing for the backbone: resize (shorter side -> 96 px), convert the PIL
# image to a float32 tensor scaled to [0, 1], then normalise with ImageNet statistics.
# `v2.ToTensor` is deprecated; `ToImage` + `ToDtype(float32, scale=True)` is its
# documented replacement and is equivalent for PIL input.
transform = T.Compose([
    T.Resize(96),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(
        mean=IMAGENET_NORMALIZE["mean"],
        std=IMAGENET_NORMALIZE["std"],
    ),
])



In [4]:
# Integer ids for the PACS object classes and domains; list position defines the id.
classes = {name: idx for idx, name in enumerate(
    ['dog', 'giraffe', 'guitar', 'house', 'person', 'horse', 'elephant']
)}

domains = {name: idx for idx, name in enumerate(
    ['sketch', 'cartoon', 'art_painting', 'photo']
)}

In [5]:
class ImageDataset(Dataset):
    """Map-style dataset of images described by a list of metadata dicts.

    Each item in ``set_map`` is expected to contain:
        img_path: path to the image; joined onto ``data_path`` with
                  ``os.path.join`` (an absolute path is therefore used as-is)
        label:    label corresponding to the image at ``img_path``
        domain:   domain corresponding to the image at ``img_path``

    If ``classes`` / ``domains`` are given, every ``label`` / ``domain`` is
    mapped through them at construction time. The mapping is applied to a deep
    copy, so the caller's ``set_map`` is never modified.
    """

    def __init__(self, data_path, set_map: List[Dict], transform=None,
                 classes: Optional[Dict[str, int]] = None,
                 domains: Optional[Dict[str, int]] = None) -> None:
        self.set_map = copy.deepcopy(set_map)
        self.transform = transform

        if classes:
            for sample in self.set_map:
                sample['label'] = classes[sample['label']]
        if domains:
            for sample in self.set_map:
                sample['domain'] = domains[sample['domain']]

        self.classes = classes
        self.domains = domains
        # Backward-compatible alias for the original, misspelled attribute name.
        self.doamins = domains
        self.data_path = data_path

    def __len__(self) -> int:
        return len(self.set_map)

    def __getitem__(self, index):
        """Load one image (as RGB), apply ``transform`` and merge it with its metadata."""
        sample = self.set_map[index]
        # Open inside `with` so the file handle is released as soon as the pixels
        # are read. Force 3 channels because grayscale / RGBA / palette files would
        # otherwise not match the 3-channel ResNet input.
        with Image.open(os.path.join(self.data_path, sample['img_path'])) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return dict(image=image, **sample)

In [6]:
# Load the test-split metadata consumed by ImageDataset (presumably a list of
# {img_path, label, domain} dicts — verify against test.json).
# JSON is UTF-8 by specification, so do not rely on the platform's locale encoding.
with open('./test.json', encoding='utf-8') as f:
    test_set_map = json.load(f)

In [7]:
# Unshuffled loader over the test split, so embeddings come out in test_set_map order.
dataset = ImageDataset(data_path='/data/PACS', set_map=test_set_map, transform=transform)
loader = DataLoader(dataset, shuffle=False, batch_size=BS)

### 0. Embed

In [8]:
def get_backbone_from_ckpt(ckpt_path: str) -> "OrderedDict[str, torch.Tensor]":
    """Extract the backbone weights from a training checkpoint.

    Reads ``checkpoint["state_dict"]`` and keeps only the entries under the
    ``backbone.`` prefix. The prefix is stripped so the keys can be loaded
    directly into a bare backbone module.

    Args:
        ckpt_path: Path to a checkpoint file whose top-level dict has a
            ``"state_dict"`` entry.

    Returns:
        An ordered mapping from backbone parameter/buffer names to tensors.
        No module is built here; the caller loads this into one.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines too.
    # NOTE(review): torch>=2.6 defaults to weights_only=True. If this (trusted)
    # checkpoint stores non-tensor objects, pass weights_only=False explicitly.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    prefix = "backbone."
    # Match the full "backbone." prefix. A bare startswith("backbone") would also
    # accept sibling modules such as "backbone_momentum.*", whose stripped keys
    # collide with, and silently overwrite, the real backbone weights.
    return OrderedDict(
        (name[len(prefix):], param)
        for name, param in checkpoint["state_dict"].items()
        if name.startswith(prefix)
    )

In [9]:
# Backbone state dict (prefix-stripped) extracted from the training checkpoint.
weights = get_backbone_from_ckpt(ckpt_path="/home/yasin/repos/dispatch_smol/notebooks/data/r50_ms.ckpt")

In [10]:
# Build a torchvision ResNet-50 and load the extracted backbone weights into it.
model = resnet50()
# strict=False because the backbone weights need not include the classifier (fc.*),
# which is replaced right below anyway. Any *other* missing key means the checkpoint
# names do not match torchvision's. Those layers would then silently keep their
# random initialisation, so fail loudly instead.
load_result = model.load_state_dict(weights, strict=False)
missing = [k for k in load_result.missing_keys if not k.startswith("fc.")]
if missing:
    raise RuntimeError(
        f"{len(missing)} ResNet-50 keys missing from the checkpoint backbone, "
        f"e.g. {missing[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"Ignoring {len(load_result.unexpected_keys)} unexpected checkpoint keys, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )
# Drop the classification head so the forward pass returns pooled emb_dim-wide features.
model.fc = torch.nn.Identity()
_ = model.eval()

In [11]:
# Embed the whole test split and collect integer class / domain ids per image.
X, Y, D = [], [], []
with torch.no_grad():  # inference only: no autograd graph is needed
    for batch in tqdm(loader):
        x = batch['image']
        # The dataset is built without `classes`/`domains`, so labels and domains
        # arrive as the raw set_map values (presumably names); map them to ids here.
        y = torch.tensor([classes.get(label) for label in batch['label']])
        d = torch.tensor([domains.get(domain) for domain in batch['domain']])
        X.append(model(x))
        Y.append(y)
        D.append(d)
X, Y, D = (torch.cat(parts, dim=0).contiguous() for parts in (X, Y, D))

100%|██████████| 16/16 [00:07<00:00,  2.22it/s]


### 1. Exclude samples that share the query's class (Y) and domain (D)

In [12]:
q_idx = 987  # position of the query sample in X / Y / D
# Slice instead of integer-indexing so the query keeps a leading batch dim of 1.
q = X[q_idx : q_idx + 1]

In [13]:
tuple(t[q_idx] for t in (Y, D))  # (class id, domain id) of the query

(tensor(1), tensor(0))

In [14]:
# Keep every sample whose class OR domain differs from the query's, i.e. drop the
# query's own (class, domain) group, which includes the query itself.
valid = ((Y != Y[q_idx]) | (D != D[q_idx])).squeeze()

In [15]:
X_, Y_, D_ = (t[valid] for t in (X, Y, D))  # retrieval gallery without the query's group

### 2. FAISS k-NN search

In [16]:
# Exact (brute-force) nearest-neighbour index over emb_dim-wide vectors, scored by inner product.
# NOTE(review): the embeddings are not L2-normalised anywhere in the visible code, so these
# scores are raw dot products rather than cosine similarities. Normalise X_ and q first if
# cosine retrieval is intended.
index = faiss.IndexFlatIP(emb_dim)

In [17]:
# X_ is a CPU float32 tensor produced under no_grad, so .numpy() is a zero-copy view
# in the array format FAISS's Python API consumes.
index.add(X_.numpy())

In [18]:
# Retrieve up to 1000 nearest neighbours of the query, best score first.
# Cap k at the index size. When k exceeds ntotal, FAISS pads the result with -1 ids,
# and Y_[neighbors] / D_[neighbors] would then silently read the *last* element.
k = min(1000, index.ntotal)
_, neighbors = index.search(q, k)

In [19]:
Y_[torch.as_tensor(neighbors)]  # class ids of the retrieved neighbours, shape (1, k)

tensor([[5, 5, 5, 0, 5, 5, 5, 5, 0, 5, 0, 6, 0, 5, 5, 5, 5, 0, 4, 5, 5, 2, 5, 5,
         3, 5, 1, 2, 0, 2, 5, 2, 3, 0, 2, 2, 5, 5, 5, 2, 2, 2, 5, 3, 5, 5, 2, 1,
         5, 6, 5, 5, 5, 2, 5, 2, 0, 1, 6, 5, 1, 2, 2, 5, 2, 2, 3, 5, 1, 5, 6, 1,
         5, 2, 2, 3, 5, 2, 0, 0, 6, 5, 2, 2, 3, 6, 1, 3, 4, 5, 1, 0, 5, 5, 3, 1,
         2, 0, 4, 5, 2, 3, 2, 6, 1, 2, 4, 2, 5, 4, 1, 2, 2, 6, 4, 3, 5, 2, 0, 6,
         0, 6, 1, 5, 0, 5, 0, 5, 0, 1, 1, 5, 6, 5, 2, 0, 1, 1, 1, 6, 1, 6, 1, 4,
         6, 6, 1, 1, 0, 4, 0, 1, 6, 4, 2, 1, 6, 1, 1, 0, 0, 2, 5, 0, 5, 1, 3, 6,
         5, 0, 2, 4, 0, 6, 1, 4, 1, 0, 4, 1, 0, 5, 0, 6, 1, 1, 5, 2, 2, 6, 5, 0,
         6, 6, 0, 5, 5, 4, 2, 0, 4, 4, 2, 5, 3, 0, 5, 1, 5, 2, 0, 6, 5, 6, 5, 0,
         6, 6, 1, 5, 0, 5, 4, 0, 4, 6, 6, 4, 0, 2, 5, 6, 0, 6, 0, 2, 5, 4, 5, 4,
         2, 6, 4, 6, 0, 0, 0, 0, 5, 2, 2, 5, 6, 6, 6, 0, 5, 6, 6, 6, 4, 6, 2, 0,
         6, 6, 0, 6, 0, 5, 0, 4, 5, 2, 4, 6, 2, 0, 5, 4, 0, 5, 0, 4, 0, 0, 6, 0,
         6, 2, 0, 6, 1, 4, 0

In [20]:
D_[torch.as_tensor(neighbors)]  # domain ids of the retrieved neighbours, shape (1, k)

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
         0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1,
         0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1,
         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 3, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 3,
         1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 3, 0, 0, 0, 1, 1, 0, 0, 0, 0,
         1, 0, 1, 0, 0, 0, 0, 1, 0, 3, 1, 1, 0, 0, 3, 0, 1, 1, 1, 0, 1, 0, 3, 1,
         0, 1, 3, 1, 0, 0, 0, 1, 0, 0, 1, 3, 1, 1, 1, 0, 0, 1, 0, 0, 1, 3, 1, 0,
         0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 2, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 1, 0,
         1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 1, 1, 3, 0, 0, 1, 0, 0, 0,
         1, 0, 1, 1, 2, 0, 3, 0, 1, 0, 0, 3, 1, 3, 1, 0, 0, 0, 0, 0, 1, 2, 1, 0,
         1, 0, 0, 0, 1, 2, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 3, 0, 1, 1,
         0, 0, 1, 1, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 2, 3, 3, 1, 0, 1,
         0, 0, 1, 1, 2, 1, 0