<a href="https://colab.research.google.com/github/caleb-stewart/Trademark-Analysis-Identification-Tool/blob/main/video_trait.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# !pip install opencv-python
# !pip install ultralytics
# !pip install pillow
# !pip install faiss-cpu

In [14]:
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as tr
import cv2
import faiss
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, BeitFeatureExtractor, BeitModel
from torchvision.models.feature_extraction import create_feature_extractor
from ultralytics import YOLO
import os
from collections import defaultdict

In [15]:
class EmbeddingExtractor:
    """Class for extracting image embeddings using ResNet-50, CLIP, and BEiT."""

    def __init__(self, device=None):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Load ResNet-50 model
        self.resnet = torchvision.models.resnet50(pretrained=True)
        self.resnet.fc = nn.Identity()
        self.resnet = self.resnet.to(self.device).eval()

        # Load CLIP model
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device).eval()
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # Load BEiT model
        self.beit_model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224").to(self.device).eval()
        self.beit_processor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224")

    def preprocess_resnet(self, img):
        transformations = tr.Compose([
            tr.Resize((224, 224)),
            tr.ToTensor(),
            tr.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        img = transformations(img).unsqueeze(0).to(self.device)
        return img

    def preprocess_clip(self, img):
        return self.clip_processor(images=img, return_tensors="pt")["pixel_values"].to(self.device)

    def preprocess_beit(self, img):
        return self.beit_processor(images=img, return_tensors="pt")["pixel_values"].to(self.device)

    def get_embedding(self, img, model_name="resnet"):

        img = Image.fromarray(img)

        if model_name == "resnet":
            img_tensor = self.preprocess_resnet(img)
            with torch.no_grad():
                embedding = self.resnet(img_tensor).cpu().numpy()

        elif model_name == "clip":
            img_tensor = self.preprocess_clip(img)
            with torch.no_grad():
                embedding = self.clip_model.get_image_features(img_tensor).cpu().numpy()

        elif model_name == "beit":
            img_tensor = self.preprocess_beit(img)
            with torch.no_grad():
                embedding = self.beit_model(img_tensor).last_hidden_state.mean(dim=1).cpu().numpy()

        else:
            raise ValueError("Invalid model name. Choose from: resnet, clip, beit.")

        return embedding

    @staticmethod
    def cosine_similarity(emb1, emb2):

        return torch.nn.functional.cosine_similarity(torch.tensor(emb1), torch.tensor(emb2)).item()

    @staticmethod
    def euclidean_distance(emb1, emb2):

        return np.linalg.norm(emb1 - emb2)

In [22]:
model = YOLO("best.pt")
embed_extract = EmbeddingExtractor()

In [23]:
def extract_logo_regions(image, save_crop=False, output_dir="cropped_logos"):
    """Runs YOLO on an image and extracts detected logo regions."""

    # Check if input is a file path or an image array
    if isinstance(image, str):
        img = cv2.imread(image)
    else:
        img = image

    if img is None:
        print("Error: Could not load image.")
        return [], []

    results = model(img)
    logo_regions = []
    bounding_boxes = []

    if save_crop and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for idx, box in enumerate(results[0].boxes):
        xyxy = box.xyxy[0].tolist()
        x1, y1, x2, y2 = map(int, xyxy)
        cropped_logo = img[y1:y2, x1:x2]  # extract detected region

        if save_crop and cropped_logo.size > 0:
            cropped_logo_path = os.path.join(output_dir, f"cropped_logo_{idx}.jpg")
            cv2.imwrite(cropped_logo_path, cropped_logo)
            print(f"Logo {idx} saved: {cropped_logo_path}")

        if cropped_logo.size > 0:
            logo_regions.append(cropped_logo)
            bounding_boxes.append((x1, y1, x2, y2))
            print(f"Logo {idx} detected at coordinates: ({x1}, {y1}) -> ({x2}, {y2})")

    return logo_regions, bounding_boxes

In [94]:


def process_video(input_video_path, output_video_path, frame_skip=5):
    # resnet size = 2048
    # clip size = 512
    # beit size = 768
    embedding_dim = 768 # Size for BEIT
    faiss_index = faiss.IndexFlatL2(embedding_dim)
    logo_id_counter = 0 # How many unique logos we've seen
    logo_id_map = {}  # maps FAISS index to logo ID
    logo_appearance_counts = defaultdict(int) # How many times a unique logo has appeared
    
    cap = cv2.VideoCapture(input_video_path)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

    frame_idx = 0
    save_frame = False

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break  # stop if video ends

        if frame_idx % frame_skip == 0:  # process every 5th frame
            print(f"Processing frame {frame_idx}")

            # extract detected logos from the current frame
            input_logos, input_bboxes = extract_logo_regions(frame, save_crop=False)

            # draw a bounding box around each detected logo
            for input_logo, bbox in zip(input_logos, input_bboxes):
                embedding = embed_extract.get_embedding(input_logo, model_name='beit') # Get the embedding of the frame
                faiss.normalize_L2(embedding) # normalize the embedding. Works really well with FAISS

                if faiss_index.ntotal == 0: # First entry into FAISS
                    faiss_index.add(np.array(embedding)) # Get embedding
                    logo_id_map[0] = logo_id_counter # First unique logo
                    logo_appearance_counts[logo_id_counter] += 1 # increment the unique logo
                    logo_id_counter += 1 # Go to next unique ID
                    assigned_id = logo_id_counter - 1 # current unique ID
                    save_frame = True
                else:
                    # Get the L2 distance and index 
                    D, I = faiss_index.search(np.array(embedding), k=1)
                    print("Distance:", D[0][0])
                    # Lower the distance, the better
                    if D[0][0] < 0.5:  # If a distance is above a 0.5, then the logo hasnt been seen
                        print("INDEX ALREADY EXISTS")
                        assigned_id = logo_id_map[I[0][0]]
                        logo_appearance_counts[assigned_id] += 1 # increase the amount of times weve seen this logo
                        save_frame = False
                    else:                        
                        print("ADDING NEW INDEX") # Create a new index into FAISS
                        faiss_index.add(np.array(embedding)) # Add a new index (embedding)
                        logo_id_map[faiss_index.ntotal - 1] = logo_id_counter # Assign the new index to a logo_id
                        logo_appearance_counts[logo_id_counter] += 1 # increment how many times weve seen this unique ID
                        assigned_id = logo_id_counter # current assigned ID
                        logo_id_counter += 1 # Go to the next unique ID
                        save_frame = True
                    
                x1, y1, x2, y2 = bbox
                cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 255, 255), 5)
                cv2.putText(frame, f"ID: {assigned_id}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (77, 33, 191), 2)

                if save_frame:
                    save_dir = "new_logo_frames"
                    os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"frame_{frame_idx}_logo_{logo_id_counter}.jpg")
                    cv2.imwrite(save_path, frame)


        out.write(frame)  # write processed frame to output

        frame_idx += 1 # next frame

    cap.release()
    out.release()
    print(f"Processed video saved as {output_video_path}")

    for faiss_idx, counter in logo_id_map.items():
        print(f'{faiss_idx} appeared approx {logo_appearance_counts[counter] * 5} times')



