# Libs etc

In [1]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git
!pip install torch torchvision 

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /private/var/folders/_t/v7tb441j7_l32mrtl8d_nr7r0000gn/T/pip-req-build-myszvrr_
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /private/var/folders/_t/v7tb441j7_l32mrtl8d_nr7r0000gn/T/pip-req-build-myszvrr_
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25ldone


In [2]:
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import pandas as pd
import ast
import os
import torch
import clip
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from clip import clip
import random
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from torch.nn.functional import normalize

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
model, preprocess = clip.load("ViT-B/32")

# work with csv

In [5]:
#def create_full_path(row):
#    return f"remaster_fashion_for_clip/{row['file_path']}/{row['file_name']}.jpg"

In [6]:
#fashion_data['full_path'] = fashion_data.apply(create_full_path, axis=1)

In [7]:
fashion_data = pd.read_csv("remaster_fashion_for_clip/fashion.csv")

In [8]:
fashion_data.head()

Unnamed: 0.4,Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,file_name,description,file_path,category,full_path
0,0,0,0,0,dress01,daisy slip dress,dress,dress,remaster_fashion_for_clip/dress/dress01.jpg
1,1,1,1,1,dress02,bunny tank top dress,dress,dress,remaster_fashion_for_clip/dress/dress02.jpg
2,2,2,2,2,dress03,floral midi slip dress,dress,dress,remaster_fashion_for_clip/dress/dress03.jpg
3,3,3,3,3,dress04,jeans slip dress,dress,dress,remaster_fashion_for_clip/dress/dress04.jpg
4,4,4,4,4,dress05,distressed double-corset dress,dress,dress,remaster_fashion_for_clip/dress/dress05.jpg


In [9]:
fashion_data = fashion_data.drop(fashion_data[fashion_data['file_name'] == 'ts03'].index)

In [10]:
fashion_data = fashion_data.drop(columns=["Unnamed: 0.3", "Unnamed: 0.2", "Unnamed: 0.1", "Unnamed: 0"])

In [11]:
fashion_data.head()

Unnamed: 0,file_name,description,file_path,category,full_path
0,dress01,daisy slip dress,dress,dress,remaster_fashion_for_clip/dress/dress01.jpg
1,dress02,bunny tank top dress,dress,dress,remaster_fashion_for_clip/dress/dress02.jpg
2,dress03,floral midi slip dress,dress,dress,remaster_fashion_for_clip/dress/dress03.jpg
3,dress04,jeans slip dress,dress,dress,remaster_fashion_for_clip/dress/dress04.jpg
4,dress05,distressed double-corset dress,dress,dress,remaster_fashion_for_clip/dress/dress05.jpg


In [12]:
def calculate_text_features(row):
    text = row['description']
    text_tokens = clip.tokenize([text]).to(device)

    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

    return text_features

In [13]:
def calculate_image_features(row):
    image_path = row['full_path']
    img = preprocess(Image.open(image_path)).unsqueeze(0).to(device)

    image_features = model.encode_image(img)
    image_features /= image_features.norm(dim=-1, keepdim=True)

    return image_features

In [None]:
text_features_list = []
image_features_list = []

# Calculate features for each row in fashion_data and store them in the lists
for index, row in fashion_data.iterrows():
    text_features = calculate_text_features(row)
    image_features = calculate_image_features(row)
    text_features_list.append(text_features.detach().numpy())
    image_features_list.append(image_features.detach().numpy())

In [None]:
fashion_data['text_features'] = text_features_list
fashion_data['image_features'] = image_features_list

# Set the display options for float formatting
pd.set_option('display.float_format', lambda x: '%.5f' % x)

In [None]:
#fashion_data.to_csv("remaster_fashion_for_clip/f_w_F.csv")

In [None]:
#fashion_data = pd.read_csv("remaster_fashion_for_clip/f_w_f.csv", index_col=0)

In [None]:
fashion_data.head()

In [None]:
category_counts = fashion_data['category'].value_counts()

print("Number of items in each category:")
print(category_counts)

# functions

In [None]:
def get_random_image_info(dataset):
    if dataset.empty:
        raise ValueError("Dataset is empty.")

    random_index = np.random.randint(0, dataset.shape[0])
    random_row = dataset.iloc[random_index]

    random_image_info = {
        "file_path": random_row["file_path"],
        "full_path": random_row["full_path"],
        "file_name": random_row["file_name"],
        "description": random_row["description"],
        "text_features": np.array(random_row["text_features"]),
        "image_features": np.array(random_row["image_features"])
    }

    return random_image_info

