# 4Quant Interactive Visual Search
Here we use deep learning to train and run a visual search engine directly in the browser.
## Train
To train the algorithm, switch to train mode, fill in the item name, and click the camera button to record a picture.

## Search
To search, switch to search mode and click the camera button to run the search. To show all the items that have been recorded, click the `Show entire database` button.

In [None]:
# Shape handed to the network as `input_shape`; its first two entries are also
# the size every camera frame is resized to before prediction.
IMG_SIZE = (224, 224, 3)

In [None]:
from keras import applications as kapps, models, layers
from skimage.io import imread
import ipywidgets as ipw
from ipywebrtc import ImageRecorder, CameraStream
import PIL, io
import numpy as np
import seaborn as sns
import pandas as pd
from IPython.display import display, Javascript
# Tiny placeholder image: a 3x3 all-zero uint8 array, PNG-encoded into an
# in-memory buffer.
empty_image_io = io.BytesIO()
PIL.Image.fromarray(np.zeros((3, 3), dtype='uint8')).save(empty_image_io, format='png')
def setup_appmode():
    """Prepare the notebook front-end for use as an app.

    Switches seaborn to the "whitegrid" style with the axes grid turned off,
    then injects JavaScript that hides the ``#appmode-leave`` and
    ``#appmode-busy`` elements and disables output-area scrolling.
    """
    sns.set_style("whitegrid", {'axes.grid': False})
    frontend_tweaks = """$('#appmode-leave').hide()
        // Hides the edit app button.
        $('#appmode-busy').hide()
        // Hides the kernel busy indicator.
        IPython.OutputArea.prototype._should_scroll = function(lines) {
            return false
            // disable scrolling
        }"""
    display(Javascript(frontend_tweaks))

def get_app_user_id():
    """
    Get the userid from the `jupyter_notebook_url`
    injected by the appmode extension (if in use)
    otherwise return a 'nobody'

    The URL is looked up first in this module's globals and then in the
    local variables of the calling frame. The user id is the first ``user``
    value found in the URL's query string.

    :return: appmode username of current user
    >>> get_app_user_id()
    'nobody'
    >>> jupyter_notebook_url = 'https://a.b.c?user=dan#hello'
    >>> get_app_user_id()
    'dan'
    """
    # Imported locally: the file's import section does not provide these
    # names, so without this the function raised NameError when called.
    import inspect
    from urllib.parse import parse_qs, urlparse

    cur_url = globals().get('jupyter_notebook_url', None)
    if cur_url is None:
        # black magic to get the 'injected' variable from the caller's scope
        frame = inspect.currentframe()
        try:
            # currentframe() may return None on interpreters without frame
            # support, and the outermost frame has no caller.
            caller = frame.f_back if frame is not None else None
            if caller is not None:
                cur_url = caller.f_locals.get('jupyter_notebook_url', None)
        finally:
            # drop the frame reference promptly to avoid a reference cycle
            del frame
    if cur_url is None:
        return 'nobody'
    qs_info = parse_qs(urlparse(cur_url).query)
    return qs_info.get('user', ['nobody'])[0]
import base64
def _wrap_uri(data_uri):
    return "data:image/png;base64,{0}".format(data_uri)
def image_to_uri(in_img, width=100, height=100):
    """Render an image as a self-contained HTML ``<img>`` tag.

    The image is resized to ``(width, height)``, PNG-encoded in memory,
    base64-encoded and embedded as a data URI.

    :param in_img: object offering ``resize`` and ``save`` (presumably a
        PIL image - confirm against callers)
    :param width: thumbnail width in pixels
    :param height: thumbnail height in pixels
    :return: HTML string containing the ``<img>`` element
    """
    thumbnail = in_img.resize((width, height))
    png_buffer = io.BytesIO()
    thumbnail.save(png_buffer, format='png')
    encoded = base64.b64encode(png_buffer.getvalue()).decode("ascii")
    uri = _wrap_uri(encoded.replace("\n", ""))
    img_template = '<img src="{uri}" width="{width}px" height="{height}px"/>'
    return img_template.format(uri=uri, width=width, height=height)
