In [None]:
import openai
import os
from PIL import Image
import pandas as pd
import json
from torchvision.transforms import ToTensor

## 목차

#### 1. 이미지 서치와 동일한 방식으로 검색 결과 제공
#### 2. 결과에서 추가적으로 input text를 활용한 rerank

## Load data & models

In [None]:
from pinecone import Pinecone

# Read the Pinecone API key from the environment instead of committing it to the
# notebook. Any key that was previously hard-coded here should be treated as
# leaked and rotated.
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
if not pinecone_api_key:
    raise RuntimeError(
        "PINECONE_API_KEY is not set. Export it in the shell before starting Jupyter."
    )

pc = Pinecone(api_key=pinecone_api_key)

# Handle to the "fastcampus" index; the last line displays its stats as a sanity check.
index = pc.Index("fastcampus")
index.describe_index_stats()

In [None]:
from image_utils import fetch_clip, draw_images, extract_img_features

# fetch_clip (from image_utils) returns three objects, bound here as the CLIP
# model, processor and tokenizer for the fashion-clip checkpoint.
clip_model, clip_processor, clip_tokenizer = fetch_clip(
    model_name="patrickjohncyh/fashion-clip",
)

In [None]:
from splade.splade.models.transformer_rep import Splade
from transformers import AutoTokenizer

# Checkpoint id shared by the SPLADE model and its tokenizer.
splade_model_id = 'naver/splade-cocondenser-ensembledistil'

# Build the SPLADE model with agg='max' (presumably max-aggregation over token
# activations -- confirm in splade.models.transformer_rep).
splade_model = Splade(splade_model_id, agg='max')
# Keep the model on CPU and call eval() (assuming Splade follows the
# torch.nn.Module API, this switches it to inference mode).
splade_model.to('cpu')
splade_model.eval()

# Tokenizer loaded from the same checkpoint id as the model.
splade_tokenizer = AutoTokenizer.from_pretrained(splade_model_id)

In [None]:
from yolo_utils import fix_channels, visualize_predictions, rescale_bboxes, plot_results, box_cxcywh_to_xyxy
from transformers import YolosFeatureExtractor, YolosForObjectDetection

# Hugging Face checkpoint id of the detection model loaded below.
MODEL_NAME = "valentinafeve/yolos-fashionpedia"

# NOTE: the feature extractor is loaded from the 'hustvl/yolos-small' checkpoint,
# while the model weights come from MODEL_NAME.
yolo_feature_extractor = YolosFeatureExtractor.from_pretrained('hustvl/yolos-small')
yolo_model = YolosForObjectDetection.from_pretrained(MODEL_NAME)

# Pre-selected prediction labels, one per line.
# NOTE(review): the order presumably matches the detector's class indices -- do not reorder.
cats = [
    'shirt, blouse',
    'top, t-shirt, sweatshirt',
    'sweater',
    'cardigan',
    'jacket',
    'vest',
    'pants',
    'shorts',
    'skirt',
    'coat',
    'dress',
    'jumpsuit',
    'cape',
    'glasses',
    'hat',
    'headband, head covering, hair accessory',
    'tie',
    'glove',
    'watch',
    'belt',
    'leg warmer',
    'tights, stockings',
    'sock',
    'shoe',
    'bag, wallet',
    'scarf',
    'umbrella',
    'hood',
    'collar',
    'lapel',
    'epaulette',
    'sleeve',
    'pocket',
    'neckline',
    'buckle',
    'zipper',
    'applique',
    'bead',
    'bow',
    'flower',
    'fringe',
    'ribbon',
    'rivet',
    'ruffle',
    'sequin',
    'tassel',
]

In [None]:
from search_utils import fashion_query_transformer, clothes_detector, text_search, gen_sparse_vector, describe_clothes, additional_search

In [None]:
# Load the local table and decode the JSON-encoded 'values' column in one
# chained expression, so the frame is never left half-parsed.
local_db = pd.read_csv("local_db.csv").assign(
    values=lambda frame: frame['values'].map(json.loads)
)

## Image and text input <br>: Item level sequential search (text embeddings + image embeddings) -> text search

1. 제공되는 이미지를 기준으로 search를 하여 유관한 아이템을 N개 서치
2. fashion style을 명시한 text search를 활용하여 rerank

In [None]:
# Free-text style request; passed to fashion_query_transformer in the reranking section.
text_input = "I want the clothes to be more casual and easy to wear"

In [None]:
# Query image used for the image-based search.
IMAGE_PATH = 'test_images/test_image7.jpg'

# Open the file through a context manager so the handle is closed once the
# pixels have been read. ToTensor() reads the pixel data inside the `with`
# block, so nothing depends on the file after it is closed.
with Image.open(IMAGE_PATH) as pil_image:
    image = fix_channels(ToTensor()(pil_image))
