In [4]:
import ntpath
import os
import sys
import warnings
from functools import partial, update_wrapper

# Presumably the local `appraiser` package lives in the parent directory. "../" is resolved
# against the kernel's working directory, so run this notebook from its own folder.
sys.path.insert(0, "../")

import dotenv
import ipyplot
import pymongo
import torch
from IPython.display import clear_output, display
from ipywidgets import Dropdown, Output
from transformers import ViTForImageClassification

from appraiser.config import MODEL_DIR
from appraiser.distances import scapy_distance
from appraiser.extractors import embeddings_vec, raw_vec, raw_vec_flatten, read_image
from appraiser.retrievers.annoy import AnnoyRetriever
from appraiser.retrievers.bruteforce import BruteForceRetriever

# Load the ViT image-classification model (config + weights) saved under MODEL_DIR.
model = ViTForImageClassification.from_pretrained(pretrained_model_name_or_path=MODEL_DIR)



### Choose test image

In [6]:
# Folder holding the query images. Set TEST_IMAGES_DIR in the environment to override
# this machine-specific default instead of editing the notebook.
test_images_folder = os.environ.get('TEST_IMAGES_DIR', 'C:\\research\\data\\test_images')
# os.listdir returns entries in arbitrary order; sort so the dropdown order (and hence the
# default selection used by the cells below) is the same on every run.
images_list = sorted(os.listdir(test_images_folder))

PREVIEW_MAX_HEIGHT = 250  # preview height in pixels; the width follows the aspect ratio

output = Output()
dropdown = Dropdown(description="Choose image:", options=images_list)


def show_test_image(image_name):
    """Render a preview of ``image_name`` (a file in ``test_images_folder``) into ``output``.

    The image is resized to ``PREVIEW_MAX_HEIGHT`` pixels high, keeping its aspect ratio.
    """
    with output:
        clear_output()
        image_path = os.path.join(test_images_folder, image_name)
        image = read_image(image_path)
        # assumes read_image returns a PIL-like image (.size == (width, height), .resize) — TODO confirm
        aspect_ratio = image.size[0] / image.size[1]
        resized_width = int(PREVIEW_MAX_HEIGHT * aspect_ratio)
        display(image.resize((resized_width, PREVIEW_MAX_HEIGHT)))


def dropdown_eventhandler(change):
    """ipywidgets observer: re-render the preview whenever the selected file changes."""
    show_test_image(change.new)


dropdown.observe(dropdown_eventhandler, names='value')
display(dropdown, output)
# observe() only fires on *changes*, so render the initially selected image explicitly.
if dropdown.value is not None:
    show_test_image(dropdown.value)

