In [None]:
from PIL import Image
from transformers import YolosFeatureExtractor, YolosForObjectDetection
import matplotlib.pyplot as plt
import numpy as np
from torchvision.transforms import ToTensor

import json
import pandas as pd
import warnings
import openai
import os

# Ignore warnings
warnings.filterwarnings('ignore')
# pandas dataframe display
pd.set_option('display.max_columns', None)

In [None]:
# Initialize the OpenAI API key.
# setdefault keeps a real key that is already exported in the environment; the
# placeholder below is only used when the variable is not set at all, and must
# then be replaced with a valid key.
os.environ.setdefault('OPENAI_API_KEY', "YOUR_OPENAI_API_KEY")
openai.api_key = os.environ["OPENAI_API_KEY"]

## Table of Contents

### 1. Object detection using YOLOS
### 2. Bbox integration
### 3. Cropping
### 4. Determining if an image is searchable
### 5. Providing search results for each detected category

In [None]:
attributes = pd.read_csv("attribute_specific.csv")

In [None]:
from yolo_utils import fix_channels, visualize_predictions, rescale_bboxes, plot_results, box_cxcywh_to_xyxy

In [None]:
# Checkpoint name of the object-detection model loaded below.
MODEL_NAME = "valentinafeve/yolos-fashionpedia"

# NOTE(review): the feature extractor is loaded from 'hustvl/yolos-small' while
# the model weights come from MODEL_NAME — presumably both share the same
# preprocessing; confirm.
feature_extractor = YolosFeatureExtractor.from_pretrained('hustvl/yolos-small')
model = YolosForObjectDetection.from_pretrained(MODEL_NAME)

# Pre-selected prediction labels
# Later cells look names up with cats[idx], where idx is the argmax over the
# class probabilities, so the order of this list must match the class indices.
cats = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']

In [None]:
IMAGE_PATH = 'test_images/test_image5.jpg'

In [None]:
# Open the picture through a context manager so the underlying file handle is
# closed as soon as the pixels have been read (passing a bare open(...) object
# to Image.open left that handle open for the rest of the session).
with Image.open(IMAGE_PATH) as raw_image:
    # ToTensor reads the pixel data here, inside the block, so nothing still
    # depends on the file once it is closed.
    image = fix_channels(ToTensor()(raw_image))
image

In [None]:
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)

In [None]:
visualize_predictions(image, outputs, threshold=0.)

In [None]:
probas = outputs.logits.softmax(-1)[0, :, :-1]
len(probas)

In [None]:
probas

In [None]:
def idx_to_text(i):
    """Return the category name stored at position ``i`` of the ``cats`` list."""
    label = cats[i]
    return label

# Softmax over the last (class) dimension of the logits. Index 0 selects the
# first entry along the leading dimension (presumably the batch — TODO
# confirm), and [:-1] drops the last class column (presumably the "no object"
# class — verify against the model config).
probas = outputs.logits.softmax(-1)[0, :, :-1]
# Boolean mask: True for predictions whose best class score is above 0.5.
keep = probas.max(-1).values > 0.5

# Class-probability rows of the predictions that passed the threshold.
prob = probas[keep]

In [None]:
keep

In [None]:
probas.shape

In [None]:
probas[0]

In [None]:
prob[0]

In [None]:
indices = [np.argmax(idx.detach().numpy()) for idx in prob]

indices

In [None]:
detected_cats = [cats[idx] for idx in indices]

detected_cats

In [None]:
len(outputs.pred_boxes[0])

In [None]:
boxes = outputs.pred_boxes[0, keep].cpu()

In [None]:
boxes

In [None]:
bboxes_scaled = rescale_bboxes(boxes, image.size).tolist()

In [None]:
bboxes_scaled

In [None]:
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

def plot_results_2(pil_img, labels, boxes):
    """Show ``pil_img`` with one labelled rectangle per bounding box.

    Args:
        pil_img: Image handed to ``plt.imshow``.
        labels: Texts written at the top-left corner of each box.
        boxes: Iterable of ``(xmin, ymin, xmax, ymax)`` boxes, paired with
            ``labels`` by position.
    """
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    axes = plt.gca()
    # The palette is repeated 100 times; zip stops at the shortest input, so
    # at most len(COLORS) * 100 boxes are drawn.
    palette = COLORS * 100
    for text, corners, colour in zip(labels, boxes, palette):
        left, top, right, bottom = corners
        outline = plt.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            fill=False,
            color=colour,
            linewidth=3,
        )
        axes.add_patch(outline)
        axes.text(left, top, text, fontsize=10,
                  bbox={"facecolor": colour, "alpha": 0.8})
    plt.axis('off')
    plt.show()

