## Image Embeddings using EfficientNet

### Preface

**Summary**
- Frameworks: TensorFlow and FAISS
- Area: Image Embeddings
- Pre-Trained Model: EfficientNet
- Algorithms: KNN

**Acknowledgements**
- Blog Article: https://rom1504.medium.com/image-embeddings-ed1b194d113e
- Code File: https://colab.research.google.com/github/rom1504/image_embeddings/blob/master/notebooks/using_the_lib.ipynb
- Code Repo: https://github.com/rom1504/image_embeddings
- Code File: https://github.com/rom1504/image_embeddings/blob/master/notebooks/from_scratch.ipynb

### Initialization: Download and Save

**Packages**

In [None]:
import numpy as pkg_num
import os as pkg_os
import time as pkg_time
import math as pkg_math
import random as pkg_random
import warnings as pkg_warnings
import matplotlib.pyplot as pkg_mplot
import matplotlib.image as pkg_mp_image
import pathlib as pkg_pathlib
import shutil as pkg_shutil

In [None]:
# Run start time: wall-clock reference captured once, before any heavy work.
# Later cells subtract it from the current time to report elapsed seconds.
run_start_time = pkg_time.time()

In [None]:
# Miscellaneous
# Render matplotlib figures inline in the notebook output
%matplotlib inline

# TensorFlow is optimized for CUDA GPUs,
# but we are running on a different GPU or simply on the CPU, so ignore warnings.
# NOTE: no category is given, so this silences all Python warnings for the rest of the run
pkg_warnings.filterwarnings(action="ignore")

# The OpenMP library may be linked and loaded multiple times due to transitive dependencies.
# Declare that multiple instances of the OpenMP library are acceptable for this notebook;
# the duplicate-library error goes away with the following setting
pkg_os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

In [None]:
import pandas as pkg_pandas
import absl.logging as pkg_logging
import IPython.display as pkg_disp
import ipywidgets as pkg_widgets
import PIL as pkg_pil
import PIL.Image as pkg_pil_image
import pyarrow as pkg_arrow
import pyarrow.parquet as pkg_parquet
import tensorflow as pkg_tf
import tensorflow_datasets as pkg_tfds
import tensorflow_datasets.core.dataset_utils as pkg_tfds_utils
import tensorflow_datasets.core.features as pkg_tfds_features
import efficientnet as pkg_effinet
import efficientnet.preprocessing as pkg_effinet_preprocessing
import efficientnet.tfkeras as pkg_tfkeras
from efficientnet.tfkeras import EfficientNetB6 as EfficientNetRef
import faiss as pkg_faiss

**Common**

In [None]:
# Let tf.data choose buffer sizes dynamically (used for prefetching)
AUTOTUNE = pkg_tf.data.AUTOTUNE

# Target image size, as (height, width)
image_height, image_width = 180, 180
image_size = (image_height, image_width)

# Directory layout, anchored at the parent of the current working directory
repo_root_dirpath = pkg_pathlib.Path.cwd().parent
temp_root_dirpath = repo_root_dirpath / ".outputs/.datasets"
data_root_dirpath = repo_root_dirpath / "data/images"

In [None]:
def configure_for_performance(ds):
  """Return `ds` with caching, shuffling and prefetching applied, in that order.

  The shuffle buffer holds 1000 elements; the prefetch buffer size is the
  module-level AUTOTUNE setting.
  """
  return (
    ds.cache()
      .shuffle(buffer_size=1000)
      .prefetch(buffer_size=AUTOTUNE)
  )

def configure_for_batch_performance(ds,batch_size):
  """Apply `configure_for_performance` to `ds`, then group it into batches of `batch_size`."""
  tuned_ds = configure_for_performance(ds)
  return tuned_ds.batch(batch_size=batch_size)

def create_dirpath(dirpath):
  """Make sure the directory `dirpath` exists (creating missing parents) and return it.

  An already existing directory is left untouched.
  """
  dirpath.mkdir(exist_ok=True, parents=True)
  return dirpath

def recreate_dirpath(dirpath):
  """Delete `dirpath` with everything in it, create it again empty, and return it."""
  # Removal errors (including a directory that does not exist yet) are ignored
  pkg_shutil.rmtree(dirpath, ignore_errors=True)
  # exist_ok=False: raise if the old directory is somehow still there
  dirpath.mkdir(exist_ok=False, parents=True)
  return dirpath

