In [1]:
# Importing libraries
import lancedb
import pyarrow as pa
from transformers import CLIPModel, CLIPProcessor
from torchvision.io import read_image
import torch
from torch.utils.data import Dataset, DataLoader
import glob
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import os
import matplotlib.pyplot as plt
from PIL import Image
from typing import Union
import ipywidgets
from IPython.display import display
from deep_translator import GoogleTranslator
import open_clip
import math
import requests
import io
import json

In [2]:
# Vietnamese -> English translator, used further down to translate the text queries.
translator = GoogleTranslator(target='en', source='vi')

In [3]:
class LancedbModel():
    """Keyframe retrieval over two LanceDB tables using OpenCLIP and CLIP encoders.

    On construction the object connects to the local ``database.lance`` store,
    opens one embedding table per encoder and loads both encoders onto the GPU
    when one is available (CPU otherwise).

    Attributes:
        lancedb_instance: Connection returned by ``lancedb.connect``.
        database: Maps the model name (``'openclip'`` / ``'clip'``) to its table.
        device: ``'cuda'`` or ``'cpu'``.
        model_openclip, preprocess_openclip, tokenizer_openclip: OpenCLIP
            ``ViT-H-14-378-quickgelu`` (``dfn5b`` weights) model, image
            transform and tokenizer.
        model_clip, processor_clip: Hugging Face
            ``openai/clip-vit-large-patch14`` model and processor.
    """

    # Model name -> LanceDB table holding the embeddings searched for that model.
    _TABLE_NAMES = {
        "openclip": "patch14v2_old_and_extended_openclip",
        "clip": "patch14v2_extended",
    }

    def __init__(self) -> None:
        """Open both embedding tables and load both encoders.

        Raises:
            FileExistsError: If either expected table is missing from the database.
        """
        self.lancedb_instance = lancedb.connect("database.lance")
        existing_tables = self.lancedb_instance.table_names()
        self.database = {}
        for model_name, table_name in self._TABLE_NAMES.items():
            if table_name not in existing_tables:
                print("Load database error!")
                # Exception type kept as-is so existing `except FileExistsError` callers still work.
                raise FileExistsError(f"Database not found! Missing table: {table_name}")
            self.database[model_name] = self.lancedb_instance[table_name]

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model_openclip, _, self.preprocess_openclip = open_clip.create_model_and_transforms('ViT-H-14-378-quickgelu', pretrained='dfn5b')
        self.model_openclip.eval()
        self.model_openclip.to(self.device)
        self.tokenizer_openclip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')

        self.model_clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
        self.processor_clip = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    @staticmethod
    def _check_model_name(model_name) -> None:
        """Raise ValueError unless ``model_name`` is 'clip' or 'openclip'."""
        if model_name not in ["clip", "openclip"]:
            raise ValueError("Invalid model name. Must be 'clip' or 'openclip'.")

    def _autocast(self):
        """Return the mixed-precision context used for OpenCLIP inference.

        Autocast is only enabled when the models actually live on CUDA, so a
        CPU-only machine does not request (and warn about) CUDA autocast.
        """
        return torch.amp.autocast('cuda', enabled=self.device == 'cuda')

    def _inference_text(self, query, model_name):
        """Encode a text query into a 1-D numpy embedding.

        Args:
            query: Text to embed.
            model_name: 'openclip' (embedding is L2-normalised) or 'clip'
                (raw ``get_text_features`` output).

        Returns:
            The embedding as a numpy array with the batch dimension squeezed out.

        Raises:
            ValueError: If ``model_name`` is not 'clip' or 'openclip'.
        """
        self._check_model_name(model_name)

        if model_name == "openclip":
            inputs = self.tokenizer_openclip([query]).to(self.device)
            with torch.no_grad(), self._autocast():
                text_embedding = self.model_openclip.encode_text(inputs)
                text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
                text_embedding = text_embedding.squeeze().cpu().numpy()
            return text_embedding

        # model_name == "clip"
        # truncation=True clips over-long queries to the tokenizer's maximum length
        # instead of letting them overflow the text encoder's context window.
        inputs = self.processor_clip(text=query, return_tensors="pt", truncation=True).to(self.device)
        with torch.no_grad():
            text_embedding = self.model_clip.get_text_features(**inputs).squeeze().cpu().numpy()
        return text_embedding

    def _inference_image(self, query: Image.Image, model_name):
        """Encode a PIL image into a 1-D numpy embedding.

        Args:
            query: Image to embed.
            model_name: 'openclip' (embedding is L2-normalised) or 'clip'
                (raw ``get_image_features`` output).

        Returns:
            The embedding as a numpy array with the batch dimension squeezed out.

        Raises:
            ValueError: If ``model_name`` is not 'clip' or 'openclip'.
        """
        self._check_model_name(model_name)

        if model_name == "openclip":
            pixels = self.preprocess_openclip(query).unsqueeze(0).to(self.device)
            with torch.no_grad(), self._autocast():
                embedding = self.model_openclip.encode_image(pixels)
                embedding /= embedding.norm(dim=-1, keepdim=True)
                embedding = embedding.squeeze().cpu().numpy()
            return embedding

        # model_name == "clip"
        inputs = self.processor_clip(images=query, return_tensors="pt").to(self.device)
        with torch.no_grad():
            embedding = self.model_clip.get_image_features(**inputs).cpu().squeeze().numpy()
        return embedding

    def _search(self, embedding, num_data, metric, sql_filter, model_name) -> pd.DataFrame:
        """Run a vector search on the table of ``model_name`` and return a DataFrame.

        When ``sql_filter`` is given it is applied as a pre-filter, i.e. before
        the nearest-neighbour ranking.
        """
        search = self.database[model_name].search(embedding, vector_column_name='embedding')
        if sql_filter is not None:
            search = search.where(sql_filter, prefilter=True)
        return search.metric(metric).limit(num_data).to_pandas()

    def find_by_text(self, query: str, num_data=25, metric="cosine", sql_filter : None | str = None, model_name: str = 'openclip') -> pd.DataFrame:
        """Return the ``num_data`` rows closest to a text query.

        Args:
            query: Text to search for.
            num_data: Maximum number of rows to return.
            metric: Distance metric passed to the LanceDB search.
            sql_filter: Optional SQL ``WHERE`` expression applied before the search.
            model_name: 'openclip' or 'clip'; selects both encoder and table.

        Raises:
            ValueError: If ``model_name`` is not 'clip' or 'openclip'.
        """
        text_embedding = self._inference_text(query, model_name)
        return self._search(text_embedding, num_data, metric, sql_filter, model_name)

    def find_by_picture(self, query: Union[Image.Image, str], num_data=25, metric="cosine", sql_filter: None | str = None, model_name: str = 'openclip') -> pd.DataFrame:
        """Return the ``num_data`` rows closest to an image query.

        Args:
            query: A PIL image, or a path to an image file.
            num_data: Maximum number of rows to return.
            metric: Distance metric passed to the LanceDB search.
            sql_filter: Optional SQL ``WHERE`` expression applied before the search.
            model_name: 'openclip' or 'clip'; selects both encoder and table.

        Raises:
            TypeError: If ``query`` is neither a PIL image nor a string.
            ValueError: If ``model_name`` is not 'clip' or 'openclip'.
        """
        if isinstance(query, str):
            # Copy the pixels while the file is open so the handle is closed right away.
            with Image.open(query) as opened:
                query = opened.copy()
        elif not isinstance(query, Image.Image):
            raise TypeError("query must be a PIL Image object or a string representing an image path")

        embedding = self._inference_image(query, model_name)
        return self._search(embedding, num_data, metric, sql_filter, model_name)
        

