In [1]:

import IPython
import os
import pickle
import requests
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import torch
from torchvision import models, transforms
from PIL import Image
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
import pandas as pd
from scipy.sparse import csr_matrix
from sklearn.neighbors import NearestNeighbors
import io
import random
from sentence_transformers import SentenceTransformer, util


In [None]:
# Load the dataset from a CSV file; the path is resolved relative to the
# current working directory.
dataset_path = "dataset.csv"
dataset = pd.read_csv(dataset_path)

# info() prints its summary and returns None, so it is called on its own line;
# head() is left as the cell's final expression so the preview is rendered as a
# table instead of being embedded in a "(..., None)" tuple.
dataset.info()
dataset.head()

In [None]:
# Filter valid image links: keep unique, non-null entries of the 'image_links'
# column that are strings starting with "http". The isinstance guard keeps a
# stray non-string value from raising AttributeError on .startswith().
valid_image_links = [
    url
    for url in dataset['image_links'].dropna().unique()
    if isinstance(url, str) and url.startswith("http")
]

# Show the count and a short preview rather than dumping every URL into the
# cell output.
print(f"{len(valid_image_links)} valid image links")
valid_image_links[:5]

In [5]:
# Step 1: Extract Image Embeddings and Cache Locally using ResNet
EMBEDDING_CACHE = "embeddings_cache.pkl"  # pickle file used to persist the cache
embeddings_cache = dict()  # in-memory cache, filled as image URL -> embedding

In [None]:
# Preprocessing applied to every image before it is passed to the model:
# resize to 224x224, convert to a tensor, then normalize each channel with the
# mean / std values listed below.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
transform

In [None]:
# Build the embedding model: a pre-trained ResNet-50 with its last child module
# (the classification layer) dropped, wrapped in a Sequential and switched to
# evaluation mode.
model = torch.nn.Sequential(
    *list(models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).children())[:-1]
)
model.eval()


In [8]:
# Sanity-check that the cache file can be opened for reading.
# The original bare open() raised FileNotFoundError on a fresh run (before any
# cache had been written) and left the handle open. Here the check is skipped
# when the file is missing, and the handle is closed as soon as the check ends.
f = None
if os.path.exists(EMBEDDING_CACHE):
    with open(EMBEDDING_CACHE, "rb") as f:
        pass  # opening successfully is the whole check

In [None]:
# Restore previously saved embeddings when a non-empty cache file is present.
# NOTE(security): pickle.load can execute arbitrary code while unpickling, so
# only load cache files that were written by this notebook.
if not os.path.exists(EMBEDDING_CACHE) or os.path.getsize(EMBEDDING_CACHE) <= 0:
    print("No cache!!!")
else:
    with open(EMBEDDING_CACHE, "rb") as f:
        embeddings_cache = pickle.load(f)
    print("Embeddings cache loaded successfully.")

In [10]:
def extract_embedding_from_url(image_url):
    """Return the embedding of the image at ``image_url``, using the cache when possible.

    On a cache miss the image is downloaded, decoded as RGB, passed through
    ``transform`` and ``model``, and the result is stored in
    ``embeddings_cache`` under ``image_url``.

    Args:
        image_url: URL of the image to embed; also used as the cache key.

    Returns:
        The squeezed NumPy array produced by ``model`` for the image, or
        ``None`` when the download, decoding or processing fails (the reason
        is printed).
    """
    # Imported locally because the notebook's import cell only brings in
    # ``Image``; without this the ``except UnidentifiedImageError`` clause below
    # raised NameError and hid the real decoding error.
    from PIL import UnidentifiedImageError

    # Check cache first
    if image_url in embeddings_cache:
        return embeddings_cache[image_url]

    try:
        # Download the full body (no streaming) so the image bytes are complete
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()  # Raise an exception for bad HTTP responses

        # Wrap the bytes in a file-like object for PIL
        image_data = io.BytesIO(response.content)

        # Report undecodable payloads (e.g. an HTML error page) explicitly
        try:
            image = Image.open(image_data).convert("RGB")
        except UnidentifiedImageError:
            print(f"Could not identify image from {image_url}")
            return None

        # Apply transformations
        image = transform(image).unsqueeze(0)  # Add batch dimension

        # Inference only: no gradients are needed
        with torch.no_grad():
            embedding = model(image).squeeze().numpy()

        # Cache the embedding
        embeddings_cache[image_url] = embedding
        return embedding

    except requests.RequestException as e:
        print(f"Error fetching image from {image_url}: {e}")
        return None
    except Exception as e:
        # Best-effort: one bad image must not abort a bulk extraction run
        print(f"Unexpected error processing {image_url}: {e}")
        return None

