# Installations

In [None]:
!pip install gradio --quiet

# Imports

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity

# Inference

In [None]:
# Path (relative to the working directory) of the trained TripletNetwork
# state_dict that the model-loading cell passes to torch.load.
WEIGHTS_PATH = 'weights/triplet_model.pth'

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Free cached GPU memory (e.g. left over from earlier cells) before loading.
torch.cuda.empty_cache()


def _load_model(target_device):
    """Build a TripletNetwork on ``target_device``, load WEIGHTS_PATH into it and set eval mode."""
    net = TripletNetwork().to(target_device)
    net.load_state_dict(torch.load(WEIGHTS_PATH, map_location=target_device))
    net.eval()
    return net


try:
    model = _load_model(device)
except RuntimeError:
    # A CPU fallback can only help when the failed attempt was on the GPU;
    # on CPU the retry would hit the very same error, so surface it directly
    # instead of printing a misleading "CUDA" message first.
    if device.type != "cuda":
        raise
    print("⚠️ CUDA Memory issue detected. Switching to CPU.")
    device = torch.device("cpu")
    model = _load_model(device)


# In-memory stores keyed by person name; add_person() appends to both and
# remove_person() deletes a name from both, so their keys stay in sync.
# name -> list of embedding vectors (NumPy arrays returned by extract_embedding)
face_database = {}
# name -> list of 128x128 thumbnails, in the same order as face_database[name]
image_database = {}

def extract_embedding(image):
    """Compute the model's feature vector for a single PIL image.

    The image is converted to RGB, passed through the module-level
    ``transform``, given a batch dimension of 1 and moved to ``device``;
    ``model.forward_once`` then runs with gradient tracking disabled.

    Returns:
        The embedding as a flattened (1-D) NumPy array on the CPU.
    """
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.forward_once(batch)
        vector = features.cpu().numpy().flatten()
    return vector

def add_person(name, image):
    """Add one reference image of a person to the in-memory face database.

    Args:
        name: Display name typed by the user. Surrounding whitespace is
            stripped, and a blank name is rejected.
        image: PIL image to enroll, or None when nothing was uploaded.

    Returns:
        A status message for the UI.
    """
    if image is None:
        return "❌ No image uploaded!"

    # An empty textbox arrives as "" (possibly None); without this guard the
    # image would be filed under an empty key and appear as a nameless person.
    name = (name or "").strip()
    if not name:
        return "❌ No name provided!"

    embedding = extract_embedding(image)
    thumbnail = image.resize((128, 128))  # fixed size for consistent display

    # Embeddings and thumbnails are kept in parallel lists under the same name.
    face_database.setdefault(name, []).append(embedding)
    image_database.setdefault(name, []).append(thumbnail)

    return f"✅ {name} has been added to the database!"

def remove_person(name):
    """Delete a person, with all of their embeddings and thumbnails, from the database.

    Returns:
        A status message for the UI.
    """
    if name not in face_database:
        return "❌ Person not found in the database!"

    face_database.pop(name)
    image_database.pop(name)
    return f"✅ {name} and their images have been removed from the database!"

def recognize_person(image, threshold=0.5):
    """Identify the database person most similar to ``image``.

    Each person is represented by the mean of their stored embeddings, and the
    query embedding is compared with every mean using cosine similarity.

    Args:
        image: PIL image to identify, or None when nothing was uploaded.
        threshold: A match is reported only if its cosine similarity is
            strictly greater than this; otherwise the result is "Unknown".
            Defaults to 0.5, the value previously hard-coded here. Tune it
            for your data.

    Returns:
        A status message for the UI.
    """
    if image is None:
        return "❌ No image uploaded!"
    if not face_database:
        # Without this guard the loop never runs, and the UI would report
        # "Unknown" with a meaningless similarity of -1.
        return "❌ The database is empty! Add a person first."

    # The query does not change per person, so shape it into a 1xD row once.
    query = extract_embedding(image).reshape(1, -1)
    best_match = None
    best_score = float("-inf")  # any real similarity beats this; higher is better

    for name, stored_embeddings in face_database.items():
        centroid = np.mean(stored_embeddings, axis=0).reshape(1, -1)
        similarity = cosine_similarity(query, centroid)[0][0]
        if similarity > best_score:
            best_match, best_score = name, similarity

    result = best_match if best_score > threshold else "Unknown"
    return f"🔍 Recognized as: {result} (Similarity: {best_score:.4f})"