In [95]:

input_video_path = "starbucks_video.mp4"  # path to input video
output_video_path = "output_video.mp4"  # output processed video

process_video(input_video_path, output_video_path)

Processing frame 0

0: 384x640 1 logo, 113.0ms
Speed: 3.8ms preprocess, 113.0ms inference, 2.5ms postprocess per image at shape (1, 3, 384, 640)
Logo 0 detected at coordinates: (1128, 503) -> (1345, 654)
Processing frame 5

0: 384x640 1 logo, 89.4ms
Speed: 3.6ms preprocess, 89.4ms inference, 1.4ms postprocess per image at shape (1, 3, 384, 640)
Logo 0 detected at coordinates: (1129, 504) -> (1343, 654)
Distance: 0.042679153
INDEX ALREADY EXISTS
Processing frame 10

0: 384x640 1 logo, 105.4ms
Speed: 3.2ms preprocess, 105.4ms inference, 1.4ms postprocess per image at shape (1, 3, 384, 640)
Logo 0 detected at coordinates: (1129, 503) -> (1342, 653)
Distance: 0.046831667
INDEX ALREADY EXISTS
Processing frame 15

0: 384x640 1 logo, 92.1ms
Speed: 2.9ms preprocess, 92.1ms inference, 1.4ms postprocess per image at shape (1, 3, 384, 640)
Logo 0 detected at coordinates: (1128, 503) -> (1341, 653)
Distance: 0.047004372
INDEX ALREADY EXISTS
Processing frame 20

0: 384x640 1 logo, 90.2ms
Speed: 3.1