def raw_html_render(temp_df):
    """Render a DataFrame as an unescaped, untruncated HTML table.

    Cell contents are emitted verbatim (``escape=False``) so that HTML stored
    in the frame, such as ``<img>`` tags, is rendered rather than shown as
    text. The column-width limit is lifted for the duration of the call so
    long cell values are not cut off.

    :param temp_df: the DataFrame to render
    :return: HTML markup of the table, without the index column
    """
    old_wid = pd.get_option('display.max_colwidth')
    try:
        # None means "no limit" in current pandas; -1 is no longer accepted.
        pd.set_option('display.max_colwidth', None)
    except ValueError:
        # Older pandas releases only accept integers and used -1 for "no limit".
        pd.set_option('display.max_colwidth', -1)
    try:
        return temp_df.to_html(classes="table table-striped table-hover",
                               escape=False,
                               float_format=lambda x: '%2.2f' % x,
                               na_rep='',
                               index=False,
                               max_rows=None,
                               max_cols=None)
    finally:
        # Restore the caller's setting even if rendering fails.
        pd.set_option('display.max_colwidth', old_wid)
# Apply the plot style and inject the front-end JavaScript when this cell runs.
setup_appmode()

In [None]:
# Set to True to request the user-facing camera instead of the
# environment-facing one.
USE_FRONT_CAMERA = False

if USE_FRONT_CAMERA:
    camera = CameraStream.facing_user(audio=False)
else:
    camera = CameraStream.facing_environment(audio=False)
# Recorder widget that captures still images from the camera stream.
image_recorder = ImageRecorder(stream=camera)

In [None]:
# Feature extractor: MobileNetV2 without its top layers, followed by global
# average pooling over the spatial dimensions.
raw_model = kapps.MobileNetV2(input_shape=IMG_SIZE, include_top=False)
mn_model = models.Sequential([raw_model, layers.GlobalAvgPool2D()])
# Input normalisation function that belongs to MobileNetV2.
prep_func = kapps.mobilenet_v2.preprocess_input

In [None]:
# In-memory "database", kept as three parallel lists: one feature vector,
# one item name and one recorded image per trained item.
vec_list = []
name_list = []
img_list = []

In [None]:
# Panel that shows status messages and the rendered results table as HTML.
# (An image widget used to be created here first and was immediately
# overwritten, so it has been dropped.)
search_results = ipw.HTML()
# Collects anything printed or raised inside the decorated callback.
output = ipw.Output()
# Mode selector: the option values (1 = train, 0 = search) are what the
# callbacks compare `tool_mode.value` against.
tool_options = [('train', 1), ('search', 0)]
tool_mode = ipw.ToggleButtons(options=tool_options)
progress_bar = ipw.IntProgress(
    value=0,
    min=0,
    max=100,
    description='Waiting...',
    bar_style='info',
    orientation='horizontal'
)
# Name stored alongside each item recorded in train mode.
name_text = ipw.Text(description='Item Name', value='Default')

@output.capture()
def update_image(change):
    """Process a newly recorded camera frame: train on it or search with it.

    The frame is read from ``image_recorder``, resized to the network input
    size and turned into a feature vector by ``mn_model``. In train mode the
    vector, the current item name and the full-size image are appended to the
    database lists. In search mode the vector is scored against every stored
    vector and the best matches are shown through ``show_results``.

    :param change: change notification from the observed trait (unused; the
        image bytes are read from the recorder directly)
    """
    raw_bytes = image_recorder.image.value
    if not raw_bytes:
        # An empty payload cannot be decoded; there is nothing to process.
        return
    training = tool_mode.value == 1
    progress_bar.value = 0
    progress_bar.bar_style = 'info'
    if not training:
        # Only the search view shows this panel; writing the placeholder in
        # train mode left a stale "Loading..." behind after switching modes.
        search_results.value = '<b>Loading...</b>'
    full_im = PIL.Image.open(io.BytesIO(raw_bytes))
    # convert('RGB') always yields exactly three channels: it drops an alpha
    # channel and expands greyscale/palette images, which plain slicing of
    # the channel axis could not handle.
    im_in = full_im.convert('RGB').resize(IMG_SIZE[:2])
    im_array = np.array(im_in)
    progress_bar.description = 'Processing Image...'
    progress_bar.value = 25
    # Add the leading batch axis expected by predict().
    im_array = np.expand_dims(im_array.astype(np.float32), 0)
    prep_array = prep_func(im_array)
    out_vec = mn_model.predict(prep_array)[0]
    if training:
        progress_bar.description = 'Training...'
        progress_bar.value = 75
        # The three lists are parallel: index i describes the same item.
        vec_list.append(out_vec)
        name_list.append(name_text.value)
        img_list.append(full_im)
    else:
        progress_bar.description = 'Searching...'
        progress_bar.value = 50
        if vec_list:
            vec_mat = np.stack(vec_list, 0)
            # Dot product with the query, scaled by the query's squared norm,
            # so a stored vector identical to the query scores 1.0.
            # NOTE(review): this is not a cosine similarity - stored vectors
            # with a larger norm score higher; confirm that is intended.
            v_score = np.dot(vec_mat, out_vec) / np.dot(out_vec, out_vec)
            show_results(v_score, None, 4)
        else:
            # Previously the panel stayed on "Loading..." forever here.
            search_results.value = (
                '<b>No items recorded yet. Switch to train mode to add some.</b>'
            )
    progress_bar.value = 100
    progress_bar.bar_style = 'success'
    