def print_dirpath_stats(dirpath):
  """Print the `stat()` result of `dirpath` along with the seconds elapsed since `run_start_time`."""
  elapsed_seconds = int(pkg_time.time() - run_start_time)
  stats = dirpath.stat()
  print("Directory Stats ({}), at {} seconds: {}".format(
    dirpath.name, elapsed_seconds, stats))

**Load Data**

In [None]:
def download_dataset(dataset_name, data_dirpath):
    """Load the 'train' split of the TFDS dataset `dataset_name`.

    Args:
      dataset_name: Name passed to `tfds.load`.
      data_dirpath: Directory passed to `tfds.load` as `data_dir`.

    Returns:
      Tuple `(dataset, dataset_info)`.
    """
    dataset, dataset_info = pkg_tfds.load(
        name=dataset_name,
        data_dir=data_dirpath,
        split='train',
        with_info=True,
    )
    return dataset, dataset_info

def save_examples(ds, ds_info, data_dirpath, num_examples = 10, image_key=None):
  """Save images from an image classification dataset as JPEG files.

  Only works with datasets that have 1 image feature and optionally 1 label
  feature (both inferred from `ds_info`). Note the dataset should be unbatched.

  Each image is center-cropped and resized to 224x224 and written to
  `data_dirpath` as `image_<index>_<label>.jpeg`, where `<index>` is the
  zero-padded position of the example and `<label>` is empty when no single
  label feature could be inferred.

  Usage:

  ```python
  ds, ds_info = tfds.load('cifar10', split='train', with_info=True)
  save_examples(ds, ds_info, data_dirpath)
  ```

  Args:
    ds: `tf.data.Dataset`. The dataset to export. Examples should not be
      batched.
    ds_info: The dataset info object from which the label and image features
      are inferred. Available either through
      `tfds.load('mnist', with_info=True)` or `tfds.builder('mnist').info`
    data_dirpath: `pathlib.Path`. Directory where the images are saved (it is
      not created here).
    num_examples: `int`. Maximum number of examples to save
    image_key: `string`, name of the feature that contains the image. If not
       set, the system will try to auto-detect it.

  Returns:
    None. The images are written to disk as a side effect.

  Raises:
    ValueError: if `image_key` is not given and zero or several image features
      are found, if an example is not a `dict`, or if an image does not have
      exactly 3 dimensions.
  """

  if not image_key:
    # Infer the image key from the dataset features
    image_keys = [
        k for k, feature in ds_info.features.items()
        if isinstance(feature, pkg_tfds_features.Image)
    ]

    if not image_keys:
      raise ValueError(
          "Saving examples is not supported for dataset `{}`. Was not able to "
          "auto-infer image.".format(ds_info.name))

    if len(image_keys) > 1:
      # Ambiguous: refuse to guess rather than silently picking one
      raise ValueError(
          "Multiple image features detected in the dataset. Use the "
          "`image_key` argument to select one. Images detected: %s" %
          (",".join(image_keys)))

    image_key = image_keys[0]

  label_keys = [
      k for k, feature in ds_info.features.items()
      if isinstance(feature, pkg_tfds_features.ClassLabel)
  ]

  # A label is only used when exactly one ClassLabel feature exists
  label_key = label_keys[0] if len(label_keys) == 1 else None
  if not label_key:
    pkg_logging.info("Was not able to auto-infer label.")

  # Stream the examples one by one instead of materializing all decoded
  # images in memory before the first file is written
  examples = pkg_tfds_utils.as_numpy(ds.take(num_examples))

  # Save the images as files on disk
  for i, ex in enumerate(examples):
    if not isinstance(ex, dict):
      raise ValueError(
          "save_examples requires examples as `dict`, with the same "
          "structure as `ds_info.features`. It is currently not compatible "
          "with `as_supervised=True`. Received: {}".format(type(ex)))

    # Extract and validate the image
    image = ex[image_key]
    if len(image.shape) != 3:
      raise ValueError(
          "Image dimension should be 3. save_examples does not support "
          "batched examples or video.")
    _, _, c = image.shape
    if c == 1:
      # Single-channel image: drop the channel axis so PIL gets a 2-D array
      image = image.reshape(image.shape[:2])
    image = pkg_effinet_preprocessing.center_crop_and_resize(image, 224).astype(pkg_num.uint8)
    im = pkg_pil_image.fromarray(image)

    # Resolve the label text used in the file name
    if label_key:
      label = ex[label_key]
      label_str = ds_info.features[label_key].int2str(label)
    else:
      label_str = ""
    filepath = data_dirpath.joinpath("image_{:04d}_{}.jpeg".format(i, label_str))
    im.save(filepath)

