In [None]:
import openai
import os
from PIL import Image
import pandas as pd
import json
from torchvision.transforms import ToTensor

## 0. Load necessary models and functions

Connect to the Pinecone vector database

In [None]:
from pinecone import Pinecone

# Read the Pinecone API key from the environment instead of hardcoding it in the
# notebook; fall back to an interactive prompt so the key never lands in the saved file.
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
if not pinecone_api_key:
    from getpass import getpass
    pinecone_api_key = getpass("Pinecone API key: ")

pc = Pinecone(api_key=pinecone_api_key)

index = pc.Index("fastcampus")
index.describe_index_stats()

CLIP for generating dense vectors for images & text

In [None]:
from image_utils import fetch_clip, draw_images, extract_img_features

# FashionCLIP checkpoint; fetch_clip hands back the model, its processor and its tokenizer.
CLIP_MODEL_NAME = "patrickjohncyh/fashion-clip"
clip_model, clip_processor, clip_tokenizer = fetch_clip(model_name=CLIP_MODEL_NAME)

SPLADE for generating sparse vectors for text

In [None]:
from splade.splade.models.transformer_rep import Splade
from transformers import AutoTokenizer

# Hugging Face checkpoint id shared by the SPLADE model and its tokenizer.
splade_model_id = 'naver/splade-cocondenser-ensembledistil'

# Sparse encoder. agg='max' is forwarded to Splade as its aggregation mode.
# The model is pinned to CPU and put in evaluation mode; it is presumably used
# only to encode query text (see gen_sparse_vector) — confirm.
splade_model = Splade(splade_model_id, agg='max')
splade_model.to('cpu')
splade_model.eval()

splade_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=splade_model_id)

Functions for implementing various search methods

In [None]:
from search_utils import fashion_query_transformer, clothes_detector, text_search, gen_sparse_vector, describe_clothes, additional_search

Local database (`local_db.csv`) used in the second stage of the two-stage search

In [None]:
# Local copy of candidate records, used by the second (image-based) search stage.
# The 'values' column is stored as JSON text and is decoded cell by cell here
# (presumably into embedding vectors — confirm against how local_db.csv is built).
local_db = (
    pd.read_csv("local_db.csv")
    .assign(values=lambda frame: frame["values"].apply(json.loads))
)

YOLOS object detector (Fashionpedia checkpoint) for detecting clothing items in images

In [None]:
from yolo_utils import fix_channels, visualize_predictions, rescale_bboxes, plot_results, box_cxcywh_to_xyxy
from transformers import YolosFeatureExtractor, YolosForObjectDetection

# Detection checkpoint (YOLOS) with Fashionpedia-style clothing labels.
MODEL_NAME = "valentinafeve/yolos-fashionpedia"
# Preprocessing configuration is loaded from the base YOLOS-small checkpoint.
FEATURE_EXTRACTOR_NAME = 'hustvl/yolos-small'

# NOTE(review): YolosFeatureExtractor is deprecated in recent transformers
# releases in favour of YolosImageProcessor — consider migrating.
yolo_feature_extractor = YolosFeatureExtractor.from_pretrained(FEATURE_EXTRACTOR_NAME)
yolo_model = YolosForObjectDetection.from_pretrained(MODEL_NAME)

# Pre-selected prediction labels (presumably in the model's label-id order — confirm)
cats = [
    'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan',
    'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress',
    'jumpsuit', 'cape', 'glasses', 'hat',
    'headband, head covering, hair accessory', 'tie', 'glove', 'watch',
    'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe',
    'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel',
    'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper',
    'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet',
    'ruffle', 'sequin', 'tassel',
]

In [None]:
# initialize openai
# Never overwrite a key that is already configured in the environment, and never
# hardcode it in the notebook; prompt interactively only when it is missing.
if not os.environ.get("OPENAI_API_KEY"):
    from getpass import getpass
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")
openai.api_key = os.environ["OPENAI_API_KEY"]

# Define user input

In [None]:
from search_utils import fashion_query_transformer, clothes_detector, get_top_indices