# Build the retrieval object once: this connects to the database and loads both encoders.
model = LancedbModel()

  checkpoint = torch.load(checkpoint_path, map_location=map_location)


In [7]:
# Smoke-test query using the default model, metric and result count.
df = model.find_by_text(query="a photo of a cat")

In [8]:
# Display the search results returned by the previous cell.
df

Unnamed: 0,embedding,video_name,image_name,frame_idx,path,origin,_distance
0,"[0.02330017, 0.024993896, -0.0024909973, 0.011...",L09_V007,233.jpg,22058,keyframes\L09_V007\233.jpg,original,0.708628
1,"[0.026245117, 0.030532837, 0.00022137165, 0.00...",L09_V007,image-01-scene-279-frame-22059.jpg,22059,keyframes\L09_V007\image-01-scene-279-frame-22...,extended,0.71347
2,"[0.024627686, 0.024719238, -0.0024299622, 0.00...",L09_V007,image-03-scene-279-frame-22122.jpg,22122,keyframes\L09_V007\image-03-scene-279-frame-22...,extended,0.71495
3,"[0.022903442, 0.007160187, 0.025299072, -0.006...",L09_V007,image-03-scene-285-frame-22528.jpg,22528,keyframes\L09_V007\image-03-scene-285-frame-22...,extended,0.718616
4,"[0.021881104, 0.01474762, -0.010116577, 0.0229...",L09_V007,image-02-scene-279-frame-22091.jpg,22091,keyframes\L09_V007\image-02-scene-279-frame-22...,extended,0.724461
5,"[-0.0031871796, 0.013023376, -0.0042686462, 0....",L08_V004,image-01-scene-334-frame-28932.jpg,28932,keyframes\L08_V004\image-01-scene-334-frame-28...,extended,0.72684
6,"[0.026565552, 0.016815186, 0.026733398, 0.0061...",L09_V007,image-02-scene-285-frame-22494.jpg,22494,keyframes\L09_V007\image-02-scene-285-frame-22...,extended,0.727361
7,"[0.02619934, 0.0031204224, -0.014678955, 0.014...",L04_V012,image-01-scene-188-frame-15436.jpg,15436,keyframes\L04_V012\image-01-scene-188-frame-15...,extended,0.730284
8,"[0.0018615723, 0.020965576, -0.019897461, -0.0...",L05_V005,234.jpg,24144,keyframes\L05_V005\234.jpg,original,0.730543
9,"[0.008590698, 0.004142761, 0.00806427, 0.00682...",L11_V028,image-03-scene-062-frame-4532.jpg,4532,keyframes\L11_V028\image-03-scene-062-frame-45...,extended,0.731702


