# Streamlit Pet Breed Similarity Search App

This Streamlit application allows users to upload a pet image and find similar-looking pets from a pre-loaded database using a trained deep metric learning model.

## Running the Streamlit App

To run this Streamlit app:
1. Install the required Python packages: `streamlit`, `torch`, `torchvision`, `Pillow`, `numpy`, `tqdm`, `faiss-cpu` (or `faiss-gpu`), `scikit-learn`, and `matplotlib`. (`umap-learn` is not used by the app code below and is optional.)
2. Make sure the trained model file (e.g., `pet_metric_learning_resnet18_triplet.pth`) and the pet image database (e.g., `./data/oxford-iiit-pet/images`) are available at the correct paths relative to the directory from which you run the command. You can change both paths in the application's sidebar. Breed labels are derived from each image's file name, following the Oxford-IIIT Pet convention `<breed_name>_<number>.jpg`.
3. Open your terminal in the directory containing this notebook.
4. Convert this notebook to a Python script: `jupyter nbconvert --to python streamlit_pet_similarity_app.ipynb`
5. Run the command: `streamlit run streamlit_pet_similarity_app.py`

In [None]:
# Pet Similarity Search - Streamlit Application

import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models import ResNet18_Weights, ResNet50_Weights
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
import os
import sys
from tqdm import tqdm
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from sklearn.manifold import TSNE
import faiss
from pathlib import Path
import json
import random
from collections import Counter

# Compute device shared by the whole app: the default CUDA device when PyTorch
# can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Embedding network used for similarity search. Its submodule names/order
# (``backbone`` then ``projection_head``) must match the training-time model so
# that ``model_state_dict`` checkpoints load without key mismatches.
class EmbeddingNet(nn.Module):
    """Map a batch of images to L2-normalised embedding vectors.

    A torchvision ResNet with its final child module removed produces pooled
    features; a two-layer MLP head (Linear -> BatchNorm1d -> ReLU -> Linear)
    projects them to ``embedding_size`` dimensions.

    Args:
        backbone_name: ``'resnet18'`` (512-d features) or ``'resnet50'``
            (2048-d features).
        embedding_size: Width of the output embedding.
        pretrained: Initialise the backbone with torchvision's ``DEFAULT``
            weights when True; random initialisation otherwise.

    Raises:
        ValueError: If ``backbone_name`` is not supported.
    """

    def __init__(self, backbone_name='resnet18', embedding_size=128, pretrained=True):
        super().__init__()
        # name -> (constructor, weights enum, width of the flattened backbone output)
        supported = {
            'resnet18': (models.resnet18, ResNet18_Weights, 512),
            'resnet50': (models.resnet50, ResNet50_Weights, 2048),
        }
        if backbone_name not in supported:
            raise ValueError(f"Unsupported backbone: {backbone_name}")
        build_resnet, weight_enum, feature_dim = supported[backbone_name]
        resnet = build_resnet(weights=weight_enum.DEFAULT if pretrained else None)

        # Drop the last child (the ResNet classification layer) and keep the rest.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, embedding_size),
        )

    def forward(self, x):
        """Embed ``x`` and L2-normalise each row of the result (dim=1)."""
        pooled = self.backbone(x)
        flat = pooled.view(pooled.size(0), -1)
        return F.normalize(self.projection_head(flat), p=2, dim=1)

    def get_embedding(self, x):
        """Alias for :meth:`forward`."""
        return self.forward(x)

