In [None]:
#|default_exp embeddings

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#|hide
from nbdev.showdoc import *

In [None]:
#| export

from clip_plot.utils import timestamp, clean_filename
from clip_plot.images import image_to_array, Image

from pathlib import Path

import torch
import timm

### Silence tensorflow
# import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# import tensorflow as tf

# from tensorflow.keras.applications import InceptionV3
# from tensorflow.keras.models import Model
# from tensorflow.keras.applications.inception_v3 import preprocess_input

from tqdm.auto import tqdm
import numpy as np

# Create and/or load embeddings

In [None]:
#| export

def timm_embed_model(model_name: str):
    '''
    Build a pre-trained timm feature extractor and its matching image transform
    Reference: https://huggingface.co/docs/timm/main/en/feature_extraction#pooled

    input:          model name as found in timm documentation
    return tuple:   pre-trained embedding model (switched to eval mode),
                    transform function to prep images for inference
    '''

    # num_classes=0 is the "pooled features" setup described in the reference above
    embed_model = timm.create_model(model_name, pretrained=True, num_classes=0)
    embed_model.eval()

    # Derive the preprocessing settings from the model's own pretrained config
    data_config = timm.data.resolve_data_config(embed_model.pretrained_cfg)
    preprocess = timm.data.create_transform(**data_config)

    return embed_model, preprocess

def timm_embed(img, model, transform) -> np.ndarray:
    '''
    apply transform to image and run inference on it to generate an embedding

    input:      img: Pillow image or similar (whatever `transform` accepts)
                model: Torch model (callable on a batched tensor)
                transform: Torch image transformation pipeline to match how model was trained;
                           must return a tensor
    returns: embedding as numpy array with all size-1 dimensions squeezed out
             (a 1D vector when the model returns one row of features per image)
    '''
    # The model expects a batch, so add a leading batch dimension of size 1
    batch = transform(img).unsqueeze(0)

    # Inference only: disable autograd so no graph is built for every image
    with torch.no_grad():
        emb = model(batch)

    # .cpu() is a no-op for CPU tensors and makes .numpy() safe for other devices
    return emb.detach().cpu().numpy().squeeze()

def get_timm_embeds(imageEngine, model_name: str, **kwargs):
    '''
    Create embedding vectors for input images using a pre-trained model from timm

    Each embedding is stored as one `.npy` file per image under
    `<out_dir>/image-vectors/inception/` and is re-used on later runs
    when `use_cache` is truthy.

    input:      imageEngine: iterable of image objects exposing `.path` and `.original`,
                             plus a `.count` attribute used for the progress bar
                model_name: model name as found in timm documentation
                kwargs: "out_dir" and "seed" are required; "use_cache" is read
                        whenever a saved vector already exists for an image
    returns: numpy array holding one embedding per image, in iteration order
    '''
    # NOTE(review): the cache folder is still named "inception" and is not keyed by
    # model_name, so cached vectors produced by one model would be re-used for
    # another. Path left unchanged here for backward compatibility -- confirm.
    vector_dir = Path(kwargs["out_dir"]) / "image-vectors" / "inception"
    vector_dir.mkdir(exist_ok=True, parents=True)

    torch.manual_seed(kwargs["seed"])

    print(timestamp(), f"Creating embeddings using {model_name}")
    embeds = []

    model, transform = timm_embed_model(model_name)

    for img in tqdm(imageEngine, total=imageEngine.count):
        embed_path = vector_dir / (clean_filename(img.path) + ".npy")
        if embed_path.exists() and kwargs["use_cache"]:
            # Cache hit: load into the SAME variable that gets appended below
            # (previously this was stored in an unused name, so the cached
            # vector was never the one added to the result)
            emb = np.load(embed_path)
        else:
            # 299x299 pre-resize is carried over from the earlier Inception pipeline;
            # `transform` applies the model's own preprocessing afterwards
            emb = timm_embed(img.original.resize((299, 299)), model, transform)
            np.save(embed_path, emb)
        embeds.append(emb)
    return np.array(embeds)

In [None]:
#| export

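# NOTE: legacy TensorFlow/Keras InceptionV3 implementation, kept commented out for
# reference only; `get_timm_embeds` above appears to be its timm/PyTorch replacement.
# def get_inception_vectors(imageEngine, **kwargs):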
#     """Create and return Inception vector representation of Image() instances"""

#     vector_dir = Path(kwargs["out_dir"]) / "image-vectors" / "inception"
#     vector_dir.mkdir(exist_ok=True, parents=True)
#     base = InceptionV3(
#         include_top=True,
#         weights="imagenet",
#     )
#     model = Model(inputs=base.input, outputs=base.get_layer("avg_pool").output)
#     tf.random.set_seed(kwargs["seed"])

#     print(timestamp(), "Creating Inception vectors")
#     vecs = []   

#     for img in tqdm(imageEngine, total=imageEngine.count):
#         vector_path = vector_dir / (clean_filename(img.path) + ".npy")
#         if vector_path.exists() and kwargs["use_cache"]:
#             vec = np.load(vector_path)
#         else:
#             img_processed = preprocess_input(image_to_array(img.original.resize((299, 299))))
#             vec = model.predict(np.expand_dims(img_processed, 0), verbose = 0).squeeze()
#             np.save(vector_path, vec)
#         vecs.append(vec)
#     return np.array(vecs)

In [None]:
#|hide
# Run nbdev's export step so the `#| export` cells are written out to the library modules
import nbdev; nbdev.nbdev_export()