In [40]:
# Category to attribute mapping
# Each category lists its attribute names in column order; the attribute at
# 1-based position i is stored under the generic label 'attr_i'.
category_class_attribute_mapping = {
    category: {
        attribute: f'attr_{position}'
        for position, attribute in enumerate(attributes, start=1)
    }
    for category, attributes in (
        ('Kurtis', (
            'color', 'fit_shape', 'length', 'occasion', 'ornamentation',
            'pattern', 'print_or_pattern_type', 'sleeve_length', 'sleeve_styling',
        )),
        ('Men Tshirts', (
            'color', 'neck', 'pattern', 'print_or_pattern_type', 'sleeve_length',
        )),
        ('Sarees', (
            'blouse_pattern', 'border', 'border_width', 'color', 'occasion',
            'ornamentation', 'pallu_details', 'pattern', 'print_or_pattern_type',
            'transparency',
        )),
        ('Women Tops & Tunics', (
            'color', 'fit_shape', 'length', 'neck_collar', 'occasion', 'pattern',
            'print_or_pattern_type', 'sleeve_length', 'sleeve_styling',
            'surface_styling',
        )),
        ('Women Tshirts', (
            'color', 'fit_shape', 'length', 'pattern', 'print_or_pattern_type',
            'sleeve_length', 'sleeve_styling', 'surface_styling',
        )),
    )
}

import numpy as np
import pandas as pd
import faiss
from sklearn.preprocessing import normalize
import os
import torch
from PIL import Image
import clip
from typing import List, Tuple
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import math

