In [None]:
import torch
from PIL import Image
import open_clip
import torchvision
import os
import cv2
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from torchvision.transforms.functional import to_tensor
from torchvision.transforms import Compose, Resize, Normalize
import pandas as pd
import altair as alt
import ipywidgets as widgets
from ipyfilechooser import FileChooser
from tqdm import tqdm
import matplotlib.pyplot as plt

# Utilities

In [None]:
class LoadVideo(Dataset):
    """Dataset over every ``vid_stride``-th frame of a video file, read with OpenCV.

    Item ``i`` is the frame with 1-based number ``(i + 1) * vid_stride``.
    Consecutive indices (as produced by ``DataLoader(shuffle=False)``) only
    advance the decoder. Any other index (e.g. a second pass starting again at
    0) seeks first, so an item is always the frame its index names.

    Args:
        path: Path to the video file.
        transforms: Callable applied to each RGB ``PIL.Image`` frame (in this
            notebook, the CLIP preprocessing pipeline).
        vid_stride: Number of frames to advance per item; must be >= 1.

    Raises:
        ValueError: If ``vid_stride`` is less than 1.
        OSError: If OpenCV cannot open ``path``.
    """

    def __init__(self, path, transforms, vid_stride=1):
        if vid_stride < 1:
            raise ValueError(f"vid_stride must be >= 1, got {vid_stride}")

        self.transforms = transforms
        self.vid_stride = vid_stride
        # 1-based number of the last frame grabbed; 0 means nothing read yet.
        self.cur_frame = 0
        self.cap = cv2.VideoCapture(path)
        if not self.cap.isOpened():
            # Without this check an unreadable file silently yields an empty dataset.
            raise OSError(f"Could not open video: {path}")
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)

    def __getitem__(self, index):
        """Return ``(transformed_frame, frame_tensor, frame_number, timestamp_ms)``.

        ``frame_tensor`` is the untransformed RGB frame passed through
        ``to_tensor``. ``frame_number`` is 1-based. ``timestamp_ms`` is
        OpenCV's ``CAP_PROP_POS_MSEC`` read after decoding the frame.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
            RuntimeError: If the frame cannot be read. This may happen when
                ``CAP_PROP_FRAME_COUNT`` overstates the readable length.
        """
        if not 0 <= index < self.total_frames:
            raise IndexError(f"index {index} out of range for {self.total_frames} frames")

        # 1-based number of the frame this item should return.
        target = (index + 1) * self.vid_stride
        if target != self.cur_frame + self.vid_stride:
            # Non-sequential access: position the decoder so that grabbing
            # `vid_stride` frames below ends exactly on `target`.
            # CAP_PROP_POS_FRAMES is the 0-based index of the next frame.
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, target - self.vid_stride)
            self.cur_frame = target - self.vid_stride

        # Skip over frames; only the last grabbed frame is retrieved (decoded).
        for _ in range(self.vid_stride):
            if not self.cap.grab():
                raise RuntimeError(f"Could not read frame {self.cur_frame + 1} of the video")
            self.cur_frame += 1

        ok, img = self.cap.retrieve()
        if not ok or img is None:
            raise RuntimeError(f"Could not decode frame {self.cur_frame} of the video")
        timestamp = self.cap.get(cv2.CAP_PROP_POS_MSEC)

        # OpenCV frames are BGR; convert to an RGB PIL image for the transforms.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(np.uint8(img))

        return self.transforms(img), to_tensor(img), self.cur_frame, timestamp

    def __len__(self):
        """Number of items: the reported frame count divided by ``vid_stride``, rounded down."""
        return self.total_frames
    
# Dropdown label -> (open_clip architecture name, pretrained weights tag).
# Every label has the form "<architecture> - <weights>", so it is derived
# from the (architecture, weights) pairs listed below, in display order.
MODELS = {
    f"{arch} - {tag}": (arch, tag)
    for arch, tag in (
        ("convnext_base", "laion400m_s13b_b51k"),
        ("convnext_base_w", "laion2b_s13b_b82k"),
        ("convnext_base_w", "laion2b_s13b_b82k_augreg"),
        ("convnext_base_w", "laion_aesthetic_s13b_b82k"),
        ("convnext_base_w_320", "laion_aesthetic_s13b_b82k"),
        ("convnext_base_w_320", "laion_aesthetic_s13b_b82k_augreg"),
        ("convnext_large_d", "laion2b_s26b_b102k_augreg"),
        ("convnext_large_d_320", "laion2b_s29b_b131k_ft"),
        ("convnext_large_d_320", "laion2b_s29b_b131k_ft_soup"),
        ("convnext_xxlarge", "laion2b_s34b_b82k_augreg"),
        ("convnext_xxlarge", "laion2b_s34b_b82k_augreg_rewind"),
        ("convnext_xxlarge", "laion2b_s34b_b82k_augreg_soup"),
    )
}

