# Install require packages

In [None]:
!conda install faiss-gpu -c pytorch -y

In [None]:
!pip install faiss-gpu

In [11]:
import faiss

n_gpu = faiss.get_num_gpus()

n_gpu

1

In [None]:
!pip install fiftyone
!pip install torch
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

# Models

In [22]:
from abc import abstractmethod
import clip
import torch
import numpy as np
from PIL import Image

class BaseModel():
    @abstractmethod
    def encode_text(self):
        raise NotImplementedError

    @abstractmethod
    def encode_image(self):
        raise NotImplementedError

In [35]:
def normalize(v):
    norm = np.linalg.norm(v)
    if norm == 0:
        norm = np.finfo(v.dtype).eps
    return v/norm

class VIT(BaseModel):
    def __init__(self):
        self.model, self.preprocess = clip.load("ViT-B/16", device="cpu")

        self.model.eval()
        self.input_resolution = self.model.visual.input_resolution
        self.context_length = self.model.context_length
        self.vocab_size = self.model.vocab_size

    def encode_text(self, text):
        text_tokens = clip.tokenize([text])

        with torch.no_grad():
            text_features = self.model.encode_text(text_tokens).float()

        return text_features

    def encode_frame(self, path = None):
        image = self.preprocess(Image.open(path)).unsqueeze(0)
        
        with torch.no_grad():
            image_features = self.model.encode_image(image)
            
        frame_features = image_features[0].cpu().detach().numpy()
        # frame_features = normalize(frame_features)
        
        return frame_features

# Retrival

In [3]:
import fiftyone as fo
from fiftyone import ViewField as F
import faiss
import json
import numpy as np
from glob import glob
import os
import pandas as pd
from tqdm.notebook import trange, tqdm

In [4]:
class Retriever:
    def __init__(self, img_dir: str, vector_dim: int = 512):
        self.img_dir = img_dir
        self.dataset = fo.Dataset.from_images_dir(
            img_dir, name=None, tags=None, recursive=True)
        self.object_dir = None
        self.video_feature_dict = {}
        # faiss
        self.index = faiss.IndexFlatIP(vector_dim)


    def add_meta_data_images(self):
        # Add video, frameid
        print('Adding meta data')
        pbar = tqdm(self.dataset)
        for sample in pbar:
            _, sample['video'], sample['frameid'] = sample['filepath'][:-4].rsplit('/', 2)
            sample.save()

    def add_object_detection(self, object_dir: str):
        self.object_dir = object_dir
        print('Adding object detection')
        pbar = tqdm(self.dataset)
        for sample in pbar:
            object_path = os.path.join(object_dir, sample['filepath'][-20:-4] + '.json')
            try:
                with open(object_path) as jsonfile:
                    det_data = json.load(jsonfile)
            except:
                continue
            detections = []
            for cls, box, score in zip(det_data['detection_class_entities'], det_data['detection_boxes'], det_data['detection_scores']):
                # Convert to [top-left-x, top-left-y, width, height]
                boxf = [float(box[1]), float(box[0]), float(box[3]) -
                        float(box[1]), float(box[2]) - float(box[0])]
                scoref = float(score)

                # Only add objects with confidence > 0.4
                if scoref > 0.4:
                    detections.append(
                        fo.Detection(
                            label=cls,
                            bounding_box=boxf,
                            confidence=float(score)
                        )
                    )
            sample["object_faster_rcnn"] = fo.Detections(detections=detections)
            sample.save()


    def get_keyframe_list(self):
        '''
            Return:
                a dictionary: {
                    'video_name': List[keyframe]
                }
        '''
        path_all_keyframe = os.path.join(self.img_dir,'*','*.jpg')

        all_keyframe = glob(path_all_keyframe)
        video_keyframe_dict = {}

        for kf in all_keyframe:
            _, vid, kf = kf[:-4].rsplit('/', 2)
            if vid not in video_keyframe_dict.keys():
                video_keyframe_dict[vid] = [kf]
            else:
                video_keyframe_dict[vid].append(kf)

        for k, v in video_keyframe_dict.items():
            video_keyframe_dict[k] = sorted(v)

        return video_keyframe_dict


    def get_video_list(self):
        path_all_video = os.path.join(self.img_dir, '*')
        
        all_video = glob(path_all_video)
        all_video = [v.rsplit('/', 1)[-1] for v in all_video]

        return all_video

    def extract_vector_features_per_frame(self, features_dir):
        self.features_dir = features_dir

        video_list = self.get_video_list()
        keyframe_list = self.get_keyframe_list()
        
        print('Extracting key frames')
        pbar = tqdm(video_list)
        for video_name in pbar:
            clip_path = os.path.join(features_dir,  video_name + '.npy')
            features = np.load(clip_path)
            feature_video_dir = os.path.join(features_dir, video_name)

            self.video_feature_dict[video_name] = {}
            for i, frameid in enumerate(keyframe_list[video_name]):
                self.video_feature_dict[video_name][frameid] = features[i]

    def add_clip_embedding(self):
        print('Add clip embedding')
        pbar = tqdm(self.dataset)
        embeddings = []
        for sample in pbar:
            tokens = sample['filepath'].split('/')
            video_name, frame_id = tokens[-2], tokens[-1][:-4]
            
            clip_embedding = self.video_feature_dict[video_name][frame_id]
            clip_embedding = clip_embedding.flatten()
            embeddings.append(clip_embedding)

            sample.save()
        
        embeddings = np.stack(embeddings, axis=0)
        embeddings = embeddings.astype(np.float32)
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)
        
    def search_queries(self, X, top_k: int):
        # convert to numpy array
        faiss.normalize_L2(X)
        distances, indices = self.index.search(X.astype(np.float32), top_k)
        return distances, indices

    def export(self, export_dir, distances, indices):
        indices = indices.tolist()
        mask = np.zeros(len(self.dataset), dtype=bool)
        for idx in indices:
            mask[idx] = True
            
        result = self.dataset[mask]
        for idx, sample in enumerate(result):
            sample['similarity'] = distances[idx]
            sample.save()
            
        result.sort_by('similarity', reverse=False)   
        result.export(export_dir=export_dir,
                       dataset_type=fo.types.FiftyOneDataset)