class CLIPEmbedder:
    """Turns image files into CLIP ViT-L/14 image embeddings."""

    def __init__(self, checkpoint_path: str = None, device: str = None):
        """
        Initialize the CLIP embedder with an optional custom checkpoint.
        
        Args:
            checkpoint_path: Path to the checkpoint containing CLIP model weights
                under the key 'clip_model_state_dict'. If falsy, the stock
                weights returned by clip.load are used.
            device: Device to run the model on ('cuda' or 'cpu'). If None, automatically detected.
        
        Raises:
            ValueError: If the checkpoint has no 'clip_model_state_dict' entry.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        
        # Load the base CLIP model and preprocessor
        self.model, self.preprocess = clip.load("ViT-L/14", device=self.device)
        
        # Load custom weights if provided
        if checkpoint_path:
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            if 'clip_model_state_dict' not in checkpoint:
                raise ValueError("Checkpoint does not contain 'clip_model_state_dict'")
            self.model.load_state_dict(checkpoint['clip_model_state_dict'])
        
        self.model.eval()
    
    @torch.no_grad()
    def extract_single_feature(self, image_path: str) -> np.ndarray:
        """
        Extract features from a single image.
        
        Args:
            image_path: Path of the image file to embed.
        
        Returns:
            The output of the model's image encoder for a batch of one image,
            moved to CPU and converted to a NumPy array.
        
        Raises:
            RuntimeError: If the image cannot be read or encoded; the original
                error is attached as ``__cause__``.
        """
        try:
            # The context manager releases the file handle even when decoding
            # or encoding fails; convert() returns an independent in-memory copy.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            pixel_values = self.preprocess(image).unsqueeze(0).to(self.device)
            features = self.model.encode_image(pixel_values)
            return features.cpu().numpy()
        except Exception as e:
            # RuntimeError is still caught by `except Exception`; chaining keeps
            # the root cause visible in the traceback.
            raise RuntimeError(f"Error processing image {image_path}: {str(e)}") from e

class FAISSIndex:
    """Flat inner-product FAISS index built over row-normalised reference features."""

    def __init__(self, features, filenames):
        """
        Store the reference data; the index itself is created by build_index().
        
        Args:
            features: 2-D array of reference features, one row per image
                (its second dimension is used as the index dimension).
            filenames: Filenames stored alongside the features.
                # presumably filenames[i] belongs to features[i]; verify against caller
        """
        self.features = features
        self.filenames = filenames
        self.dimension = features.shape[1]
        self.index = None
        # Holds the GPU resources object while a GPU index is alive; stays
        # None when the index is kept on the CPU.
        self.res = None
        
    def build_index(self):
        """
        Normalise the features, add them to an IndexFlatIP and move the index
        to GPU 0 when the installed FAISS build reports a usable GPU.
        
        On a CPU-only FAISS build, or when no GPU is visible, the index is
        kept on the CPU instead of raising.
        """
        features_c = normalize(self.features.astype('float32'))
        # FAISS expects C-contiguous float32 input.
        features_c = np.ascontiguousarray(features_c, dtype='float32')
        self.index = faiss.IndexFlatIP(self.dimension)
        self.index.add(features_c)

        # CPU-only builds may not expose the GPU helpers at all, so probe
        # defensively instead of assuming they exist.
        gpu_count = getattr(faiss, 'get_num_gpus', lambda: 0)()
        if gpu_count > 0 and hasattr(faiss, 'StandardGpuResources'):
            print("Transferring index to GPU...")
            self.res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(self.res, 0, self.index)
            print("Index transferred to GPU.")
        else:
            print("No FAISS GPU support detected; keeping index on CPU.")

class SimilaritySearch:
    """Category-restricted nearest-neighbour lookup on top of a FAISS-style index."""

    def __init__(self, index, filenames, train_categories, threshold=0.55):
        """
        Args:
            index: Object exposing ``search(queries, k)`` that returns
                ``(scores, ids)`` arrays, as FAISS indexes do.
            filenames: Sequence mapping an index id to a filename.
            train_categories: Dict mapping filename -> category.
            threshold: Minimum similarity score for a hit to be kept.
        """
        self.index = index
        self.filenames = filenames
        self.train_categories = train_categories  # Add train categories
        self.threshold = threshold
        
    def search(self, query_features, query_category, k=10):
        """
        Search for similar images within the same category
        
        Args:
            query_features: Features of the query image, either a single
                vector or a (1, dim) array
            query_category: Category of the query image
            k: Number of results to return
        
        Returns:
            A list of at most ``k`` dicts with keys 'filename', 'similarity'
            and 'category', in the order the index returned them.
        """
        if k <= 0:
            return []

        # Get more initial results to ensure enough after category filtering
        initial_k = k * 3
        # Accept a bare 1-D vector as well as the (1, dim) array normalize() needs.
        query = np.atleast_2d(np.asarray(query_features, dtype='float32'))
        query = normalize(query)
        distances, indices = self.index.search(query, initial_k)
        
        results = []
        for similarity, idx in zip(distances[0], indices[0]):
            # FAISS pads with id -1 when fewer than initial_k neighbours exist;
            # indexing filenames with -1 would silently pick the last file.
            if idx < 0:
                continue
            if similarity < self.threshold:
                continue
            filename = self.filenames[int(idx)]
            # Only include results from the same category
            if self.train_categories.get(filename) != query_category:
                continue
            results.append({
                'filename': filename,
                'similarity': float(similarity),
                'category': query_category
            })
            # Candidates arrive best-first, so the first k survivors are the top k.
            if len(results) == k:
                break
        
        return results

class ImageVisualizer:
    """Matplotlib helpers for showing a query image next to its retrieved matches."""

    @staticmethod
    def _resolve_reference_path(reference_image_folder: str, filename) -> str:
        """Join folder and filename, appending '.png' only if the name has no extension."""
        path = os.path.join(reference_image_folder, str(filename))
        if not os.path.splitext(path)[1]:
            path += ".png"
        return path

    @staticmethod
    def _load_rgb(path: str):
        """Read an image as RGB and release the underlying file handle."""
        with Image.open(path) as handle:
            # convert() returns an independent in-memory copy, so the file
            # can be closed as soon as the with-block exits.
            return handle.convert('RGB')

    @staticmethod
    def plot_similar_images(query_image_path: str, 
                          similar_images: pd.DataFrame, 
                          reference_image_folder: str,
                          figsize=(15, 10)):
        """
        Plot the query image and its similar matches with annotations.
        
        The query image occupies the first grid cell; matches follow in
        row-major order with at most five cells per row.
        
        Args:
            query_image_path: Path to the query image
            similar_images: DataFrame containing similar image results; the
                'filename' and 'similarity' columns are read
            reference_image_folder: Path to the folder containing reference images
            figsize: Size of the figure (width, height)
        
        Returns:
            The matplotlib figure holding the grid.
        """
        # Load the query first so a missing query image does not leave an
        # orphaned, never-closed figure behind.
        query_img = ImageVisualizer._load_rgb(query_image_path)

        n_similar = len(similar_images)
        n_cols = min(5, n_similar + 1)  # +1 for query image
        n_rows = math.ceil((n_similar + 1) / n_cols)
        
        # Create figure with gridspec
        fig = plt.figure(figsize=figsize)
        gs = GridSpec(n_rows, n_cols, figure=fig)
        
        # Plot query image
        ax = fig.add_subplot(gs[0, 0])
        ax.imshow(query_img)
        ax.set_title('Query Image\n' + os.path.basename(query_image_path), fontsize=8)
        ax.axis('off')
        
        # Plot similar images
        for idx, row in enumerate(similar_images.itertuples()):
            # Cell 0 is taken by the query image, hence the +1 offset.
            row_idx, col_idx = divmod(idx + 1, n_cols)
            ax = fig.add_subplot(gs[row_idx, col_idx])
            
            # Filenames produced elsewhere in this file already end in '.png';
            # only extension-less names get the suffix added.
            img_path = ImageVisualizer._resolve_reference_path(reference_image_folder, row.filename)
            try:
                ax.imshow(ImageVisualizer._load_rgb(img_path))
                title = f'Match {idx+1}\nFile: {row.filename}\nSimilarity: {row.similarity:.3f}'
                ax.set_title(title, fontsize=8)
            except Exception as e:
                # Best effort: one unreadable match must not abort the whole grid.
                print(f"Error loading image {img_path}: {str(e)}")
                ax.text(0.5, 0.5, 'Image not found', ha='center', va='center')
            ax.axis('off')
        
        plt.tight_layout()
        return fig

class SingleImageSimilaritySearch:
    """End-to-end pipeline: embed a query image, search the reference index
    within the query's category, and optionally save comparison plots."""

    # Folder used to build 'file_path' when the caller does not supply one.
    DEFAULT_REFERENCE_IMAGE_FOLDER = "/scratch/data/m23csa016/train_images_bg_removed"
    # Column order of the DataFrame returned by find_similar_images().
    RESULT_COLUMNS = ['query_image', 'query_category', 'filename', 'file_path', 'similarity']

    def __init__(self, checkpoint_path: str, reference_features, reference_filenames, 
                 train_df: pd.DataFrame, test_categories_df: pd.DataFrame,
                 threshold: float = 0.55):
        """
        Initialize the similarity search system.
        
        Note: an 'id_as_filename' column is added in place to both
        ``train_df`` and ``test_categories_df``.
        
        Args:
            checkpoint_path: Path to CLIP checkpoint
            reference_features: Features of reference images
            reference_filenames: Filenames of reference images
            train_df: DataFrame containing training data with categories
                ('id' and 'Category' columns are read)
            test_categories_df: DataFrame containing test image categories
                ('id' and 'Category' columns are read)
            threshold: Minimum similarity for a match to be kept
        """
        self.feature_extractor = CLIPEmbedder(checkpoint_path)
        self.faiss_index = FAISSIndex(reference_features, reference_filenames)
        self.faiss_index.build_index()
        
        # Prepare category mappings
        self.train_df = train_df
        self.test_categories_df = test_categories_df
        
        # Create filename to category mappings
        self.train_categories = self._filename_to_category(self.train_df)
        # NOTE(review): test_categories is built but not consulted by
        # get_image_category(); confirm whether test queries should be supported.
        self.test_categories = self._filename_to_category(self.test_categories_df)
        
        self.similarity_search = SimilaritySearch(
            self.faiss_index.index, 
            reference_filenames,
            self.train_categories,
            threshold=threshold
        )
        self.visualizer = ImageVisualizer()

    @staticmethod
    def _filename_to_category(frame: pd.DataFrame) -> dict:
        """Add the zero-padded '<id>.png' column to ``frame`` and map it to 'Category'."""
        frame['id_as_filename'] = frame['id'].astype(str).str.zfill(6) + '.png'
        return dict(zip(frame['id_as_filename'], frame['Category']))
    
    def get_image_category(self, image_filename: str) -> str:
        """Get category for an image filename (training images only); None if unknown."""
        return self.train_categories.get(image_filename, None)
    
    def find_similar_images(self, image_path: str, k: int = 10,
                            reference_image_folder: str = None) -> pd.DataFrame:
        """
        Find similar images for a single query image within the same category.
        
        Args:
            image_path: Path to the query image; its basename is used to look
                up the category
            k: Number of similar images requested from the search
            reference_image_folder: Folder used to build the 'file_path'
                column; defaults to DEFAULT_REFERENCE_IMAGE_FOLDER
        
        Returns:
            DataFrame with columns query_image, query_category, filename,
            file_path and similarity; empty (but with those columns) when
            nothing matches.
        
        Raises:
            ValueError: If the query filename has no known category.
        """
        # Get query image category
        query_filename = os.path.basename(image_path)
        query_category = self.get_image_category(query_filename)
        
        if query_category is None:
            raise ValueError(f"Category not found for image: {query_filename}")
        
        # Extract features and search
        query_features = self.feature_extractor.extract_single_feature(image_path)
        similar_images = self.similarity_search.search(query_features, query_category, k=k)
        
        # Explicit columns keep the frame well-formed when there are no hits.
        results_df = pd.DataFrame(similar_images, columns=['filename', 'similarity', 'category'])
        results_df['query_image'] = query_filename
        results_df['query_category'] = query_category

        folder = (self.DEFAULT_REFERENCE_IMAGE_FOLDER
                  if reference_image_folder is None else reference_image_folder)
        # Build one path per row; formatting the whole Series into an f-string
        # would put the Series repr into every cell.
        results_df['file_path'] = [os.path.join(folder, name) for name in results_df['filename']]
        return results_df[self.RESULT_COLUMNS]
    
    def find_and_save_visualizations(self, 
                                     image_id,
                                     image_path: str, 
                                     reference_image_folder: str, 
                                     output_folder: str,
                                     k: int = 10,
                                     image_size: tuple = (25, 10)):
        """
        Find similar images and save each visualization plot as an image file in the specified output folder.
        
        One PNG is written per match, named '<image_id>_<rank>.png', showing
        the query image on the left and the match on the right.
        
        Args:
            image_id: Identifier used as the prefix of every saved plot filename
            image_path: Path to the query image
            reference_image_folder: Path to the folder containing reference images
            output_folder: Path to the folder to save individual plots
            k: Number of similar images to find
            image_size: Size of each plot image (width, height)
        
        Returns:
            DataFrame with the results of the similarity search.
        """
        results_df = self.find_similar_images(
            image_path, k, reference_image_folder=reference_image_folder
        )
        
        # Ensure output folder exists
        os.makedirs(output_folder, exist_ok=True)
        
        if not results_df.empty:
            # The query side is identical in every plot: load it once and make
            # sure the file handle is released.
            with Image.open(image_path) as raw_query:
                query_img = raw_query.convert('RGB')
            query_category = results_df['query_category'].iloc[0]
            query_title = f'Query Image\n{os.path.basename(image_path)}\nCategory: {query_category}'

            # Loop through each result to save a separate plot for each image
            for idx, row in enumerate(results_df.itertuples()):
                fig, axes = plt.subplots(1, 2, figsize=image_size)
                try:
                    # Plot the query image on the left
                    axes[0].imshow(query_img)
                    axes[0].set_title(query_title, fontsize=8)
                    axes[0].axis('off')
                    
                    # Plot the similar image on the right
                    similar_img_path = os.path.join(reference_image_folder, row.filename)
                    try:
                        with Image.open(similar_img_path) as raw_similar:
                            similar_img = raw_similar.convert('RGB')
                        axes[1].imshow(similar_img)
                        axes[1].set_title(f'Similar Image {idx+1}\nFile: {row.filename}\nSimilarity: {row.similarity:.3f}', fontsize=8)
                    except Exception as e:
                        # Best effort: an unreadable match still yields a placeholder plot.
                        print(f"Error loading image {similar_img_path}: {str(e)}")
                        axes[1].text(0.5, 0.5, 'Image not found', ha='center', va='center')
                    axes[1].axis('off')
                    
                    # Save the individual plot
                    output_path = os.path.join(output_folder, f"{image_id}_{idx+1}.png")
                    fig.savefig(output_path, bbox_inches='tight')
                finally:
                    # Always release the figure, even if saving fails.
                    plt.close(fig)
        
        print(f"Saved {len(results_df)} images to {output_folder}")
        return results_df
    

