In [1]:
import html
import os
import pickle as pkl
import urllib.request

import clip
import ipywidgets as widgets
import numpy as np
import pandas as pd
import torch
from annoy import AnnoyIndex
from IPython.display import display, Markdown, HTML, clear_output
from PIL import Image

### Load the CLIP model and the Annoy id → image path mapping

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# CLIP ViT-L/14; its text embeddings are what `image_search` queries the Annoy index with.
model, preprocess = clip.load("ViT-L/14", device=device)

# Directory holding styles.csv, the Annoy index and its id -> path map.
# Set the FASHION_DATA_DIR environment variable to load them from elsewhere;
# the default is the original hard-coded location.
BASE_DIR = os.environ.get("FASHION_DATA_DIR", "/home/ubuntu/AI_Search/fashion_data/")

# Maps Annoy item ids to image filepaths (looked up by `image_search`).
# NOTE: pickle.load can execute arbitrary code — only load this file from a trusted source.
with open(os.path.join(BASE_DIR, "idx_to_pth_fashion_example_6_4_2023.pkl"), "rb") as f:
    idx_to_pth = pkl.load(f)

##### Build product id ↔ display name lookups from `styles.csv`

In [3]:
# Build lookups between product ids (first column of styles.csv) and display
# names (last column).
with open(os.path.join(BASE_DIR, "styles.csv"), "r", encoding="utf-8") as f:
    csv_str = f.readlines()

# Split each row at most (header column count - 1) times so a display name that
# itself contains commas stays whole; a plain split(",")[-1] would keep only the
# text after its last comma.
n_columns = len(csv_str[0].split(",")) if csv_str else 1

id_to_displayName = {}
displayName_to_id = {}
for csv_row in csv_str[1:]:
    if not csv_row.strip():
        continue  # skip blank lines instead of mapping "" -> ""
    fields = csv_row.split(",", n_columns - 1)
    id_ = fields[0]
    display_name = fields[-1].strip()
    id_to_displayName[id_] = display_name
    # If several products share a display name, the last one read wins.
    displayName_to_id[display_name] = id_

## Launch Search Interface

In [4]:


clear_output()

# --- Search UI -------------------------------------------------------------
# One centred row: a 400px query box next to a bold light-blue "Search" button.
# The `output` area shown beneath the row receives the results table.
button_style = {'button_color': 'lightblue', 'font_weight': 'bold'}

query = widgets.Text(
    placeholder="Natural Language Query",
    layout=widgets.Layout(width='400px'),
)
button = widgets.Button(description="Search", style=button_style)
output = widgets.Output()

display(
    widgets.HBox([query, button], layout=widgets.Layout(justify_content='center')),
    output,
)

def image_search(query, n_results=6):
    """
    Perform a natural language image search based on the given query.

    The query is embedded with CLIP's text encoder and its nearest neighbours
    are looked up in the Annoy index stored under BASE_DIR.

    Args:
        query (str): The query string for the image search.
        n_results (int, optional): The number of results to return. Defaults to 6.

    Returns:
        list: A list of filepaths representing the top matching images, in the
        order returned by Annoy (nearest first).

    """
    # Width of the ViT-L/14 embeddings; must match the dimension the .ann file was built with.
    embedding_dim = 768
    old_root = "/home/ubuntu/"
    new_root = "/home/ubuntu/AI_Search/"

    # By default clip.tokenize raises for text longer than CLIP's context window;
    # truncate long queries instead of failing the whole search.
    text = clip.tokenize(query, truncate=True).to(device)

    with torch.no_grad():
        text_features = model.encode_text(text)

    # Give Annoy a plain list of Python floats rather than a (possibly CUDA / fp16) tensor.
    query_vector = text_features[0].float().cpu().tolist()

    index = AnnoyIndex(embedding_dim, 'angular')
    index.load(os.path.join(BASE_DIR, "fashion_example_6_4_2023.ann"))  # memory-maps the file, so this is cheap
    nns = index.get_nns_by_vector(query_vector, n_results, search_k=-1, include_distances=False)

    filepaths = []
    for item_id in nns:
        filepath = idx_to_pth[item_id]
        # Stored paths appear to use the old /home/ubuntu/ root while the images live
        # under /home/ubuntu/AI_Search/. Rewrite only the leading root, and leave paths
        # that are already relocated alone (a blanket str.replace would double "AI_Search/").
        if filepath.startswith(old_root) and not filepath.startswith(new_root):
            filepath = new_root + filepath[len(old_root):]
        filepaths.append(filepath)

    return filepaths


def on_button_clicked(b):
    """
    Run an image search for the current query and render the results table.

    Parameters:
        b (Button): The button that was clicked (unused; required by ``Button.on_click``).

    Returns:
        None
    """
    search_text = query.value.strip()
    # Ignore empty or whitespace-only queries instead of searching for nothing.
    if not search_text:
        return

    output.clear_output()
    # Run the search inside the Output context: exceptions raised in widget
    # callbacks may otherwise not be shown anywhere (e.g. in JupyterLab), whereas
    # the Output widget renders the traceback in its own area.
    with output:
        filepaths = image_search(search_text)
        display_images_table(filepaths)
            



def display_images_table(filepaths, num_columns=3):
    """
    Render images as an HTML table, each cell showing the product's display name above its image.

    Parameters:
        filepaths (List[str]): Filepaths of the images, in display order.
        num_columns (int, optional): Number of images per table row. Defaults to 3.

    Returns:
        None. The table is rendered with IPython's ``display``.

    Raises:
        ValueError: If ``num_columns`` is less than 1.
    """
    if num_columns < 1:
        raise ValueError(f"num_columns must be at least 1, got {num_columns}")

    empty_cell = '<td style="padding: 2%;"></td>'
    cells = []
    for filepath_og in filepaths:
        # Drop the absolute prefix so the <img src> is a relative path
        # (presumably resolved against the notebook's directory).
        filepath = filepath_og.replace('/home/ubuntu/AI_Search/', '')
        # Product id = file name without extension, e.g. ".../1234.jpg" -> "1234".
        image_id = os.path.splitext(os.path.basename(filepath))[0]
        # Fall back to the id itself rather than raising KeyError for unknown products.
        display_name = id_to_displayName.get(image_id, image_id)

        # Escape both values: the name comes from a data file and the path lands in an HTML attribute.
        image_html = f'<img src="{html.escape(filepath)}" style="max-width:200px; max-height:200px; display: block; margin: 0 auto; padding: 5%;">'
        cells.append(f'<td style="padding: 2%; text-align: center; cursor: pointer; background-color: lightblue;" onmouseover="this.style.backgroundColor=\'lightgrey\';" onmouseout="this.style.backgroundColor=\'lightblue\';"><span style="font-weight: bold;">{html.escape(display_name)}</span><br>{image_html}</td>')

    # Pad the last row with as many empty cells as needed to fill it to num_columns.
    cells.extend([empty_cell] * (-len(cells) % num_columns))

    # Group the cells into rows of num_columns, each wrapped in exactly one <tr>.
    rows = "".join(
        f"<tr>{''.join(cells[start:start + num_columns])}</tr>"
        for start in range(0, len(cells), num_columns)
    )
    table_html = f'<table style="border-collapse: collapse; width: 100%;">{rows}</table>'
    display(HTML(table_html))

# Wire the Search button to the handler; every click runs a search for the current query text.
button.on_click(on_button_clicked)

HBox(children=(Text(value='', layout=Layout(width='400px'), placeholder='Natural Language Query'), Button(desc…

Output()

ValueError: No such metric