In [1]:
import warnings
# Install blanket 'ignore' filters so that no warning is shown for the rest
# of the session (both calls match every warning category).
# NOTE(review): this also hides deprecation / numerical warnings coming from
# torch, torchvision, PIL and faiss -- consider narrowing the filter.
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

# Feature extraction

In [2]:
import glob
import os
import numpy as np

from PIL import Image, ImageFile
from omegaconf import OmegaConf

from config.init import create_baseconfig_from_checkpoint
from model.lgffem import LGFFEM

import faiss
from pathlib import Path
from natsort import natsorted
import itertools

from tqdm import tqdm

In [3]:
import torch
from torchvision.transforms import v2

from torchinfo import summary

# Use the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
def pil_loader(path):
    """Read the image stored at ``path`` and return it converted to RGB.

    Side effect: sets ``ImageFile.LOAD_TRUNCATED_IMAGES`` to ``True`` so that
    truncated (corrupted) files are decoded instead of raising.
    """
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    # Going through an explicit file handle avoids Pillow's ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835); the pixel data is
    # decoded by convert() while the handle is still open.
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image

In [5]:
def kimia_metrics_from_checkpoint(l_checkpoint: list[str],
                                  path_test: str = '/thesis/classical/test-patches-kimia/',
                                  n_scans: int = 24):
    """Print retrieval scores on the Kimia test patches for each checkpoint.

    For every checkpoint an ``LGFFEM`` embedder is built from the configuration
    stored in the checkpoint, its neck and head weights are loaded, every test
    patch is embedded, and an exact L2 faiss index over those embeddings is
    searched with the embeddings themselves. Three scores are printed per
    checkpoint: ``eta_p`` (sum of the per-scan overlaps divided by the total
    number of patches), ``eta_w`` (mean over scans of overlap / scan size) and
    ``eta_tot = eta_p * eta_w``.

    Args:
        l_checkpoint: Paths of the ``.pth`` checkpoint files to evaluate.
        path_test: Root folder holding one sub-folder ``s0`` ... ``s{n_scans-1}``
            of ``*.jpg`` patches per scan.
        n_scans: Number of scan sub-folders to read.

    Returns:
        None. Results are written to stdout.

    Raises:
        FileNotFoundError: If one of the scan sub-folders contains no ``*.jpg``.
    """
    # Nothing to evaluate: return before touching the dataset or the device.
    if not l_checkpoint:
        return

    # The preprocessing pipeline and the file listing do not depend on the
    # checkpoint, so they are built once instead of once per checkpoint.
    img_transforms = v2.Compose([
                        v2.ToImage(),
                        v2.Resize(size=(224, 224)),
                        v2.ToDtype(torch.float32, scale=True),
                        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                      ])

    # One naturally-sorted list of patch paths per scan, keyed 's0', 's1', ...
    dict_img = {}
    for i in range(n_scans):
        scan_dir = os.path.join(path_test, f's{i}')
        dict_img[f's{i}'] = natsorted(list(Path(scan_dir).rglob('*.jpg')), key=str)
        if not dict_img[f's{i}']:
            # Fail early with a clear message; an empty scan would otherwise
            # only surface as a ZeroDivisionError after the embedding pass.
            raise FileNotFoundError(f'No *.jpg patches found in {scan_dir}')

    # Flat list in scan order: scan s occupies one contiguous slice of img_l.
    img_l = list(itertools.chain.from_iterable(dict_img.values()))

    for path_checkpoint in l_checkpoint:

        # Create model from checkpoint
        print(path_checkpoint)
        # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
        # torch.load unpickles the file: only load checkpoints you trust.
        checkpoint = torch.load(path_checkpoint, map_location=device)

        base_config = create_baseconfig_from_checkpoint(checkpoint)

        embedder = LGFFEM(base_config).eval().to(device)
        # Only the neck and head state dicts are restored from the checkpoint.
        match_n = embedder.neck.load_state_dict(checkpoint['model_neck_state_dict'], strict = False)
        print('[++] Loaded neck weights.', match_n)
        match_h = embedder.head.load_state_dict(checkpoint['model_head_state_dict'], strict = False)
        print('[++] Loaded head weights.', match_h)

        # Embed every patch; gradients are never needed here, so autograd is
        # switched off for the whole pass.
        im_embs_l = []
        with torch.no_grad():
            for img_path in tqdm(img_l):
                im = img_transforms(pil_loader(img_path)).unsqueeze(0).to(device)
                im_embs_l.append(embedder(im).cpu().numpy())

        # Stack the per-image (1, dim) embeddings into a single (size, dim) matrix.
        X = np.stack(im_embs_l, axis=0).squeeze(1)

        ## Create index and retrieve
        size, dim = X.shape
        index = faiss.IndexFlatL2(dim) ##euclidean (faiss.IndexFlatIP(dim) would rank by inner product)
        index.add(X)
        # Full ranking of the dataset for every query; distances are not used.
        _, ranks_f = index.search(X, size)

        eta_p = .0
        eta_w = .0

        i_low = 0
        for s in range(n_scans):
            v = dict_img[f's{s}']

            i_hight = i_low + len(v)

            # File names of the patches that belong to scan s.
            labels = {path.name for path in img_l[i_low:i_hight]}
            # File names retrieved for the FIRST patch of scan s (row i_low),
            # cut off at the cumulative index i_hight (original comment: "k=1").
            # NOTE(review): the cut-off is i_hight, not len(v), so it grows with
            # the scan index and spans the whole dataset for the last scan --
            # confirm this is the intended definition. Left unchanged so the
            # scores stay comparable with previously reported numbers.
            # NOTE(review): both sets hold path.name only; identical file names
            # in different scan folders would be conflated -- verify that names
            # are unique across scans.
            retrieve = {img_l[idx].name for idx in ranks_f[i_low, :i_hight]}

            cardi = len(retrieve & labels)

            eta_p += cardi
            eta_w += (cardi / len(v))

            i_low = i_hight

        eta_p = eta_p / len(img_l)
        eta_w = eta_w / n_scans
        eta_tot = eta_p * eta_w

        print(f'>> {Path(path_checkpoint).stem}')
        print(f"eta_tot: {round(eta_tot*100,2)} ; eta_p: {round(eta_p*100,2)} ; eta_w {round(eta_w*100,2)}")
        print('\n\n')

