<a href="https://colab.research.google.com/github/daisysong76/AI--Machine--learning/blob/main/Video_Search_and_Retrieval_System_with_spatial_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import cv2
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from typing import List, Tuple, Dict, Optional, NamedTuple
from dataclasses import dataclass
import torch.nn.functional as F
from transformers import CLIPProcessor, CLIPModel, DetrImageProcessor, DetrForObjectDetection
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.ops import box_iou
import faiss
import logging
from datetime import datetime
import json
import concurrent.futures
from tqdm import tqdm

# Configure the root logger to emit INFO-level (and higher) records, and create
# the module-level logger used by the classes and functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class BoundingBox(NamedTuple):
    """Axis-aligned box described by two corners in normalized coordinates.

    ``(x1, y1)`` and ``(x2, y2)`` are expressed as fractions of the image
    width and height rather than in pixels.
    """
    x1: float
    y1: float
    x2: float
    y2: float

    def area(self) -> float:
        """Return the product of the box's x-extent and y-extent."""
        extent_x = self.x2 - self.x1
        extent_y = self.y2 - self.y1
        return extent_x * extent_y

    def to_pixels(self, width: int, height: int) -> Tuple[int, int, int, int]:
        """Scale the box to an image of ``width`` x ``height`` pixels.

        Each coordinate is multiplied by the matching image dimension and
        truncated with ``int``; the result is ordered ``(x1, y1, x2, y2)``.
        """
        left, top, right, bottom = self
        return int(left * width), int(top * height), int(right * width), int(bottom * height)

@dataclass
class SpatialFeature:
    """A detected image region together with its embedding.

    Instances are created by ``SpatialFeatureExtractor.extract_spatial_features``
    and stored per frame in ``VideoFrame.spatial_features``.
    """
    # Feature vector of the cropped region (1-D; the extractor in this file
    # L2-normalizes it before storing).
    embedding: np.ndarray
    # Location of the region in normalized image coordinates.
    bbox: BoundingBox
    # Detection score of the region (highest class probability of the detection).
    confidence: float
    # Class name of the detection, when one is available.
    label: Optional[str] = None

@dataclass
class VideoFrame:
    """A sampled video frame with its metadata, global embedding and detected regions."""
    # Identifier built by VideoProcessor as "<video file stem>_<timestamp, 2 decimals>".
    frame_id: str
    # Path of the source video, as passed to the processor.
    video_path: str
    # Position of the frame in the video in seconds (frame index / fps).
    timestamp: float
    # Whole-frame embedding; this is the vector added to the global FAISS index.
    global_embedding: np.ndarray
    # Per-region detections with their embeddings; added to the spatial FAISS index.
    spatial_features: List[SpatialFeature]
    # Frame dimensions in pixels, as reported by the video capture.
    frame_width: int
    frame_height: int