In [None]:
plot_results_2(image, detected_cats, bboxes_scaled)

### filter bounding boxes (select only necessary categories)

In [None]:
import pandas as pd

In [None]:
new_df = pd.read_csv("clothes_final2.csv")
new_df.name.unique()

In [None]:
category_of_interest = new_df.name.unique().tolist()

In [None]:
category_of_interest

In [None]:
# Keep only the detections whose label is one of the categories of interest.
# Labels and boxes are appended together so they stay paired by position.
# (Despite its name, keep_indices holds label names, not numeric indices.)
keep_indices, keep_bboxes = [], []

for name, box in zip(detected_cats, bboxes_scaled):
    if name not in category_of_interest:
        continue
    keep_indices.append(name)
    keep_bboxes.append(box)

In [None]:
keep_indices, keep_bboxes

### Merge overlapping bounding boxes

In [None]:
def iou(boxA, boxB):
    """Compute the intersection-over-union of two boxes in xyxy format.

    Args:
        boxA: Sequence ``(xmin, ymin, xmax, ymax)``.
        boxB: Sequence ``(xmin, ymin, xmax, ymax)``.

    Returns:
        float: Intersection area divided by union area. ``0.0`` is returned
        when the union has no positive area (for example when both boxes are
        degenerate), instead of dividing by zero.
    """
    # Corners of the intersection rectangle.
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    # Clamp at 0 so disjoint boxes give an empty intersection, not a
    # negative width/height.
    interArea = max(0, xB - xA) * max(0, yB - yA)

    boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

    # Union = sum of both areas minus the part counted twice.
    unionArea = float(boxAArea + boxBArea - interArea)
    if unionArea <= 0:
        # Nothing to overlap with; avoids ZeroDivisionError (or NaN for NumPy
        # scalars) on zero-area boxes.
        return 0.0

    return interArea / unionArea

def merge_boxes(boxes, labels, iou_threshold=0.5):
    """Merge same-label boxes that overlap strongly into one enclosing box.

    Boxes are visited in order. Each box that has not been absorbed yet
    becomes a "current" box, and every later box with the same label whose IoU
    with the current box is strictly above ``iou_threshold`` is absorbed into
    it. The current box grows as it absorbs others, so later comparisons are
    made against the enlarged box.

    Args:
        boxes: Sequence of ``(xmin, ymin, xmax, ymax)`` boxes.
        labels: Sequence of labels paired with ``boxes`` by position.
        iou_threshold: Minimum IoU (exclusive) required to merge two boxes.
            Defaults to 0.5, the value that was previously hard-coded.

    Returns:
        tuple: ``(merged_boxes, merged_labels)`` where ``merged_boxes`` is a
        NumPy array with one row per merged box (shape ``(0, 4)`` when there
        is nothing to return) and ``merged_labels`` is the matching list of
        labels.
    """
    merged_boxes = []
    merged_labels = []
    used = set()

    for i in range(len(boxes)):
        if i in used:
            continue
        current_box = boxes[i]
        for j in range(i + 1, len(boxes)):
            # Only boxes of the same category are candidates for merging.
            if j in used or labels[i] != labels[j]:
                continue
            if iou(current_box, boxes[j]) > iou_threshold:
                # In xyxy format the enclosing box is the min of the top-left
                # corners and the max of the bottom-right corners.
                current_box = [
                    min(current_box[0], boxes[j][0]),
                    min(current_box[1], boxes[j][1]),
                    max(current_box[2], boxes[j][2]),
                    max(current_box[3], boxes[j][3])
                ]
                used.add(j)
        merged_boxes.append(current_box)
        merged_labels.append(labels[i])
        used.add(i)

    if not merged_boxes:
        # Keep the (n, 4) layout even when no box survives, so callers can
        # rely on a consistent shape.
        return np.empty((0, 4)), merged_labels

    return np.array(merged_boxes), merged_labels


In [None]:
plot_results_2(image, keep_indices, keep_bboxes)

In [None]:
merged_bbox, merged_labels = merge_boxes(keep_bboxes, keep_indices)

In [None]:
merged_labels, merged_bbox

In [None]:
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

def plot_results_2(pil_img, labels, boxes):
    """Show ``pil_img`` with one labelled rectangle per bounding box.

    Args:
        pil_img: Image handed to ``plt.imshow``.
        labels: Texts written at the top-left corner of each box.
        boxes: ``(xmin, ymin, xmax, ymax)`` boxes, paired with ``labels`` by
            position. Both a NumPy array and a plain list of boxes are
            accepted.
    """
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    ax = plt.gca()
    # The palette is repeated; zip stops at the shortest input.
    colors = COLORS * 100
    # np.asarray(...).tolist() works for arrays and lists alike, whereas
    # calling boxes.tolist() directly failed on the plain lists that other
    # cells of this notebook pass in.
    box_rows = np.asarray(boxes).tolist()
    for label, (xmin, ymin, xmax, ymax), c in zip(labels, box_rows, colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))

        ax.text(xmin, ymin, label, fontsize=10,
                bbox=dict(facecolor=c, alpha=0.8))
    plt.axis('off')
    plt.show()

