In [None]:
#default_exp streamlit_app

In [None]:
#hide
from nbdev.showdoc import *

# Streamlit app

Streamlit is a more convenient way to spin up a quick user-facing GUI than Voila was, especially since Voila has dependencies that conflict with nbdev.

However, Streamlit runs a `.py` file rather than a notebook. This is a bit annoying: to get Streamlit's hot-reload behavior we have to develop outside the notebook, but to keep the documentation (and export the module along with the rest of the library) the source of truth has to live right here. Perhaps a better solution will present itself later; meanwhile, I have been developing in a scratch file, `streamlit-app.py`, and then copying the code back here.

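The context managers below are a workaround for `query_flow` printing to stdout. While active, they route anything written to `sys.stdout` (or `sys.stderr`) from the Streamlit script thread into a Streamlit element instead. Maybe this should be handled natively by Streamlit?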

In [None]:
#export 
import streamlit as st
from memery import core

from pathlib import Path
from PIL import Image

from streamlit.report_thread import REPORT_CONTEXT_ATTR_NAME
from threading import current_thread
from contextlib import contextmanager
from io import StringIO
import sys

In [None]:
#export 
@contextmanager
def st_redirect(src, dst):
    """Redirect writes on the text stream ``src`` into a Streamlit element.

    While the ``with`` body runs, ``src.write`` is replaced:

    * Writes made from a thread that carries a Streamlit report context are
      appended to an in-memory buffer. The whole accumulated text is then
      rendered into a single ``st.empty()`` placeholder via its ``dst`` method.
    * Writes from any other thread fall through to the original ``src.write``.

    The original ``write`` is restored on exit, even if the body raises.

    Args:
        src: a writable text stream such as ``sys.stdout``. Its ``write``
            attribute must be assignable.
        dst: name of a method on a Streamlit element that accepts a string
            (e.g. ``'info'``).
    """
    placeholder = st.empty()
    output_func = getattr(placeholder, dst)

    with StringIO() as buffer:
        old_write = src.write

        def new_write(b):
            if getattr(current_thread(), REPORT_CONTEXT_ATTR_NAME, None):
                written = buffer.write(b)
                # Re-render everything captured so far; the st.empty()
                # placeholder replaces its previous content on each call.
                output_func(buffer.getvalue())
                # Honour the TextIOBase.write contract: return the number of
                # characters written instead of None.
                return written
            return old_write(b)

        try:
            src.write = new_write
            yield
        finally:
            src.write = old_write


@contextmanager
def st_stdout(dst):
    """Send ``sys.stdout`` output into a Streamlit ``dst`` element for the ``with`` body.

    Thin wrapper around ``st_redirect`` with ``sys.stdout`` as the source stream.
    """
    redirect = st_redirect(sys.stdout, dst)
    with redirect:
        yield


@contextmanager
def st_stderr(dst):
    """Send ``sys.stderr`` output into a Streamlit ``dst`` element for the ``with`` body.

    Thin wrapper around ``st_redirect`` with ``sys.stderr`` as the source stream.
    """
    redirect = st_redirect(sys.stderr, dst)
    with redirect:
        yield

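Here we try to make good use of Streamlit's caching (`st.cache`): if the search query and folder are the same as in a previous search, the cached result is served instead of re-running the search. This might introduce some breakage points, though, and that remains to be seen. For example, the cache key is the folder path, not its contents, so results can go stale if images are added to or removed from the folder.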

In [None]:
#export
@st.cache
def send_image_query(path, text_query, image_query):
    """Run ``core.query_flow`` with an image (plus optional text) query, cached by ``st.cache``.

    Args:
        path: directory to search, passed through to ``core.query_flow``.
        text_query: text query string (may be empty).
        image_query: the query image. Either an already-opened
            ``PIL.Image.Image`` or anything ``Image.open`` accepts (e.g. the
            file returned by ``st.file_uploader``).

    Returns:
        Whatever ``core.query_flow`` returns. The app treats it as a sequence
        of image file paths (presumably best match first).
    """
    # Build the query image from the argument itself. The previous version
    # read a module-level ``img`` global, which ignored this parameter and
    # raised NameError unless the caller had happened to define ``img`` first.
    if isinstance(image_query, Image.Image):
        img = image_query
    else:
        img = Image.open(image_query).convert('RGB')
    ranked = core.query_flow(path, text_query, image_query=img)
    return ranked

@st.cache
def send_text_query(path, text_query):
    """Run ``core.query_flow`` for a text-only query, memoised by ``st.cache``.

    Args:
        path: directory to search, passed through to ``core.query_flow``.
        text_query: the text query string.

    Returns:
        The result of ``core.query_flow`` unchanged.
    """
    return core.query_flow(path, text_query)

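This is the sidebar content: the search inputs (directory, text query, image query), plus two containers that are filled later with the query image and the search log.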

In [None]:
#export
# Sidebar: title and search inputs first, then two empty containers that the
# main script fills later (query-image preview and captured log output).
st.sidebar.title("Memery")

path = st.sidebar.text_input('Directory', value='./images')
text_query = st.sidebar.text_input('Text query', value='')
image_query = st.sidebar.file_uploader('Image query')

# Placeholders, in display order: query-image preview, then log messages.
im_display_zone = st.sidebar.beta_container()
logbox = st.sidebar.beta_container()

2021-06-19 21:10:02.837 
  command:

    streamlit run /home/mage/.local/lib/python3.7/site-packages/ipykernel_launcher.py [ARGUMENTS]


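The image grid parameters: how many results to show, how wide each image is, and whether to caption images with their filenames.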

In [None]:
#export
# Width (in pixels, as passed to st.image) for each image-size choice.
sizes = {'small': 115, 'medium': 230, 'large': 332, 'xlarge': 600}

# One row of controls laid out in a 4:1:1 width ratio.
l, m, r = st.beta_columns([4, 1, 1])
with l:
    num_images = st.slider('Number of images', value=12)
with m:
    # index=1 preselects 'medium' (dict insertion order).
    size_choice = st.selectbox('Image width', list(sizes), index=1)
with r:
    captions_on = st.checkbox("Caption filenames", value=False)

And finally the main logic. Streamlit reruns the whole script whenever a widget value changes, so this runs every time the query parameters change.

This doesn't really work in Jupyter at all; hopefully it does once the notebook is exported to a `.py` module and launched with `streamlit run`.

In [None]:
#export
def _display_name(filename, root):
    """Return ``filename`` as a string with a leading ``root`` prefix removed (for captions).

    Only a leading prefix is stripped, so a directory name that also occurs
    later in the filename is left intact (``str.replace`` removed every
    occurrence).
    """
    name = str(filename)
    if root and name.startswith(root):
        return name[len(root):]
    return name


# Runs on every Streamlit rerun in which a text or image query is present.
if text_query or image_query:
    with logbox:
        # Anything query_flow prints to stdout shows up as an info box in the sidebar log.
        with st_stdout('info'):
            if image_query is not None:
                img = Image.open(image_query).convert('RGB')
                with im_display_zone:
                    st.image(img)
                ranked = send_image_query(path, text_query, image_query)
            else:
                ranked = send_text_query(path, text_query)

    shown = ranked[:num_images]
    ims = [Image.open(o).convert('RGB') for o in shown]
    names = [_display_name(o, path) for o in shown]

    # caption=None is st.image's default, i.e. no captions.
    st.image(ims, width=sizes[size_choice], channels='RGB',
             caption=names if captions_on else None)