In [None]:
def plot_image_with_info(image_info):
    image_path = image_info['full_path']
    image = mpimg.imread(image_path)

    plt.imshow(image)
    plt.title("Random Image")
    plt.axis("off")
    plt.show()

    print("File name:", image_info['file_name']+".jpg")
    print("Directory Name:", image_info['file_path'])
    print("Description:", image_info['description'])

In [None]:
def find_similar_images(image_info, needed_category, num_similar=3):
    category_data = fashion_data[fashion_data['file_path'] == needed_category]

    # Get the text and image features of the given image
    query_text_features = torch.flatten(torch.tensor(image_info['text_features']))
    query_img_features = torch.flatten(torch.tensor(image_info['image_features']))

    similarities = []
    for index, row in category_data.iterrows():
        if row['file_name'] != image_info['file_name']:
            # Flatten the 2D tensors from the dataset
            row_text_features = torch.flatten(torch.tensor(row['text_features']))
            row_img_features = torch.flatten(torch.tensor(row['image_features']))

            similarity = (
                torch.dot(row_text_features, query_text_features)
                + torch.dot(row_img_features, query_img_features)
            )
            similarities.append((row['full_path'], similarity.item()))

    similarities.sort(key=lambda x: x[1], reverse=True)
    similar_images = similarities[:num_similar]

    return similar_images

In [None]:
def plot_similar_images(similar_images):
    fig, axs = plt.subplots(1, len(similar_images), figsize=(15, 5))

    for idx, (img_path, similarity) in enumerate(similar_images, 1):
        img = Image.open(img_path)
        axs[idx - 1].imshow(img)
        axs[idx - 1].set_title(f"Top {idx} Similar Image\nSimilarity: {similarity:.2f}")
        axs[idx - 1].axis("off")

    plt.show()

In [None]:
needed_category1 = "dress"
needed_category2 = "top"
needed_category3 = "pants"
needed_category4 = "highheels"
needed_category5 = "jacket"
needed_category6 = "sneakers"
needed_category7 = "blazer"
needed_category8 = "tshirt"
needed_category9 = "boots"

# similar imgs

In [None]:
random_image_info = get_random_image_info(fashion_data)
plot_image_with_info(random_image_info)

In [None]:
similar_images1 = find_similar_images(random_image_info, needed_category1)
plot_similar_images(similar_images1)

In [None]:
similar_images2 = find_similar_images(random_image_info, needed_category2)
plot_similar_images(similar_images2)

In [None]:
similar_images3 = find_similar_images(random_image_info, needed_category3)
plot_similar_images(similar_images3)

In [None]:
similar_images4 = find_similar_images(random_image_info, needed_category4)
plot_similar_images(similar_images4)

In [None]:
similar_images5 = find_similar_images(random_image_info, needed_category5)
plot_similar_images(similar_images5)

In [None]:
similar_images6 = find_similar_images(random_image_info, needed_category6)
plot_similar_images(similar_images6)

In [None]:
similar_images7 = find_similar_images(random_image_info, needed_category7)
plot_similar_images(similar_images7)

In [None]:
similar_images8 = find_similar_images(random_image_info, needed_category8)
plot_similar_images(similar_images8)

In [None]:
similar_images9 = find_similar_images(random_image_info, needed_category9)
plot_similar_images(similar_images9)

# Let's try to search img by text description

def preprocess_text(text):
    # Tokenize and encode the text description using the same method you used in the dataset
    text_tokens = clip.tokenize([text]).to(device)
    return text_tokens

def calculate_text_features(text_tokens):
    # Calculate text features using the CLIP model
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features

def find_similar_images(text_query, k=5):
    # Preprocess the text query
    query_text_tokens = preprocess_text(text_query)

    # Calculate text features for the query
    query_text_features = calculate_text_features(query_text_tokens)

    # Initialize a list to store image indices and their similarity scores
    similarity_scores = []

    for index, row in fashion_data.iterrows():
        # Calculate the similarity between the query text features and each row's text features
        row_text_features = torch.tensor(row['text_features'])
        row_img_features = torch.tensor(row['image_features'])

        similarity = torch.dot(query_text_features, row_text_features) + torch.dot(query_img_features, row_img_features)

        # Append the image index and its similarity score to the list
        similarity_scores.append((index, similarity.item()))

    # Sort the images based on similarity scores (in descending order)
    similarity_scores.sort(key=lambda x: x[1], reverse=True)

    # Get the indices of the top-k most similar images
    top_k_indices = [index for index, _ in similarity_scores[:k]]

    # Return the top-k rows from the fashion_data DataFrame
    return fashion_data.iloc[top_k_indices]

text_query = "red dress with floral pattern"
similar_images = find_similar_images(text_query, k=5)