In [None]:
plot_results_2(image, merged_labels, merged_bbox)

## crop images

In [None]:
from image_utils import crop_bbox

In [None]:
from PIL import Image, ImageFilter

def resize_img(image, category):
    """Resize ``image`` to the standard size registered for ``category``.

    The image is always resized to the category's fixed ``(width, height)``
    with LANCZOS resampling. When the source has fewer pixels than the target
    (upsizing), an unsharp mask is applied to the result afterwards.

    Args:
        image: Object exposing ``size``, ``resize`` and ``filter`` (a PIL image
            in this notebook).
        category: Key of the standard-size table below.

    Returns:
        The resized (and, when upsized, sharpened) image.

    Raises:
        KeyError: If ``category`` is not in the standard-size table.
    """
    standard_size = {
        "lowerbody": [420, 540],
        "upperbody": [500, 700],
        "wholebody": [480, 880],
        "legs and feet": [100, 150],
        "head": [150, 100],
        "others": [200, 350],
        "waist": [200, 100],
        "arms and hands": [75, 75],
        "neck": [120, 200],
    }

    src_width, src_height = image.size
    target_width, target_height = standard_size[category]

    resized = image.resize((target_width, target_height), Image.Resampling.LANCZOS)

    if src_width * src_height >= target_width * target_height:
        # Downsizing (or equal pixel count): the plain resize is the result.
        return resized

    # Upsizing: sharpen after the resize.
    return resized.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))

In [None]:
categories = new_df[['supercategory', 'name']].drop_duplicates()

In [None]:
# categories.to_csv("categories.csv", index=False)

In [None]:
# The crops are stored in a dict keyed by category name, so when the same
# category is detected more than once (e.g. two 'shoe' boxes) only the last
# crop for that category is kept.
cropped_images = {}

for label, box in zip(merged_labels, merged_bbox):
    crop = crop_bbox(image, box)
    # Supercategory of this label; it selects the target size in resize_img.
    supercategory = categories.loc[categories['name'] == label, 'supercategory'].values[0]
    cropped = resize_img(crop, supercategory)
    cropped_images[label] = cropped

In [None]:
cropped_images

### Search from DB

In [None]:
from pinecone import Pinecone

# Prefer a key exported as PINECONE_API_KEY so the real secret never has to be
# written into the notebook; the placeholder is only a fallback and must be
# replaced when the variable is not set.
pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY", "YOUR_PINECONE_API_KEY"))
# Check the number of indexes
# index_list = pc.list_indexes().indexes

# index description
index = pc.Index("fastcampus")
index.describe_index_stats()

In [None]:
# CLIP
from image_utils import fetch_clip, extract_img_features, draw_images

# NOTE(review): this rebinds the global `model`, which until now held the YOLOS
# detector. Any earlier cell that calls `model(**inputs)` will use whatever
# fetch_clip returns first (presumably the CLIP model) if re-run after this
# cell.
model, processor, tokenizer = fetch_clip(model_name="patrickjohncyh/fashion-clip")

In [None]:
cropped_images

In [None]:
# For every cropped item, embed the crop and fetch the 5 nearest entries of the
# same category from the index; only their image paths are kept.
results = dict()

# The loop variable is deliberately not called `image`: that name holds the
# full input picture at module level, and rebinding it here would make earlier
# cells work on the last crop when re-run.
for label, cropped_image in cropped_images.items():
    img_emb = extract_img_features(cropped_image, processor, model).tolist()

    result = index.query(
        vector=img_emb[0],
        top_k=5,
        # Restrict the search to entries stored under the same category.
        filter={"category": {"$eq": label}},
        include_metadata=True
    )

    paths = [match['metadata']['img_path'] for match in result.matches]

    results[label] = paths


In [None]:
for k, paths in results.items():
    print(k)
    draw_images([Image.open(i) for i in paths])