In [None]:
def download_and_save_dataset(dataset_name, downloads_dirpath, example_count, images_dirpath):
    """Download a TFDS dataset, export up to `example_count` images, and preview it.

    Args:
      dataset_name: TFDS dataset name.
      downloads_dirpath: Directory used as the TFDS data directory (created if missing).
      example_count: Maximum number of images written to `images_dirpath`.
      images_dirpath: Directory that is wiped and refilled with the exported images.

    Returns:
      Tuple `(ds, ds_info)` where `ds` has been passed through
      `configure_for_performance`.
    """
    # Fetch the dataset into the downloads directory
    create_dirpath(downloads_dirpath)
    ds, ds_info = download_dataset(data_dirpath=downloads_dirpath, dataset_name=dataset_name)
    print_dirpath_stats(downloads_dirpath)

    # Start from an empty images directory and export the examples as files
    recreate_dirpath(images_dirpath)
    save_examples(ds, ds_info, images_dirpath, num_examples=example_count)
    print_dirpath_stats(images_dirpath)

    # Show a preview figure of the tuned dataset
    tuned_ds = configure_for_performance(ds)
    pkg_tfds.show_examples(tuned_ds, ds_info)
    return tuned_ds, ds_info

### Transform: Calculate Embeddings

**Routines**

In [None]:
def list_dirpath(dirpath, pattern="*", shuffle=False):
    """Return a tf.data dataset of the file paths in `dirpath` matching the glob `pattern`.

    `dirpath` must offer `as_posix()` (e.g. a `pathlib.Path`); `shuffle` is
    forwarded to `tf.data.Dataset.list_files`.
    """
    glob_pattern = "/".join([dirpath.as_posix(), pattern])
    return pkg_tf.data.Dataset.list_files(glob_pattern, shuffle=shuffle)

def load_image(filepath):
    """Read the raw bytes of the file at `filepath`.

    `filepath` must be a pathlib-style object: its `stem` and `as_posix()` are
    used, so a plain string or a tensor will not work.

    Returns:
      Tuple `(image_name, image_data)`: the file name without its last suffix
      and the file contents as returned by `tf.io.read_file`.
    """
    posix_path = filepath.as_posix()
    return filepath.stem, pkg_tf.io.read_file(posix_path)

def read_data_from_files(list_ds):
    """Turn a dataset of JPEG file paths into `(image_name, image_data)` pairs.

    `Dataset.map` calls the mapped function with string tensors, so the loader
    has to rely on TensorFlow string/IO ops only; pathlib attributes such as
    `.stem` are not available on tensors. The image bytes are decoded here
    because `images_to_embeddings` passes `image_data` directly to the model.

    Args:
      list_ds: dataset of file path strings (e.g. from `list_dirpath`).

    Returns:
      Dataset of `(image_name, image_data)` where `image_name` is the file
      name without its last extension and `image_data` is the decoded image as
      a float32 tensor with 3 channels.
    """
    def _load_and_decode(file_path):
        # File name without its directory and without the last extension
        file_name = pkg_tf.strings.split(file_path, '/')[-1]
        image_name = pkg_tf.strings.regex_replace(file_name, r"\.[^.]*$", "")
        # Decode to RGB so every element has the channel count the model expects
        image_data = pkg_tf.io.read_file(file_path)
        image_data = pkg_tf.image.decode_jpeg(image_data, channels=3)
        image_data = pkg_tf.image.convert_image_dtype(image_data, pkg_tf.float32)
        return image_name, image_data

    return list_ds.map(_load_and_decode,
        num_parallel_calls=pkg_tf.data.experimental.AUTOTUNE)

def images_to_embeddings(model, dataset, batch_size):
    """Run `model.predict` on the image component of `dataset`, `batch_size` elements per batch.

    `dataset` must yield `(image_name, image_data)` pairs. The names are
    dropped after batching so that only the image data reaches the model.

    Returns:
      Whatever `model.predict` returns for the batched images.
    """
    def _keep_image_only(image_name, image_data):
        return image_data

    batched_images = dataset.batch(batch_size).map(_keep_image_only)
    return model.predict(batched_images, verbose=1)

def save_embeddings_ds_to_parquet(embeddings, dataset, path):
    """Write image names and embeddings to a Parquet file at `path`.

    Args:
      embeddings: array-like with a `tolist()` method; stored as a list<float32> column.
      dataset: dataset of `(image_name, image_data)` pairs; it is iterated once
        more here to collect the names.
      path: destination of the Parquet file.

    The table has the columns "image_name" and "embedding".
    """
    name_iterator = dataset.map(lambda image_name, image_data: image_name).as_numpy_iterator()
    name_column = pkg_arrow.array(name_iterator)

    embedding_type = pkg_arrow.list_(pkg_arrow.float32())
    embedding_column = pkg_arrow.array(embeddings.tolist(), type=embedding_type)

    table = pkg_arrow.Table.from_arrays(
        [name_column, embedding_column],
        ["image_name", "embedding"],
    )
    pkg_parquet.write_table(table, path)