class SpatialFeatureExtractor:
    """Handles spatial feature extraction using DETR for object detection and CLIP for features.

    DETR proposes object boxes and class scores for an image; every box that
    passes the confidence threshold is cropped out of the image and embedded
    with CLIP's image encoder, yielding one ``SpatialFeature`` per object.
    """

    def __init__(self):
        """Load the pretrained DETR and CLIP checkpoints, on the GPU when one is available."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize DETR for object detection
        self.detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
        self.detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
        self.detr_model.to(self.device)
        # Inference only: make sure dropout / batch-norm run in evaluation mode.
        self.detr_model.eval()

        # Initialize CLIP for feature extraction
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_model.to(self.device)
        self.clip_model.eval()

    @staticmethod
    def _cxcywh_to_bbox(cx: float, cy: float, w: float, h: float) -> BoundingBox:
        """Convert a normalized (center_x, center_y, width, height) box to clamped corners."""
        def clamp(value: float) -> float:
            return max(0.0, min(1.0, value))

        half_w, half_h = w / 2, h / 2
        return BoundingBox(
            clamp(cx - half_w),
            clamp(cy - half_h),
            clamp(cx + half_w),
            clamp(cy + half_h),
        )

    def extract_spatial_features(self, image: Image.Image,
                               confidence_threshold: float = 0.7) -> List[SpatialFeature]:
        """Extract spatial features using DETR and CLIP.

        Args:
            image: Image to analyse. Non-RGB images are converted to RGB first.
            confidence_threshold: Minimum class probability a detection must
                exceed to be kept.

        Returns:
            One ``SpatialFeature`` per kept detection, holding the L2-normalized
            CLIP embedding of the cropped region, its normalized bounding box,
            the detection score and the class label. Detections whose box
            collapses to zero pixels are skipped. The list is empty when
            nothing passes the threshold.
        """
        # Both processors are fed 3-channel input; query images loaded from
        # disk may be grayscale, palettized or carry an alpha channel.
        if image.mode != "RGB":
            image = image.convert("RGB")
        width, height = image.size

        # Detect objects using DETR
        with torch.no_grad():
            detr_inputs = self.detr_processor(images=image, return_tensors="pt")
            detr_inputs = {k: v.to(self.device) for k, v in detr_inputs.items()}
            outputs = self.detr_model(**detr_inputs)

            # The last logit is DETR's "no object" class; drop it before scoring.
            probas = outputs.logits.softmax(-1)[0, :, :-1]
            scores, class_ids = probas.max(-1)
            keep = scores > confidence_threshold

            boxes = outputs.pred_boxes[0][keep].cpu().tolist()
            kept_scores = scores[keep].cpu().tolist()
            kept_class_ids = class_ids[keep].cpu().tolist()

        # The id -> name mapping lives on the model config, not on the processor.
        id2label = self.detr_model.config.id2label

        spatial_features = []
        # Run CLIP without autograd: this is inference only, and a tensor that
        # requires grad cannot be converted with ``.numpy()``.
        with torch.no_grad():
            for (cx, cy, w, h), score, class_id in zip(boxes, kept_scores, kept_class_ids):
                # DETR emits boxes as normalized (center_x, center_y, width, height).
                bbox = self._cxcywh_to_bbox(cx, cy, w, h)

                left, top, right, bottom = bbox.to_pixels(width, height)
                if right <= left or bottom <= top:
                    # A box thinner than one pixel would give an empty crop.
                    continue

                # Extract region feature using CLIP
                region = image.crop((left, top, right, bottom))
                clip_inputs = self.clip_processor(images=region, return_tensors="pt")
                clip_inputs = {k: v.to(self.device) for k, v in clip_inputs.items()}
                region_features = self.clip_model.get_image_features(**clip_inputs)
                region_features = F.normalize(region_features, dim=-1)

                spatial_features.append(SpatialFeature(
                    embedding=region_features[0].cpu().numpy(),
                    bbox=bbox,
                    confidence=float(score),
                    label=id2label.get(class_id),
                ))

        return spatial_features

class SpatialVectorDatabase:
    """Manages vector databases for both global and spatial features.

    Two inner-product FAISS indices are kept: one with a single vector per
    frame (global embedding) and one with a vector per detected region.
    ``frame_metadata[i]`` is the frame behind row ``i`` of the global index;
    ``spatial_feature_map`` maps a row of the spatial index to its
    ``(frame, feature)`` pair.
    """

    def __init__(self, dimension: int = 512):
        """Create empty indices for embeddings of size ``dimension``."""
        self.dimension = dimension
        self.global_index = faiss.IndexFlatIP(dimension)
        self.spatial_index = faiss.IndexFlatIP(dimension)

        self.frame_metadata: List[VideoFrame] = []
        self.spatial_feature_map: Dict[int, Tuple[VideoFrame, SpatialFeature]] = {}

    def _as_rows(self, vectors) -> np.ndarray:
        """Return ``vectors`` as a C-contiguous float32 matrix of shape ``(n, dimension)``.

        FAISS only accepts contiguous float32 input, so the conversion is done
        here once instead of relying on every producer to emit that dtype.

        Raises:
            ValueError: If the data is empty or cannot be split into rows of
                ``self.dimension`` values.
        """
        rows = np.ascontiguousarray(vectors, dtype=np.float32)
        if rows.size == 0 or rows.size % self.dimension != 0:
            raise ValueError(
                f"Expected embeddings of dimension {self.dimension}, "
                f"got array of shape {rows.shape}"
            )
        return rows.reshape(-1, self.dimension)

    def add_frame(self, frame: VideoFrame):
        """Add frame with both global and spatial features to indices.

        All embeddings are validated before either index is touched, so a
        malformed frame cannot leave the two indices out of step.

        Raises:
            ValueError: If an embedding does not match the index dimension.
        """
        global_row = self._as_rows(frame.global_embedding)
        if global_row.shape[0] != 1:
            raise ValueError("global_embedding must be a single vector")

        spatial_rows = None
        if frame.spatial_features:
            spatial_rows = self._as_rows(
                np.vstack([feature.embedding for feature in frame.spatial_features])
            )

        self.global_index.add(global_row)

        if spatial_rows is not None:
            start_idx = self.spatial_index.ntotal
            self.spatial_index.add(spatial_rows)

            # Map spatial feature indices to their metadata
            for offset, feature in enumerate(frame.spatial_features):
                self.spatial_feature_map[start_idx + offset] = (frame, feature)

        self.frame_metadata.append(frame)

    def spatial_search(self,
                      query_embedding: np.ndarray,
                      query_bbox: Optional[BoundingBox] = None,
                      k: int = 100,
                      iou_threshold: float = 0.5) -> List[Tuple[VideoFrame, SpatialFeature, float]]:
        """Search for similar spatial features with optional spatial constraints.

        Args:
            query_embedding: Single query vector of size ``self.dimension``.
            query_bbox: When given, only regions whose box has an IoU of at
                least ``iou_threshold`` with it are returned.
            k: Maximum number of nearest neighbours to retrieve before the
                spatial filter is applied.
            iou_threshold: Minimum IoU with ``query_bbox`` for a hit to be kept.

        Returns:
            ``(frame, feature, similarity)`` tuples in the order returned by
            FAISS. Empty when the index is empty or ``k`` is not positive.

        Raises:
            ValueError: If ``query_embedding`` is not one vector of the index
                dimension.
        """
        query_row = self._as_rows(query_embedding)
        if query_row.shape[0] != 1:
            raise ValueError("query_embedding must be a single vector")

        # Never ask for more neighbours than exist; FAISS would pad with -1.
        k = min(k, self.spatial_index.ntotal)
        if k <= 0:
            return []

        # Perform similarity search
        similarities, indices = self.spatial_index.search(query_row, k)

        # The query box is the same for every hit, so build its tensor once.
        query_box = None
        if query_bbox is not None:
            query_box = torch.tensor([[
                query_bbox.x1, query_bbox.y1,
                query_bbox.x2, query_bbox.y2
            ]])

        results = []
        for idx, similarity in zip(indices[0], similarities[0]):
            entry = self.spatial_feature_map.get(int(idx))
            if entry is None:
                # -1 padding or a row without metadata.
                continue
            frame, feature = entry

            # Apply spatial constraint if query_bbox is provided
            if query_box is not None:
                feature_box = torch.tensor([[
                    feature.bbox.x1, feature.bbox.y1,
                    feature.bbox.x2, feature.bbox.y2
                ]])
                iou = box_iou(query_box, feature_box)[0][0].item()
                if iou < iou_threshold:
                    continue

            results.append((frame, feature, float(similarity)))

        return results

class VideoProcessor:
    """Enhanced video processor with spatial feature extraction.

    Samples frames from a video at a fixed time interval and computes, for
    each sampled frame, a global embedding and a list of spatial features.
    """

    def __init__(self, frame_interval: float = 1.0):
        """Create the processor.

        Args:
            frame_interval: Seconds between sampled frames.
        """
        self.frame_interval = frame_interval
        self.feature_extractor = FeatureExtractor(use_clip=True)
        self.spatial_extractor = SpatialFeatureExtractor()

    def extract_frames(self, video_path: str) -> List[VideoFrame]:
        """Extract frames with both global and spatial features.

        Args:
            video_path: Path of the video file to read.

        Returns:
            One ``VideoFrame`` per sampled frame, in playback order. Empty when
            the video cannot be opened or reports no usable frame rate; a
            warning is logged in both cases.
        """
        frames: List[VideoFrame] = []
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                logger.warning("Could not open video: %s", video_path)
                return frames

            fps = cap.get(cv2.CAP_PROP_FPS)
            # ``not (fps > 0)`` also rejects NaN; without a frame rate neither
            # the sampling step nor the timestamps can be computed.
            if not (fps > 0):
                logger.warning("Video reports invalid FPS (%s): %s", fps, video_path)
                return frames

            # At least 1: a step of 0 (interval shorter than one frame) would
            # make the modulo below divide by zero. Every frame is used then.
            frame_skip = max(1, int(fps * self.frame_interval))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            video_stem = Path(video_path).stem

            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % frame_skip == 0:
                    # OpenCV decodes to BGR; the extractors work on RGB PIL images.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    image = Image.fromarray(frame_rgb)

                    global_embedding = self.feature_extractor.extract_features(image)
                    spatial_features = self.spatial_extractor.extract_spatial_features(image)

                    timestamp = frame_count / fps
                    frame_id = f"{video_stem}_{timestamp:.2f}"

                    frames.append(VideoFrame(
                        frame_id=frame_id,
                        video_path=video_path,
                        timestamp=timestamp,
                        global_embedding=global_embedding,
                        spatial_features=spatial_features,
                        frame_width=width,
                        frame_height=height
                    ))

                frame_count += 1
        finally:
            # Release the capture even when feature extraction raises.
            cap.release()

        return frames

class VideoSearchEngine:
    """Enhanced search engine with spatial search capabilities.

    Indexes the frames of a set of videos into a ``SpatialVectorDatabase`` and
    answers region queries: given a query image (and optionally a region of
    it), find video segments containing a similar region at a similar location.
    """

    # Lower-case file suffixes that ``index_videos`` treats as video files.
    VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv", ".webm")

    def __init__(self, frame_interval: float = 1.0):
        """Create the engine.

        Args:
            frame_interval: Seconds between the frames sampled while indexing.
        """
        self.video_processor = VideoProcessor(frame_interval=frame_interval)
        self.vector_db = SpatialVectorDatabase()
        # NOTE(review): the processor builds its extractor with use_clip=True
        # while this one uses the defaults -- confirm both produce embeddings
        # compatible with the CLIP vectors stored in the spatial index.
        self.feature_extractor = FeatureExtractor()
        # Share the processor's extractor instead of constructing a second one,
        # which would load the DETR and CLIP weights into memory twice.
        self.spatial_extractor = self.video_processor.spatial_extractor

    def index_videos(self, video_dir: str) -> int:
        """Index every video file found directly inside ``video_dir``.

        Files are recognised by suffix (see ``VIDEO_EXTENSIONS``) and processed
        in sorted order; each extracted frame is added to the vector database.

        Args:
            video_dir: Directory containing the video files.

        Returns:
            Number of frames added to the index.

        Raises:
            NotADirectoryError: If ``video_dir`` is not an existing directory.
        """
        directory = Path(video_dir)
        if not directory.is_dir():
            raise NotADirectoryError(f"Video directory not found: {video_dir}")

        video_paths = sorted(
            path for path in directory.iterdir()
            if path.is_file() and path.suffix.lower() in self.VIDEO_EXTENSIONS
        )
        if not video_paths:
            logger.warning("No video files found in %s", video_dir)

        indexed = 0
        for path in video_paths:
            logger.info("Indexing %s", path)
            frames = self.video_processor.extract_frames(str(path))
            for frame in frames:
                self.vector_db.add_frame(frame)
            indexed += len(frames)

        return indexed

    def search_by_region(self,
                        query_image: Image.Image,
                        query_bbox: Optional[BoundingBox] = None,
                        min_similarity: float = 0.7,
                        iou_threshold: float = 0.5) -> List[SearchResult]:
        """Search for video segments using a query image and optional spatial constraints.

        Args:
            query_image: Image containing the object or region to look for.
            query_bbox: Region of ``query_image`` (normalized coordinates) to
                search with. When omitted, the most confident detection in the
                query image is used instead.
            min_similarity: Minimum similarity score for a match to be kept.
            iou_threshold: Minimum IoU between the query region and a matched
                region's box.

        Returns:
            Matches grouped into video segments. Empty when nothing matches or,
            if no ``query_bbox`` is given, when nothing is detected in the
            query image.
        """
        if query_bbox is None:
            # Detection is only needed to pick a region automatically; an
            # explicitly given region must work even if nothing is detected.
            spatial_features = self.spatial_extractor.extract_spatial_features(query_image)

            if not spatial_features:
                logger.warning("No spatial features detected in query image")
                return []

            # If no specific bbox is provided, use the most confident detection
            query_feature = max(spatial_features, key=lambda x: x.confidence)
            query_bbox = query_feature.bbox
            query_embedding = query_feature.embedding
        else:
            # Extract features from the specified region
            region = query_image.crop(query_bbox.to_pixels(*query_image.size))
            query_embedding = self.feature_extractor.extract_features(region)

        # Perform spatial search
        similar_regions = self.vector_db.spatial_search(
            query_embedding,
            query_bbox,
            iou_threshold=iou_threshold
        )

        # Filter by similarity threshold
        similar_regions = [(frame, feature, sim)
                          for frame, feature, sim in similar_regions
                          if sim >= min_similarity]

        # Group results by video and temporal proximity
        return self._cluster_spatial_results(similar_regions)

    def _cluster_spatial_results(self,
                               similar_regions: List[Tuple[VideoFrame, SpatialFeature, float]],
                               time_threshold: float = 3.0) -> List[SearchResult]:
        """Group similar spatial regions into video segments.

        Matches from the same video whose timestamps are at most
        ``time_threshold`` seconds apart (after sorting) end up in one segment.
        The input list is left unmodified.
        """
        if not similar_regions:
            return []

        # Sort a copy by video path and timestamp so the caller's list keeps its order.
        ordered = sorted(similar_regions, key=lambda x: (x[0].video_path, x[0].timestamp))

        results = []
        current_group = []

        for frame, feature, similarity in ordered:
            if not current_group:
                current_group.append((frame, feature, similarity))
                continue

            prev_frame, _, _ = current_group[-1]

            if (frame.video_path == prev_frame.video_path and
                frame.timestamp - prev_frame.timestamp <= time_threshold):
                current_group.append((frame, feature, similarity))
            else:
                results.append(self._create_spatial_search_result(current_group))
                current_group = [(frame, feature, similarity)]

        if current_group:
            results.append(self._create_spatial_search_result(current_group))

        return results

    def _create_spatial_search_result(self,
                                    group: List[Tuple[VideoFrame, SpatialFeature, float]]) -> SearchResult:
        """Create a SearchResult object from a group of spatial matches.

        The segment spans the earliest to the latest matching timestamp and its
        confidence is the mean similarity of the group.
        """
        frames, features, similarities = zip(*group)
        return SearchResult(
            video_path=frames[0].video_path,
            start_time=min(f.timestamp for f in frames),
            end_time=max(f.timestamp for f in frames),
            confidence=float(np.mean(similarities)),
            matching_frames=[f.timestamp for f in frames],
            bboxes=[f.bbox for f in features]  # Include matching regions
        )

def main():
    """Demonstrate the engine: index a directory, then run a region-constrained query."""
    engine = VideoSearchEngine(frame_interval=0.5)

    # Index videos from a directory
    engine.index_videos("path/to/videos")

    # Region of interest, in normalized coordinates
    region_of_interest = BoundingBox(x1=0.2, y1=0.3, x2=0.8, y2=0.9)

    # Query the index with an image restricted to that region
    probe_image = Image.open("path/to/query.jpg")
    search_options = {
        "query_bbox": region_of_interest,
        "min_similarity": 0.7,
        "iou_threshold": 0.5,
    }
    matches = engine.search_by_region(probe_image, **search_options)

    # Report each matching segment followed by its matched regions
    for result in matches:
        print(f"Found match in {result.video_path}")
        print(f"Time range: {result.start_time:.2f}s - {result.end_time:.2f}s")
        print(f"Confidence: {result.confidence:.2%}")
        print("Matching regions:")
        for bbox in result.bboxes:
            print(f"  Region: ({bbox.x1:.2f}, {bbox.y1:.2f}) to ({bbox.x2:.2f}, {bbox.y2:.2f})")
        print("---")

# Run the example only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()