Dropdown(description='Choose image:', options=('mizulina_trench_masked.png', 'on_people_obvious.png', 'palto.j…

Output()

### Predict classes

In [7]:
# Build the model input from the image currently selected in the dropdown above.
# An empty image folder leaves the dropdown without a value; fail with a clear message
# instead of a TypeError from os.path.join.
if dropdown.value is None:
    raise ValueError(f"No test image selected; is {test_images_folder!r} empty?")
test_image_path = os.path.join(test_images_folder, dropdown.value)
vec_to_predict = raw_vec(test_image_path)

In [8]:
# Predict the top-k classes for the test image.
TOP_K = 3

# Inference only: no_grad() skips building the autograd graph, saving memory and time.
with torch.no_grad():
    outputs = model(vec_to_predict)
logits = outputs.logits

# Never ask topk for more classes than the model has.
k = min(TOP_K, logits.shape[-1])
# flatten() assumes a single image in the batch — TODO confirm if batching is ever added.
top_classes = torch.topk(logits, k).indices.flatten().tolist()
most_accurate_class_idx = top_classes[0]
most_accurate_label = model.config.id2label[most_accurate_class_idx]
for rank, class_idx in enumerate(top_classes, start=1):
    print(f"{rank} - Predicted class: {model.config.id2label[class_idx]}")

1 - Predicted class: verkhnyaya
2 - Predicted class: kardiganyi
3 - Predicted class: platya


### Retrieve closest images

In [9]:
# Bind the model and the top predicted class so the extractor can be called with just an
# image path (as the retrievers and the query cells below do).
# A bare partial has no __name__; update_wrapper copies __name__ (plus __doc__, __qualname__,
# __module__) from embeddings_vec. Presumably the retrievers rely on the extractor's
# __name__ (e.g. to key cached features) — confirm in appraiser.retrievers.
embedding_vec_setted = update_wrapper(
    partial(embeddings_vec, model=model, class_idx=most_accurate_class_idx),
    embeddings_vec,
)

In [10]:
# Brute-force retriever for the top predicted class, using the classifier embeddings as features.
brute_on_embedding = BruteForceRetriever(
    most_accurate_label,
    feature_extractor=embedding_vec_setted,
)

In [11]:
# Brute-force retriever for the same class, using flattened raw image vectors as features.
brute_on_raw_flat = BruteForceRetriever(
    most_accurate_label,
    feature_extractor=raw_vec_flatten,
)

In [12]:
# The Annoy retriever needs the vector dimensionality up front; read it off the
# embedding of the query image itself.
embedding_size = embedding_vec_setted(test_image_path).shape[-1]
annoy_retriver_by_embedding = AnnoyRetriever(
    most_accurate_label,
    feature_extractor=embedding_vec_setted,
    vector_size=embedding_size,
)

In [13]:
# Ten nearest neighbours from the embedding-based brute-force retriever.
closest_by_embedding = brute_on_embedding.find_n_closest(
    embedding_vec_setted(test_image_path),
    n_count=10,
    distance_comparator=scapy_distance,
)

In [14]:
# Ten nearest neighbours from the raw-vector brute-force retriever.
closest_by_flatten = brute_on_raw_flat.find_n_closest(
    raw_vec_flatten(test_image_path),
    n_count=10,
    distance_comparator=scapy_distance,
)

In [15]:
# Same query through the Annoy retriever. It receives element [0] of the embedding,
# presumably dropping a leading batch dimension — confirm against AnnoyRetriever.
closest_by_embedding_annoy = annoy_retriver_by_embedding.find_n_closest(
    embedding_vec_setted(test_image_path)[0],
    n_count=10,
    distance_comparator=scapy_distance,
)

### Get source pages and prices from MongoDB

In [16]:
# Requires python-dotenv, pymongo and ipyplot. If missing, install into this kernel with: %pip install python-dotenv pymongo ipyplot

In [17]:
# Read MONGO_* settings from a local .env file into the environment.
dotenv.load_dotenv()

# Initialize the MongoDB client.
# Credentials are passed as keyword arguments rather than embedded in the URI: inside a URI
# they would need percent-escaping, and a password containing '@', ':', '/' or '%' would
# otherwise corrupt the connection string.
# Host/port default to the previously hard-coded server and can be overridden with
# MONGO_HOST / MONGO_PORT (e.g. in .env).
client = pymongo.MongoClient(
    host=os.environ.get('MONGO_HOST', '35.195.198.192'),
    port=int(os.environ.get('MONGO_PORT', '28123')),
    username=os.environ['MONGO_INITDB_ROOT_USERNAME'],
    password=os.environ['MONGO_INITDB_ROOT_PASSWORD'],
)

mdb = client['clothes']
scrapy_items = mdb['scrapy_items']

In [18]:
def enrich_closest_images(closest):
    """Look up the scraped item record for each retrieved image, keeping retrieval order.

    Each path's file name is matched against the ``'full/<file name>'`` value stored as the
    first image path (``images.0.path``) of a ``scrapy_items`` document. All records are
    fetched with a single ``$in`` query and then re-ordered to follow ``closest``, since
    MongoDB does not return ``$in`` matches in list order.

    Args:
        closest: iterable of image file paths (Windows or POSIX separators).

    Returns:
        A list with one dict per path that has a matching record, in the order of
        ``closest``. Each dict holds the record's ``_id``, ``prices`` and ``source_url``
        plus ``img_full_path`` (the original path). Paths without a record are skipped
        with a warning.
    """
    closest = list(closest)
    if not closest:
        return []
    # ntpath.basename splits on both '\\' and '/', so paths from either OS work.
    search_keys = ['full/' + ntpath.basename(path) for path in closest]

    records_by_key = {}
    cursor = scrapy_items.find(
        {'images.0.path': {'$in': search_keys}},
        {'prices': 1, 'source_url': 1, 'images': 1},
    )
    for record in cursor:
        # Keep a single record per key, as the previous per-path find_one() did.
        records_by_key.setdefault(record['images'][0]['path'], record)

    items = []
    for path, key in zip(closest, search_keys):
        record = records_by_key.get(key)
        if record is None:
            warnings.warn(f"No scrapy_items record found for {key!r} ({path})")
            continue
        # Fresh dict per path: a record matched by several paths gets its own img_full_path.
        # 'images' was only fetched for the re-ordering above, so drop it.
        item = {field: value for field, value in record.items() if field != 'images'}
        item['img_full_path'] = path
        items.append(item)
    return items

def visualise_images(closest, max_images=30, img_width=150):
    """Plot retrieved images in a grid, each labelled with the last entry of its ``prices``.

    Args:
        closest: image paths as returned by a retriever's ``find_n_closest``.
        max_images: maximum number of images passed on to ``ipyplot.plot_images``.
        img_width: width in pixels of each rendered image.
    """
    # Enrich the images that were passed in (not a global), so each retriever's own
    # results are shown.
    enriched = enrich_closest_images(closest)
    images = [item['img_full_path'] for item in enriched]
    # Fall back to a placeholder when a record has no (or an empty) prices list.
    labels = [item['prices'][-1] if item.get('prices') else 'no price' for item in enriched]
    ipyplot.plot_images(images, labels=labels, max_images=max_images, img_width=img_width)

In [19]:
# Neighbours found with flattened raw image vectors.
visualise_images(closest=closest_by_flatten)

In [20]:
# Neighbours found by brute force over classifier embeddings.
visualise_images(closest=closest_by_embedding)

In [21]:
# Neighbours found through the Annoy retriever over classifier embeddings.
visualise_images(closest=closest_by_embedding_annoy)