def show_dataset():
    """Render every stored thumbnail as an HTML gallery grouped by person.

    Names come straight from user input, so they are HTML-escaped before
    being placed into the markup (in both the heading and the ``alt``
    attribute). Otherwise a name containing tags or quotes could inject
    arbitrary markup into the page.

    Returns:
        An HTML string, which is empty when no one is in the database.
    """
    import html

    parts = []
    for name, images in image_database.items():
        safe_name = html.escape(name)  # also escapes quotes, so it is safe inside alt="..."
        parts.append(f"<h3>👤 {safe_name} ({len(images)} images)</h3>")
        for img in images:
            encoded_img = encode_image(img)
            parts.append('<div style="display:inline-block; margin:5px; text-align:center;">')
            parts.append(
                f'<img src="data:image/png;base64,{encoded_img}" alt="{safe_name}" '
                'style="width:128px; height:128px; margin:5px; border:1px solid #ccc;">'
            )
            parts.append("</div>")
    return "".join(parts)

def encode_image(image):
    """Serialize an image to PNG and return the bytes as a base64 text string.

    The result is used to build ``data:image/png;base64,...`` URIs for HTML.
    """
    from base64 import b64encode
    from io import BytesIO

    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    raw_bytes = png_buffer.getvalue()
    return b64encode(raw_bytes).decode()

with gr.Blocks() as demo:
    gr.Markdown("# Triplet Face Recognition System")
    
    with gr.Row():
        with gr.Column():
            gr.Markdown("## ➕ Add a Person")
            new_name_input = gr.Textbox(label="Enter Name", placeholder="Enter new name")
            image_input = gr.Image(type="pil", label="Upload Image")
            add_button = gr.Button("➕ Add Person")
            add_output = gr.Textbox()
        
        with gr.Column():
            gr.Markdown("## 🔍 Recognize a Person")
            image_recognize = gr.Image(type="pil", label="Upload Image to Recognize")
            recognize_button = gr.Button("🔍 Recognize Person")
            recognize_output = gr.Textbox()

    with gr.Row():
        with gr.Column():
            gr.Markdown("## 📂 Current Dataset")
            dataset_output = gr.HTML()
            show_dataset_button = gr.Button("🔄 Refresh Dataset")

        with gr.Column():
            gr.Markdown("## ❌ Remove a Person")
            remove_name_dropdown = gr.Dropdown(label="Select Person to Remove", choices=list(face_database.keys()), interactive=True)
            remove_person_button = gr.Button("Remove Person")
            remove_person_output = gr.Textbox()

    def add_person_ui(new_name, image):
        """Add a person, then refresh the remove-dropdown choices and the gallery.

        The gallery is re-rendered here (as remove_person_ui already does) so a
        newly added image shows up without pressing "Refresh Dataset".
        """
        result = add_person(new_name, image)
        return result, gr.update(choices=list(face_database.keys())), show_dataset()

    def remove_person_ui(name):
        """Remove a person, then refresh the gallery and the dropdown.

        The dropdown's selected value is cleared, because the removed name is
        no longer among its choices.
        """
        result = remove_person(name)
        return result, show_dataset(), gr.update(choices=list(face_database.keys()), value=None)

    add_button.click(
        add_person_ui,
        inputs=[new_name_input, image_input],
        outputs=[add_output, remove_name_dropdown, dataset_output]
    )

    recognize_button.click(
        recognize_person,
        inputs=[image_recognize],
        outputs=recognize_output
    )

    show_dataset_button.click(
        show_dataset,
        outputs=dataset_output
    )

    remove_person_button.click(
        remove_person_ui,
        inputs=[remove_name_dropdown],
        outputs=[remove_person_output, dataset_output, remove_name_dropdown]
    )

    # Initialize dataset display on load
    demo.load(
        fn=show_dataset,
        outputs=dataset_output
    )

demo.launch()