# 1. Text input only

In [None]:
# Free-text query for the text-only search below; whether it is accepted depends
# on the fashion_query_transformer gateway.
text_input = 'a black cat'

In [None]:
# Gateway: structure the free text into a fashion query; a falsy result means
# the text was judged unrelated to fashion.
text_query = fashion_query_transformer(text_input)
print(f"### Result from the text_input gateway : {text_query}")

if not text_query:
    print("This text is not related to fashion. Please enter again.")
else:
    print("Searching ...")
    # Hybrid search (dense CLIP vector + sparse SPLADE vector), 10 hits per query item.
    result = text_search(index, text_query, clip_model, clip_tokenizer, splade_model, splade_tokenizer, top_k=10, hybrid=True)

    # Result key -> image paths taken from each match's metadata.
    paths = {
        key: [match['metadata']['img_path'] for match in response['matches']]
        for key, response in result.items()
    }

    # One labelled row of images per result key.
    for k, v in paths.items():
        print(k)
        draw_images([Image.open(img_path) for img_path in v])

```python
def text_search(index, items_dict, model, tokenizer, splade_model, splade_tokenizer, top_k=10, hybrid=False):
    """Run one vector-index query per structured query item.

    Args:
        index: vector index exposing ``query(**kwargs)``.
        items_dict: dict with an ``'items'`` list; each item has ``'refined_text'``
            and optionally ``'clothes_type'``.
        model, tokenizer: passed to ``get_single_text_embedding`` for the dense vector.
        splade_model, splade_tokenizer: passed to ``gen_sparse_vector`` when ``hybrid``.
        top_k: number of matches requested per item.
        hybrid: if True, also send a sparse vector; otherwise ``sparse_vector=None``.

    Returns:
        dict mapping the item's ``clothes_type`` (or ``'all'`` when absent) to the raw
        query response. Items that map to the same key overwrite earlier ones.
    """
    search_results = {}
    for item in items_dict['items']:
        query_text = item['refined_text']
        dense = get_single_text_embedding(query_text, model, tokenizer)
        sparse = gen_sparse_vector(query_text, splade_model, splade_tokenizer) if hybrid else None

        query_kwargs = dict(vector=dense[0], sparse_vector=sparse, top_k=top_k, include_metadata=True)
        if 'clothes_type' in item:
            result_key = item['clothes_type']
            # Restrict matches to the requested clothing category.
            query_kwargs['filter'] = {"category": {"$eq": result_key}}
        else:
            result_key = 'all'

        search_results[result_key] = index.query(**query_kwargs)
    return search_results
```

# 2. Image input only

In [None]:
# Query image for the image-only search below.
image_path = 'test_images/test_image8.jpg'

In [None]:
# Stage sizes: candidates per detected item from text search, results kept after
# the image-based re-search, and the detector confidence threshold.
text_top_k = 100
image_top_k = 10
detection_thresh = 0.7

image = Image.open(image_path)
# Round-trip through a tensor so the project helper fix_channels can normalise
# the channel layout before detection.
image = fix_channels(ToTensor()(image))
# object detections
print("Detecting items from the image.")
cropped_images = clothes_detector(image, yolo_feature_extractor, yolo_model, thresh=detection_thresh)

if not cropped_images:
    print("Nothing detected from the image")
else:
    print("Detected ", cropped_images.keys())

    # Describe each detected item, keyed by its detected label.
    descriptions = dict()
    print("Start creating descriptions for each item.")
    for label, crop in cropped_images.items():
        descriptions[label] = describe_clothes(crop, label, openai.api_key)
        # Log after the call so the message is only shown once the description exists.
        print("Created descriptions for {}".format(label))

    print("\nTransform the descriptions into structured query.")
    text_query = fashion_query_transformer(str(descriptions))
    print(text_query)
    results = text_search(index, text_query, clip_model, clip_tokenizer, splade_model, splade_tokenizer, top_k=text_top_k)
    # text_search queries once per item, so this count is per detected item.
    print(f"\nRetrieved {text_top_k} images per item based on text search")

    print("\nConducting additional search using the input images")
    results2 = additional_search(local_db, cropped_images, results, clip_processor, clip_model, clip_tokenizer, image_top_k)

    for k, v in results2.items():
        print(k)
        draw_images([Image.open(i) for i in v])

