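# Feature extraction using the VAE latent space `z`

Each image passes through the frozen detector. Its hooked neck features go through the trained NECK VAEs, and the resulting latent codes `z` are saved as image embeddings.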

In [1]:
import glob
import os
import numpy as np

from PIL import Image
from omegaconf import OmegaConf

In [2]:
import torch

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Use the CUDA GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
from cabifpn.utils.getter import IntermediateLayerGetter
from cabifpn.utils.datasets import CocoDetectionV2, LVISDetection

from model.neck_vae import NeckVAE
from utils import _create_model, _create_config

## Base configuration

### Create VAE model

In [4]:
# Trained NECK VAE checkpoint (weights + VAE config + name of the extractor checkpoint).
PATH_VAE_CHECKPOINT = os.path.join('/thesis/checkpoint', '20240201_1708_VAE_convnext_small_cabifpn_12.pth')

In [5]:
# Load the trained NECK VAE checkpoint. `map_location=device` remaps tensors that were
# saved on a GPU, so the notebook also runs on a CPU-only machine (where `device` is "cpu").
checkpoint_vae = torch.load(PATH_VAE_CHECKPOINT, map_location=device)
# The VAE hyper-parameters are stored next to the weights; wrap them in an OmegaConf
# container so they can be read as attributes (vae_config.LATENT_DIM, ...).
vae_config = OmegaConf.create(checkpoint_vae['vae_config'])

In [6]:
# === GLOBAL VARIABLES ===
## Neck layers to hook, as recorded in the VAE config.
set_neck_indices = vae_config.NECK_INDICES
## Hooked module path -> short output key, e.g.
## 'backbone.neck.neck.neck_layer_0.proj_p4_2' -> 'p4_2_l0' (keys use the layer index).
_RETURN_NECK_NODES = {
    f'backbone.neck.neck.neck_layer_{idx}.proj_p4_2': f'p4_2_l{idx}'
    for idx in set_neck_indices
}

# === Create and load extractor model ===
print('[+] Loading extractor model ...')

## Rebuild the detector whose file name is stored in the VAE checkpoint under
## 'fn_checkpoint', move it to `device` and put it in eval mode.
base_config, checkpoint = _create_config(
    os.path.join('/thesis/checkpoint/', checkpoint_vae['fn_checkpoint'])
)
model_extractor = _create_model(base_config, checkpoint).to(device).eval()

## Freeze the extractor: it is only a fixed feature source here.
model_extractor.requires_grad_(False)

## Getter that hooks the neck layers listed in _RETURN_NECK_NODES. Calling it returns a
## 2-tuple (dict of hooked outputs keyed by the short names, model output); the second
## item is presumably dropped/None with keep_output=False -- TODO confirm in cabifpn.
mid_extractor_getter = IntermediateLayerGetter(
    model_extractor,
    return_layers=_RETURN_NECK_NODES,
    keep_output=False,
)

# === Create NECK VAE base model ===
print('[+] Building the NECK VAE base model ...')
print(f'[++] Using VAE configs : total VAEs->{len(set_neck_indices)} | in_channels->{vae_config.IN_CHANNELS} | in_shape->{vae_config.IN_SHAPE} | latent_dim->{vae_config.LATENT_DIM}.')
## One VAE per hooked neck layer, restored from the trained weights and switched to eval mode.
base_model = NeckVAE(
    len(set_neck_indices),
    vae_config.IN_CHANNELS,
    vae_config.IN_SHAPE,
    vae_config.LATENT_DIM,
).to(device)
base_model.load_state_dict(checkpoint_vae['model_state_dict'])
base_model.eval()

print('[+] Ready !')

[+] Loading extractor model ...
[+] Loading checkpoint...
[+] Ready !
[+] Preparing base configs...
[+] Ready !
[i+] Configuring backbone and neck models with variables: {'BACKBONE': {'MODEL_NAME': 'convnext_small', 'OUT_INDICES': [0, 1, 2, 3]}, 'NECK': {'MODEL_NAME': 'cabifpn', 'IN_CHANNELS': [96, 192, 384, 768], 'NUM_CHANNELS': 256, 'NUM_LAYERS': 3}}
[i+] Ready !
[i+] Building the base model with MaskRCNN head ...
[++] Numbers of classes: 91
[+] Loading checkpoint...
[++] All keys matched successfully
[+] Ready. last_epoch: 12 - last_loss: 1.0497519969940186
[i+] Ready !
[+] Building the NECK VAE base model ...
[++] Using VAE configs : total VAEs->3 | in_channels->256 | in_shape->[25, 25] | latent_dim->256.
[+] Ready !


In [7]:
## Albumentations pipeline applied to every input image:
## resize to a square IMAGE_SIZE x IMAGE_SIZE (aspect ratio is not preserved),
## normalize with the detector's dataset mean/std, then convert to a CHW torch tensor.
pre_transform = A.Compose(
    [
        A.Resize(
            height=base_config.DATASET.IMAGE_SIZE,
            width=base_config.DATASET.IMAGE_SIZE,
        ),
        A.Normalize(
            mean=base_config.DATASET.MEAN,
            std=base_config.DATASET.STD,
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)

### Example: extracting the latent space `z`

In [8]:
## Dummy 500x500 RGB image with uniformly random pixels covering the full uint8
## range [0, 255] (`high` is exclusive, hence 256). It is seeded so reruns are
## identical, and is only used to check the pipeline end to end and the tensor shape.
_rng = np.random.default_rng(0)
input_img = _rng.integers(low=0, high=256, size=(500, 500, 3), dtype=np.uint8)
input_img = pre_transform(image=input_img)['image'].unsqueeze(0).to(device)

input_img.shape

torch.Size([1, 3, 224, 224])

In [9]:
## Run the frozen detector once and keep only the hooked neck outputs (a dict keyed 'p4_2_l<layer>').
neck_layers_vector = mid_extractor_getter(input_img)[0]

In [10]:
%%timeit

return_layers_vae = {'latent_z':'latent_z'}

for i in range(len(base_model.vaes)):
    vae_i = base_model.vaes[i]
    layer_neck_i = neck_layers_vector[f'p4_2_l{i}']
    
    vae_extractor_getter = IntermediateLayerGetter(vae_i,
                                                   return_layers=return_layers_vae,
                                                   keep_output=False)
    
    latent_z_i, _ = vae_extractor_getter(layer_neck_i)
    
#     print(f'Node p4_2_l{i} shape:',latent_z_i['latent_z'].squeeze(0).shape)

3.63 ms ± 41.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Feature extraction on the `oxford5k` dataset

In [11]:
# Folder holding the oxford5k test images to embed.
DATASET_IMG = os.path.join('/thesis/classical/cbir/oxford5k', 'test', 'img')

In [12]:
## Every file in the image folder. glob's order depends on the filesystem, so sort it:
## the embedding rows and the saved id index must come out in a reproducible order.
img_l = sorted(glob.glob(os.path.join(DATASET_IMG, "*")))
assert img_l, f'No files found in {DATASET_IMG}'

In [13]:
## Which VAE submodule output to capture: the latent code `latent_z`.
return_layers_vae = {'latent_z': 'latent_z'}

## Build one hook getter per VAE up front and reuse it for every image, the same way
## mid_extractor_getter is reused. VAE number i is assumed to encode neck layer
## set_neck_indices[i] (NeckVAE is built with len(set_neck_indices) VAEs).
vae_extractor_getters = [
    IntermediateLayerGetter(base_model.vaes[i],
                            return_layers=return_layers_vae,
                            keep_output=False)
    for i in range(len(set_neck_indices))
]

## 'fn_id' -> list of image ids; 'p4_2_l<layer>' -> list of per-image latent tensors (same order).
dict_latent_z = {'fn_id': []}
for idx in set_neck_indices:
    dict_latent_z[f'p4_2_l{idx}'] = []

for img_i in img_l:
    fn_i = os.path.splitext(os.path.basename(img_i))[0]

    ## Force 3-channel RGB: grayscale/RGBA/CMYK files would not match Normalize's
    ## 3-channel mean/std. The context manager closes the file handle promptly.
    with Image.open(img_i) as img_raw:
        img_rgb = np.asarray(img_raw.convert('RGB'))
    t_img = pre_transform(image=img_rgb)['image'].unsqueeze(0).to(device)

    ## Pure inference: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        neck_layers_vector, _ = mid_extractor_getter(t_img)
        ## Hook keys are named after the neck layer index, so look features up by `idx`.
        for idx, vae_extractor_getter in zip(set_neck_indices, vae_extractor_getters):
            latent_z_i, _ = vae_extractor_getter(neck_layers_vector[f'p4_2_l{idx}'])
            dict_latent_z[f'p4_2_l{idx}'].append(latent_z_i['latent_z'].squeeze(0).cpu())

    ## Record the id only after all of this image's latents are stored, so ids stay aligned with rows.
    dict_latent_z['fn_id'].append(fn_i)
        

In [14]:
# Output folder for the saved latent embeddings and their image-id index.
PATH_EMB = '/thesis/embedding'

In [15]:
## For every neck layer, stack the per-image latents into a single tensor whose first
## dim is the number of images, and save it as '<prefix>-p4_2_l<layer>.pt'. Then write
## the matching image ids (space-separated, same order as the rows) to '<prefix>-index.txt'.
os.makedirs(PATH_EMB, exist_ok=True)
EMB_PREFIX = 'oxford5k-VAE_convnext_small_cabifpn_12'

for layer_key in [k for k in dict_latent_z if k != 'fn_id']:
    ## Stack only once, so re-running this cell does not try to re-stack an existing tensor.
    if not torch.is_tensor(dict_latent_z[layer_key]):
        dict_latent_z[layer_key] = torch.stack(dict_latent_z[layer_key])
    torch.save(dict_latent_z[layer_key], os.path.join(PATH_EMB, f'{EMB_PREFIX}-{layer_key}.pt'))

with open(os.path.join(PATH_EMB, f'{EMB_PREFIX}-index.txt'), 'w') as ff:
    ff.write(' '.join(dict_latent_z['fn_id']))