# 4Quant Interactive Visual Search
Here we use deep learning to build a visual search engine that you can train and query directly in the browser, using snapshots from your webcam.

In [None]:
# Model input shape for MobileNetV2, given as (rows, cols, channels).
IMG_SIZE = (224,) * 2 + (3,)

In [None]:
import inspect
import io
from urllib.parse import parse_qs, urlparse

import ipywidgets as ipw
import numpy as np
import PIL
import PIL.Image
import seaborn as sns
from IPython.display import display, Javascript
from ipywebrtc import ImageRecorder, CameraStream
from keras import applications as kapps, models, layers
from skimage.io import imread
def setup_appmode():
    """Tidy the notebook page for app-mode use.

    Switches seaborn to the 'whitegrid' style with grid lines disabled, then
    sends a JavaScript snippet to the front end that hides the
    ``#appmode-leave`` (edit app) button and the ``#appmode-busy`` kernel
    indicator, and overrides ``IPython.OutputArea``'s ``_should_scroll`` to
    always return false so output areas never auto-scroll.
    """
    sns.set_style("whitegrid", {'axes.grid': False})
    # Building the literal has no side effects, so doing it after
    # set_style keeps the observable order: style first, then display.
    hide_ui_js = """$('#appmode-leave').hide()
        // Hides the edit app button.
        $('#appmode-busy').hide()
        // Hides the kernel busy indicator.
        IPython.OutputArea.prototype._should_scroll = function(lines) {
            return false
            // disable scrolling
        }"""
    display(Javascript(hide_ui_js))

def get_app_user_id():
    """
    Get the userid from the `jupyter_notebook_url`
    injected by the appmode extension (if in use)
    otherwise return a 'nobody'

    The URL is looked up first in this module's globals and then in the
    local namespace of the immediate caller. Its ``user`` query parameter
    is returned; a missing URL or a missing/blank ``user`` yields 'nobody'.

    :return: appmode username of current user
    >>> get_app_user_id()
    'nobody'
    >>> jupyter_notebook_url = 'https://a.b.c?user=dan#hello'
    >>> get_app_user_id()
    'dan'
    """
    cur_url = globals().get('jupyter_notebook_url', None)
    if cur_url is None:
        # black magic to get the 'injected' variable: peek one frame up
        # into the caller's namespace.
        frame = inspect.currentframe()
        try:
            # currentframe() may return None on Python implementations
            # without stack-frame support.
            caller = frame.f_back if frame is not None else None
            if caller is not None:
                cur_url = caller.f_locals.get('jupyter_notebook_url')
        finally:
            # Drop the frame reference promptly to avoid a reference cycle.
            del frame
    if cur_url is None:
        return 'nobody'
    qs_info = parse_qs(urlparse(cur_url).query)
    return qs_info.get('user', ['nobody'])[0]
# Apply the app-mode page tweaks and plot style as soon as this cell runs.
setup_appmode()

In [None]:
# Which webcam to stream from: True selects the user-facing camera
# (CameraStream.facing_user), False the environment-facing one
# (CameraStream.facing_environment). Replaces a hard-coded `if True:` toggle.
USE_FRONT_CAMERA = True
if USE_FRONT_CAMERA:
    camera = CameraStream.facing_user(audio=False)
else:
    camera = CameraStream.facing_environment(audio=False)
# Records still frames from the camera stream; its `image` trait is observed
# by the train/search callback further down.
image_recorder = ImageRecorder(stream=camera)

In [None]:
# Feature extractor: MobileNetV2 without its classification head, followed by
# global average pooling so each image collapses to a single feature vector.
raw_model = kapps.MobileNetV2(input_shape=IMG_SIZE, include_top=False)
mn_model = models.Sequential([raw_model, layers.GlobalAvgPool2D()])
# Input normalisation that matches the MobileNetV2 backbone.
prep_func = kapps.mobilenet_v2.preprocess_input

In [None]:
# Parallel stores appended to in 'train' mode, one entry per training
# snapshot: feature vector, label, and the captured image.
vec_list, name_list, img_list = [], [], []

In [None]:
# Mode selector: each option is a (label, value) pair, and the callbacks
# treat value 1 as training and 0 as searching.
tool_options = [('train', 1), ('search', 0)]
tool_mode = ipw.ToggleButtons(options=tool_options)

# Captures printed output / tracebacks from the image callback.
output = ipw.Output()
# Shows the best-matching stored images in search mode.
search_results = ipw.Image()