# Metadata tables: training rows (with categories) and test-image categories.
train_df = pd.read_csv('/scratch/data/m23csa016/meesho_data/new_train.csv')
test_categories_df = pd.read_csv('/scratch/data/m23csa016/meesho_data/test.csv')
print(f"Loaded train data: {len(train_df)} rows, test categories: {len(test_categories_df)} rows")

# Load pre-computed embeddings
train_features_df = pd.read_parquet('/scratch/data/m23csa016/meesho_data/cvl_nobg_max_train_em_1.parquet')
test_features_df = pd.read_parquet('/scratch/data/m23csa016/meesho_data/cvl_nobg_max_test_em_1.parquet')

# Embedding columns are the ones prefixed 'feature_'; 'filename' identifies each row.
feature_cols = list(filter(lambda name: name.startswith('feature_'), train_features_df.columns))
train_features = train_features_df.loc[:, feature_cols].to_numpy()
train_filenames = train_features_df['filename'].tolist()

# Builds the CLIP embedder and the FAISS index over the training embeddings.
searcher = SingleImageSimilaritySearch(
    "/scratch/data/m23csa016/meesho_data/bm_epoch_34_trainval_120024.pth",
    train_features,
    train_filenames,
    train_df,
    test_categories_df,
)

Loaded train data: 70379 rows, test categories: 30205 rows


  checkpoint = torch.load(checkpoint_path, map_location=self.device)


Transferring index to GPU...
Index transferred to GPU.


In [50]:
import os
# Folder holding the background-removed training images; it is used both to
# locate the query image and as the reference folder for the matches.
train_images = "/scratch/data/m23csa016/meesho_data/train_images_bg_removed"
# Query image filename (zero-padded id + '.png').
image = "028878.png"
image_path = os.path.join(train_images, image)
# Destination for the saved query-vs-match comparison plots.
output_folder = "/iitjhome/m23csa016/meesho_code/sim_output"
os.makedirs(output_folder, exist_ok=True)

In [51]:
# Save one comparison plot per match for the query image and keep the result table.
df = searcher.find_and_save_visualizations(
    image_id=28878,
    image_path=image_path,
    reference_image_folder=train_images,
    output_folder=output_folder,
    k=500,
)

Saved 77 images to /iitjhome/m23csa016/meesho_code/sim_output