In [None]:
# Show the results without the bulky columns.
# - No `inplace=True`: an in-place drop returns None, so the original printed "None"
#   and made the cell non-idempotent.
# - errors='ignore': 'text' / 'label' may not exist in this table; skip missing columns
#   instead of raising KeyError.
df.drop(columns=['embedding', 'text', 'label'], errors='ignore')

In [6]:
# Raw string: the backslashes are literal Windows path separators.
print(r'.\keyframes\L09_V007\233.jpg')

.\keyframes\L09_V007\233.jpg


In [4]:
def find_by_2_query(query1: str, query2: str, num_data=5, model_name='openclip', num_data_per_video=5):
    """Search with ``query1``, then search ``query2`` inside each video that matched.

    Args:
        query1: Text of the first query, searched over the whole table.
        query2: Text of the second query, searched once per video found by ``query1``.
        num_data: Number of results requested for ``query1``.
        model_name: Model/table used for BOTH searches ('openclip' or 'clip').
        num_data_per_video: Number of ``query2`` results requested per video.

    Returns:
        ``(df1, df2)``: the ``query1`` results, and the concatenated per-video
        ``query2`` results (an empty DataFrame when ``query1`` matched nothing).
    """
    df1 = model.find_by_text(query1, num_data=num_data, model_name=model_name)

    per_video_hits = []
    for video_name in df1['video_name'].unique():
        # Double any single quote so the name cannot terminate the SQL string literal.
        escaped_name = str(video_name).replace("'", "''")
        per_video_hits.append(
            model.find_by_text(
                query2,
                num_data=num_data_per_video,
                sql_filter=f"video_name = '{escaped_name}'",
                # Forward the model so both queries hit the same embedding table.
                model_name=model_name,
            )
        )

    # Concatenate once instead of growing a DataFrame inside the loop.
    df2 = pd.concat(per_video_hits, ignore_index=True) if per_video_hits else pd.DataFrame()
    return df1, df2

In [5]:
def show_images_grid(data, output_path="output.jpg"):
    """Plot one row of images per key of ``data``, save the figure and show it.

    Args:
        data: Mapping whose values are lists of dicts with the keys
            ``'image'`` (passed to ``imshow``), ``'image_info'`` (subplot title)
            and ``'color'`` (title colour). Each key becomes one grid row.
        output_path: File the rendered figure is saved to.

    Rows shorter than the longest one are padded with blank axes. Nothing is
    drawn or saved when ``data`` holds no images.
    """
    rows = len(data)
    columns = max((len(data[key]) for key in data), default=0)
    if rows == 0 or columns == 0:
        # plt.subplots rejects a grid with zero rows/columns; nothing to draw anyway.
        return

    # squeeze=False keeps `axs` two-dimensional even for a single row or column,
    # so one indexing scheme works for every grid shape.
    fig, axs = plt.subplots(rows, columns, figsize=(3*3*columns, 3*2*rows), squeeze=False)

    for i, key in enumerate(data):
        items = data[key]
        for j in range(columns):
            ax = axs[i, j]
            if j < len(items):
                ax.imshow(items[j]['image'])
                ax.set_title(items[j]['image_info'], color=items[j]['color'])
            # Hide the frame/ticks on used and padding cells alike.
            ax.axis('off')

    plt.tight_layout()
    plt.savefig(output_path)
    plt.show()

