# Neural Fashion Product Search

Type a text query to retrieve the best-matching fashion product images. This demo is based on [VSE++](https://arxiv.org/pdf/1707.05612) and the Kaggle Fashion Product Images [dataset](https://www.kaggle.com/paramaggarwal/fashion-product-images-dataset).  

Please refer to this GitHub [repo](https://github.com/muelerma/vsepp) for implementation details.


In [None]:
import ipywidgets as widgets
import voila

import torch
import os
from pathlib import Path

from inference import load_model, load_embs, query_embd, get_data_loader, top_k_img

In [None]:
# Runtime configuration for the search demo.
ON_GPU = True  # forwarded to query_embd when embedding the search text
DATASET = "fashion"  # dataset name handed to load_model
EMBS_PATH = f"data/{DATASET}"  # path handed to load_embs
SPLIT = "test"  # split used for the image embeddings and the data loader
VOCAB_PATH = "vocab"  # NOTE(review): not referenced by the visible cells

In [None]:
# Result grid dimensions: the search requests height * width images.
height, width = 4, 4

In [None]:
# Load the trained checkpoint with all tensors mapped onto the CPU.
# NOTE(review): "runX" looks like a placeholder run directory - point it at
# the actual training run.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# On torch >= 2.6 the default weights_only=True may reject checkpoints that
# pickle non-tensor objects (e.g. checkpoint["opt"]); pass weights_only=False
# for a trusted file if loading fails.
_CKPT_FILE = "./runs/runX/model_best.pth.tar"
_cpu = torch.device("cpu")
checkpoint = torch.load(_CKPT_FILE, map_location=_cpu)

# Rebuild the model and its vocabulary from the checkpoint, then load the
# (presumably precomputed) image embeddings for the configured split.
model, vocab = load_model(checkpoint, DATASET)
img_embs = load_embs(EMBS_PATH, SPLIT)


In [None]:
# Default location of the product images (machine-specific; override via the
# images_dir argument of create_image).
_DEFAULT_IMAGES_DIR = "/home/dl-station/deep-learning/search/vsepp_orig/data/fashion/images"

# File suffix -> format name passed to widgets.Image. Unknown or missing
# suffixes fall back to "png", which was the previously hard-coded value.
_IMAGE_FORMATS = {
    ".jpg": "jpeg",
    ".jpeg": "jpeg",
    ".png": "png",
    ".gif": "gif",
}


def create_image(query, img_file, images_dir=_DEFAULT_IMAGES_DIR):
    """Load ``img_file`` from ``images_dir`` and wrap it in an ipywidgets Image.

    Args:
        query: The search text. Unused; kept so existing callers still work.
        img_file: File name of the image, relative to ``images_dir``.
        images_dir: Directory containing the images. Defaults to the
            original hard-coded location.

    Returns:
        A ``widgets.Image`` with width/height 128 and an ``auto`` layout.

    Raises:
        OSError: If the image file cannot be read.
    """
    # str() keeps the old f-string behaviour for non-string file identifiers.
    img_path = Path(images_dir) / str(img_file)
    # Label the bytes with their real format instead of always claiming PNG.
    img_format = _IMAGE_FORMATS.get(img_path.suffix.lower(), "png")
    return widgets.Image(
        value=img_path.read_bytes(),
        format=img_format,
        width=128,
        height=128,
        layout=widgets.Layout(height="auto", width="auto"),
    )

In [None]:
def create_grid(h, w, img_files, query=None):
    """Lay out ``img_files`` row by row in an ``h`` x ``w`` GridspecLayout.

    Args:
        h: Number of grid rows.
        w: Number of grid columns.
        img_files: Indexable sequence of image file names, in display order.
            Only the first ``h * w`` entries are shown. If fewer are given,
            the remaining cells are left empty.
        query: Search text forwarded to ``create_image``. Defaults to the
            current value of the ``inp`` text widget.

    Returns:
        The populated ``widgets.GridspecLayout``.
    """
    if query is None:
        query = inp.value
    # Use the requested dimensions (previously the globals height/width were
    # used and h/w were silently ignored).
    grid = widgets.GridspecLayout(h, w, grid_gap="10px")
    for idx, img_file in enumerate(img_files[: h * w]):
        row, col = divmod(idx, w)
        grid[row, col] = create_image(query, img_file)
    return grid

In [None]:
def search_btn_cb(event):
    """Run a text-to-image search for the current query and show the results.

    Used as the Search button's click handler. ``event`` is whatever
    ipywidgets passes to click callbacks and is unused here.

    The handler embeds the text of ``inp`` with ``query_embd``, asks
    ``top_k_img`` for ``height * width`` matches against ``img_embs``, and
    renders the matches as a grid inside the ``out`` widget. A blank query
    only shows a hint.
    """
    query = inp.value
    if not query.strip():
        # Nothing to embed: prompt the user instead of querying the model
        # with an empty string.
        with out:
            out.clear_output()
            print("Please enter a search term.")
        return

    print("Search term: ", query)
    txt_emb = query_embd(query, model, vocab, ON_GPU)
    print("Text Embd shape: ", txt_emb.shape)
    # NOTE(review): the loader is rebuilt on every click; consider caching it
    # if get_data_loader turns out to be expensive.
    data_loader = get_data_loader(SPLIT, vocab, checkpoint["opt"])
    img_files = top_k_img(txt_emb, img_embs, height * width, data_loader)
    print("Img files: ", img_files)
    img_grid = create_grid(height, width, img_files)

    with out:
        # wait=True keeps the old results until the new grid is displayed,
        # which avoids a visible flicker.
        out.clear_output(wait=True)
        display(img_grid)


In [None]:
## create widgets

In [None]:
# Search UI: a free-text query box, the button that triggers the search, and
# the output area where search_btn_cb renders its results.
inp = widgets.Text()
btn = widgets.Button(description="Search")
out = widgets.Output()

In [None]:
# Run search_btn_cb every time the Search button is clicked.
btn.on_click(callback=search_btn_cb)

In [None]:
# Show the query box and the Search button side by side.
search_bar = widgets.HBox(children=[inp, btn])
display(search_bar)

In [None]:
# Render the output area that search_btn_cb clears and fills with results.
display(out)