In [None]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

# Feature extraction

In [2]:
import glob
import os
import numpy as np

from PIL import Image, ImageFile
from omegaconf import OmegaConf

from config.init import create_baseconfig_from_checkpoint
from model.lgffem import LGFFEM

from tqdm import tqdm

import umap

import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path
from natsort import natsorted
import itertools

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Global seaborn theme: white style, notebook context and a 14x10 default
# figure size (applied through matplotlib's rc parameters).
sns.set(style='white', context='notebook', rc={'figure.figsize':(14,10)})

In [3]:
import torch
from torchvision.transforms import v2

from torchinfo import summary

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
def pil_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    Args:
        path: Location of the image file; it is opened in binary mode.

    Returns:
        The result of ``Image.open(...).convert('RGB')``.
    """
    # Module-wide Pillow switch (set on every call): tolerate truncated
    # (corrupted) image files instead of crashing while decoding them.
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    # Open the file ourselves so the handle is closed deterministically and no
    # ResourceWarning is emitted (https://github.com/python-pillow/Pillow/issues/835).
    # The conversion happens inside the `with` block, while the file is open.
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')

In [5]:
def _list_kimia_test_images(path_test, n_classes):
    """Collect the test images and the class index of each one.

    ``*.jpg`` files are searched recursively under ``<path_test>/s0`` ...
    ``<path_test>/s{n_classes - 1}``. Folders are visited in index order and
    the files of each folder are naturally sorted by their path string. The
    class index of an image is read from the file-name prefix before the first
    underscore (``s<idx>_...``).

    Returns:
        ``(img_l, classes)``: the list of image paths and the parallel list of
        integer class indices.
    """
    class_to_idx = {f's{i}': i for i in range(n_classes)}

    img_l = []
    for class_name in class_to_idx:
        class_dir = Path(path_test) / class_name
        img_l.extend(natsorted(class_dir.rglob('*.jpg'), key=str))

    classes = [class_to_idx[path.name.split('_')[0]] for path in img_l]
    return img_l, classes


def _embed_images(embedder, img_l, img_transforms):
    """Run ``embedder`` on every image, one at a time, and stack the outputs.

    Returns:
        A NumPy array with one row per image (the singleton batch axis of each
        model output is squeezed away).
    """
    im_embs_l = []
    # Pure inference: skip building the autograd graph for every forward pass.
    with torch.no_grad():
        for img_path in tqdm(img_l):
            im = img_transforms(pil_loader(img_path)).unsqueeze(0).to(device)
            im_embs_l.append(embedder(im).detach().cpu().numpy())

    return np.stack(im_embs_l, axis=0).squeeze(1)


def umap_embs_from_checkpoint(l_checkpoint: list[str],
                              path_test: str = '/thesis/classical/test-patches-kimia/',
                              n_classes: int = 24):
    """Embed the test patches with each checkpoint and project them with UMAP.

    For every checkpoint an ``LGFFEM`` model is built from the configuration
    stored in the checkpoint, its neck and head weights are loaded, every test
    image is embedded, and the embeddings are reduced with
    ``umap.UMAP(random_state=42)``.

    Args:
        l_checkpoint: Paths of the checkpoint files to evaluate.
        path_test: Root folder holding one sub-folder ``s<idx>`` per class.
        n_classes: Number of class sub-folders (``s0`` ... ``s{n_classes-1}``).

    Returns:
        A list with one ``(proj_embedding, classes)`` tuple per checkpoint, in
        the order of ``l_checkpoint``. ``proj_embedding`` is the UMAP
        projection of the embeddings and ``classes`` the class index of each
        image, in the same row order.

    Raises:
        ValueError: If no ``*.jpg`` image is found under ``path_test``.
    """
    if not l_checkpoint:
        return []

    # The image list, labels and preprocessing do not depend on the
    # checkpoint, so they are prepared once instead of once per model.
    img_l, classes = _list_kimia_test_images(path_test, n_classes)
    if not img_l:
        raise ValueError(f'No *.jpg test images found under {path_test!r}')

    img_transforms = v2.Compose([
                        v2.ToImage(),
                        v2.Resize(size=(224, 224)),
                        v2.ToDtype(torch.float32, scale=True),
                        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                      ])

    umap_proj_embs = []

    for path_checkpoint in l_checkpoint:

        # Create model from checkpoint
        print(path_checkpoint)
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(path_checkpoint, map_location=device)

        base_config = create_baseconfig_from_checkpoint(checkpoint)

        embedder = LGFFEM(base_config).eval().to(device)
        # strict=False tolerates missing/unexpected keys; the printed result
        # shows whether every key actually matched.
        match_n = embedder.neck.load_state_dict(checkpoint['model_neck_state_dict'], strict = False)
        print('[++] Loaded neck weights.', match_n)
        match_h = embedder.head.load_state_dict(checkpoint['model_head_state_dict'], strict = False)
        print('[++] Loaded head weights.', match_h)

        # Get embeddings from the Kimia test images
        X = _embed_images(embedder, img_l, img_transforms)

        # 2D projection of the embeddings. fit_transform returns the embedding
        # learned during fitting, which is what the previous
        # fit + transform + equality assert produced.
        reducer = umap.UMAP(random_state=42)
        proj_embedding = reducer.fit_transform(X)

        # Hand out a fresh label list per checkpoint so results stay independent.
        umap_proj_embs.append((proj_embedding, list(classes)))

    return umap_proj_embs

In [6]:
# Checkpoints to compare; they are processed, and later plotted, in this order.
final_models = [
    '/thesis/checkpoint/BASE/BASE-EMB-11_convnextv2-02_neck_512_3-00_head_A-epoch10--in1k-224.pth',
    '/thesis/checkpoint/BASE/BASE-B-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch198.pth',
    '/thesis/checkpoint/BASE/BASE-EMB-KIMIA-11_convnextv2-02_neck_512_3-00_head_A-epoch499.pth',
]

In [7]:
# Embed the test patches with each checkpoint and project them with UMAP.
# list_projs[k] is the (projection, class labels) pair for final_models[k].
list_projs = umap_embs_from_checkpoint(final_models)

/thesis/checkpoint/BASE/BASE-EMB-11_convnextv2-02_neck_512_3-00_head_A-epoch10--in1k-224.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:39<00:00, 33.78it/s]


/thesis/checkpoint/BASE/BASE-B-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch198.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:41<00:00, 31.71it/s]


/thesis/checkpoint/BASE/BASE-EMB-KIMIA-11_convnextv2-02_neck_512_3-00_head_A-epoch499.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:42<00:00, 30.90it/s]


In [8]:
# Save one scatter plot per checkpoint, named after the checkpoint file.
# savefig does not create missing folders, so make sure the target exists.
fig_dir = Path('figures')
fig_dir.mkdir(parents=True, exist_ok=True)

for (emb_proj, emb_c), path_checkpoint in zip(list_projs, final_models):

    fn_fig = Path(path_checkpoint).stem

    fig, ax = plt.subplots()

    # One point per image, coloured by its class index.
    sca = ax.scatter(emb_proj[:, 0], emb_proj[:, 1], c=emb_c, cmap='nipy_spectral', s=10)

    # Boundaries at -0.5, 0.5, ..., 23.5 give each of the 24 class ids its own
    # colour band, with the tick placed on the integer id at the band centre.
    cbar = fig.colorbar(sca, ax=ax, boundaries=np.arange(25)-0.5, label='ID of classes')
    cbar.set_ticks(np.arange(24))

    # Hide the axis tick labels.
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])

    fig.savefig(fig_dir / f'{fn_fig}.eps', format='eps', dpi=1200)
    # Close the figure so figures do not pile up in memory across iterations.
    plt.close(fig)