In [None]:
# Preview the section-2 query image after fix_channels (presumably a PIL image,
# so Jupyter renders it inline — confirm fix_channels' return type).
image

# 3. Text and Image input

- If the requested fashion style is very different from the style of the input image, the search is unlikely to return the desired results.

In [None]:
# Desired style (not a specific item) plus the reference image for the combined search.
text_input = 'softer and more comfortable material'
image_path = 'test_images/test_image2.jpg'

In [None]:
# Preview the image selected for this section. Displaying the global `image` here
# would show whatever section 2 left behind (or fail on a fresh kernel).
Image.open(image_path)

In [None]:
# Stage sizes: text-search candidates per item, image re-search results per item,
# and final results after re-ranking by the user's text.
text_top_k = 200
image_top_k = 100
rerank_top_k = 10
cropped_image_dir = os.path.join("imaterialist-fashion-2020-fgvc7", "cropped_images")

# Gateway: structure the user's style request. It stays in `text_query`; the
# image-derived query gets its own name so it cannot replace the user's text
# before the re-ranking step.
text_query = fashion_query_transformer(text_input)
print(f"### Result from the text_input gateway : {text_query}")

# Fashion-related query
# Reject non-fashion (falsy or item-less) queries first: indexing into them would raise.
if not text_query or not text_query.get('items'):
    print("This text is not related to fashion.")
elif 'clothes_type' in text_query['items'][0]:
    print("Please enter the desired fashion style, rather than a specific item.")
else:
    user_refined_text = text_query['items'][0]['refined_text']

    image = Image.open(image_path)
    image = fix_channels(ToTensor()(image))
    # object detections
    print("Detecting items from the image.")
    cropped_images = clothes_detector(image, yolo_feature_extractor, yolo_model)

    if not cropped_images:
        print("Nothing detected from the image")
    else:
        print("Detected ", cropped_images.keys())
        print("-"*10, "Start image only search", "-"*10)

        # Describe each detected item, keyed by its detected label.
        descriptions = dict()
        print("Start creating descriptions for each item.")
        for label, crop in cropped_images.items():
            descriptions[label] = describe_clothes(crop, label, openai.api_key)
            # Log after the call so the message is only shown once the description exists.
            print("Created descriptions for {}".format(label))
        print("\nTransform the descriptions into structured query.")
        image_query = fashion_query_transformer(str(descriptions))
        print(image_query)
        results = text_search(index, image_query, clip_model, clip_tokenizer, splade_model, splade_tokenizer, top_k=text_top_k)
        print(f"\nRetrieved {text_top_k} images based on text search")

        print("\nConducting additional search using the input images")

        results2 = additional_search(local_db, cropped_images, results, clip_processor, clip_model, clip_tokenizer, image_top_k)
        print(f"\nRetrieved {image_top_k} items each, from sequential image search")

        print("-"*10, "Start reranking the results based on user input text", "-"*10)
        # Re-rank each category's candidates against the USER's refined text.
        refined_result = dict()
        for category, candidate_paths in results2.items():
            # The file stem of each candidate image is matched against local_db's vdb_id.
            candidate_ids = [os.path.splitext(os.path.basename(p))[0] for p in candidate_paths]
            candidates = local_db.loc[local_db['vdb_id'].isin(candidate_ids)]

            reranked = get_top_indices(candidates, user_refined_text, category, clip_processor, clip_model, clip_tokenizer, rerank_top_k, type='text')
            # get_top_indices returns a dict; only its first entry is used.
            result_category, ranked_ids = next(iter(reranked.items()))
            refined_result[result_category] = [os.path.join(cropped_image_dir, vdb_id + ".jpg") for vdb_id in ranked_ids]

        for k, v in refined_result.items():
            print(k)
            draw_images([Image.open(i) for i in v])