In [11]:
# Save the cache after each run
def save_embeddings_cache():
    """Pickle the in-memory ``embeddings_cache`` into the ``EMBEDDING_CACHE`` file.

    The file is opened in binary write mode, so any previous content is
    overwritten.
    """
    with open(EMBEDDING_CACHE, mode="wb") as cache_file:
        pickle.dump(embeddings_cache, file=cache_file)

In [None]:
# Counters used by the sequential extraction loop that is kept below, disabled,
# for reference.
progress_length = len(valid_image_links)
progress_index = 0

# Sequential alternative (disabled):
# for url in valid_image_links:
#     extract_embedding_from_url(url)
#     print(f"{progress_index}/{progress_length}: extracting embeddings from {url}")
#     progress_index += 1

def extract_embedding_multithread(image_links, max_thread = 10):
    """Extract embeddings for many image URLs using a thread pool.

    Each URL is passed to ``extract_embedding_from_url``; progress is shown
    with a tqdm bar and a summary line is printed at the end.

    Args:
        image_links: Sized collection of image URLs.
        max_thread: Maximum number of worker threads in the pool.

    Returns:
        None. Results are only counted, not returned.
    """
    with ThreadPoolExecutor(max_workers=max_thread) as pool:
        pending = pool.map(extract_embedding_from_url, image_links)
        results = list(tqdm(pending, total=len(image_links), desc="Extracting embeddings"))

    # A None result marks a URL whose extraction failed
    total_images = len(image_links)
    successful_embeddings = len([outcome for outcome in results if outcome is not None])
    failed_count = total_images - successful_embeddings
    print(f"Extraction complete. Total images: {total_images}, Successful: {successful_embeddings}, Failed: {failed_count}")

# Run the extraction over every valid image link found in the dataset
extract_embedding_multithread(valid_image_links)

In [13]:
# Persist the in-memory embeddings cache to the EMBEDDING_CACHE file.
save_embeddings_cache()

In [14]:
# res = random.choice(valid_image_links)
# print(str(res))
# len(res)

In [None]:
# Step 2: Image Similarity Matching Logic

def find_similar_images_from_urls(query_image_url, image_urls=None, top_k=5):
    """Find the cached images most similar to a query image.

    The query embedding is obtained through ``extract_embedding_from_url``
    (which also stores it in ``embeddings_cache``), compared against every
    cached embedding with cosine similarity, and the best matches are printed
    and returned.

    Args:
        query_image_url (str): URL of the query image to compare.
        image_urls (list, optional): Accepted for compatibility; not currently used.
        top_k (int): Number of top similar images to retrieve. Default is 5.
            Values below 1 yield an empty result; values larger than the cache
            are capped at the cache size.

    Returns:
        list: URLs of the ``top_k`` most similar cached images, best match
        first. Empty when the query embedding cannot be extracted or the cache
        is empty.
    """
    # Step 2.1: Extract embedding for the query image
    print("Extracting query image embedding...")
    query_embedding = extract_embedding_from_url(query_image_url)
    if query_embedding is None:  # Handle invalid or failed extraction
        return []

    # Step 2.2: Load embeddings from the cache, keeping URLs and embeddings in
    # the same order so a similarity index maps back to the right URL
    print("Loading embeddings from cache...")
    cached_urls = list(embeddings_cache.keys())
    cached_embeddings = [embeddings_cache[url] for url in cached_urls]

    # Step 2.3: Check if the cache is empty
    if not cached_embeddings:
        print("No embeddings available in the cache.")
        return []

    # Step 2.4: Compute cosine similarity between the query and cached embeddings
    similarities = cosine_similarity([query_embedding], cached_embeddings)[0]

    # Step 2.5: Take the top_k most similar entries in descending order.
    # Clamping avoids the `[-0:]` slice, which would select every entry.
    k = max(0, min(int(top_k), len(cached_urls)))
    top_indices = similarities.argsort()[::-1][:k]

    print("\nTop Similar Images:")
    for rank, idx in enumerate(top_indices, 1):
        # Index into the cache-ordered list: the similarity indices refer to
        # the cache order, not to the dataset's valid_image_links order.
        print(f"{rank}. Image: {cached_urls[idx]}")
        print(f"   Similarity Score: {similarities[idx]:.4f}")

    # Step 2.6: Persist the cache (the query embedding may have been added)
    save_embeddings_cache()

    return [cached_urls[idx] for idx in top_indices]