def compute_save_embeddings(image_dirpath, embeddings_dirpath, num_shards, model, batch_size):
    """Embed the `*.jpeg` files of `image_dirpath` shard by shard and save one Parquet file per shard.

    Shard `i` is written to `embeddings_dirpath/part-<i:04d>.parquet`. Progress
    and the total duration are printed in whole seconds.
    """
    started_at = pkg_time.time()

    def _elapsed_seconds():
        return int(pkg_time.time() - started_at)

    jpeg_paths = list_dirpath(image_dirpath, "*.jpeg").cache()

    for shard_id in range(num_shards):
        shard_paths = jpeg_paths.shard(num_shards=num_shards, index=shard_id)
        shard = read_data_from_files(shard_paths)
        embeddings = images_to_embeddings(model, shard, batch_size)
        print(f"Shard {shard_id} done after {_elapsed_seconds()}s")
        embeddings_filepath = embeddings_dirpath.joinpath(f"part-{shard_id:04d}.parquet")
        save_embeddings_ds_to_parquet(embeddings, shard, embeddings_filepath)
        print(f"Shard {shard_id} saved after {_elapsed_seconds()}s")
    print(f"Total time : {_elapsed_seconds()}")

def infer_images(input_folder_name, output_folder_name, num_shards=100, batch_size=1000):
    """Build an ImageNet-pretrained EfficientNet (no top, average pooling) and embed a folder of images.

    Despite their names, both folder arguments are handed to
    `compute_save_embeddings`, which uses them as pathlib-style paths.
    """
    embedding_model = EfficientNetRef(weights='imagenet', include_top=False, pooling="avg")
    compute_save_embeddings(
        input_folder_name,
        output_folder_name,
        num_shards,
        embedding_model,
        batch_size,
    )