In [6]:
# Checkpoint files (BASE-A, BASE-B, BASE-C) to pass to
# kimia_metrics_from_checkpoint.
kimia_m_checks = [
    '/thesis/checkpoint/BASE/BASE-A-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch200.pth',
    '/thesis/checkpoint/BASE/BASE-B-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch198.pth',
    '/thesis/checkpoint/BASE/BASE-C-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch199.pth',
]

In [7]:
# Evaluate each checkpoint in kimia_m_checks; the scores are printed per checkpoint.
kimia_metrics_from_checkpoint(kimia_m_checks)

/thesis/checkpoint/BASE/BASE-A-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch200.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:39<00:00, 33.62it/s]


>> BASE-A-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch200
eta_tot: 58.0 ; eta_p: 75.25 ; eta_w 77.08



/thesis/checkpoint/BASE/BASE-B-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch198.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:39<00:00, 33.53it/s]


>> BASE-B-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch198
eta_tot: 61.33 ; eta_p: 77.36 ; eta_w 79.28



/thesis/checkpoint/BASE/BASE-C-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch199.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:38<00:00, 34.33it/s]


>> BASE-C-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch199
eta_tot: 57.57 ; eta_p: 75.25 ; eta_w 76.51





In [8]:
# Checkpoint files of the two final models to pass to
# kimia_metrics_from_checkpoint.
final_models = [
    '/thesis/checkpoint/BASE/BASE-EMB-11_convnextv2-02_neck_512_3-00_head_A-epoch10--in1k-224.pth',
    '/thesis/checkpoint/BASE/BASE-EMB-KIMIA-11_convnextv2-02_neck_512_3-00_head_A-epoch499.pth',
]

In [9]:
# Evaluate each checkpoint in final_models; the scores are printed per checkpoint.
kimia_metrics_from_checkpoint(final_models)

/thesis/checkpoint/BASE/BASE-EMB-11_convnextv2-02_neck_512_3-00_head_A-epoch10--in1k-224.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:38<00:00, 34.02it/s]


>> BASE-EMB-11_convnextv2-02_neck_512_3-00_head_A-epoch10--in1k-224
eta_tot: 53.6 ; eta_p: 72.08 ; eta_w 74.37



/thesis/checkpoint/BASE/BASE-EMB-KIMIA-11_convnextv2-02_neck_512_3-00_head_A-epoch499.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:38<00:00, 34.01it/s]


>> BASE-EMB-KIMIA-11_convnextv2-02_neck_512_3-00_head_A-epoch499
eta_tot: 98.87 ; eta_p: 99.4 ; eta_w 99.47