# Demo: use a randomly chosen dataset image as the query.
res = random.choice(valid_image_links)
print(res)
find_similar_images_from_urls(query_image_url=res)


In [None]:

# Step 2: Build Similarity Search Engines
# # Build FAISS index
# def build_faiss_index(embeddings):
#     dimension = embeddings[0].shape[0]
#     index = faiss.IndexFlatL2(dimension)  # L2 distance
#     faiss_embeddings = np.vstack(embeddings)
#     index.add(faiss_embeddings)
#     return index

# # Build HNSW index
# def build_hnsw_index(embeddings, space="cosine"):
#     dimension = embeddings[0].shape[0]
#     index = hnswlib.Index(space=space, dim=dimension)
#     index.init_index(max_elements=len(embeddings), ef_construction=200, M=16)
#     for i, embedding in enumerate(embeddings):
#         index.add_items(embedding, i)
#     index.set_ef(50)  # Controls recall quality at query time
#     return index

# Function to find similar images using cosine similarity
def find_similar_cosine(query_embedding, embeddings, top_k=5):
    """Rank ``embeddings`` by cosine similarity to ``query_embedding``.

    Args:
        query_embedding: Single embedding vector for the query.
        embeddings: Collection of candidate embeddings, one per row.
        top_k: Number of best matches to return. Values below 1 yield no
            indices; values above the number of candidates are capped.

    Returns:
        tuple: ``(top_indices, similarities)`` where ``top_indices`` holds the
        positions of the best matches in descending similarity order and
        ``similarities`` holds the similarity of every candidate.
    """
    similarities = cosine_similarity([query_embedding], embeddings)[0]
    # Clamp before slicing: the previous `[-top_k:]` form returned every
    # candidate for top_k == 0 and sliced from the wrong end for negative values.
    k = max(0, min(int(top_k), similarities.shape[0]))
    top_indices = similarities.argsort()[::-1][:k]
    return top_indices, similarities

# Step 3: Find Similar Images from Cache
def find_similar_images(query_image_url, top_k=20):
    """Return the cached image URLs most similar to the query image.

    The query embedding comes from ``extract_embedding_from_url`` (which also
    adds it to ``embeddings_cache``); candidates are ranked with
    ``find_similar_cosine``. Only cosine search is implemented; the FAISS and
    HNSW variants exist only as commented-out sketches.

    Args:
        query_image_url: URL of the query image.
        top_k: Number of matches to return. Default is 20.

    Returns:
        list: URLs of the best matches, most similar first. Empty when the
        query embedding cannot be extracted or the cache is empty.
    """
    print("Extracting query image embedding...")
    query_embedding = extract_embedding_from_url(query_image_url)
    if query_embedding is None:
        return []

    print("Loading embeddings from cache...")
    # URLs and embedding rows are built in the same order so that a row index
    # returned by the search maps back to its URL.
    valid_image_urls = list(embeddings_cache.keys())
    if not valid_image_urls:
        print("No embeddings available in the cache.")
        return []

    embeddings = np.vstack([embeddings_cache[url] for url in valid_image_urls])
    query_embedding = np.array(query_embedding)

    top_indices, similarities = find_similar_cosine(query_embedding, embeddings, top_k)

    # Note: the cache is not written to disk here; call save_embeddings_cache()
    # to persist a newly added query embedding.

    print("Top similar images:")
    # Print the 1-based rank. The previous output printed the raw array index
    # fused with the score (e.g. "12.0.98"), leaving the enumerate counter unused.
    for rank, index in enumerate(top_indices, 1):
        print(f"{rank}. {similarities[index]:.4f}: {valid_image_urls[index]}")
    return [valid_image_urls[index] for index in top_indices]

