## Query screenshots using SBERT or OpenAI embeddings

The image descriptions were created using OpenAI (script `generate_image_descriptions.py`).

Embeddings can be created using SBERT (`create_weaviate_database_sbert.py`) or OpenAI's text-embedding-3-* models (`create_weaviate_database_openai.py`).

Those embeddings are stored in a Weaviate database and retrieved with a `near_vector` query.

#### Choose the embedder (it must match the one used to build the database)

In [None]:
import os

# Folder containing the screenshots. The default is a machine-specific path;
# set the SCREENSHOT_IMAGE_DIR environment variable to use another folder
# without editing the notebook.
IMAGE_DIR = os.environ.get(
    "SCREENSHOT_IMAGE_DIR", "/Users/gustavofuhr/projects/data/my_screenshots/"
)

# Embedding backend: "SBERT" or "OPENAI". It selects both the Weaviate
# collection to query and the model used to embed queries, so it must match
# the backend that was used to populate the database.
EMBEDDER = "OPENAI"

# Fail here, with a clear message, instead of in a later cell.
if EMBEDDER not in ("SBERT", "OPENAI"):
    raise ValueError(f"EMBEDDER must be 'SBERT' or 'OPENAI', got {EMBEDDER!r}")

#### (Optional) Run the script that populates the Weaviate database

This should not take long.

In [None]:
if EMBEDDER == "OPENAI":
    !python create_weaviate_database_openai.py --image_dir {IMAGE_DIR}
elif EMBEDDER == "SBERT":
    !python create_weaviate_database_sbert.py --image_dir {IMAGE_DIR}

#### Connect to Weaviate

In [None]:
from urllib.parse import urlparse

import weaviate

WEAVIATE_URL = "http://localhost:8080"
# Collection name depends on the embedder, because SBERT and OpenAI vectors
# are stored in separate collections.
class_name = "ScreenshotQuerySBERT" if EMBEDDER == "SBERT" else "ScreenshotQueryOPENAI"

# Derive host/port from WEAVIATE_URL, so that editing that constant actually
# changes where we connect, instead of relying on connect_to_local's defaults.
_weaviate_endpoint = urlparse(WEAVIATE_URL)
client = weaviate.connect_to_local(
    host=_weaviate_endpoint.hostname or "localhost",
    port=_weaviate_endpoint.port or 8080,
)

if not client.collections.exists(class_name):
    # Release the connection before aborting; the query cells below cannot
    # work without this collection.
    client.close()
    raise LookupError(f"Class {class_name} not found in schema")

#### Define embedding function for inference

In [None]:
import os

if EMBEDDER == "SBERT":
    # Local SentenceTransformer model. This presumably must match the model
    # used by create_weaviate_database_sbert.py; verify before changing it.
    from sentence_transformers import SentenceTransformer

    SBERT_MODEL = "all-MiniLM-L6-v2"
    model = SentenceTransformer(SBERT_MODEL)

    def embed_query(query_str):
        """Embed *query_str* with the SBERT model, returned as a plain list via ``.tolist()``."""
        return model.encode(query_str).tolist()

elif EMBEDDER == "OPENAI":
    from create_descriptions_openai_embeddings import get_openai_embedding

    def embed_query(query_str):
        """Embed *query_str* with the project's ``get_openai_embedding`` helper."""
        return get_openai_embedding(query_str)

else:
    def embed_query(query_str):
        """Placeholder for an unsupported EMBEDDER value; always raises."""
        # EMBEDDER is always set; the problem is that it names no known backend.
        raise NotImplementedError(
            f"Unsupported EMBEDDER {EMBEDDER!r}; expected 'SBERT' or 'OPENAI'"
        )
    
def query_screenshots(query_str, distance=0.7, n_images_limit=5):
    """Return paths of the screenshots whose stored vectors are nearest to *query_str*.

    Args:
        query_str: Free-text query, embedded with ``embed_query``.
        distance: Maximum vector distance passed to Weaviate's ``near_vector``.
        n_images_limit: Maximum number of results. Non-positive values return
            an empty list without embedding the query or contacting Weaviate.

    Returns:
        A list of paths, each being ``IMAGE_DIR`` joined with an object's
        ``filename`` property, in the order Weaviate returned them.
    """
    # The UI slider can produce 0. Answer that locally instead of paying for an
    # embedding call and relying on how Weaviate interprets limit=0.
    if n_images_limit <= 0:
        return []

    query_vector = embed_query(query_str)
    collection = client.collections.get(class_name)
    response = collection.query.near_vector(
        near_vector=query_vector,
        distance=distance,
        limit=n_images_limit,
    )
    return [os.path.join(IMAGE_DIR, obj.properties["filename"]) for obj in response.objects]


#### Finally, create the interface and query the screenshots!

In [None]:
from bokeh.plotting import output_notebook, show
import ipywidgets as widgets
from IPython.display import display, clear_output
from ipywidgets import HBox, VBox

from utils import image_and_descriptions_plot

# Load BokehJS into the notebook so that `show` renders plots inline.
output_notebook()

# Grid layout handed to image_and_descriptions_plot.
N_COLUMNS = 4
N_ROWS = 3

distance_slider = widgets.FloatSlider(
    value=0.7, min=0, max=1.0, step=0.01,
    description='Distance:', continuous_update=False,
    style={'description_width': 'initial'}, layout=widgets.Layout(width='50%'),
)

# min=1: requesting zero images is meaningless.
n_images_slider = widgets.IntSlider(
    value=12, min=1, max=50, step=1,
    description='n_images:', continuous_update=False,
    style={'description_width': 'initial'}, layout=widgets.Layout(width='50%'),
)

text_input = widgets.Text(
    value='Which screenshots show dogs?', placeholder='Enter text...',
    description='Filter text:', disabled=False,
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='100%', height='40px'),
)

filter_button = widgets.Button(description='Filter it', button_style='success')

sliders_box = HBox([distance_slider, n_images_slider])
widgets_box = VBox([sliders_box, text_input, filter_button])

# Dedicated area for everything the button callback prints or shows. Output
# produced in widget callbacks is not routed to the cell in every frontend, so
# it is captured here instead of clearing the cell and re-displaying the controls.
results_output = widgets.Output()

display(widgets_box, results_output)


def on_button_click(b):
    """Query screenshots with the current widget values and plot the results.

    Args:
        b: The clicked Button (passed by ipywidgets; unused).
    """
    with results_output:
        # Replace the previous results, keeping the old ones until new output arrives.
        results_output.clear_output(wait=True)
        print(f"Distance: {distance_slider.value}, Number of images: {n_images_slider.value}, Filter Text: {text_input.value}")

        images_retrieved = query_screenshots(
            text_input.value,
            distance=distance_slider.value,
            n_images_limit=n_images_slider.value,
        )
        p = image_and_descriptions_plot(images_retrieved, N_COLUMNS, N_ROWS)
        show(p)


filter_button.on_click(on_button_click)