In [None]:
def clothes_detector(image, feature_extractor, model, thresh=0.5):
    """Detect clothing items in ``image`` and return one resized crop per category.

    Steps: run the detector, keep predictions whose best class score is above
    ``thresh``, drop labels outside the categories of interest, merge
    overlapping boxes of the same label, then crop and resize each remaining
    box.

    Args:
        image: Input image; it is passed to the feature extractor and to
            ``crop_bbox``, and its ``size`` is used to rescale the boxes.
        feature_extractor: Callable producing the model inputs.
        model: Detection model called as ``model(**inputs)``.
        thresh: Minimum (exclusive) best-class probability for a prediction
            to be kept.

    Returns:
        dict: Category label -> resized crop. Because the dict is keyed by
        label, only the last crop of a repeated category is kept.
    """
    # Every label the detector can predict, in class-index order.
    cats = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts',
            'skirt', 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory',
            'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf',
            'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper',
            'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
    # Labels that are kept for the downstream search.
    category_of_interest = ['pants', 'shirt, blouse', 'jacket', 'top, t-shirt, sweatshirt', 'dress', 'shoe', 'glasses',
                            'skirt', 'bag, wallet', 'belt', 'headband, head covering, hair accessory', 'sock', 'hat',
                            'watch', 'glove', 'tights, stockings', 'sweater', 'tie', 'shorts', 'scarf', 'coat', 'vest',
                            'umbrella', 'cardigan', 'cape', 'jumpsuit', 'leg warmer']

    # Run the detector.
    inputs = feature_extractor(images=image, return_tensors="pt")
    outputs = model(**inputs)

    # Class probabilities per prediction, without the last class column.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > thresh
    confident = probas[keep]

    # Best label and rescaled box of every prediction above the threshold.
    detected_cats = [cats[np.argmax(row.detach().numpy())] for row in confident]
    boxes = outputs.pred_boxes[0, keep].cpu()
    bboxes_scaled = rescale_bboxes(boxes, image.size).tolist()

    # Drop detections outside the categories of interest.
    keep_indices = []
    keep_bboxes = []
    for name, box in zip(detected_cats, bboxes_scaled):
        if name not in category_of_interest:
            continue
        keep_indices.append(name)
        keep_bboxes.append(box)

    # Merge overlapping boxes that share a label.
    merged_bbox, merged_labels = merge_boxes(keep_bboxes, keep_indices)

    # Crop every merged box and resize it according to its supercategory.
    categories = pd.read_csv("categories.csv")
    cropped_images = {}
    for label, box in zip(merged_labels, merged_bbox):
        crop = crop_bbox(image, box)
        supercategory = categories.loc[categories['name'] == label, 'supercategory'].values[0]
        cropped_images[label] = resize_img(crop, supercategory)

    return cropped_images

In [None]:
def image_search(index, cropped_images, model, processor, top_k=10):
    """Query the vector index with the embedding of every cropped item.

    Args:
        index: Object exposing ``query(vector=..., top_k=..., filter=...,
            include_metadata=...)``.
        cropped_images: Mapping of category label to cropped image.
        model: Passed to ``extract_img_features`` together with ``processor``.
        processor: Passed to ``extract_img_features`` together with ``model``.
        top_k: Number of matches requested per item.

    Returns:
        dict: Category label -> the raw response of ``index.query`` for that
        item, restricted to entries whose ``category`` equals the label.
    """
    def _query(label, crop):
        # Only the first row of the extracted features is used as the query.
        embedding = extract_img_features(crop, processor, model).tolist()
        return index.query(
            vector=embedding[0],
            top_k=top_k,
            filter={"category": {"$eq": label}},
            include_metadata=True,
        )

    return {label: _query(label, crop) for label, crop in cropped_images.items()}

## Test

In [None]:
MODEL_NAME = "valentinafeve/yolos-fashionpedia"

# Reload the detector: the name `model` was rebound by the earlier
# `model, processor, tokenizer = fetch_clip(...)` cell, so at this point it no
# longer refers to the YOLOS model.
feature_extractor = YolosFeatureExtractor.from_pretrained('hustvl/yolos-small')
model = YolosForObjectDetection.from_pretrained(MODEL_NAME)

In [None]:
from search_utils import clothes_detector

In [None]:
clip_model, clip_processor, clip_tokenizer = fetch_clip(model_name="patrickjohncyh/fashion-clip")

In [None]:
image = Image.open("test_images/test.jpg")
image = fix_channels(ToTensor()(image))
image

In [None]:
cropped_items = clothes_detector(image, feature_extractor, model, thresh=0.5)

In [None]:
cropped_items

In [None]:
search_result = image_search(index, cropped_items, clip_model, clip_processor)

In [None]:
# Collect, per category, the image path stored in the metadata of each match.
paths = {
    label: [match['metadata']['img_path'] for match in response['matches']]
    for label, response in search_result.items()
}

# Print each category name, then draw the images retrieved for it.
for label, img_paths in paths.items():
    print(label)
    draw_images([Image.open(img_path) for img_path in img_paths])