In [None]:
def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in a `tf.train.Feature` holding an int64_list."""
    int64_list = pkg_tf.train.Int64List(value=[value])
    return pkg_tf.train.Feature(int64_list=int64_list)

def _bytes_feature(value):
    """Wrap a single string / byte value in a `tf.train.Feature` holding a bytes_list."""
    eager_tensor_type = type(pkg_tf.constant(0))
    # BytesList won't unpack a string from an EagerTensor, so extract its value first.
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = pkg_tf.train.BytesList(value=[payload])
    return pkg_tf.train.Feature(bytes_list=bytes_list)

def serialize_example(image_name, image_data):
    """Serialize one image as a `tf.train.Example` with the bytes features 'image_name' and 'image_data'."""
    features = pkg_tf.train.Features(feature={
        'image_name': _bytes_feature(image_name),
        'image_data': _bytes_feature(image_data),
    })
    return pkg_tf.train.Example(features=features).SerializeToString()

def tf_serialize_example(image_name, image_data):
    """Graph-compatible wrapper that runs `serialize_example` through `tf.py_function`.

    Returns:
      A scalar string tensor with the serialized example.
    """
    serialized = pkg_tf.py_function(
        func=serialize_example,
        inp=(image_name, image_data),
        Tout=pkg_tf.string,
    )
    # Force the result to a scalar (shape ())
    return pkg_tf.reshape(serialized, ())

def load_image_tf(file_path):
    """Read an image file using TensorFlow ops only (safe inside `Dataset.map`).

    Args:
      file_path: scalar string tensor with a '/'-separated file path.

    Returns:
      Tuple `(image_name, image_data)`: the file name without its last
      extension, and the raw (undecoded) file contents.
    """
    parts = pkg_tf.strings.split(file_path, '/')
    # Strip only the last extension so names containing dots stay intact;
    # this keeps the name consistent with `name + ".jpeg"` lookups elsewhere.
    image_name = pkg_tf.strings.regex_replace(parts[-1], r"\.[^.]*$", "")
    image_data = pkg_tf.io.read_file(file_path)
    return image_name, image_data

def read_image_file_write_tfrecord(files_ds, output_filepath):
    """Load every file of `files_ds`, serialize it, and write the records to `output_filepath`.

    Args:
      files_ds: dataset of image file paths.
      output_filepath: pathlib-style destination of the TFRecord file.
    """
    autotune = pkg_tf.data.experimental.AUTOTUNE
    serialized_ds = (
        files_ds
        .map(load_image_tf, num_parallel_calls=autotune)
        .map(tf_serialize_example, num_parallel_calls=autotune)
    )
    pkg_tf.data.experimental.TFRecordWriter(output_filepath.as_posix()).write(serialized_ds)

def image_files_to_tfrecords(images_dirpath, tfrecords_dirpath, num_shard):
    """Pack the `*.jpeg` files of `images_dirpath` into `num_shard` TFRecord files.

    Shard `i` is written to `tfrecords_dirpath/part-<i:04d>.tfrecord`; the
    elapsed whole seconds are printed after each shard.
    """
    jpeg_paths = list_dirpath(images_dirpath, "*.jpeg")
    print(jpeg_paths)
    started_at = pkg_time.time()
    for shard_id in range(num_shard):
        tfrecords_filepath = tfrecords_dirpath.joinpath(f"part-{shard_id:04d}.tfrecord")
        shard_paths = jpeg_paths.shard(num_shards=num_shard, index=shard_id)
        read_image_file_write_tfrecord(shard_paths, tfrecords_filepath)
        elapsed_seconds = int(pkg_time.time() - started_at)
        print(f"Shard {shard_id} saved after {elapsed_seconds}s")

def _parse_function(example_proto):
    """Parse one serialized example into a dict with scalar string 'image_name' and 'image_data'."""
    feature_description = {
        key: pkg_tf.io.FixedLenFeature([], pkg_tf.string)
        for key in ('image_name', 'image_data')
    }
    return pkg_tf.io.parse_single_example(example_proto, feature_description)

def preprocess_image(d):
    """Decode the JPEG bytes of a parsed example into a float32 RGB image.

    Args:
      d: dict with 'image_name' and 'image_data' (JPEG-encoded bytes).

    Returns:
      Tuple `(image_name, image_data)` where `image_data` is the decoded image
      converted to float32 by `tf.image.convert_image_dtype`.
    """
    image_name = d['image_name']
    image_data = d['image_data']
    # Force 3 channels: grayscale JPEGs would otherwise decode to 1 channel
    # and could not be batched with RGB images or fed to the model.
    image_data = pkg_tf.image.decode_jpeg(image_data, channels=3)
    image_data = pkg_tf.image.convert_image_dtype(image_data, pkg_tf.float32)

    return image_name, image_data

def read_tfrecord(filepath):
    """Return a dataset of `(image_name, image_data)` pairs read from one TFRecord file.

    Elements that raise an error while being produced are dropped.
    """
    autotune = pkg_tf.data.experimental.AUTOTUNE
    records = pkg_tf.data.TFRecordDataset(filenames=[filepath])
    parsed = records.map(_parse_function, num_parallel_calls=autotune)
    decoded = parsed.map(preprocess_image, num_parallel_calls=autotune)
    return decoded.apply(pkg_tf.data.experimental.ignore_errors())

def tfrecords_to_embeddings(tfrecords_dirpath, embeddings_dirpath, model, batch_size):
    """Compute embeddings for every `*.tfrecord` file and save one Parquet file per input file.

    The file at position `i` of the listing is written to
    `embeddings_dirpath/part-<i:04d>.parquet`; elapsed whole seconds are
    printed after computing and after saving each shard.
    """
    tfrecord_paths = [
        path_tensor.numpy().decode("utf-8")
        for path_tensor in list_dirpath(tfrecords_dirpath, "*.tfrecord")
    ]
    started_at = pkg_time.time()

    def _elapsed_seconds():
        return int(pkg_time.time() - started_at)

    for shard_id, tfrecord_path in enumerate(tfrecord_paths):
        shard = read_tfrecord(tfrecord_path)
        embeddings = images_to_embeddings(model, shard, batch_size)
        print(f"Shard {shard_id} done after {_elapsed_seconds()}s")
        embeddings_filepath = embeddings_dirpath.joinpath(f"part-{shard_id:04d}.parquet")
        save_embeddings_ds_to_parquet(embeddings, shard, embeddings_filepath)
        print(f"Shard {shard_id} saved after {_elapsed_seconds()}s")

**Models**

In [None]:
def calculate_and_save_embeddings(model, images_dirpath, tfrecords_dirpath, embeddings_dirpath, shard_count, batch_size):
    """Convert images to TFRecords, then TFRecords to Parquet embeddings.

    Both output directories are wiped and recreated before being filled, and
    their stats are printed afterwards.
    """
    # Stage 1: image files -> TFRecord shards
    recreate_dirpath(tfrecords_dirpath)
    image_files_to_tfrecords(
        images_dirpath=images_dirpath,
        tfrecords_dirpath=tfrecords_dirpath,
        num_shard=shard_count,
    )
    print_dirpath_stats(tfrecords_dirpath)

    # Stage 2: TFRecord shards -> embedding Parquet files
    recreate_dirpath(embeddings_dirpath)
    tfrecords_to_embeddings(
        tfrecords_dirpath=tfrecords_dirpath,
        embeddings_dirpath=embeddings_dirpath,
        model=model,
        batch_size=batch_size,
    )
    print_dirpath_stats(embeddings_dirpath)

### Search

**Display**

In [None]:
def get_image_filepath(images_dirpath, image_name):
  """Return the path of `<image_name>.jpeg` inside `images_dirpath`."""
  file_name = image_name + ".jpeg"
  return images_dirpath.joinpath(file_name)

def display_picture(image_filepath):
  """Print the query image's name and display it at the configured image height and width."""
  print("\nQuery Image: {}\n".format(image_filepath.stem))
  picture = pkg_disp.Image(filename=image_filepath, height=image_height, width=image_width)
  pkg_disp.display(picture)