# Select an Input Video and Configure Options

In [None]:
video_widget = FileChooser(
    path='./',
    title="Input Video:",
)
model_widget = widgets.Dropdown(
    options=list(MODELS.keys()),
    value="convnext_base_w - laion2b_s13b_b82k",
    description='Model:',
    disabled=False,
)
# Plain IntText has no min/max, so a stride or batch size of 0 could be
# entered. The Bounded variants enforce the range; the max values are only
# sanity caps.
stride_widget = widgets.BoundedIntText(
    value=4,
    min=1,
    max=1_000_000,
    description='Frame Stride:',
    disabled=False
)
batch_widget = widgets.BoundedIntText(
    value=4,
    min=1,
    max=4096,
    description='Batch Size:',
    disabled=False
)
crop_widget = widgets.Checkbox(
    value=False,
    description='Center Crop',
    disabled=False,
    indent=True
)
cuda_widget = widgets.Checkbox(
    # Default to CUDA only when it is available, so the defaults also run on CPU-only machines.
    value=torch.cuda.is_available(),
    description='Cuda',
    disabled=False,
    indent=True
)
text_query_widget = widgets.Text(
    value='',
    placeholder='Text Search Query',
    description='',
    disabled=False
)
image_query_widget = FileChooser(
    path='./',
    title="Image Search Query:",
)

query_widget = widgets.Dropdown(
    options=['Text', 'Image'],
    value='Text',
    description='Query Type:',
    disabled=False,
)
# Shows whichever query input matches the selected query type. It has its own
# name because the "View Frames" cell below binds a separate `output` widget.
query_output = widgets.Output()


def query_type_handler(change):
    """Show the text box for "Text" queries and the image file chooser otherwise."""
    query_output.clear_output()
    with query_output:
        if change.new == "Text":
            display(text_query_widget)
        else:
            display(image_query_widget)


query_widget.observe(query_type_handler, names="value")

display(
    video_widget,
    model_widget,
    query_widget,
    query_output,
    stride_widget,
    batch_widget,
    crop_widget,
    cuda_widget,
)
# Initial contents match the dropdown's default value ("Text").
with query_output:
    display(text_query_widget)

# Search Video

In [None]:
print("Preparing model...")

# Check inputs. An empty Text widget holds '' (never None), so test truthiness,
# and validate the query that matches the selected "Query Type".
assert video_widget.selected, "An input video should be provided"
use_text_query = query_widget.value == "Text"
if use_text_query:
    assert text_query_widget.value.strip(), "A text query should be provided"
else:
    assert image_query_widget.selected, "An image query should be provided"
# LoadVideo divides by the stride, and DataLoader needs a positive batch size.
assert stride_widget.value >= 1, "Frame stride must be at least 1"
assert batch_widget.value >= 1, "Batch size must be at least 1"

if cuda_widget.value:
    assert torch.cuda.is_available(), "Selected cuda but cuda is not available"
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Initialize model
name, weights = MODELS[model_widget.value]
model, _, preprocess = open_clip.create_model_and_transforms(
    name, pretrained=weights, device=device
)
model.eval()

# Drop the center-crop step when it is disabled. Match it by type rather than
# by list position, and do it before the dataset is built so frames and the
# image query share the final pipeline.
if not crop_widget.value:
    preprocess.transforms = [
        t for t in preprocess.transforms
        if not isinstance(t, torchvision.transforms.CenterCrop)
    ]

# Load video. LoadVideo decodes from one cv2.VideoCapture, so keep ordered, single-process loading.
dataset = LoadVideo(video_widget.selected, transforms=preprocess, vid_stride=stride_widget.value)
dataloader = DataLoader(
    dataset, batch_size=batch_widget.value, shuffle=False, num_workers=0
)

# Encode the query (text or image, per the "Query Type" dropdown).
with torch.no_grad():
    if use_text_query:
        tokenizer = open_clip.get_tokenizer(name)
        text = tokenizer([text_query_widget.value]).to(device)
        query_features = model.encode_text(text)
    else:
        # Context manager closes the image file once it has been preprocessed.
        with Image.open(image_query_widget.selected) as query_img:
            query_tensor = preprocess(query_img).unsqueeze(0).to(device)
        query_features = model.encode_image(query_tensor)
    # L2-normalise so the dot products below are cosine similarities.
    query_features /= query_features.norm(dim=-1, keepdim=True)

