In [None]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

# Feature extraction

In [2]:
import glob
import os
import numpy as np

from PIL import Image, ImageFile
from omegaconf import OmegaConf

from config.init import create_baseconfig_from_checkpoint
from model.lgffem import LGFFEM

from tqdm import tqdm

In [3]:
import torch
from torchvision.transforms import v2

from torchinfo import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Create model from checkpoint

In [4]:
path_checkpoint = os.path.join('/thesis/checkpoint/BASE/BASE-B-EMB-KIMIA-11_convnextv2-02_neck_512_3-00_head_A-epoch50.pth')
checkpoint = torch.load(path_checkpoint)

In [5]:
base_config = create_baseconfig_from_checkpoint(checkpoint)

embedder = LGFFEM(base_config).eval().to(device)
match_n = embedder.neck.load_state_dict(checkpoint['model_neck_state_dict'], strict = False)
print('[++] Loaded neck weights.', match_n)
match_h = embedder.head.load_state_dict(checkpoint['model_head_state_dict'], strict = False)
print('[++] Loaded head weights.', match_h)

[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


In [6]:
img_transforms = v2.Compose([
                            v2.ToImage(),
                            v2.Resize(size=(224, 224)),
                            v2.ToDtype(torch.float32, scale=True),
                            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                          ])

## Get embeddings from Kimia Test imgs

In [7]:
from pathlib import Path
from natsort import natsorted
import itertools

PATH_TEST = '/thesis/classical/test-patches-kimia/'

# img_l = [path for path in Path(PATH_TEST).rglob('*.jpg')]
# label_l = [path.name for path in img_l]

dict_img = {}
for i in range(0,24):
    path_i = f'/thesis/classical/test-patches-kimia/s{i}'
    dict_img[f's{i}'] = natsorted([path for path in Path(path_i).rglob('*.jpg')], key=str)
    
img_l = list(itertools.chain.from_iterable(dict_img.values()))

In [8]:
def pil_loader(path):
    # to avoid crashing for truncated (corrupted images)
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    # open path as file to avoid ResourceWarning 
    # (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

In [9]:
im_embs_l = []
for i in tqdm(img_l):
    im = pil_loader(i)
    im = img_transforms(im).unsqueeze(0).to(device)
    im_embs_l.append(embedder(im).detach().cpu().numpy())

X = np.stack(im_embs_l, axis=0).squeeze(1)

100%|██████████| 1325/1325 [00:38<00:00, 34.13it/s]


## Create index and retrieve

In [10]:
import faiss

In [11]:
size, dim = X.shape

index = faiss.IndexFlatL2(dim) ##euclidean
# index = faiss.IndexFlatIP(dim) ##cosine

index.add(X)

sims, ranks_f = index.search(X, size)

In [12]:
eta_p = .0
eta_w = .0

K = 5

i_low = 0
for s in range(0,24):
    v = dict_img[f's{s}']
    
    i_hight = i_low + len(v)
    
    labels = set([path.name for path in img_l[i_low:i_hight]])
    retrieve = set([path.name for path in [img_l[idx] for idx in ranks_f[i_low,:i_hight]]]) ## k=1
    
    cardi = len(retrieve&labels)
    
    eta_p+=cardi
    eta_w+=(cardi/len(v))

    i_low = i_hight
    
eta_p = eta_p/len(img_l)
eta_w = eta_w/24
eta_tot = eta_p*eta_w

print(round(eta_tot*100,2),round(eta_p*100,2),round(eta_w*100,2))

75.11 86.19 87.14