# Data and params

In [6]:
img_dir = '/run/media/zephy_manjaro/Crucial X6/AIC2022/data/keyframes/'
obj_dir = 'data/objects'
feature_dir = '/home/zephy_manjaro/workspace/code/projects/AIC2022/AIChallenge2022/data/clipfeatures'
vector_dim = 512

# Prepare data

In [7]:
retriever = Retriever(img_dir, vector_dim)
retriever.add_meta_data_images()
retriever.add_object_detection(obj_dir)
retriever.extract_vector_features_per_frame(feature_dir)
retriever.add_clip_embedding()



 100% |███████████| 887227/887227 [2.3m elapsed, 0s remaining, 6.7K samples/s]      
Adding meta data


  0%|          | 0/887227 [00:00<?, ?it/s]

Adding object detection


  0%|          | 0/887227 [00:00<?, ?it/s]

Extracting key frames


  0%|          | 0/1249 [00:00<?, ?it/s]

Add clip embedding


  0%|          | 0/887227 [00:00<?, ?it/s]

# Searching

In [36]:
encoder = VIT()

In [18]:
n_top_k_images = 1000

In [44]:
text_queries = {
    1: 'A motorbike parking lot along a street for an event. Then there are two rows of red and yellow lanterns.',
    2: 'A man wears a green shirt. His face painted green. There is a number on his shirt',
    3: 'A football match. People are wearing yellow shirt and red shirt',
    4: 'Neswspaper with pictures of an old man is wearing a blue shirt and a black glass',
    5: 'A yellow statue horse and a yellow statue man'
}

In [45]:
QUERY_ID = [2]

!rm -rf submission

text_feature_list = np.array([encoder.encode_text(text).cpu().numpy()[0] for (id, text) in text_queries.items() if id in QUERY_ID])
print(text_feature_list.shape)
text_id_list = [id for id in QUERY_ID]

distances_list, indices_list = retriever.search_queries(text_feature_list, n_top_k_images)
for i, (distances, indices) in enumerate(zip(distances_list, indices_list)):
    retriever.export('submission/{}_top_k_images'.format(text_id_list[i]), distances, indices)

(1, 512)
Exporting samples...
 100% |██████████████████| 1000/1000 [228.6ms elapsed, 0s remaining, 4.4K docs/s]      


# Get Similar Image

In [47]:
video = 'C02_V0381'
image = '001280.jpg'

path_image = os.path.join('/run/media/zephy_manjaro/Crucial X6/AIC2022/data/keyframes/', video, image)

image_feature_list = np.array([encoder.encode_frame(path_image)])

distances_list, indices_list = retriever.search_queries(image_feature_list, 200)
for i, (distances, indices) in enumerate(zip(distances_list, indices_list)):
    retriever.export('submission/similar-frame', distances, indices)
    
dataset = fo.Dataset.from_dir(
    dataset_dir='submission/similar-frame',
    dataset_type=fo.types.FiftyOneDataset
)
    
session = fo.launch_app(dataset, auto=False)
session.show()
    

Directory 'submission/similar-frame' already exists; export will be merged with existing files
Exporting samples...
 100% |████████████████████| 200/200 [58.5ms elapsed, 0s remaining, 3.4K docs/s] 
Importing samples...
 100% |█████████████████| 200/200 [12.9ms elapsed, 0s remaining, 15.5K samples/s]      
Import complete
Session launched. Run `session.show()` to open the App in a cell output.