In [6]:
# Video names for which get_ytb_link divides the frame index by 30 instead of 25.
with open('fps30.json', mode='r') as f:
    fps30 = json.loads(f.read())

In [7]:
# Inspect the video names loaded from fps30.json.
fps30

['L06_V003',
 'L09_V009',
 'L15_V003',
 'L15_V004',
 'L15_V005',
 'L15_V008',
 'L15_V009',
 'L15_V010',
 'L15_V011',
 'L15_V012',
 'L15_V015',
 'L15_V016',
 'L15_V017',
 'L15_V018',
 'L15_V019',
 'L15_V022',
 'L15_V023',
 'L15_V024',
 'L15_V025',
 'L15_V028',
 'L15_V029',
 'L15_V030',
 'L16_V001',
 'L16_V002',
 'L16_V003',
 'L16_V006',
 'L16_V007',
 'L16_V008',
 'L16_V009',
 'L16_V012',
 'L16_V013',
 'L16_V014',
 'L16_V015',
 'L16_V016',
 'L16_V019',
 'L16_V020',
 'L16_V021',
 'L16_V022',
 'L16_V023',
 'L16_V026',
 'L16_V027',
 'L16_V028',
 'L16_V029',
 'L17_V003',
 'L17_V004',
 'L17_V005',
 'L17_V006',
 'L17_V007',
 'L17_V010',
 'L17_V011',
 'L17_V012',
 'L17_V013',
 'L17_V014',
 'L17_V017',
 'L17_V018',
 'L17_V019',
 'L17_V020',
 'L17_V021',
 'L17_V024',
 'L17_V025',
 'L17_V026',
 'L17_V027',
 'L17_V028',
 'L18_V002',
 'L18_V003',
 'L18_V004',
 'L18_V005',
 'L18_V006',
 'L18_V009',
 'L18_V010',
 'L18_V011',
 'L18_V012',
 'L18_V013',
 'L18_V016',
 'L18_V017',
 'L18_V018',
 'L18_V019',

In [8]:
media_info_df = pd.read_csv('media-info.csv')

# Frame-rate assumptions used to turn a frame index into a timestamp.
_DEFAULT_FPS = 25                     # videos not listed anywhere else
_FPS30 = 30                           # videos listed in `fps30`
_FPS_OVERRIDES = {"L24_V044": 26.44}  # per-video exceptions; take priority over `fps30`
_LEAD_IN_SECONDS = 2                  # start playback slightly before the keyframe


def get_ytb_link(video_name: str, keyframe_number: int):
    """Build a watch URL that starts shortly before the given keyframe.

    Args:
        video_name: Video identifier, e.g. ``'L09_V007'``; matched against the
            ``filename`` column of ``media_info_df`` as ``video_name + '.json'``.
        keyframe_number: Frame index of the keyframe inside the video.

    Returns:
        The video's ``watch_url`` with ``&t=<seconds>s`` appended, or the string
        ``"Video not found!"`` when the video is not listed in ``media_info_df``.
    """
    if video_name in _FPS_OVERRIDES:
        timestamp = int(keyframe_number / _FPS_OVERRIDES[video_name])
    else:
        fps = _FPS30 if video_name in fps30 else _DEFAULT_FPS
        timestamp = keyframe_number // fps
    # Back off a little, but never before the start of the video.
    timestamp = max(0, timestamp - _LEAD_IN_SECONDS)

    # Vectorised lookup of the first matching row instead of a Python-level iterrows scan.
    watch_urls = media_info_df.loc[media_info_df['filename'] == video_name + '.json', 'watch_url']
    if watch_urls.empty:
        return "Video not found!"
    return watch_urls.iloc[0]+f'&t={timestamp}s'

In [9]:
# Vietnamese source queries, translated to English before searching.
query1_input, query2_input = (
    translator.translate(vietnamese_text)
    for vietnamese_text in (
        "Cảnh quay từ một chiếc camera trên một chiếc xe quay lại hành trình di chuyển.",
        "một người mặc áo đen cùng vali màu hồng đứng bên tay phải",
    )
)
print(query1_input)
print(query2_input)

Footage from a camera on a car filming the journey.
A man in a black shirt with a pink suitcase stands on the right.


In [10]:
# Manual override of the first query. The pasted text carried invisible zero-width
# characters in front of "COM"; they are removed so they cannot affect tokenisation.
query1_input = "COM BINH DAN in red letters"

In [11]:
# Swap the two queries. NOTE: not idempotent - running this cell twice restores the original order.
query1_input, query2_input = query2_input, query1_input

In [12]:
df1, df2 = find_by_2_query(query1_input, query2_input, 30, 'openclip')
URL_BASE = 'http://127.0.0.1:5000/images/'
df1['query'] = 'query1'
df2['query'] = 'query2'


def _rows_for_video(video_name, *frames):
    """Yield, frame by frame in the order given, the rows that belong to ``video_name``."""
    for frame in frames:
        for _, row in frame[frame['video_name'] == video_name].iterrows():
            yield row


def _fetch_keyframe(row):
    """Download the keyframe described by ``row`` from URL_BASE and decode it into an array."""
    url = URL_BASE + row['video_name'] + '/' + row['image_name']
    # A timeout avoids hanging forever; raise_for_status surfaces HTTP errors here
    # instead of handing an error page to the image decoder.
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    return plt.imread(io.BytesIO(response.content), format=row['image_name'].rsplit('.', 1)[1])


video_names = df1['video_name'].unique()

# One grid row per video: query-1 hits (red titles) first, then query-2 hits (blue titles).
img_path = {}
for video_name in video_names:
    img_path[video_name] = [
        {
            'image': _fetch_keyframe(row),
            'image_info': row['video_name'] + " " + row['image_name'] + " " + str(row['frame_idx']),
            'color': 'red' if row['query'] == 'query1' else 'blue'
        }
        for row in _rows_for_video(video_name, df1, df2)
    ]
show_images_grid(img_path)

# Print one link per displayed image, numbered like the columns of its grid row.
for i, video_name in enumerate(video_names):
    print("Row", i+1)
    for j, row in enumerate(_rows_for_video(video_name, df1, df2), start=1):
        print(j, get_ytb_link(video_name, row['frame_idx']))

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


ConnectionError: HTTPConnectionPool(host='127.0.0.1', port=5000): Max retries exceeded with url: /images/L12_V007/image-01-scene-209-frame-17998.jpg (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x0000012940ACF7D0>: Failed to establish a new connection: [WinError 10061] No connection could be made because the target machine actively refused it'))

In [49]:
# Use a keyframe served by the local image server as the query image.
response = requests.get(URL_BASE + 'L12_V015/image-01-scene-171-frame-14647.jpg', timeout=10)
# Surface HTTP errors here rather than as an opaque image-decoding failure.
response.raise_for_status()
image = Image.open(io.BytesIO(response.content))
df1 = model.find_by_picture(image, 50, model_name='openclip')

In [84]:
# Query with a local image file, ranking the 50 nearest frames by L2 distance.
image = Image.open("Untitled.png")
df1 = model.find_by_picture(
    image,
    num_data=50,
    metric="L2",
    model_name='openclip',
)

In [85]:
df1['query'] = 'query1'

# Group the hits by video (in order of first appearance) and download each keyframe.
img_path = {}
for video_name in df1['video_name'].unique():
    # Select this video's rows once instead of rescanning the whole frame row by row.
    video_rows = df1[df1['video_name'] == video_name]
    entries = []
    for index, row in video_rows.iterrows():
        url = URL_BASE + row['video_name']+ '/' + row['image_name']
        # A timeout avoids hanging forever; raise_for_status surfaces HTTP errors here
        # instead of handing an error page to the image decoder.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        entries.append({
            'image': plt.imread(io.BytesIO(response.content), format=row['image_name'].rsplit('.', 1)[1]),
            'image_info': row['video_name'] + " " + row['image_name']+ " " + str(row['frame_idx']),
            'color': 'red' if row['query'] == 'query1' else 'blue'
        })
    img_path[video_name] = entries

In [None]:
# Render the grid, then list one link per image, numbered per grid row.
show_images_grid(img_path)
for row_number, video_name in enumerate(df1['video_name'].unique(), start=1):
    print("Row", row_number)
    frames_of_video = df1.loc[df1['video_name'] == video_name, 'frame_idx']
    for j, frame_idx in enumerate(frames_of_video, start=1):
        print(j, get_ytb_link(video_name, frame_idx))