# Load the trained model
@st.cache_resource
def _load_model_cached(model_path):
    """Build an EmbeddingNet from the checkpoint at ``model_path`` (cached per path).

    Raises on any failure. Streamlit does not cache a call that raises, so a
    broken or not-yet-present checkpoint is retried on the next rerun instead
    of the failure being remembered for the life of the process.

    Returns:
        ``(model, class_mapping, embedding_size)``.
    """
    # NOTE(review): torch.load unpickles the file (unless weights_only is in
    # effect for the installed torch version) and the path comes from the
    # sidebar; only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)
    embedding_size = checkpoint.get('embedding_size', 128)

    model = EmbeddingNet(
        backbone_name=checkpoint.get('backbone_name', 'resnet18'),
        embedding_size=embedding_size,
        pretrained=False,  # all weights come from the checkpoint below
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # `or` also covers a checkpoint that stored class_mapping=None.
    class_mapping = checkpoint.get('class_mapping') or {'idx_to_class': {}}
    return model, class_mapping, embedding_size


def load_model(model_path='pet_metric_learning_resnet18_triplet.pth'):
    """Load the trained embedding model, reporting problems in the Streamlit UI.

    Only successful loads are cached (see ``_load_model_cached``). The file
    check and the error messages run on every call, so a checkpoint that
    appears or is fixed at the same path is picked up on the next rerun.

    Args:
        model_path: Path to a checkpoint dict containing ``model_state_dict``
            and optionally ``backbone_name``, ``embedding_size`` and
            ``class_mapping``.

    Returns:
        ``(model, class_mapping, embedding_size)`` on success, or
        ``(None, None, None)`` after displaying an error.
    """
    if not os.path.exists(model_path):
        st.error(f"Model file not found at {model_path}. Please ensure the trained model is in the correct location.")
        return None, None, None
    try:
        return _load_model_cached(model_path)
    except Exception as e:
        # UI boundary: show the failure instead of crashing the whole app.
        st.error(f"Error loading model: {e}")
        return None, None, None

# Preprocessing pipelines: deterministic by default, randomised when augment=True
def get_transform(augment=False):
    """Return the torchvision pipeline that turns a PIL image into model input.

    Args:
        augment: False (default) resizes straight to 224x224. True resizes to
            256x256, then applies a random 224 crop, a random horizontal flip
            and mild colour jitter.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` and per-channel
        normalisation with the standard ImageNet mean/std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    if augment:
        steps = [
            transforms.Resize((256, 256)),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        ]
    else:
        steps = [transforms.Resize((224, 224))]

    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    return transforms.Compose(steps)

# Embed a single image
def get_embedding(model, image, augment=False):
    """Run ``model`` on one PIL image and return the output as a CPU tensor.

    The image is preprocessed with ``get_transform(augment=augment)``, given
    a leading batch dimension and moved to ``device``. The forward pass runs
    under ``torch.no_grad()``.
    """
    batch = get_transform(augment=augment)(image)[None].to(device)
    with torch.no_grad():
        return model(batch).cpu()

# Test-time augmentation: average several embeddings of the same image
def get_robust_embedding(model, image, num_augmentations=5):
    """Embed ``image`` once without augmentation plus ``num_augmentations``
    times with random augmentation, then return the L2-normalised mean
    (kept as a batch of one via ``keepdim=True``).
    """
    views = [get_embedding(model, image, augment=False)]
    views.extend(get_embedding(model, image, augment=True) for _ in range(num_augmentations))
    mean_embedding = torch.cat(views, dim=0).mean(dim=0, keepdim=True)
    return F.normalize(mean_embedding, p=2, dim=1)

# Function to load a single image and process it
def _breed_label_from_path(img_path):
    """Derive a breed label from an image file name.

    Oxford-IIIT Pet files are named ``<breed_name>_<index>.<ext>`` (e.g.
    ``american_bulldog_10.jpg``), so the label is the stem up to the *last*
    underscore. Splitting on the first underscore would merge multi-word
    breeds such as ``american_bulldog`` and ``american_pit_bull_terrier``
    into ``american``. A stem without an underscore is returned whole.
    """
    stem = os.path.splitext(os.path.basename(img_path))[0]
    return stem.rsplit('_', 1)[0]


def load_and_process_image(args):
    """Load and preprocess one database image (also used as a thread-pool worker).

    Args:
        args: ``(img_path, transform, model)``; ``model`` is accepted for
            backward compatibility and is not used.

    Returns:
        ``(image, img_path, label, tensor)`` where ``tensor`` is a CPU batch
        of one. On any load/transform failure, returns
        ``(None, img_path, None, None)`` so the caller can skip the file.
    """
    img_path, transform, _unused_model = args
    try:
        # `with` closes the file handle; convert() returns an independent copy.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        # Stay on the CPU: load_database moves one batch at a time to the
        # device instead of keeping every image tensor there at once.
        img_tensor = transform(img).unsqueeze(0)
    except Exception:
        # Deliberate best-effort: unreadable or corrupt files are skipped.
        return None, img_path, None, None
    return img, img_path, _breed_label_from_path(img_path), img_tensor


# Function to compute embeddings in batches
def compute_embeddings_batch(model, image_tensors):
    """Run ``model`` on an already-batched tensor without tracking gradients."""
    with torch.no_grad():
        return model(image_tensors)


# Function to find similar pets using FAISS
def find_similar_pets_faiss(query_embedding, faiss_index, database_images, database_labels, top_k=5):
    """Return the ``top_k`` database entries with the highest index score.

    Args:
        query_embedding: Tensor holding a single embedding; it is flattened
            to one row before searching.
        faiss_index: Index built by ``load_database`` (inner product over
            L2-normalised embeddings), or None.
        database_images: Images aligned with the index rows.
        database_labels: Labels aligned with the index rows.
        top_k: Maximum number of results (capped at the database size).

    Returns:
        ``(images, labels, scores)`` best-first; three empty lists if there
        is no index, the database is empty, or fewer than one result is requested.
    """
    k = min(top_k, len(database_labels))
    if faiss_index is None or len(database_images) == 0 or k < 1:
        return [], [], []
    query_np = query_embedding.detach().cpu().numpy().astype('float32').reshape(1, -1)
    scores, indices = faiss_index.search(query_np, k=k)

    similar_images, similar_labels, similarities = [], [], []
    for idx, score in zip(indices[0].tolist(), scores[0].tolist()):
        if idx < 0:
            # FAISS pads missing neighbours with -1; don't wrap to the last item.
            continue
        similar_images.append(database_images[idx])
        similar_labels.append(database_labels[idx])
        similarities.append(score)
    return similar_images, similar_labels, similarities


# Load database images and compute embeddings - optimized version
@st.cache_data
def load_database(directory, _model, embedding_size, batch_size=32, max_db_size=500, use_parallel=True):
    """Load up to ``max_db_size`` images under ``directory``, embed them and build a FAISS index.

    Streamlit leaves ``_model`` out of the cache key because of its leading
    underscore. After switching to a different model with otherwise
    identical arguments, clear the cache (the "Load/Reload Database" button)
    to avoid reusing embeddings from the previous model.

    Args:
        directory: Root folder, searched recursively for .jpg/.jpeg/.png files.
        _model: Embedding model (not hashed by Streamlit).
        embedding_size: Kept for API/cache-key compatibility; the index
            dimension is taken from the computed embeddings.
        batch_size: Images per forward pass (values < 1 are treated as 1).
        max_db_size: Cap on the number of images, sampled at random.
        use_parallel: Load images with a thread pool instead of sequentially.

    Returns:
        ``(images, embeddings_tensor, labels, file_paths, faiss_index, stats)``;
        ``faiss_index`` is None when no image could be loaded.
    """
    start_time = time.time()
    transform = get_transform()
    image_files = [
        os.path.join(root, name)
        for root, _, files in os.walk(directory)
        for name in files
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    ]
    total_found = len(image_files)
    # Shuffle before capping so the subset is not biased toward os.walk order.
    random.shuffle(image_files)
    image_files = image_files[:max_db_size]

    images, labels, file_paths, batched_tensors = [], [], [], []

    def _collect(result):
        # Keep successfully loaded images; failures come back with img=None.
        img, path, label, tensor = result
        if img is not None:
            images.append(img)
            file_paths.append(path)
            labels.append(label)
            batched_tensors.append(tensor)

    st.write(f"Found {total_found} images. Processing up to {max_db_size}...")
    progress_bar = st.progress(0)
    total = len(image_files)
    if use_parallel:
        # os.cpu_count() may return None when the count cannot be determined.
        workers = min(os.cpu_count() or 1, 8)
        with ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [executor.submit(load_and_process_image, (path, transform, _model)) for path in image_files]
            for done, future in enumerate(as_completed(futures), start=1):
                _collect(future.result())
                progress_bar.progress(done / total)
    else:
        for done, img_path in enumerate(tqdm(image_files, desc="Loading Database"), start=1):
            _collect(load_and_process_image((img_path, transform, _model)))
            progress_bar.progress(done / total)

    _model = _model.to(device)
    step = max(1, int(batch_size))  # batch_size <= 0 would otherwise never advance
    all_embeddings = []
    for start in range(0, len(batched_tensors), step):
        batch = torch.cat(batched_tensors[start:start + step], dim=0).to(device)
        all_embeddings.append(compute_embeddings_batch(_model, batch).cpu())
    embeddings_tensor = torch.cat(all_embeddings, dim=0) if all_embeddings else torch.tensor([])

    faiss_index = None
    if len(embeddings_tensor) > 0:
        embeddings_np = embeddings_tensor.numpy().astype('float32')
        # Size the index from the real embedding width so a checkpoint whose
        # stored embedding_size disagrees with the network cannot break add().
        faiss_index = faiss.IndexFlatIP(embeddings_np.shape[1])
        faiss_index.add(embeddings_np)

    progress_bar.empty()
    load_time = time.time() - start_time
    breed_distribution = dict(Counter(labels))
    stats = {
        'num_images': len(images),
        'num_breeds': len(breed_distribution),
        'breed_distribution': breed_distribution,
        'load_time': load_time
    }
    return images, embeddings_tensor, labels, file_paths, faiss_index, stats

# Function to visualize embeddings with t-SNE
def plot_tsne(embeddings, labels, query_embedding=None, n_components=2):
    """Project embeddings with t-SNE and draw a scatter plot coloured by breed.

    Args:
        embeddings: Tensor of database embeddings, one row per entry in ``labels``.
        labels: Breed label for each row; used for colours and for one text
            annotation per breed.
        query_embedding: Optional tensor with a single query row, drawn as a
            red star and listed in the legend.
        n_components: t-SNE output dimensionality; only the first two
            coordinates are plotted.

    Returns:
        A matplotlib Figure, or None when there are fewer than 3 points in
        total (too few for a meaningful t-SNE).
    """
    if len(embeddings) == 0:
        return None
    db_np = embeddings.cpu().numpy()
    if query_embedding is not None:
        points = np.vstack([query_embedding.cpu().numpy(), db_np])
        offset = 1  # row 0 is the query; database rows follow
    else:
        points = db_np
        offset = 0

    n_points = len(points)
    if n_points < 3:
        return None
    # scikit-learn requires perplexity < n_samples; otherwise keep the
    # original heuristic of roughly n/10, clamped to [5, 30].
    perplexity = min(30, max(5, n_points // 10), n_points - 1)
    coords = TSNE(n_components=n_components, random_state=42, perplexity=perplexity).fit_transform(points)
    db_coords = coords[offset:]

    # Stable colour per breed: its index in the sorted set of labels.
    # hash(str) is salted per process, so hash-based colours changed between
    # runs and could collide.
    breed_index = {breed: i for i, breed in enumerate(sorted(set(labels)))}
    colours = [breed_index[label] for label in labels]

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(db_coords[:, 0], db_coords[:, 1], c=colours, cmap='viridis', alpha=0.6, s=50)
    if offset:
        ax.scatter(
            coords[:1, 0],
            coords[:1, 1],
            color='red',
            s=200,
            marker='*',
            edgecolors='black',
            label='Query'
        )
        # Only the query has a legend label; legend() without one just warns.
        ax.legend()
    ax.set_title('t-SNE Visualization of Pet Embeddings')
    ax.grid(True, linestyle='--', alpha=0.7)

    # Annotate each breed once, at its first occurrence (single O(n) pass).
    first_index = {}
    for i, label in enumerate(labels):
        first_index.setdefault(label, i)
    for label, i in first_index.items():
        ax.annotate(
            label,
            (db_coords[i, 0], db_coords[i, 1]),
            textcoords='offset points',
            xytext=(5, 5),
            ha='center',
            fontsize=8,
            bbox=dict(boxstyle='round,pad=0.3', fc='yellow', alpha=0.3)
        )
    return fig

# Persist a search result set to disk for later analysis
def save_results(query_image, similar_images, similar_labels, similarities, filename='similarity_results.json'):
    """Write the query image, the matched images and a JSON summary under ``results/``.

    Images are saved as ``results/query_image.jpg`` and
    ``results/similar_<i>.jpg``, overwriting any earlier save. The JSON file
    ``results/<filename>`` records those paths, the labels, the similarity
    scores and a local timestamp.

    Returns:
        Path of the JSON file, as a string.
    """
    out_dir = Path('results')
    out_dir.mkdir(exist_ok=True)

    query_path = out_dir / 'query_image.jpg'
    query_image.save(query_path)

    saved_paths = []
    for position, match in enumerate(similar_images):
        target = out_dir / f'similar_{position}.jpg'
        match.save(target)
        saved_paths.append(str(target))

    summary = {
        'query_image': str(query_path),
        'similar_images': saved_paths,
        'similar_labels': similar_labels,
        'similarities': similarities,
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    }
    json_path = out_dir / filename
    with json_path.open('w') as handle:
        json.dump(summary, handle, indent=2)
    return str(json_path)

# Retrieval quality metrics for a single query
def compute_metrics(query_label, similar_labels, top_k=5):
    """Score a ranked list of result labels against the query's label.

    Only the first ``top_k`` labels are considered. Labels are compared with
    ``==``; any case or spelling normalisation is up to the caller.

    Returns:
        dict with
        ``'precision@k'`` – fraction of considered labels that match (0 if none considered),
        ``'first_match'`` – 1-based rank of the first match, or ``"Not found"``,
        ``'mrr'`` – reciprocal of that rank (0.0 if no match),
        ``'matches'`` – 1/0 hit flags in rank order.
    """
    considered = similar_labels[:top_k]
    hits = [int(label == query_label) for label in considered]
    precision = sum(hits) / len(considered) if considered else 0
    rank = next((pos for pos, hit in enumerate(hits, start=1) if hit), None)
    return {
        'precision@k': precision,
        'first_match': rank if rank is not None else "Not found",
        'mrr': 1.0 / rank if rank is not None else 0.0,
        'matches': hits
    }

# Main Streamlit app
def _rerun_app():
    """Request an immediate script rerun.

    Newer Streamlit releases provide ``st.rerun`` and dropped the deprecated
    ``st.experimental_rerun``; use whichever this installation has.
    """
    rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun")
    rerun()


def _normalize_breed(label):
    """Return a comparison key for a breed name.

    Lower-cases the name and treats underscores and whitespace runs as single
    spaces, so a typed "American Bulldog" matches a label "american_bulldog".
    """
    return " ".join(label.replace("_", " ").split()).lower()


def _render_search_results(results, query_label):
    """Display a stored search: result grid, optional metrics, save button and t-SNE plot.

    Args:
        results: dict stored in session state by ``run_app`` (keys
            ``query_image``, ``images``, ``labels``, ``similarities``,
            ``search_time``, ``tsne_fig``, ``query_key``); ``images`` is non-empty.
        query_label: Breed typed by the user for evaluation; may be empty.
    """
    similar_images = results["images"]
    similar_labels = results["labels"]
    similarities = results["similarities"]
    num_results = len(similar_images)
    st.success(f"Found {num_results} similar pets in {results['search_time']:.3f} seconds.")

    num_cols = min(5, num_results)
    for row_start in range(0, num_results, num_cols):
        cols = st.columns(num_cols)
        for col, idx in zip(cols, range(row_start, min(row_start + num_cols, num_results))):
            with col:
                st.image(similar_images[idx], caption=f"{similar_labels[idx].replace('_', ' ').title()}", use_column_width=True)
                st.write(f"Similarity: {similarities[idx]:.3f}")

    if query_label and query_label.strip():
        metrics = compute_metrics(
            _normalize_breed(query_label),
            [_normalize_breed(label) for label in similar_labels],
            top_k=len(similar_labels),
        )
        st.subheader("Evaluation Metrics")
        metrics_cols = st.columns(3)
        metrics_cols[0].metric("Precision@k", f"{metrics['precision@k']:.2f}")
        metrics_cols[1].metric("First Match Position", metrics['first_match'])
        metrics_cols[2].metric("Mean Reciprocal Rank", f"{metrics['mrr']:.2f}")
        match_html = "<div style='display:flex;'>"
        for match in metrics['matches']:
            color = "green" if match else "red"
            match_html += f"<div style='margin-right:5px;width:20px;height:20px;background:{color};'></div>"
        match_html += "</div>"
        st.write("Match Pattern:")
        st.markdown(match_html, unsafe_allow_html=True)

    # Rendered from session state, so this button survives the rerun its own click triggers.
    if st.button("Save Results"):
        result_path = save_results(results["query_image"], similar_images, similar_labels, similarities)
        st.success(f"Results saved to {result_path}")

    st.subheader("Embedding Visualization")
    if results["tsne_fig"] is not None:
        st.pyplot(results["tsne_fig"])
    else:
        st.warning("Could not generate embedding visualization.")


def run_app():
    """Build and run the Streamlit UI.

    The sidebar holds the model/database paths and search options. There are
    three tabs: similarity search (with optional evaluation metrics, saving
    and a t-SNE view), database statistics, and a t-SNE view of the whole
    database.

    The last search is kept in ``st.session_state``. Every widget interaction
    reruns the script, and a button returns True only on the rerun its own
    click triggered. Widgets nested under "Find Similar Pets" (such as "Save
    Results") could therefore never fire; rendering from session state fixes that.
    """
    results_key = "pet_search_results"  # session-state slot for the last search

    st.set_page_config(layout="wide")
    st.title("🐾 Pet Breed Similarity Search 🐾")
    st.write("Upload a pet image to find similar-looking pets in our database!")
    st.sidebar.header("System Info")
    st.sidebar.write(f"Using device: {device}")
    st.sidebar.header("Configuration")
    model_file = st.sidebar.text_input("Model File Path", "pet_metric_learning_resnet18_triplet.pth")
    database_dir = st.sidebar.text_input("Database Directory Path", "./data/oxford-iiit-pet/images")
    top_k_similar = st.sidebar.slider("Number of Similar Pets to Show", 1, 20, 5)
    use_parallel = st.sidebar.checkbox("Use Parallel Processing", value=True)
    max_db_size = st.sidebar.number_input("Max Database Size", min_value=10, max_value=5000, value=500, step=50)
    batch_size = st.sidebar.selectbox("Batch Size", options=[8, 16, 32, 64], index=1)
    st.sidebar.header("Robustness Options")
    use_augmentation = st.sidebar.checkbox("Use Data Augmentation", value=False)
    num_augmentations = st.sidebar.slider("Number of Augmentations", 1, 10, 5, disabled=not use_augmentation)

    model, _class_mapping, embedding_size = load_model(model_file)
    if model is None:
        st.stop()

    st.sidebar.header("Database")
    if st.sidebar.button("Load/Reload Database"):
        st.cache_data.clear()
        st.session_state.pop(results_key, None)
        _rerun_app()

    if not os.path.isdir(database_dir):
        st.warning(f"Database directory not found: '{database_dir}'. Please provide a valid path in the sidebar.")
        database_images, database_embeddings, database_labels, file_paths, faiss_index, stats = [], torch.tensor([]), [], [], None, {
            'num_images': 0,
            'num_breeds': 0,
            'breed_distribution': {},
            'load_time': 0
        }
    else:
        with st.spinner(f"Loading database from '{database_dir}'... This might take a while."):
            database_images, database_embeddings, database_labels, file_paths, faiss_index, stats = load_database(
                database_dir, model, embedding_size, batch_size=batch_size, max_db_size=max_db_size, use_parallel=use_parallel
            )
        if not database_images:
            st.error("No images loaded from the database directory. Check the path and image files.")
        else:
            st.sidebar.success(f"Loaded {stats['num_images']} images from {stats['num_breeds']} breeds in {stats['load_time']:.2f}s.")

    tab1, tab2, tab3 = st.tabs(["Similarity Search", "Database Stats", "Embedding Visualization"])
    with tab1:
        col1, col2 = st.columns([1, 2])
        query_image = None
        query_label = ""
        with col1:
            st.header("Query Image")
            uploaded_file = st.file_uploader("Choose a pet image...", type=["jpg", "jpeg", "png"])
            if uploaded_file is not None:
                query_image = Image.open(uploaded_file).convert('RGB')
                st.image(query_image, caption="Your Query Image", use_column_width=True)
                query_label = st.text_input("Optional: Enter the breed of this pet for evaluation metrics", "")
        with col2:
            st.header("Similar Pets Found")
            if uploaded_file is not None and len(database_images) > 0:
                # Stored results are shown only for the upload that produced them.
                query_key = (uploaded_file.name, uploaded_file.size)
                if st.button("Find Similar Pets"):
                    with st.spinner("Comparing your pet..."):
                        start_time = time.time()
                        if use_augmentation:
                            query_embedding = get_robust_embedding(model, query_image, num_augmentations)
                        else:
                            query_embedding = get_embedding(model, query_image)
                        similar_images, similar_labels, similarities = find_similar_pets_faiss(
                            query_embedding, faiss_index, database_images, database_labels, top_k=top_k_similar
                        )
                        search_time = time.time() - start_time
                    if not similar_images:
                        st.session_state.pop(results_key, None)
                        st.warning("Could not find any similar pets.")
                    else:
                        with st.spinner("Generating t-SNE visualization..."):
                            tsne_fig = plot_tsne(database_embeddings, database_labels, query_embedding)
                        st.session_state[results_key] = {
                            "query_key": query_key,
                            "query_image": query_image,
                            "images": similar_images,
                            "labels": similar_labels,
                            "similarities": similarities,
                            "search_time": search_time,
                            "tsne_fig": tsne_fig,
                        }
                stored = st.session_state.get(results_key)
                if stored is not None and stored["query_key"] == query_key:
                    _render_search_results(stored, query_label)
            elif uploaded_file is None:
                st.info("Upload an image to start the search.")
            elif len(database_images) == 0:
                st.warning("Database is empty or not loaded. Please check the database path and click 'Load/Reload Database' in the sidebar.")
    with tab2:
        if len(database_images) > 0:
            st.header("Database Statistics")
            st.write(f"Total images: {stats['num_images']}")
            st.write(f"Number of breeds: {stats['num_breeds']}")
            st.write(f"Load time: {stats['load_time']:.2f} seconds")
            st.subheader("Breed Distribution")
            breeds = list(stats['breed_distribution'].keys())
            counts = list(stats['breed_distribution'].values())
            breed_fig, ax = plt.subplots(figsize=(10, 8))
            y_pos = np.arange(len(breeds))
            ax.barh(y_pos, counts)
            ax.set_yticks(y_pos)
            ax.set_yticklabels(breeds)
            ax.invert_yaxis()
            ax.set_xlabel('Count')
            ax.set_title('Number of Images per Breed')
            st.pyplot(breed_fig)
        else:
            st.info("Load a database to see statistics.")
    with tab3:
        if len(database_images) > 0:
            st.header("Embedding Space Visualization")
            st.write("This visualization shows the distribution of pet breed embeddings in a 2D space using t-SNE.")
            with st.spinner("Generating embedding visualization..."):
                vis_fig = plot_tsne(database_embeddings, database_labels)
                if vis_fig:
                    st.pyplot(vis_fig)
                else:
                    st.warning("Could not generate visualization.")
        else:
            st.info("Load a database to visualize embeddings.")

# Script entry point. `streamlit run` executes this file as the __main__
# module, so the guard fires there too; a plain import does not start the UI.
if __name__ == '__main__':
    run_app()