In [None]:
from datetime import timedelta
from pathlib import Path

import clip
import cv2
import numpy as np
import supervision as sv
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
from transformers import Blip2Processor, Blip2ForConditionalGeneration

In [None]:
class Pipeline:
    """Single-pass video pipeline: RT-DETR detection, CLIP query filtering,
    ByteTrack tracking, box/label annotation and periodic BLIP-2 captioning.

    ``captionCache`` holds ``{"start", "end", "text"}`` dicts for the most
    recently processed video, sorted by ``start``.
    """

    def __init__(self, device=None):
        """Create the pipeline and load all models.

        Args:
            device: torch device (or device string such as ``"mps"``) used by every
                model. ``None`` picks CUDA if available, then Apple MPS, then CPU.
        """
        self.device = torch.device(device) if device is not None else self._selectDevice()
        self.initialiseModels()

    @staticmethod
    def _selectDevice():
        """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        mpsBackend = getattr(torch.backends, "mps", None)
        if mpsBackend is not None and mpsBackend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @staticmethod
    def _createTracker(frameRate=30):
        """Build a ByteTrack tracker with this pipeline's matching settings for ``frameRate``."""
        return sv.ByteTrack(track_activation_threshold=0.4, lost_track_buffer=40, minimum_matching_threshold=0.4, frame_rate=frameRate)

    @staticmethod
    def _captionStride(fps, captionInterval):
        """Number of frames between captioned frames; always at least 1 so the modulo is safe."""
        return max(1, int(round(fps * captionInterval)))

    @staticmethod
    def _cropBox(frame, box):
        """Return the part of ``frame`` (H, W, ...) inside ``box`` = (x1, y1, x2, y2), clipped to the frame.

        Clipping matters: detector boxes may start slightly outside the image, and a
        negative start index would wrap around under Python slicing, producing an
        empty (or wrong) crop.
        """
        height, width = frame.shape[:2]
        x1, y1, x2, y2 = box
        left = int(np.clip(x1, 0, width))
        top = int(np.clip(y1, 0, height))
        right = int(np.clip(x2, 0, width))
        bottom = int(np.clip(y2, 0, height))
        return frame[top:bottom, left:right]

    def initialiseModels(self):
        """Load the detector, CLIP, the BLIP-2 captioner, the tracker and the annotators."""
        # Half-precision generation on CPU is slow or unsupported for some ops in several
        # PyTorch versions, so float16 is only used on an accelerator.
        self.captionDtype = torch.float32 if torch.device(self.device).type == "cpu" else torch.float16
        self.detectionModel = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365").to(self.device)
        self.detectionModelImageProcessor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        self.clipModel, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
        self.captionModel = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=self.captionDtype).to(self.device)
        self.captionProcessor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        self.tracker = self._createTracker()
        self.boxAnnotator = sv.BoxAnnotator()
        self.labelAnnotator = sv.LabelAnnotator()
        self.captionCache = []

    def processVideo(self, videoPath, query, outputPath, captionInterval=5, similarity=0.2, captionsFile="captions.txt", captionBatchSize=4):
        """Detect, filter, track, annotate and caption a video in a single pass.

        Args:
            videoPath: input video path readable by OpenCV.
            query: free-text description; detections whose CLIP crop embedding has
                cosine similarity >= ``similarity`` with it are kept.
            outputPath: annotated output video path (written with the 'mp4v' fourcc).
            captionInterval: seconds between captioned frames; each caption spans
                [t, t + captionInterval].
            similarity: CLIP cosine-similarity threshold passed to ``filterDetections``.
            captionsFile: text file receiving one "[start-->end] caption" line per caption.
            captionBatchSize: number of frames captioned together by BLIP-2.

        Returns:
            ``outputPath``.

        Raises:
            ValueError: if the input cannot be opened, reports no positive FPS, or the
                output writer cannot be opened.
        """
        queryEmbedding = self.encodeQuery(query)
        capture = cv2.VideoCapture(videoPath)
        videoWriter = None
        try:
            if not capture.isOpened():
                raise ValueError(f"Could not open video: {videoPath}")
            fps = capture.get(cv2.CAP_PROP_FPS)
            # Every timestamp below divides by fps; `not fps > 0` also rejects NaN.
            if not fps > 0:
                raise ValueError(f"Could not determine a valid FPS for video: {videoPath}")
            width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
            totalFrames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            videoWriter = cv2.VideoWriter(outputPath, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
            if not videoWriter.isOpened():
                raise ValueError(f"Could not open video writer for: {outputPath}")

            # Per-video state: track IDs and cached captions must not leak between videos,
            # and the tracker gets the video's real frame rate instead of a fixed 30.
            self.tracker = self._createTracker(frameRate=max(1, int(round(fps))))
            self.captionCache = []
            captionStride = self._captionStride(fps, captionInterval)

            pendingCaptions = []
            frameCount = 0
            with open(captionsFile, 'w', encoding="utf-8") as cf:
                # Some streams report a non-positive frame count; let tqdm run without a total then.
                with tqdm(total=totalFrames if totalFrames > 0 else None, desc="Processing Video") as bar:
                    while True:
                        ret, frame = capture.read()
                        if not ret:
                            break
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        currentTime = frameCount / fps
                        # Caption generation on every `captionStride`-th frame, in batches.
                        if frameCount % captionStride == 0:
                            pendingCaptions.append((Image.fromarray(frame), currentTime))
                            if len(pendingCaptions) >= captionBatchSize:
                                self.processCaptionBatch(pendingCaptions, captionInterval, cf)
                                pendingCaptions = []
                        annotatedFrame = self._processFrame(frame, queryEmbedding, similarity, currentTime)
                        videoWriter.write(cv2.cvtColor(annotatedFrame, cv2.COLOR_RGB2BGR))
                        frameCount += 1
                        bar.update(1)
                # Caption the final, partially filled batch.
                if pendingCaptions:
                    self.processCaptionBatch(pendingCaptions, captionInterval, cf)
        finally:
            capture.release()
            if videoWriter is not None:
                videoWriter.release()
        return outputPath

    def _processFrame(self, frame, queryEmbedding, similarity, currentTime):
        """Detect, CLIP-filter, track and annotate one RGB frame; return the annotated RGB copy."""
        with torch.no_grad():
            detections = self.detectObjects(frame)
            filteredDetections = self.filterDetections(detections, frame, queryEmbedding, similarity)
            trackedDetections = self.tracker.update_with_detections(filteredDetections)
        return self.annotateFrame(frame, trackedDetections, currentTime)

    def frameSimilarity(self, frame, queryEmbedding):
        """Return the CLIP cosine similarity between a whole RGB frame and ``queryEmbedding``."""
        clipInput = self.clip_preprocess(Image.fromarray(frame)).unsqueeze(0).to(self.device)
        with torch.no_grad():
            frameEmbedding = self.clipModel.encode_image(clipInput)
        return F.cosine_similarity(queryEmbedding, frameEmbedding, dim=-1).item()

    def formatTimeStamp(self, seconds):
        """Convert a non-negative number of seconds to an ``HH:MM:SS.mmm`` string.

        Rounds to whole milliseconds before splitting into fields, so values such as
        59.9996 carry into the next minute ("00:01:00.000") rather than printing
        an impossible "00:00:60.000".
        """
        totalMs = int(round(seconds * 1000))
        hours, remainder = divmod(totalMs, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        wholeSeconds, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{wholeSeconds:02d}.{millis:03d}"

    def processCaptionBatch(self, batch, interval, captionFile):
        """Caption a batch of frames with BLIP-2, cache, print and write each caption.

        Args:
            batch: sequence of (PIL image, start time in seconds) pairs.
            interval: caption duration in seconds; each caption spans [start, start + interval].
            captionFile: open text file; one "[start-->end] caption" line is written per frame.
        """
        if not batch:
            return
        frames, timestamps = zip(*batch)
        captionDtype = getattr(self, "captionDtype", torch.float16)
        inputs = self.captionProcessor(images=list(frames), return_tensors="pt").to(self.device, captionDtype)
        with torch.no_grad():
            generatedIDs = self.captionModel.generate(**inputs, max_new_tokens=50)
        captions = self.captionProcessor.batch_decode(generatedIDs, skip_special_tokens=True)
        for startTime, caption in zip(timestamps, captions):
            text = caption.strip()
            endTime = startTime + interval
            self.captionCache.append({"start": startTime, "end": endTime, "text": text})
            outputLine = f"[{self.formatTimeStamp(startTime)}-->{self.formatTimeStamp(endTime)}] {text}"
            print(outputLine)
            captionFile.write(outputLine + "\n")
        self.captionCache.sort(key=lambda entry: entry["start"])

    def encodeQuery(self, query):
        """Encode the text ``query`` with CLIP; returns a (1, D) embedding tensor."""
        with torch.no_grad():
            queryInput = clip.tokenize([query]).to(self.device)
            return self.clipModel.encode_text(queryInput)

    def detectObjects(self, frame, threshold=0.5):
        """Run RT-DETR on an RGB frame and return detections scoring >= ``threshold``, in pixel xyxy."""
        inputs = self.detectionModelImageProcessor(images=Image.fromarray(frame), return_tensors="pt").to(self.device)
        outputs = self.detectionModel(**inputs)
        results = self.detectionModelImageProcessor.post_process_object_detection(outputs, threshold=threshold, target_sizes=[(frame.shape[0], frame.shape[1])])[0]
        return sv.Detections.from_transformers(results)

    def filterDetections(self, detections, frame, queryEmbedding, threshold):
        """Keep detections whose CLIP crop embedding has cosine similarity >= ``threshold`` with the query.

        Crops are clipped to the frame; boxes whose clipped crop is empty are dropped.
        """
        if len(detections) == 0:
            return detections
        crops = []
        validIndices = []
        for idx, box in enumerate(detections.xyxy):
            crop = self._cropBox(frame, box)
            if crop.size > 0:
                crops.append(Image.fromarray(crop))
                validIndices.append(idx)
        if not crops:
            return sv.Detections.empty()
        clipInputs = torch.stack([self.clip_preprocess(crop) for crop in crops]).to(self.device)
        with torch.no_grad():
            imageEmbedding = self.clipModel.encode_image(clipInputs)
        # (1, D) query broadcast against (N, D) crops -> (N,) similarities.
        similarities = F.cosine_similarity(queryEmbedding, imageEmbedding, dim=-1)
        mask = (similarities >= threshold).cpu().numpy()
        filteredIndices = np.asarray(validIndices, dtype=int)[mask]
        return detections[filteredIndices]

    def annotateFrame(self, frame, detections, currentTime):
        """Draw boxes and "#<track id> <class> <confidence>" labels on a copy of the RGB ``frame``.

        ``currentTime`` is accepted for interface compatibility but is not drawn.
        """
        if len(detections) == 0:
            # Nothing to draw; also avoids zipping tracker_id/confidence arrays that may be unset.
            return frame.copy()
        id2label = self.detectionModel.config.id2label
        labels = [
            f"#{trackId} {id2label[int(classId)]} {confidence:.2f}"
            for trackId, classId, confidence in zip(detections.tracker_id, detections.class_id, detections.confidence)
        ]
        annotated = self.boxAnnotator.annotate(scene=frame.copy(), detections=detections)
        return self.labelAnnotator.annotate(scene=annotated, detections=detections, labels=labels)
        

In [None]:
# Run configuration: paths are built from the user's home directory instead of a hard-coded absolute path.
VIDEO_PATH = Path.home() / "Downloads" / "20660921-hd_1920_1080_30fps.mp4"
OUTPUT_PATH = Path.home() / "annotated.mp4"
CAPTIONS_PATH = Path.home() / "annotatedCaptions.txt"
QUERY = "person with suitcase"

pipeline = Pipeline()
result = pipeline.processVideo(
    videoPath=str(VIDEO_PATH),
    query=QUERY,
    outputPath=str(OUTPUT_PATH),
    captionInterval=5,
    similarity=0.25,
    captionsFile=str(CAPTIONS_PATH),
)
print(f"\nProcessed video saved to {result}")
print(f"Captions saved to {CAPTIONS_PATH}")