In [1]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

# Feature extraction and 2-D UMAP projection of the embeddings

In [2]:
import glob
import os
import numpy as np
import pickle

from PIL import Image, ImageFile
from omegaconf import OmegaConf
from pathlib import PurePosixPath, Path

from config.init import create_baseconfig_from_checkpoint
from model.lgffem import LGFFEM

from tqdm import tqdm

#import umap
import umap.umap_ as umap

import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path
from natsort import natsorted
import itertools

%matplotlib inline
sns.set(style='white', context='notebook', rc={'figure.figsize':(14,10)})

In [3]:
import torch
from torchvision.transforms import v2

from torchinfo import summary

# Run on CUDA when a GPU is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB.

    Sets Pillow's module-level ``ImageFile.LOAD_TRUNCATED_IMAGES`` flag first.
    Truncated or corrupted files then load instead of crashing. The flag stays
    set for every later Pillow load in the process.
    """
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    # The file handle is opened explicitly (not via Image.open(path)) to avoid a
    # ResourceWarning: https://github.com/python-pillow/Pillow/issues/835
    # convert() decodes the pixels while the handle is still open.
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image

In [5]:
def umap_embs_from_checkpoint_kimia(l_checkpoint: list[str],
                                    path_test: str = '/thesis/classical/test-patches-kimia/'):
    """Embed the KIMIA test patches with every checkpoint and project them to 2-D with UMAP.

    For each checkpoint, this function:

    1. Rebuilds an ``LGFFEM`` embedder via ``create_baseconfig_from_checkpoint``.
    2. Loads its neck and head weights non-strictly and prints the match report.
    3. Embeds every image one at a time.
    4. Reduces the stacked embeddings with ``umap.UMAP(random_state=42)``.

    Args:
        l_checkpoint: Paths to ``.pth`` checkpoints. Each one is expected to hold
            ``'model_neck_state_dict'`` and ``'model_head_state_dict'``.
        path_test: Root folder with one sub-folder per class, ``s0`` .. ``s23``.
            ``*.jpg`` files are collected recursively from each sub-folder.

    Returns:
        One ``(proj_embedding, classes)`` tuple per checkpoint, in input order.
        ``proj_embedding`` is the UMAP embedding, with one row per image.
        ``classes`` is the integer id parsed from each file-name prefix
        (``'s<i>_...'`` -> ``i``).

    Raises:
        ValueError: if checkpoints are given but no image is found under ``path_test``.
    """
    n_classes = 24  # class folders s0 .. s23

    img_transforms = v2.Compose([
        v2.ToImage(),
        v2.Resize(size=(224, 224)),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Images, labels and transforms do not depend on the checkpoint: build them once.
    # natsorted orders file names naturally inside each class folder.
    dict_img = {
        f's{i}': natsorted(Path(path_test, f's{i}').rglob('*.jpg'), key=str)
        for i in range(n_classes)
    }
    img_l = list(itertools.chain.from_iterable(dict_img.values()))
    kimia_classes_idx = {f's{i}': i for i in range(n_classes)}
    classes = [kimia_classes_idx[path.name.split('_')[0]] for path in img_l]
    if l_checkpoint and not img_l:
        raise ValueError(f'No .jpg images found under {path_test}')

    umap_proj_embs = []
    for path_checkpoint in l_checkpoint:
        print(path_checkpoint)

        # Load tensors on the CPU so GPU-saved checkpoints also open on CPU-only
        # hosts; load_state_dict copies them onto the embedder's device.
        checkpoint = torch.load(path_checkpoint, map_location='cpu')
        base_config = create_baseconfig_from_checkpoint(checkpoint)

        embedder = LGFFEM(base_config).eval().to(device)
        # strict=False: report (instead of failing on) missing/unexpected keys.
        match_n = embedder.neck.load_state_dict(checkpoint['model_neck_state_dict'], strict=False)
        print('[++] Loaded neck weights.', match_n)
        match_h = embedder.head.load_state_dict(checkpoint['model_head_state_dict'], strict=False)
        print('[++] Loaded head weights.', match_h)

        # Pure inference: do not build an autograd graph for each forward pass.
        im_embs_l = []
        with torch.no_grad():
            for img_path in tqdm(img_l):
                im = img_transforms(pil_loader(img_path)).unsqueeze(0).to(device)
                im_embs_l.append(embedder(im).detach().cpu().numpy())

        # Drop the size-1 batch axis (embedder output presumably (1, dim) -- confirm).
        X = np.stack(im_embs_l, axis=0).squeeze(1)

        # 2-D UMAP projection; fit_transform returns the fitted reducer.embedding_.
        reducer = umap.UMAP(random_state=42)
        proj_embedding = reducer.fit_transform(X)

        # Copy the labels so every returned tuple owns its own list.
        umap_proj_embs.append((proj_embedding, list(classes)))

    return umap_proj_embs

In [6]:
def umap_embs_from_checkpoint_revised(l_checkpoint: list[str], test_dataset = 'roxford5k'):
    """Embed a revisitop image set with every checkpoint and project it to 2-D with UMAP.

    Args:
        l_checkpoint: Paths to ``.pth`` checkpoints. Each one is expected to hold
            ``'model_neck_state_dict'`` and ``'model_head_state_dict'``.
        test_dataset: Dataset folder under ``/thesis/classical/revisitop/datasets``
            (e.g. ``'roxford5k'`` or ``'rparis6k'``). Its ``gnd_<test_dataset>.pkl``
            must contain an ``'imlist'`` of image stems stored in ``<test_dataset>/jpg/``.

    Returns:
        One ``(proj_embedding, landmarks_classes)`` tuple per checkpoint, in input
        order. Each label is the file stem with its last ``_``-separated token
        removed (e.g. ``'all_souls_000013'`` -> ``'all_souls'``).

    Raises:
        ValueError: if checkpoints are given but ``'imlist'`` is empty.
    """
    data_root = Path('/thesis/classical/revisitop/datasets')
    img_root = Path(data_root, f'{test_dataset}/jpg/')
    gnd_fname = Path(data_root, f'{test_dataset}/gnd_{test_dataset}.pkl')

    # NOTE: pickle.load can execute arbitrary code -- only use trusted ground-truth files.
    with open(gnd_fname, 'rb') as f:
        cfg = pickle.load(f)

    img_transforms = v2.Compose([
        v2.ToImage(),
        v2.Resize(size=(224, 224)),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Images and labels do not depend on the checkpoint: build them once.
    img_l = [Path(img_root, f'{fn}.jpg') for fn in cfg['imlist']]
    landmarks_classes = ['_'.join(p_i.stem.split('_')[:-1]) for p_i in img_l]
    if l_checkpoint and not img_l:
        raise ValueError(f"'imlist' in {gnd_fname} is empty")

    umap_proj_embs = []
    for path_checkpoint in l_checkpoint:
        print(path_checkpoint)

        # Load tensors on the CPU so GPU-saved checkpoints also open on CPU-only
        # hosts; load_state_dict copies them onto the embedder's device.
        checkpoint = torch.load(path_checkpoint, map_location='cpu')
        base_config = create_baseconfig_from_checkpoint(checkpoint)

        embedder = LGFFEM(base_config).eval().to(device)
        # strict=False: report (instead of failing on) missing/unexpected keys.
        match_n = embedder.neck.load_state_dict(checkpoint['model_neck_state_dict'], strict=False)
        print('[++] Loaded neck weights.', match_n)
        match_h = embedder.head.load_state_dict(checkpoint['model_head_state_dict'], strict=False)
        print('[++] Loaded head weights.', match_h)

        # Pure inference: do not build an autograd graph for each forward pass.
        im_embs_l = []
        with torch.no_grad():
            for img_path in tqdm(img_l):
                im = img_transforms(pil_loader(img_path)).unsqueeze(0).to(device)
                im_embs_l.append(embedder(im).detach().cpu().numpy())

        # Drop the size-1 batch axis (embedder output presumably (1, dim) -- confirm).
        X = np.stack(im_embs_l, axis=0).squeeze(1)

        # 2-D UMAP projection; fit_transform returns the fitted reducer.embedding_.
        reducer = umap.UMAP(random_state=42)
        proj_embedding = reducer.fit_transform(X)

        # Copy the labels so every returned tuple owns its own list.
        umap_proj_embs.append((proj_embedding, list(landmarks_classes)))

    return umap_proj_embs

In [7]:
def umap_embs_from_checkpoint_pannuke(l_checkpoint: list[str]):
    """Embed the PanNuke ``*.png`` images with every checkpoint and project them to 2-D with UMAP.

    Args:
        l_checkpoint: Paths to ``.pth`` checkpoints. Each one is expected to hold
            ``'model_neck_state_dict'`` and ``'model_head_state_dict'``.

    Returns:
        One ``(proj_embedding, tissue_classes)`` tuple per checkpoint, in input
        order. Each label is the last ``_``-separated token of the image's file
        stem. Images are taken (non-recursively) from
        ``/thesis/classical/pannuke/images/`` in sorted path order.

    Raises:
        ValueError: if checkpoints are given but the image folder has no ``*.png``.
    """
    data_root = Path('/thesis/classical/pannuke/images/')

    img_transforms = v2.Compose([
        v2.ToImage(),
        v2.Resize(size=(224, 224)),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # glob() yields files in arbitrary file-system order. Sorting makes the image
    # order, and hence the seeded UMAP run, reproducible across machines.
    # The list does not depend on the checkpoint, so it is built once.
    img_l = sorted(data_root.glob('*.png'))
    tissue_classes = [fn.stem.split('_')[-1] for fn in img_l]
    if l_checkpoint and not img_l:
        raise ValueError(f'No .png images found in {data_root}')

    umap_proj_embs = []
    for path_checkpoint in l_checkpoint:
        print(path_checkpoint)

        # Load tensors on the CPU so GPU-saved checkpoints also open on CPU-only
        # hosts; load_state_dict copies them onto the embedder's device.
        checkpoint = torch.load(path_checkpoint, map_location='cpu')
        base_config = create_baseconfig_from_checkpoint(checkpoint)

        embedder = LGFFEM(base_config).eval().to(device)
        # strict=False: report (instead of failing on) missing/unexpected keys.
        match_n = embedder.neck.load_state_dict(checkpoint['model_neck_state_dict'], strict=False)
        print('[++] Loaded neck weights.', match_n)
        match_h = embedder.head.load_state_dict(checkpoint['model_head_state_dict'], strict=False)
        print('[++] Loaded head weights.', match_h)

        # Pure inference: do not build an autograd graph for each forward pass.
        im_embs_l = []
        with torch.no_grad():
            for img_path in tqdm(img_l):
                im = img_transforms(pil_loader(img_path)).unsqueeze(0).to(device)
                im_embs_l.append(embedder(im).detach().cpu().numpy())

        # Drop the size-1 batch axis (embedder output presumably (1, dim) -- confirm).
        X = np.stack(im_embs_l, axis=0).squeeze(1)

        # 2-D UMAP projection; fit_transform returns the fitted reducer.embedding_.
        reducer = umap.UMAP(random_state=42)
        proj_embedding = reducer.fit_transform(X)

        # Copy the labels so every returned tuple owns its own list.
        umap_proj_embs.append((proj_embedding, list(tissue_classes)))

    return umap_proj_embs

In [8]:
# Checkpoints compared throughout the notebook. Every plotting loop zips its
# projections with this list, so the order here is the figure order.
# NOTE(review): by file name -- A: 'in1k' run, B1-B3: 'PAN' runs, C: 'KIMIA' run;
# presumably these name the training data of each model -- confirm.
final_models = [
    '/thesis/checkpoint/BASE/A_BASE-EMB-11_convnextv2-02_neck_512_3-00_head_A-epoch10--in1k-224.pth',
    '/thesis/checkpoint/BASE/B1_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch200.pth',
    '/thesis/checkpoint/BASE/B2_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch198.pth',
    '/thesis/checkpoint/BASE/B3_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch199.pth',
    '/thesis/checkpoint/BASE/C_BASE-EMB-KIMIA-11_convnextv2-02_neck_512_3-00_head_A-epoch499.pth',
]

In [9]:
# KIMIA test patches: one (projection, class ids) pair per checkpoint.
list_projs_kimia = umap_embs_from_checkpoint_kimia(l_checkpoint=final_models)

/thesis/checkpoint/BASE/A_BASE-EMB-11_convnextv2-02_neck_512_3-00_head_A-epoch10--in1k-224.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:39<00:00, 33.69it/s]


/thesis/checkpoint/BASE/B1_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch200.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:38<00:00, 34.61it/s]


/thesis/checkpoint/BASE/B2_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch198.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:38<00:00, 34.71it/s]


/thesis/checkpoint/BASE/B3_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch199.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:38<00:00, 34.72it/s]


/thesis/checkpoint/BASE/C_BASE-EMB-KIMIA-11_convnextv2-02_neck_512_3-00_head_A-epoch499.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 1325/1325 [00:38<00:00, 34.33it/s]


In [10]:
# roxford5k (the helper's default dataset): one (projection, landmark names) pair per checkpoint.
list_projs_roxford5k = umap_embs_from_checkpoint_revised(l_checkpoint=final_models, test_dataset='roxford5k')

/thesis/checkpoint/BASE/A_BASE-EMB-11_convnextv2-02_neck_512_3-00_head_A-epoch10--in1k-224.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 4993/4993 [02:26<00:00, 34.12it/s]


/thesis/checkpoint/BASE/B1_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch200.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 4993/4993 [02:34<00:00, 32.36it/s]


/thesis/checkpoint/BASE/B2_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch198.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 4993/4993 [02:32<00:00, 32.80it/s]


/thesis/checkpoint/BASE/B3_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch199.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 4993/4993 [02:35<00:00, 32.17it/s]


/thesis/checkpoint/BASE/C_BASE-EMB-KIMIA-11_convnextv2-02_neck_512_3-00_head_A-epoch499.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 4993/4993 [02:37<00:00, 31.62it/s]


In [11]:
# rparis6k: one (projection, landmark names) pair per checkpoint.
list_projs_rparis6k = umap_embs_from_checkpoint_revised(l_checkpoint=final_models, test_dataset='rparis6k')

/thesis/checkpoint/BASE/A_BASE-EMB-11_convnextv2-02_neck_512_3-00_head_A-epoch10--in1k-224.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 6322/6322 [03:18<00:00, 31.90it/s]


/thesis/checkpoint/BASE/B1_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch200.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 6322/6322 [03:15<00:00, 32.33it/s]


/thesis/checkpoint/BASE/B2_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch198.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 6322/6322 [03:11<00:00, 32.98it/s]


/thesis/checkpoint/BASE/B3_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch199.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 6322/6322 [03:13<00:00, 32.64it/s]


/thesis/checkpoint/BASE/C_BASE-EMB-KIMIA-11_convnextv2-02_neck_512_3-00_head_A-epoch499.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 6322/6322 [03:14<00:00, 32.50it/s]


In [12]:
# PanNuke images: one (projection, tissue labels) pair per checkpoint.
list_projs_pannuke = umap_embs_from_checkpoint_pannuke(l_checkpoint=final_models)

/thesis/checkpoint/BASE/A_BASE-EMB-11_convnextv2-02_neck_512_3-00_head_A-epoch10--in1k-224.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 7901/7901 [03:20<00:00, 39.40it/s]


/thesis/checkpoint/BASE/B1_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch200.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 7901/7901 [03:18<00:00, 39.77it/s]


/thesis/checkpoint/BASE/B2_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch198.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 7901/7901 [03:26<00:00, 38.34it/s]


/thesis/checkpoint/BASE/B3_BASE-EMB-PAN-11_convnextv2-02_neck_512_3-00_head_A-epoch199.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 7901/7901 [03:20<00:00, 39.50it/s]


/thesis/checkpoint/BASE/C_BASE-EMB-KIMIA-11_convnextv2-02_neck_512_3-00_head_A-epoch499.pth
[++] Loaded neck weights. <All keys matched successfully>
[++] Loaded head weights. <All keys matched successfully>


100%|██████████| 7901/7901 [03:21<00:00, 39.28it/s]


In [17]:
# One EPS scatter plot per checkpoint of the KIMIA UMAP projection, coloured by class id (0-23).
Path('figures/emb').mkdir(parents=True, exist_ok=True)  # savefig does not create missing folders

for (emb_proj, emb_c), path_checkpoint in zip(list_projs_kimia, final_models):
    fn_fig = Path(path_checkpoint).stem

    fig, ax = plt.subplots()
    sca = ax.scatter(emb_proj[:, 0], emb_proj[:, 1], c=emb_c, cmap='nipy_spectral', s=11)

    # Half-integer boundaries give one discrete colour band per class id.
    cbar = fig.colorbar(sca, boundaries=np.arange(25) - 0.5,
                        pad=.02, fraction=.12, aspect=20, shrink=0.9)
    cbar.set_ticks(np.arange(24))
    cbar.set_label('ID of classes', size=18)
    cbar.ax.tick_params(labelsize=13)

    # UMAP coordinates carry no units; hide the tick labels.
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])

    fig.savefig(f'figures/emb/kimia_{fn_fig}.eps', format='eps', dpi=1200)
    plt.close(fig)

In [18]:
# One EPS scatter plot per checkpoint of the roxford5k UMAP projection, coloured by landmark.
Path('figures/emb').mkdir(parents=True, exist_ok=True)  # savefig does not create missing folders

for (emb_proj, emb_c), path_checkpoint in zip(list_projs_roxford5k, final_models):
    # Sort the landmark names. Iterating a set of str follows the per-process hash
    # seed, so set order alone would shuffle colours and colour-bar labels between runs.
    classes = sorted(set(emb_c))
    id_classes = {name: idx for idx, name in enumerate(classes)}

    fn_fig = Path(path_checkpoint).stem

    fig, ax = plt.subplots()
    sca = ax.scatter(emb_proj[:, 0], emb_proj[:, 1], c=[id_classes[name] for name in emb_c],
                     cmap='nipy_spectral', s=11)

    # Half-integer boundaries give one discrete colour band per landmark id.
    cbar = fig.colorbar(sca, boundaries=np.arange(len(classes) + 1) - 0.5,
                        pad=.02, fraction=.12, aspect=20, shrink=0.9)
    cbar.set_ticks(np.arange(len(classes)))
    cbar.set_label('Landmarks', size=18)
    cbar.ax.tick_params(labelsize=13)
    cbar.ax.set_yticklabels(classes)

    # UMAP coordinates carry no units; hide the tick labels.
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])

    fig.savefig(f'figures/emb/roxford5k_{fn_fig}.eps', format='eps', dpi=1200)
    plt.close(fig)

In [19]:
# One EPS scatter plot per checkpoint of the rparis6k UMAP projection, coloured by landmark.
Path('figures/emb').mkdir(parents=True, exist_ok=True)  # savefig does not create missing folders

for (emb_proj, emb_c), path_checkpoint in zip(list_projs_rparis6k, final_models):
    # Sort the landmark names. Iterating a set of str follows the per-process hash
    # seed, so set order alone would shuffle colours and colour-bar labels between runs.
    classes = sorted(set(emb_c))
    id_classes = {name: idx for idx, name in enumerate(classes)}

    fn_fig = Path(path_checkpoint).stem

    fig, ax = plt.subplots()
    sca = ax.scatter(emb_proj[:, 0], emb_proj[:, 1], c=[id_classes[name] for name in emb_c],
                     cmap='nipy_spectral', s=11)

    # Half-integer boundaries give one discrete colour band per landmark id.
    cbar = fig.colorbar(sca, boundaries=np.arange(len(classes) + 1) - 0.5,
                        pad=.02, fraction=.12, aspect=20, shrink=0.9)
    cbar.set_ticks(np.arange(len(classes)))
    cbar.set_label('Landmarks', size=18)
    cbar.ax.tick_params(labelsize=13)
    cbar.ax.set_yticklabels(classes)

    # UMAP coordinates carry no units; hide the tick labels.
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])

    fig.savefig(f'figures/emb/rparis6k_{fn_fig}.eps', format='eps', dpi=1200)
    plt.close(fig)

In [20]:
# One EPS scatter plot per checkpoint of the PanNuke UMAP projection, coloured by tissue label.
Path('figures/emb').mkdir(parents=True, exist_ok=True)  # savefig does not create missing folders

for (emb_proj, emb_c), path_checkpoint in zip(list_projs_pannuke, final_models):
    # Sort the tissue names. Iterating a set of str follows the per-process hash
    # seed, so set order alone would shuffle colours and colour-bar labels between runs.
    classes = sorted(set(emb_c))
    id_classes = {name: idx for idx, name in enumerate(classes)}

    fn_fig = Path(path_checkpoint).stem

    fig, ax = plt.subplots()
    sca = ax.scatter(emb_proj[:, 0], emb_proj[:, 1], c=[id_classes[name] for name in emb_c],
                     cmap='nipy_spectral', s=11)

    # Half-integer boundaries give one discrete colour band per tissue id.
    cbar = fig.colorbar(sca, boundaries=np.arange(len(classes) + 1) - 0.5,
                        pad=.02, fraction=.12, aspect=20, shrink=0.9)
    cbar.set_ticks(np.arange(len(classes)))
    # PanNuke labels are tissue types (taken from the file-name suffix), not landmarks.
    cbar.set_label('Tissue type', size=18)
    cbar.ax.tick_params(labelsize=13)
    cbar.ax.set_yticklabels(classes)

    # UMAP coordinates carry no units; hide the tick labels.
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])

    fig.savefig(f'figures/emb/pannuke_{fn_fig}.eps', format='eps', dpi=1200)
    plt.close(fig)