# Interactive Visual Feature Search: Local Demo

We provide a small-scale demo (without caching) of Visual Feature Search for the reader to try. The demo uses a MobileNetV2 model trained on ImageNet, and we provide ten query images from the ImageNet test set in the `test_imgs` directory.

We provide a small notebook that runs locally (CPU only) because this is the easiest way to share a demo that is (1) anonymous and (2) able to run on any computer.

## Setup
First, download the ImageNet validation set. We use a 1,000-image subset of the validation set as the searchable database for visual feature search. To get the dataset:
1. Go to image-net.org and log in or sign up for access.
2. Go to the following URL: https://image-net.org/challenges/LSVRC/2012/2012-downloads.php
3. Download the "ILSVRC2012_img_val.tar" file from the link titled "Validation images (all tasks)" under the "Images" header (the file size should be about 6.3 GB).
4. Extract the .tar file, and set `IMAGENET_VAL_DIR` below to the directory containing the extracted images.
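
Before moving on, it can help to confirm that the extraction worked. The following is a minimal sketch, not part of the `vissearch` library: `count_images` is a hypothetical helper that counts files with common image extensions directly inside a directory. The full ILSVRC2012 validation set contains 50,000 JPEG images, of which this demo searches a 1,000-image subset.

```python
from pathlib import Path

# Hypothetical helper (not part of vissearch): count image files in a directory.
IMAGE_EXTENSIONS = {".jpeg", ".jpg", ".png"}

def count_images(directory):
    """Return the number of files directly inside `directory` with an image extension."""
    root = Path(directory)
    if not root.is_dir():
        raise FileNotFoundError(f"{directory!r} is not a directory")
    return sum(
        1 for p in root.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
```

Once `IMAGENET_VAL_DIR` is set in the first code cell, `count_images(IMAGENET_VAL_DIR)` should report at least 1,000 images for this demo to have a full search database.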

To run the demo, we recommend using Anaconda or Miniconda to create a temporary environment. To set everything up, complete the following steps:
1. Run `conda create -n tmp-vfs`
2. Run `conda activate tmp-vfs`, followed by `conda install jupyter`
3. Start a Jupyter notebook server by running `jupyter notebook`. **NOTE:** other notebook environments, such as JupyterLab, may not work with our interactive widget! We have only tested our library with Google Colab and `jupyter notebook`.
4. Open this notebook in Jupyter Notebook, and run the following code blocks to install and import the required components.
5. Run the remaining code blocks to create the interactive widget and search for similar regions in a small-scale example using MobileNetV2 features.

In [None]:
IMAGENET_VAL_DIR = 'path/to/val/images' # TODO: replace me
!pip install -r requirements.txt

In [None]:
import os
import numpy as np
import requests
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [10, 10]

import torch
import torchvision
from torchvision import transforms

from lucent.optvis.render import ModuleHook

from vissearch import widgets, util, data
from vissearch.searchtool import LiveSearchTool, get_crop_rect

## Set up model and search tool
We search for similar regions using activations from the last bottleneck layer of MobileNetV2 (`model.features[17]`). The following code may take several minutes to complete.

In [None]:
SEARCH_DATASET_SIZE = 1000

model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
model.eval()

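def model_layer(X):
    # capture the output of the last bottleneck block with a forward hook
    hook = ModuleHook(model.features[17])
    model(X)
    hook.close()  # remove the hook so repeated calls do not accumulate hooks
    return hook.features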

CONV5_FEATURE_SIZE = 7 # row/column length for the layer of interest

imagenet_dataset = data.SimpleDataset(IMAGENET_VAL_DIR, return_idxs=False)
# only search with a subset of images
imagenet_dataset._all_images = imagenet_dataset._all_images[:SEARCH_DATASET_SIZE] 

device = torch.device('cpu')
search_tool = LiveSearchTool(model_layer, device, imagenet_dataset, batch_size=64)
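
To make `CONV5_FEATURE_SIZE` concrete: by `features[17]`, MobileNetV2 has downsampled its input by a factor of 32, so a 224×224 image (the size implied by the `mask_size=224` used later) yields a 7×7 grid of feature vectors. The sketch below maps one grid cell to a pixel region under the simplifying assumption that each cell covers a non-overlapping 32×32 patch; actual receptive fields are larger and overlap. `cell_to_pixel_box` is a hypothetical helper for illustration, not part of `vissearch`.

```python
def cell_to_pixel_box(row, col, image_size=224, feature_size=7):
    """Map a feature-grid cell to an approximate (left, top, right, bottom) pixel box."""
    if not (0 <= row < feature_size and 0 <= col < feature_size):
        raise ValueError("cell index outside the feature grid")
    stride = image_size // feature_size  # 224 // 7 == 32
    left, top = col * stride, row * stride
    return (left, top, left + stride, top + stride)

print(cell_to_pixel_box(0, 0))  # (0, 0, 32, 32)
print(cell_to_pixel_box(6, 6))  # (192, 192, 224, 224)
```

This coarse grid is why a highlighted region is later resized to 7×7 before searching: the search operates on feature cells, not on pixels.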

## Set up region selection UI

**Instructions:** In the region selection UI below, select a query image. Then click on the image, hold down the mouse button, and drag the cursor to highlight a region.

In [None]:
test_files = sorted(os.listdir('test_imgs'))

query_model_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

query_imgs = [Image.open('test_imgs/' + file_name) for file_name in test_files]
query_vis_imgs = [data.vis_transform(img) for img in query_imgs]
# convert images to Data URLs so we can pass them into the HTML widget
query_img_urls = [util.image_to_durl(img) for img in query_vis_imgs]

highlight_data = None
highlight_index = None
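def highlight_callback(user_data):
    global highlight_data, highlight_index
    # everything before the final two characters is the mask Data URL;
    # the last character is the index of the selected query image
    highlight_data = user_data[:-2]
    highlight_index = user_data[-1]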
util.create_callback('highlight_callback', highlight_callback)

widgets.MultiHighlightWidget(all_urls=query_img_urls, callback_name='highlight_callback')

## Show selected region

In [None]:
assert highlight_data is not None, "Use the widget to highlight an image region"

# overlay the mask onto the user's selected image
curr_img = query_vis_imgs[int(highlight_index)]

mask = util.durl_to_image(highlight_data)
mask_arr = np.asarray(mask)[:,:,3] / 255 # take just the alpha channel, scaled to [0, 1]

curr_mask_overlay = util.mask_overlay(curr_img, x=0, y=0, mask_size=224, mask=mask_arr, alpha=0.5, beta=0.4)

fig = plt.figure(figsize=(10, 3))

plt.subplot(1,3,1)
plt.axis('off')
plt.imshow(curr_img, cmap='gray')

plt.subplot(1,3,2)
plt.axis('off')
plt.imshow(mask_arr, cmap='gray')

plt.subplot(1,3,3)
plt.axis('off')
_ = plt.imshow(curr_mask_overlay)
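
The widget returns the highlighted region as an RGBA image whose alpha channel marks the selection. As a rough illustration of what an overlay such as `util.mask_overlay` might do (its actual implementation may differ), here is a NumPy sketch that dims the pixels outside the mask. The function name `dim_outside_mask` and its `dim` parameter are assumptions made for this example.

```python
import numpy as np

def dim_outside_mask(image, mask, dim=0.4):
    """Keep pixels where mask == 1 and scale the rest by `dim`.

    image: float array of shape (H, W, 3) with values in [0, 1]
    mask:  float array of shape (H, W) with values in [0, 1]
    """
    image = np.asarray(image, dtype=float)
    mask = np.asarray(mask, dtype=float)[..., None]  # broadcast over channels
    return image * (dim + (1.0 - dim) * mask)

img = np.ones((4, 4, 3))
m = np.zeros((4, 4))
m[1:3, 1:3] = 1.0  # select the central 2x2 square
out = dim_outside_mask(img, m)
print(out[0, 0, 0], out[1, 1, 0])  # 0.4 1.0
```

Partially transparent mask values between 0 and 1 are blended proportionally, which gives soft edges around the selection.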

## Compute activations and find similar regions
This code should only take a few seconds to complete.

In [None]:
# input selected image into model
curr_img_tensor = data.net_transform(query_imgs[int(highlight_index)])
search_tool.set_input_image(curr_img_tensor)

# assemble masks
downsample_transform = transforms.Resize(CONV5_FEATURE_SIZE)
small_mask = downsample_transform(mask)
small_mask_arr = np.asarray(small_mask)[:,:,3] / 255

# compute the similarities
print('Loading Results...')
sims, xs, ys = search_tool.compute(small_mask_arr)
image_order = torch.argsort(sims, descending=True)
print('Done.')
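
For intuition about what the search step computes, the sketch below scores one database feature map against the query by sliding the masked query region over every valid position and keeping the best cosine similarity. This is an illustration in NumPy under stated assumptions, not the `LiveSearchTool` implementation: the function name `best_match`, the tight-cropping step, and the choice of cosine similarity are all hypothetical here.

```python
import numpy as np

def best_match(query_feats, mask, db_feats, eps=1e-8):
    """Slide the masked query region over db_feats; return (similarity, x, y).

    query_feats, db_feats: arrays of shape (C, H, W)
    mask: array of shape (H, W), nonzero where the user highlighted
    """
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    if rows.size == 0:
        raise ValueError("mask is empty")
    # crop the mask and the query features to the tight bounding box
    r0, r1 = rows[0], rows[-1] + 1
    c0, c1 = cols[0], cols[-1] + 1
    crop_mask = mask[r0:r1, c0:c1]
    q = query_feats[:, r0:r1, c0:c1] * crop_mask  # zero out unselected cells
    q_norm = np.linalg.norm(q)
    h, w = crop_mask.shape
    H, W = db_feats.shape[1:]
    best = (-np.inf, 0, 0)
    for y in range(H - h + 1):
        for x in range(W - w + 1):
            d = db_feats[:, y:y + h, x:x + w] * crop_mask
            sim = float((q * d).sum() / (q_norm * np.linalg.norm(d) + eps))
            if sim > best[0]:
                best = (sim, x, y)
    return best
```

Repeating this over every database image and sorting by similarity would give a ranking analogous to `image_order`, with the returned `x` and `y` playing the role of `xs` and `ys`.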

## Display most similar regions in activation space

These highlighted regions share the most similar activations after the last bottleneck in MobileNetV2.

In [None]:
# visualize results
# set up the figure
DISPLAY_NUM = 6

HEIGHT = 5
WIDTH = 15
plt.figure(figsize=(WIDTH, HEIGHT))

# show the query region on the left-hand side
ax = plt.subplot(1, DISPLAY_NUM, 1)
plt.axis('off')
plt.imshow(curr_mask_overlay)

# draw a border between query and results
ax.add_line(matplotlib.lines.Line2D([245,245], [0,224], lw=4, color='black')).set_clip_on(False)


# show the top DISPLAY_NUM-1 matches to the right of the query
for i in range(DISPLAY_NUM-1):
    idx = image_order[i]
    curr_img_out = util.mask_overlay(imagenet_dataset.get_vis_image(idx),
                                     x=xs[idx],
                                     y=ys[idx],
                                     mask_size=CONV5_FEATURE_SIZE,
                                     mask=util.crop_mask(small_mask_arr))

    plt.subplot(1, DISPLAY_NUM, i+2)
    plt.axis('off')
    plt.imshow(curr_img_out, cmap='gray')
    plt.title(str(np.round(sims[idx].cpu().numpy(), 2)))