def show_results(v_score, cut_off = 0.1, max_results=5):
    """Render the best-scoring database items into the results panel.

    Items are visited in order of decreasing score. Iteration stops once
    ``max_results`` rows have been collected or as soon as a score falls
    below ``cut_off``. The rows (thumbnail, name, score) are rendered as an
    HTML table and written to ``search_results``.

    :param v_score: array with one score per database item, index-aligned
        with ``img_list`` and ``name_list``
    :param cut_off: minimum score for an item to be listed; ``None``
        disables the threshold
    :param max_results: maximum number of rows to show; ``None`` shows all
    """
    out_results = []
    match_idx = np.argsort(-1*v_score)  # indices sorted by descending score
    for i, c_idx in enumerate(match_idx):
        # `>=` so that at most max_results rows are emitted; the previous
        # `>` comparison let one extra row through.
        if max_results is not None and i >= max_results:
            break
        if cut_off is not None and v_score[c_idx] < cut_off:
            break
        out_results.append(dict(image=image_to_uri(img_list[c_idx]),
                                name=name_list[c_idx],
                                score=v_score[c_idx]))
    # Explicit columns keep the table header present when nothing matched.
    result_table = pd.DataFrame(out_results, columns=['image', 'name', 'score'])
    search_results.value = raw_html_render(result_table)
    
# Run update_image every time the recorder's image value changes.
image_recorder.image.observe(update_image, names='value')

# Button that lists every stored item by giving all of them the same score
# of 1 and lifting the row limit.
show_all = ipw.Button(description='Show entire database')
show_all.on_click(
    lambda _button: show_results(np.ones(len(vec_list)), max_results=None)
)

# Empty containers; _tool_switch fills them according to the selected mode.
item_toolbar = ipw.HBox()
sr_box = ipw.VBox()
def _tool_switch(change):
    """Rebuild the toolbar and results box for the selected mode.

    Resets the progress bar, then lays out the widgets: train mode shows the
    item-name field and hides the results box, search mode shows the results
    panel together with the show-all button.

    :param change: change notification from ``tool_mode`` (unused; ``None``
        is passed for the initial call)
    """
    progress_bar.value = 0
    progress_bar.bar_style = 'info'
    in_search_mode = tool_mode.value == 0
    if not in_search_mode:
        progress_bar.description = 'Train'
        item_toolbar.children = [tool_mode, name_text, progress_bar]
        sr_box.children = []
        return
    progress_bar.description = 'Search'
    item_toolbar.children = [tool_mode, progress_bar]
    results_caption = ipw.Label('Search Results:')
    sr_box.children = [results_caption, search_results, show_all]
# Lay out the toolbar and results box for the initially selected mode.
_tool_switch(None)

# Re-run the layout whenever the mode toggle changes.
tool_mode.observe(_tool_switch, names='value')

# Last expression of the cell, so the notebook displays the assembled app:
# toolbar on top, camera next to the results, then the recorder and the
# captured callback output.
ipw.VBox(children=[
    item_toolbar,
    ipw.HBox(children=[camera, sr_box]),
    image_recorder,
    output,
])