def display_picture_of_directory(images_dirpath, image_name):
  """Display the image called `image_name` stored in `images_dirpath`."""
  image_filepath = get_image_filepath(images_dirpath, image_name)
  display_picture(image_filepath)

def display_results(images_dirpath, results):
  """Display the result images side by side, each under a "<distance> <name>" label.

  Args:
    images_dirpath: directory holding the `<image_name>.jpeg` files.
    results: iterable of `(distance, image_name)` pairs.
  """
  columns = []
  for distance, image_name in results:
    # read_bytes() opens and closes the file, so no handle is left open
    image_bytes = get_image_filepath(images_dirpath, image_name).read_bytes()
    columns.append(pkg_widgets.VBox([
      pkg_widgets.widgets.Label(f"{distance:.2f} {image_name}"),
      pkg_widgets.widgets.Image(value=image_bytes),
    ]))
  hbox = pkg_widgets.HBox(columns)
  print("\nResut Images: \n")
  pkg_disp.display(hbox)

**Load Data**

In [None]:
def load_embeddings_into_dataframe(embeddings_dirpath):
  """Load every `*.parquet` shard of `embeddings_dirpath` into one DataFrame.

  Args:
    embeddings_dirpath: directory containing the Parquet shards.

  Returns:
    The shards concatenated in listing order (each shard keeps its own row
    index), or None when the listing yields no files.
  """
  shard_paths = [f.numpy().decode("utf-8") for f in list_dirpath(embeddings_dirpath, "*.parquet")]
  if not shard_paths:
    return None

  # Read all shards first and concatenate once: concatenating inside the loop
  # re-copies the accumulated frame for every shard.
  shard_frames = [pkg_parquet.read_table(shard_path).to_pandas() for shard_path in shard_paths]
  return pkg_pandas.concat(shard_frames)

def populate_index(embeds):
  """Create a FAISS `IndexFlatIP` sized to `embeds.shape[1]` and add all rows of `embeds` to it."""
  flat_ip_index = pkg_faiss.IndexFlatIP(embeds.shape[1])
  flat_ip_index.add(embeds)
  return flat_ip_index

def build_maps(df):
  """Map 0-based row positions to the UTF-8 decoded values of `df["image_name"]`.

  The keys are positions in iteration order, not the DataFrame's index labels.
  """
  decoded_names = (raw_name.decode("utf-8") for raw_name in list(df["image_name"]))
  return dict(enumerate(decoded_names))

def build_index_from_embeddings(embeddings_dirpath):
  """Load the saved embeddings and build the search index plus the id-to-name map.

  Returns:
    Tuple `(index, embed_df, id_map)`.
  """
  embed_df = load_embeddings_into_dataframe(embeddings_dirpath)

  # Stack the per-row embedding arrays into a single array for the index
  embedding_column = embed_df["embedding"].to_numpy()
  embeds = pkg_num.stack(embedding_column)
  print("Shapes: DataFrame = {}, Embeddings = {}".format(embed_df.shape, embeds.shape))

  index = populate_index(embeds)
  id_map = build_maps(embed_df)
  print("Map Sizes: id_map = {}".format(len(id_map)))

  return index, embed_df, id_map

