In [None]:
!pip install ftfy regex tqdm -q
!pip install git+https://github.com/openai/CLIP.git -q
!pip install googletrans==3.1.0a0
!pip install translate==3.6.1
!pip install langdetect==1.0.9

In [4]:
import torch
import clip
from PIL import Image
import numpy as np

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ViT-B/32 CLIP weights together with their matching image transform.
model, preprocess = clip.load("ViT-B/32", device=device)

100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 138MiB/s]


In [5]:
def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    Computed as ``(a . b) / (||a|| * ||b||)``. If either vector has zero
    norm the denominator is 0 and NumPy's division semantics apply.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / norm_product

def correlation_coefficient(a, b):
    """Return the Pearson correlation coefficient between ``a`` and ``b``.

    For 1-D inputs ``np.corrcoef`` yields a 2x2 correlation matrix; its
    off-diagonal entry is the correlation between the two vectors.
    """
    corr_matrix = np.corrcoef(a, b)
    return corr_matrix[0, 1]

In [6]:
def check_clip(image_path, text, measure=cosine_similarity):
    """Score how well ``text`` describes the image at ``image_path`` with CLIP.

    If langdetect's ``detect`` labels ``text`` as Vietnamese ('vi'), it is
    first translated with the module-level ``translater``; any other text is
    scored as-is.

    Args:
        image_path: Path of the image file to score.
        text: Query sentence.
        measure: Callable ``(image_vec, text_vec) -> score`` applied to the
            flattened image and text embeddings, which are passed as float32
            NumPy arrays. Defaults to ``cosine_similarity``.

    Returns:
        Whatever ``measure`` returns for the two embeddings.
    """
    from langdetect.lang_detect_exception import LangDetectException

    # NOTE: `detect` and `translater` are created in a later notebook cell;
    # that cell must have been run before this function is called.
    try:
        is_vietnamese = detect(text) == 'vi'
    except LangDetectException:
        # langdetect raises on text with no usable features (e.g. only digits
        # or punctuation); nothing to translate, so score the text unchanged.
        is_vietnamese = False
    if is_vietnamese:
        text = translater(text)

    # The context manager closes the file once preprocessing has read it.
    with Image.open(image_path) as img:
        image_input = preprocess(img).unsqueeze(0).to(device)
    # truncate=True cuts prompts longer than CLIP's context length instead of
    # raising, so long queries still produce a score.
    text_input = clip.tokenize([text], truncate=True).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_input)

    # The measures in this notebook are NumPy-based. CUDA tensors cannot be
    # converted to NumPy implicitly, and the model may produce reduced-precision
    # features on GPU, so move to CPU and upcast to float32 first.
    image_vec = image_features.flatten().float().cpu().numpy()
    text_vec = text_features.flatten().float().cpu().numpy()
    return measure(image_vec, text_vec)

In [7]:
import googletrans
import translate
from langdetect import detect


class Translation():
    """Callable wrapper around two translation back-ends.

    Wraps either ``googletrans.Translator`` (``mode='google'`` or
    ``'googletrans'``) or ``translate.Translator`` (``mode='translate'``)
    behind one interface: ``Translation(...)(text) -> str``.

    Args:
        from_lang: Source language code (default ``'vi'``).
        to_lang: Target language code (default ``'en'``).
        mode: Back-end selector. It is matched exactly against the accepted
            spellings (case-insensitive, surrounding whitespace ignored).

    Raises:
        ValueError: If ``mode`` does not name a supported back-end.
    """

    # Accepted spellings for each back-end. Exact membership replaces the old
    # substring test (`mode in 'googletrans'`). That test accepted strings such
    # as '' or 'trans', and could pick one back-end in __init__ and the other
    # in __call__.
    _GOOGLE_MODES = frozenset({'google', 'googletrans'})
    _TRANSLATE_MODES = frozenset({'translate'})

    def __init__(self, from_lang='vi', to_lang='en', mode='google'):
        # Normalised back-end name: 'google' or 'translate'.
        self.__mode = self._resolve_backend(mode)
        self.__from_lang = from_lang
        self.__to_lang = to_lang

        if self.__mode == 'google':
            self.translator = googletrans.Translator()
        else:
            self.translator = translate.Translator(
                from_lang=from_lang, to_lang=to_lang)

    @classmethod
    def _resolve_backend(cls, mode):
        """Map a user-facing ``mode`` string to ``'google'`` or ``'translate'``.

        Raises:
            ValueError: If ``mode`` is not one of the accepted spellings.
        """
        key = str(mode).strip().lower()
        if key in cls._GOOGLE_MODES:
            return 'google'
        if key in cls._TRANSLATE_MODES:
            return 'translate'
        raise ValueError(
            f"Unsupported translation mode {mode!r}; expected one of "
            f"{sorted(cls._GOOGLE_MODES | cls._TRANSLATE_MODES)}")

    def preprocessing(self, text):
        """Normalise text before translation (currently: lower-case it)."""
        return text.lower()

    def __call__(self, text):
        """Translate ``text`` from ``from_lang`` to ``to_lang`` and return a str."""
        text = self.preprocessing(text)
        if not text:
            # Nothing to translate; avoid a pointless back-end round trip.
            return text
        if self.__mode == 'translate':
            return self.translator.translate(text)
        # googletrans returns a result object whose .text holds the string.
        # Pass src explicitly so the configured from_lang is honoured, as it
        # already is on the `translate` path.
        return self.translator.translate(
            text, dest=self.__to_lang, src=self.__from_lang).text


from langdetect import DetectorFactory

# langdetect is non-deterministic on short or ambiguous text unless seeded.
# check_clip branches on `detect(...) == 'vi'`, so fix the seed for
# reproducible runs.
DetectorFactory.seed = 0

# Shared Vietnamese -> English translator used by check_clip.
translater = Translation()

In [8]:
# Vietnamese query: expected to be detected as 'vi', so check_clip translates
# it to English before CLIP scores it against the image.
image_path = "query-6.jpg"
query = "Đoạn video hai người chạy bộ. Các đồ vật nằm ngổn ngang bên trái khung hình."

print(check_clip(image_path=image_path, text=query))

0.26145172


In [9]:
# The same query written in English: it should not be detected as 'vi', so
# check_clip scores it without translating. In the saved run it prints the
# same score as the Vietnamese query.
image_path = "query-6.jpg"
query = "Video of two people jogging. Objects are scattered on the left side of the frame."

print(check_clip(image_path=image_path, text=query))

0.26145172
