In [8]:
import cv2
import os
import torch
import clip
from PIL import Image
import gradio as gr
import numpy as np

# Load the open CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Load the stored image embeddings
embeddings = np.load("embeddings.npy")

def image_retrieval(query, image_folder):
    try:
        if not os.path.isdir(image_folder):
            return None, "Invalid folder path provided."

        image_embeddings = []
        for image_file in os.listdir(image_folder):
            image_path = os.path.join(image_folder, image_file)
            image = cv2.imread(image_path)

            if image is not None:
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image_pil = Image.fromarray(image_rgb)
                preprocessed_image = preprocess(image_pil).unsqueeze(0).to(device)
                image_features = model.encode_image(preprocessed_image)
                image_embeddings.append(image_features.detach().cpu().numpy())

        if not image_embeddings:
            return None, "No valid images found in the specified folder."

        image_embeddings_tensor = torch.tensor(image_embeddings, device=device)
        text_encoded = model.encode_text(clip.tokenize([query]).to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
        similarities = (image_embeddings_tensor @ text_encoded.T).squeeze(1)
        best_image_idx = similarities.argmax().item()

        best_image_path = os.path.join(image_folder, os.listdir(image_folder)[best_image_idx])
        caption = f"Best matching image to query '{query}'"
        return best_image_path, caption
    except Exception as e:
        return None, f"Error: {str(e)}"

def video_frame_retrieval(video_file, query):
    try:
        cap = cv2.VideoCapture(video_file.name)
        output_directory = "processed_video_frames"
        os.makedirs(output_directory, exist_ok=True)

        frame_counter = 0
        best_frame_idx = -1
        best_similarity = -1.0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_counter % 301 == 0:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image_pil = Image.fromarray(frame_rgb)
                preprocessed_image = preprocess(image_pil).unsqueeze(0).to(device)
                image_features = model.encode_image(preprocessed_image)

                with torch.no_grad():
                    text_encoded = model.encode_text(clip.tokenize([query]).to(device))
                    text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
                    similarity = (image_features @ text_encoded.T).item()

                if similarity > best_similarity:
                    best_similarity = similarity
                    best_frame_idx = frame_counter

            frame_counter += 1

        cap.set(cv2.CAP_PROP_POS_FRAMES, best_frame_idx)
        ret, best_frame = cap.read()
        cap.release()

        if ret:
            best_frame_path = os.path.join(output_directory, "best_frame.jpg")
            cv2.imwrite(best_frame_path, best_frame)
            caption = f"Best matching frame to query '{query}'"
            return best_frame_path, caption
        else:
            return None, "No matching frame found."
    except Exception as e:
        return None, f"Error: {str(e)}"

def process_media_and_caption(media_file, query, mode, image_folder=None):
    if mode == "image":
        if image_folder:
            return image_retrieval(query, image_folder)
        else:
            return None, "Please provide the path to the image folder."
    elif mode == "video":
        return video_frame_retrieval(media_file, query)
    else:
        return None, "Invalid mode selected."

iface = gr.Interface(
    fn=process_media_and_caption,
    inputs=[
        gr.inputs.File(label="Upload a Media"),
        gr.inputs.Textbox(label="Enter a Query"),
        gr.inputs.Radio(["image", "video"], label="Select Mode"),
        gr.inputs.Textbox(label="Enter Image Folder Path (for image mode)")
    ],
    outputs=["image", "text"],
    title="Media Retrieval and Captioning",
    description="Upload an image or video and enter a query to retrieve a matching media with caption."
)
iface.launch(share=True)


  gr.inputs.File(label="Upload a Media"),
  gr.inputs.File(label="Upload a Media"),
  gr.inputs.File(label="Upload a Media"),
  gr.inputs.Textbox(label="Enter a Query"),
  gr.inputs.Textbox(label="Enter a Query"),
  gr.inputs.Textbox(label="Enter a Query"),
  gr.inputs.Radio(["image", "video"], label="Select Mode"),
  gr.inputs.Radio(["image", "video"], label="Select Mode"),
  gr.inputs.Textbox(label="Enter Image Folder Path (for image mode)")
  gr.inputs.Textbox(label="Enter Image Folder Path (for image mode)")
  gr.inputs.Textbox(label="Enter Image Folder Path (for image mode)")


Running on local URL:  http://127.0.0.1:7875

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.