In [None]:
def build_search_index(dataset_name, model, example_count, shard_count, batch_size,
  should_download=False, should_compute_embeddings=False):
  """Build the search index for a dataset, optionally refreshing its cached artifacts.

  All artifacts live under `temp_root_dirpath/<dataset_name>` in the
  sub-directories downloads, images, tfrecords and embeddings.

  Args:
    dataset_name: TFDS dataset name.
    model: model used to compute the embeddings.
    example_count: number of images exported when downloading.
    shard_count: number of TFRecord shards.
    batch_size: batch size used for inference.
    should_download: re-download the dataset and re-export its images; this
      also forces the embeddings to be recomputed.
    should_compute_embeddings: recompute embeddings from the existing images.

  Returns:
    Tuple `(index, embed_df, images_dirpath, id_map)`.
  """
  dataset_dirpath = temp_root_dirpath / dataset_name
  downloads_dirpath = dataset_dirpath / "downloads"
  images_dirpath = dataset_dirpath / "images"
  tfrecords_dirpath = dataset_dirpath / "tfrecords"
  embeddings_dirpath = dataset_dirpath / "embeddings"

  # A downloaded and exported dataset is reused by later runs;
  # should_download=True triggers a fresh download.
  if should_download:
    download_and_save_dataset(
      dataset_name=dataset_name, downloads_dirpath=downloads_dirpath,
      example_count=example_count, images_dirpath=images_dirpath)

  # Embeddings computed for a model are reused by later runs;
  # should_compute_embeddings=True (or a fresh download) triggers recomputation.
  needs_embeddings = should_download or should_compute_embeddings
  if needs_embeddings:
    calculate_and_save_embeddings(
      model=model, images_dirpath=images_dirpath,
      tfrecords_dirpath=tfrecords_dirpath, embeddings_dirpath=embeddings_dirpath,
      shard_count=shard_count, batch_size=batch_size)

  index, embed_df, id_map = build_index_from_embeddings(embeddings_dirpath=embeddings_dirpath)
  return index, embed_df, images_dirpath, id_map

In [None]:
def load_test_embeddings(model, dataset_name, images_dirpath, shard_count, 
  batch_size, should_compute_embeddings=False):
  """Load (and optionally recompute) the embeddings of the test images.

  Artifacts live under `temp_root_dirpath/<dataset_name>_test` in the
  sub-directories tfrecords and embeddings.

  Returns:
    Tuple `(embed_df, id_map)`.
  """
  testdata_dirpath = temp_root_dirpath / (dataset_name + "_test")
  tfrecords_dirpath = testdata_dirpath / "tfrecords"
  embeddings_dirpath = testdata_dirpath / "embeddings"

  # Recompute only on request; otherwise reuse what an earlier run saved
  if should_compute_embeddings:
    calculate_and_save_embeddings(
      model=model, images_dirpath=images_dirpath,
      tfrecords_dirpath=tfrecords_dirpath, embeddings_dirpath=embeddings_dirpath,
      shard_count=shard_count, batch_size=batch_size)

  embed_df = load_embeddings_into_dataframe(embeddings_dirpath=embeddings_dirpath)
  print("Shapes: DataFrame = {}".format(embed_df.shape))

  id_map = build_maps(embed_df)
  print("Map Sizes: id_map = {}".format(len(id_map)))

  return embed_df, id_map

**Find**

In [None]:
def search_similar_images(index, id_map, image_embedding, k=5):
  """Return up to `k` nearest entries of `index` for one embedding.

  Args:
    index: object with a FAISS-style `search(queries, k)` method returning
      `(scores, ids)` arrays with one row per query.
    id_map: mapping from index id to image name.
    image_embedding: 1-D embedding of the query image.
    k: number of neighbours requested.

  Returns:
    List of `(score, image_name)` tuples in the order returned by the index.
    Fewer than `k` entries are returned when the index holds fewer than `k`
    vectors.
  """
  # The index expects a batch of queries, so add a leading axis
  query = pkg_num.expand_dims(image_embedding, 0)
  scores, ids = index.search(query, k)
  # FAISS fills missing neighbours with id -1; skip them instead of failing
  # with a KeyError on the id map.
  return [
    (score, id_map[found_id])
    for score, found_id in zip(scores[0], ids[0])
    if found_id >= 0
  ]

def search_and_display_existing(index, dataset_embed_df, dataset_id_map, 
  dataset_images_dirpath, existing_image_name, k=5):
  """Show a dataset image and its `k` most similar dataset images.

  The query embedding is looked up in `dataset_embed_df` by image name.
  """
  display_picture_of_directory(dataset_images_dirpath, existing_image_name)

  # Names are compared as bytes against the "image_name" column
  name_key = bytes(existing_image_name, "utf-8")
  matching_rows = dataset_embed_df[dataset_embed_df["image_name"] == name_key]
  image_embedding = matching_rows["embedding"].to_numpy()[0]

  results = search_similar_images(index, dataset_id_map, image_embedding, k)
  display_results(dataset_images_dirpath, results)