# Per-snapshot status indicator (0-100).
progress_bar = ipw.IntProgress(value=0, min=0, max=100,
                               description='Waiting...',
                               bar_style='info',
                               orientation='horizontal')
def _side_by_side_png(images):
    """Return PNG-encoded bytes of the PIL ``images`` pasted left to right.

    Each image is converted to RGB and scaled (keeping its aspect ratio) to
    the first image's height, so captures with different resolutions or
    colour modes can still be concatenated. Works for a single image too.
    """
    target_h = images[0].height
    strips = []
    for im in images:
        if im.height != target_h:
            new_w = max(1, round(im.width * target_h / im.height))
            im = im.resize((new_w, target_h))
        strips.append(np.asarray(im.convert('RGB')))
    joined = PIL.Image.fromarray(np.concatenate(strips, axis=1))
    buf = io.BytesIO()
    joined.save(buf, format='png')
    return buf.getvalue()


def _embed(image):
    """Return the pooled MobileNetV2 feature vector for one RGB PIL ``image``."""
    # IMG_SIZE is (rows, cols, channels) for Keras, whereas PIL's resize takes
    # (width, height); swap so non-square sizes are not transposed.
    resized = image.resize((IMG_SIZE[1], IMG_SIZE[0]))
    # np.array (not asarray) yields a fresh writable float copy, in case
    # prep_func normalises its input in place.
    batch = np.array(resized, dtype=np.float32)[np.newaxis]
    return mn_model.predict(prep_func(batch))[0]


@output.capture()
def update_image(change):
    """Embed the latest webcam snapshot, then store it or search with it.

    Registered below as an observer of ``image_recorder.image``'s ``value``.
    In 'train' mode (``tool_mode.value == 1``) the snapshot's feature vector,
    the label 'Local' and the snapshot itself are appended to
    ``vec_list`` / ``name_list`` / ``img_list``. In 'search' mode the feature
    vector is scored against every stored vector by dot product, and the
    (up to) two best-scoring stored images are shown side by side in
    ``search_results``. Progress is reported through ``progress_bar``.

    :param change: traitlets change notification; unused, because the image
        is re-read from ``image_recorder``.
    """
    progress_bar.value = 0
    progress_bar.bar_style = 'info'
    raw_bytes = image_recorder.image.value
    if not raw_bytes:
        # Nothing has been captured (e.g. the recorder was cleared).
        return
    # Normalise RGBA / palette / greyscale captures to exactly three channels;
    # channel slicing alone would mis-handle 2-D greyscale arrays.
    full_im = PIL.Image.open(io.BytesIO(raw_bytes)).convert('RGB')
    progress_bar.description = 'Processing Image...'
    progress_bar.value = 25
    out_vec = _embed(full_im)

    if tool_mode.value == 1:  # 'train'
        progress_bar.description = 'Training...'
        progress_bar.value = 75
        vec_list.append(out_vec)
        name_list.append('Local')
        img_list.append(full_im)
    else:  # 'search'
        progress_bar.description = 'Searching...'
        progress_bar.value = 50
        if not vec_list:
            # Report the empty store instead of silently claiming success.
            progress_bar.description = 'Nothing to search - train first'
            progress_bar.value = 100
            progress_bar.bar_style = 'warning'
            return
        scores = np.dot(np.stack(vec_list, 0), out_vec)
        # Indices of the (up to) two highest scores, best first. Indexing the
        # list element-wise avoids using an ndarray as a list index.
        best = np.argsort(-scores)[:2]
        search_results.value = _side_by_side_png([img_list[i] for i in best])
    progress_bar.value = 100
    progress_bar.bar_style = 'success'


image_recorder.image.observe(update_image, 'value')
def _tool_switch(change):
    """Reset the progress bar and label it with the selected mode.

    Observer for ``tool_mode``'s ``value``. The mode is read from
    ``tool_mode.value`` (``change`` is unused): 0 is labelled 'Search',
    anything else 'Train'.
    """
    progress_bar.value = 0
    progress_bar.bar_style = 'info'
    progress_bar.description = 'Search' if tool_mode.value == 0 else 'Train'


tool_mode.observe(_tool_switch, 'value')
# Assemble the UI: mode toggle + progress on top, the live camera, then the
# recorder next to the search results, and finally the callback's output log.
controls_row = ipw.HBox([tool_mode, progress_bar])
results_panel = ipw.VBox([ipw.Label('Search Results:'), search_results])
capture_row = ipw.HBox([image_recorder, results_panel])
# Kept as the cell's final bare expression so the notebook displays it.
ipw.VBox([controls_row, camera, capture_row, output])