# Encode each frame batch and compare with the query features. Rows are
# collected in a list and turned into one DataFrame at the end, avoiding
# repeated concat (quadratic, and concatenating onto an empty frame
# yields object-dtype columns and a duplicated index).
records = []
with torch.no_grad():
    for frames_t, _orig, frame, timestamp in tqdm(dataloader):
        image_features = model.encode_image(frames_t.to(device))
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # (B, D) @ (D, 1) -> (B,): one similarity per frame in the batch.
        sims = (image_features @ query_features.T).squeeze(1).cpu()
        timestamps_s = torch.round(timestamp / 1000, decimals=2)
        records.extend(
            {"Frame": f, "Timestamp": t, "Similarity": s}
            for f, t, s in zip(frame.tolist(), timestamps_s.tolist(), sims.tolist())
        )

res = pd.DataFrame(records, columns=["Frame", "Timestamp", "Similarity"])

# Create plot of similarity values
lines = (
    alt.Chart(res)
    .mark_line(color="firebrick")
    .encode(
        alt.X("Timestamp:Q", title="Timestamp (seconds)"),
        alt.Y("Similarity:Q", scale=alt.Scale(zero=True, domainMax=1)),
    )
    .properties(width=600)
)

lines

# Display All Frames Over Similarity Threshold

In [None]:
def view_frames(b):
    """Button callback: show, and optionally save, every frame scoring above the threshold.

    Reads the global ``res`` built by the search cell (columns ``Frame``,
    ``Timestamp``, ``Similarity``; ``Frame`` is the 1-based frame number) and
    re-reads each matching frame from the input video. Everything is rendered
    into the global ``output`` widget.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    output.clear_output()
    with output:
        thresh = thresh_widget.value

        assert 0 <= thresh <= 1.0, "Threshold must be between 0 and 1"
        assert not save_widget.value or path_widget.selected is not None, "Must choose a save directory"

        # Select by column name; the old positional `res.T.values` also coerced ints to float.
        matches = res[res["Similarity"] > thresh]

        cap = cv2.VideoCapture(video_widget.selected)
        try:
            for row in matches.itertuples(index=False):
                frame_no = int(row.Frame)
                # Frame numbers are 1-based; CAP_PROP_POS_FRAMES is 0-based.
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no - 1)
                ok, img = cap.read()
                if not ok or img is None:
                    print(f"Could not read frame {frame_no}; skipping")
                    continue
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

                # Plot frame
                fig, ax = plt.subplots()
                ax.imshow(img)
                ax.axis("off")
                ax.set_title(f"{row.Timestamp}s - Score={row.Similarity:.3f}")
                plt.show()
                plt.close(fig)

                if save_widget.value:
                    Image.fromarray(img).save(os.path.join(path_widget.selected, f"{frame_no}.jpg"))
        finally:
            # Release the video file even if reading, plotting or saving fails.
            cap.release()
            
def save_handler(change):
    """Toggle the save-directory chooser along with the "Save Matched Frames" checkbox.

    Args:
        change: Trait-change notification; ``change.new`` is the new checkbox value.
    """
    if not change.new:
        # Unchecked: hide the directory chooser.
        output_save.clear_output()
        return
    with output_save:
        display(path_widget)
            
            
# Similarity threshold. FloatText does not accept min/max (readout and
# orientation are slider-only options), so those bounds never applied;
# BoundedFloatText enforces the [0, 1] range.
thresh_widget = widgets.BoundedFloatText(
    value=0.3,
    min=0,
    max=1,
    step=0.01,
    description='Threshold:',
    disabled=False,
    continuous_update=False,
)

save_widget = widgets.Checkbox(
    value=False,
    description='Save Matched Frames',
    disabled=False,
    indent=True
)
# Holds the directory chooser while saving is enabled (see save_handler).
output_save = widgets.Output()
save_widget.observe(save_handler, names="value")

path_widget = FileChooser(
    path='./',
    title="Save Directory:",
    show_only_dirs=True,
)

button = widgets.Button(
    description='View Frames',
    disabled=False,
    button_style='',
    tooltip='View Frames',
    icon='check'
)
button.on_click(view_frames)

# view_frames renders into this widget, looked up as the global `output`.
output = widgets.Output()

display(thresh_widget, save_widget, output_save, button, output)

In [None]:
# Highest-scoring frames first; the full per-frame table remains available as `res`.
res.sort_values("Similarity", ascending=False).head(20)