def search_and_display_test(index, dataset_images_dirpath, dataset_id_map, 
  test_images_dirpath, test_embed_df, test_image_name, k=5):
  """Show a test image and the `k` most similar dataset images.

  The query embedding is looked up in `test_embed_df` by image name; the
  result images are read from `dataset_images_dirpath`.
  """
  display_picture_of_directory(test_images_dirpath, test_image_name)

  # Names are compared as bytes against the "image_name" column
  name_key = bytes(test_image_name, "utf-8")
  matching_rows = test_embed_df[test_embed_df["image_name"] == name_key]
  image_embedding = matching_rows["embedding"].to_numpy()[0]

  results = search_similar_images(index, dataset_id_map, image_embedding, k)
  display_results(dataset_images_dirpath, results)

### Main

**Misc**

In [25]:
# Tested Datasets: TFDS dataset names this notebook has been run against.
# The config cell selects one of them by index.
tested_datasets = ["tf_flowers", "stanford_dogs"]

**Initialization**

In [26]:
# Config
target_example_count = 2000   # images exported from the dataset
target_shard_count = 10       # TFRecord / Parquet shards
# Batch size: examples per shard, capped at 20 and never below 1.
# (Dividing the example count by itself always produced a batch size of 1.)
target_batch_size = max(1, min(target_example_count // target_shard_count, 20))
target_match_count = 5        # neighbours shown per query
target_dataset_name = tested_datasets[1]

# ImageNet-pretrained EfficientNet without the classification head;
# average pooling yields one embedding vector per image
model = EfficientNetRef(weights='imagenet', include_top=False, pooling="avg")

**Setup**

In [None]:
# Build the search index from previously saved embeddings.
# Set should_download=True to re-download the dataset (this also recomputes
# the embeddings), or should_compute_embeddings=True to recompute embeddings only.
index, dataset_embed_df, dataset_images_dirpath, dataset_id_map = build_search_index(
  dataset_name=target_dataset_name, model=model, example_count=target_example_count, 
  shard_count=target_shard_count, batch_size=target_batch_size,
  should_download=False, should_compute_embeddings=False
)

**Test Setup**

In [None]:
# Test images are read from data_root_dirpath/<dataset name>.
# Their embeddings are loaded from disk; pass should_compute_embeddings=True
# to (re)compute them first.
test_images_dirpath = data_root_dirpath.joinpath(target_dataset_name)
test_embed_df, test_id_map = load_test_embeddings(model=model, 
  dataset_name=target_dataset_name, images_dirpath=test_images_dirpath,
  shard_count=target_shard_count, batch_size=target_batch_size,
  should_compute_embeddings=False)

**Match: Test Images**

In [None]:
# For every test image, show it followed by its closest matches from the dataset
for test_image_name in test_id_map.values():
  search_and_display_test(index, dataset_images_dirpath, dataset_id_map, 
    test_images_dirpath, test_embed_df, test_image_name, target_match_count)

**Match: Random Dataset Images**

In [None]:
# Query the index with 10 randomly chosen dataset images.
# Each pick is independent, so the same image may be chosen more than once.
ds_image_names = list(dataset_id_map.values())
for _ in range(10):
  ds_image_name = pkg_random.choice(ds_image_names)
  search_and_display_existing(index, dataset_embed_df, dataset_id_map, 
    dataset_images_dirpath, ds_image_name, target_match_count)

**Match: Select Dataset Images**

In [None]:
# Fixed selection of ids (keys of dataset_id_map) used as queries.
# Every id must exist in dataset_id_map, otherwise the lookup below raises KeyError.
existing_image_ids = [
  7,    31,   67,   107,  131,  167,  278,  343,  476,  555,
  649,  761,  873,  966,  1234, 1376, 1524, 1711, 1857, 1970
]

# Show each selected image followed by its closest matches
for image_id in existing_image_ids:
  search_and_display_existing(index, dataset_embed_df, dataset_id_map, 
    dataset_images_dirpath, dataset_id_map[image_id], target_match_count)

**Conclusion**

In [None]:
# Compute Total Run Time: whole seconds elapsed since run_start_time was captured
run_time_seconds = int(pkg_time.time() - run_start_time)
print("Total Run Time: {} seconds".format(run_time_seconds))