# Example usage
# A random dataset image is drawn first and then immediately replaced by the
# fixed URL below; remove the second assignment to query with the random image.
query_image_url = random.choice(valid_image_links)
query_image_url = "https://thumbs.dreamstime.com/b/beach-ball-12760024.jpg"
print(query_image_url)
# find_similar_images takes no search-method option; it always uses cosine similarity.
similar_images = find_similar_images(query_image_url=query_image_url)
print("Top similar images:")
if similar_images:
    print(*similar_images, sep="\n")


In [None]:
# ============================
# 1. Preprocess Data
# ============================
# Uses the `dataset` frame that was loaded earlier in the notebook.

# Combine text fields for embedding. Missing or empty fields are skipped
# instead of turning the whole row into NaN, which is what `+` concatenation
# does when any operand is missing. Rows with all four fields present produce
# the same "cat1 cat2 cat3 title" string as before.
text_fields = dataset[['category_1', 'category_2', 'category_3', 'title']]
dataset['combined_text'] = [
    " ".join(str(part) for part in row if pd.notna(part) and part != "")
    for row in text_fields.itertuples(index=False, name=None)
]

# =============================
# 2. Text Embedding Generation
# =============================
# Load pre-trained Sentence Transformer model
text_model = SentenceTransformer('all-MiniLM-L6-v2')

TEXT_EMBEDDINGS_CACHE = "text_embeddings_cache.pkl"  # pickle file for the text cache
text_embeddings_cache = {}  # in-memory cache, filled as text -> embedding
text_datas = dataset['combined_text'].tolist()

# Generate text embeddings for the dataset
def generate_text_embeddings(text_data):
    """Return the sentence embedding for ``text_data``, caching it per input text.

    Args:
        text_data: The text to embed; also used as the cache key.

    Returns:
        The tensor returned by ``text_model.encode(..., convert_to_tensor=True)``,
        or ``None`` when encoding fails (the error is printed).
    """
    if text_data in text_embeddings_cache:
        return text_embeddings_cache[text_data]
    try:
        text_embedding = text_model.encode(text_data, convert_to_tensor=True)
        print("Generating text embeddings...")
        # Key by this call's text. The previous code used the global list
        # `text_datas` as the key, which is unhashable, so every call raised
        # TypeError, was reported as an error, and returned None.
        text_embeddings_cache[text_data] = text_embedding
        return text_embedding
    except Exception as e:
        print(f"Error: {e}")
        return None

def extract_text_embeddings_multithread(text_datas, max_thread = 10):
    """Generate text embeddings for many texts using a thread pool.

    Each text is passed to ``generate_text_embeddings``; progress is shown with
    a tqdm bar and a summary line is printed at the end.

    Args:
        text_datas: Sized collection of texts to embed.
        max_thread: Maximum number of worker threads in the pool.

    Returns:
        None. Results are only counted, not returned.
    """
    with ThreadPoolExecutor(max_workers=max_thread) as executor:
        results = list(tqdm(
            executor.map(generate_text_embeddings, text_datas),
            total = len(text_datas),
            desc = "Extracting text embeddings"
        ))
    # Count against the texts actually processed. The previous code referenced
    # an undefined `image_links` here and raised NameError after all the work.
    total_texts = len(text_datas)
    successful_embeddings = sum(1 for result in results if result is not None)
    failed_count = total_texts - successful_embeddings
    print(f"Extraction complete. "
                 f"Total texts: {total_texts}, "
                 f"Successful: {successful_embeddings}, "
                 f"Failed: {failed_count}")

def save_text_embedding_cache():
    """Pickle the in-memory ``text_embeddings_cache`` into the ``TEXT_EMBEDDINGS_CACHE`` file.

    The file is opened in binary write mode, so any previous content is
    overwritten.
    """
    with open(TEXT_EMBEDDINGS_CACHE, mode="wb") as cache_file:
        pickle.dump(text_embeddings_cache, file=cache_file)