image

In [None]:
# initialize openai
# The key must come from the environment; never commit it to the notebook.
# Any key that was previously hard-coded here should be treated as leaked and rotated.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Export it in the shell before starting Jupyter."
    )
openai.api_key = os.environ["OPENAI_API_KEY"]

#### 1. Image only search

- 이미지와 유관한 아이템 search

In [None]:
from search_utils import clothes_detector

In [None]:
# Run the detector on the query image and keep the resulting crops.
# thresh=0.5 is presumably the minimum detection confidence -- confirm in search_utils.
cropped_items = clothes_detector(
    image,
    yolo_feature_extractor,
    yolo_model,
    thresh=0.5,
)

In [None]:
# Inspect the detector output (displayed because it is the cell's last expression).
cropped_items

In [None]:
# Build one description per cropped item, keyed exactly like cropped_items.
# Each key is printed before its describe_clothes call so progress is visible.
descriptions = {}

for item_key, item_img in cropped_items.items():
    print(item_key)
    descriptions[item_key] = describe_clothes(item_img, item_key, openai.api_key)

In [None]:
# Print each key and its description, followed by a blank separator line.
for item_key, description in descriptions.items():
    print(f"{item_key}\n{description}\n")

In [None]:
# The whole descriptions dict is passed as its str() representation, so the
# transformer receives the Python dict repr rather than structured input.
text_query = fashion_query_transformer(str(descriptions))
text_query

In [None]:
# Query the index with the transformed text query, requesting top_k=100 matches.
# hybrid=False presumably turns off the sparse (SPLADE) part of the query --
# confirm in search_utils.text_search.
results = text_search(
    index,
    text_query,
    clip_model,
    clip_tokenizer,
    splade_model,
    splade_tokenizer,
    top_k=100,
    hybrid=False,
)
results.keys()

In [None]:
# Collect, for every key of the search results, the image path stored in each match's metadata.
paths = {
    category: [match['metadata']['img_path'] for match in hits['matches']]
    for category, hits in results.items()
}

# Show the images
for category, img_paths in paths.items():
    print(category)
    preview = img_paths[:10]  # display only the first 10
    draw_images([Image.open(img_path) for img_path in preview])

Reranking을 위한 충분한 후보군을 확보하기 위해 50개(`top_k=50`)를 retrieve

In [None]:
# Follow-up search over local_db using the cropped items and the text-search
# results, with top_k=50. additional_search lives in search_utils; its exact
# ranking logic is not visible from this notebook.
final_results = additional_search(
    local_db,
    cropped_items,
    results,
    clip_processor,
    clip_model,
    clip_tokenizer,
    top_k=50,
)

In [None]:
# Display every retrieved image, grouped under the key it was retrieved for.
for category, image_paths in final_results.items():
    print(category)
    draw_images(list(map(Image.open, image_paths)))

#### 2. Reranking using text embeddings

- Text에서 묘사된 옷의 "방향성"을 활용하여 Rerank

In [None]:
# First, determine whether the text avoids mentioning a specific item
text_result = fashion_query_transformer(text_input)
text_result

In [None]:
# Re-display the raw request so it can be compared with the parsed text_result.
text_input

In [None]:
from search_utils import get_top_indices

In [None]:
# Rerank the image-search candidates with the free-text request, but only when
# the first parsed item carries no 'clothes_type'.
# new_results is created unconditionally so the next cell, which iterates over
# it, does not raise NameError when the branch below is skipped.
new_results = list()

first_item = text_result['items'][0]

if 'clothes_type' not in first_item.keys():
    for k, v in final_results.items():
        # Recover the ids from the image file names (basename without extension)
        ids = [os.path.splitext(os.path.basename(i))[0] for i in v]
        candidates = local_db.loc[local_db['vdb_id'].isin(ids)]

        # The positional 5 is presumably the number of items to keep per key --
        # confirm against search_utils.get_top_indices.
        r = get_top_indices(candidates, first_item['refined_text'], k, clip_processor, clip_model, clip_tokenizer, 5, type='text')
        new_results.append(r)
else:
    print("text_input names a specific clothes type; skipping the text-based rerank.")

In [None]:
# Turn each rerank result (only its first key/value pair is read) into
# category -> list of cropped-image file paths.
refined_result = {}

for search_result in new_results:
    category, paths = list(search_result.items())[0]
    refined_result[category] = [
        os.path.join("imaterialist-fashion-2020-fgvc7", "cropped_images", i + ".jpg")
        for i in paths
    ]

In [None]:
# Display the reranked images, grouped by category.
for category, image_paths in refined_result.items():
    print(category)
    draw_images(list(map(Image.open, image_paths)))

In [None]:
# image