In [None]:
# Which caption column of the dataset is used as the text side of the benchmark.
# Each row of that column is a list of captions; only the first one is used.
# "th_sentences_raw" presumably holds the Thai captions and "sentences_raw" the
# original English ones — confirm against the dataset card.
# CAPTION_FIELD = "sentences_raw"
CAPTION_FIELD = "th_sentences_raw"

In [None]:
import time
import os
# Hide every CUDA device so torch (and every model loaded below) runs on CPU.
# CUDA must not be initialised before this is set, so it is done before torch
# (directly or via sentence_transformers) is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Optional: redirect the Hugging Face `datasets` cache to another directory.
# os.environ["HF_DATASETS_CACHE"] = "/workspace/cache"

from datasets import load_dataset
import numpy as np
from sentence_transformers import SentenceTransformer
import torch

In [None]:
# Test split of the Thai MS-COCO 2014 captions dataset from the Hugging Face Hub.
dataset = load_dataset(path="patomp/thai-mscoco-2014-captions", split="test")
dataset

In [None]:
# Peek at one record to see its fields (the image plus the caption lists).
# Left as the last expression so Jupyter uses its rich display instead of print().
dataset[0]

In [None]:
# Image i is paired with the first caption of record i. The recall@k evaluation
# below relies on this row alignment (text i <-> image i).
images = dataset["image"]
texts = []
for caption_list in dataset[CAPTION_FIELD]:
    texts.append(caption_list[0])

In [None]:
### ======== Multi-lingual Model ========
### (1) [Sentence-Transformers] Multilingual
# TEXT_ENCODER_MODEL_NAME = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
# img_model = SentenceTransformer('clip-ViT-B-32')
# text_model = SentenceTransformer(TEXT_ENCODER_MODEL_NAME)
# image_embeddings = img_model.encode(images)
# stime = time.time()
# text_embeddings = text_model.encode(texts)
# etime = time.time()

### (2) [M-CLIP] Multilingual: XLM-RoBERTa-Large text encoder + OpenAI CLIP ViT-B/32 image encoder
# import clip
# from multilingual_clip import pt_multilingual_clip
# import transformers

# model_name = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
# text_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_name)
# tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
# device = "cuda" if torch.cuda.is_available() else "cpu"
# visual_model, preprocess = clip.load("ViT-B/32", device=device)

# with torch.no_grad():
#     stime = time.time()
#     text_embeddings = [text_model.forward([txt,], tokenizer)[0] for txt in texts]
#     etime = time.time()
#     image_embeddings = [visual_model.encode_image(preprocess(img).unsqueeze(0).to(device))[0] for img in images]

# text_embeddings = np.array([x.detach().numpy() for x in text_embeddings])
# image_embeddings = np.array([x.detach().numpy() for x in image_embeddings])


### ======== Thai-only Model ========
### (1) Thai-Cross-CLIP
# import sys
# sys.path.append("./Thai-Cross-CLIP")
# from source.model import *
# from source.config import *
# sys.path.append("./thai2transformers/thai2transformers")
# from preprocess import process_transformers
# import clip
# text_model = TextModel().to(CFG.device)
# text_model.load_state_dict(torch.load("./CLIP-MSE-WangchanBerta/text_MSE_2m.pt", map_location=CFG.device))
# text_model.eval().requires_grad_(False)
# clip_model, compose = clip.load('ViT-B/32')
# clip_model.to(CFG.device).eval()
# input_resolution = clip_model.visual.input_resolution
# print("Text encoder parameters:", f"{np.sum([int(np.prod(p.shape)) for p in text_model.parameters()]):,}")
# print("Input image resolution:", input_resolution)
# _images = [img.convert('RGB').resize((input_resolution,input_resolution)) for img in images]
# _images = [torch.tensor(np.array(img)).permute(2, 0, 1)/255 for img in _images]
# _images = [img.unsqueeze(0).to(CFG.device) for img in _images]
# with torch.no_grad():
#     image_embeddings = np.array([clip_model.encode_image(img).detach().numpy() for img in _images])
#     stime = time.time()
#     text_embeddings = [text_model.encode_text([process_transformers(txt)]) for txt in texts]
#     etime = time.time()
#     text_embeddings= [txt.to(CFG.device).detach().numpy() for txt in text_embeddings]
#     text_embeddings = np.array(text_embeddings)
# image_embeddings = image_embeddings.reshape(-1, 512)
# text_embeddings = text_embeddings.reshape(-1, 512)

### (2) Thai2Fit
# import sys
# sys.path.append("../models")
# import numpy as np
# import torch
# from pythainlp import word_vector, word_tokenize
# from projector import Projector
# projector = Projector(input_embedding_dim=300)
# projector.load_state_dict(torch.load("../models/projector_high_alpha.pt"))
# projector.eval()
# model = word_vector.WordVector(model_name="thai2fit_wv")#.get_model()
# def embed_sentence(text):
#     embed = model.sentence_vectorizer(text, use_mean=True)[0]
#     return projector(torch.from_numpy(embed).float())
# with torch.no_grad():
#     stime = time.time()
#     text_embeddings = [embed_sentence(x) for x in texts]
#     etime = time.time()
# text_embeddings = np.array([x.detach().numpy() for x in text_embeddings])

### (3) Thai light multimodal CLIP (active model)
# Text side: the distilled Thai text encoder below. `trust_remote_code=True`
# executes Python code shipped in that Hub repository, so only use it with a
# repository you trust (ideally pinned to a specific `revision`).
# Image side: OpenAI CLIP ViT-B/32 image encoder.
from transformers import AutoModel
import clip

TEXT_ENCODER_MODEL_NAME = "patomp/thai-light-multimodal-clip-and-distill"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained(TEXT_ENCODER_MODEL_NAME, trust_remote_code=True)
visual_model, preprocess = clip.load("ViT-B/32", device=device)

# Inference only: no autograd graph is built, so outputs convert cleanly to numpy.
with torch.no_grad():
    # Time the text encoder only. The throughput cell below reads stime/etime,
    # just as the commented-out alternatives above set them.
    stime = time.perf_counter()
    text_embeddings = [model(text) for text in texts]
    etime = time.perf_counter()
    image_embeddings = [
        visual_model.encode_image(preprocess(img).unsqueeze(0).to(device))[0]
        for img in images
    ]

# NOTE(review): assumes model(text) returns one 1-D embedding (tensor or array)
# per caption, so stacking gives (n_texts, dim) — confirm against the model card.
text_embeddings = np.array(text_embeddings)
# .cpu() keeps the numpy conversion working even when a GPU is visible.
image_embeddings = np.array([x.detach().cpu().numpy() for x in image_embeddings])

In [None]:
# Text-encoder throughput, from the stime/etime recorded in the embedding cell.
n_texts = len(texts)
latent_time = etime - stime
sample_per_sec = n_texts / latent_time

print("latent_time: ", latent_time)
print("sample_per_sec: ", sample_per_sec)

## Indexing

In [None]:
# Reference for FAISS Index: https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
import faiss

# Exact (brute-force) inner-product indexes, one per modality.
d = text_embeddings.shape[1]
assert d == image_embeddings.shape[1], "text and image embeddings must have the same dimension"
# Recall@k treats text i and image i as the matching pair, so both modalities
# must contain the same number of rows.
assert len(text_embeddings) == len(image_embeddings), "text/image embedding counts differ"

text_index = faiss.IndexFlatIP(d)
image_index = faiss.IndexFlatIP(d)

In [None]:
# L2-normalise every row so inner-product search (IndexFlatIP) ranks by cosine
# similarity. FAISS works on float32, C-contiguous arrays.
image_embeddings = np.ascontiguousarray(image_embeddings, dtype=np.float32)
text_embeddings = np.ascontiguousarray(text_embeddings, dtype=np.float32)
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)

# Empty the indexes first. Otherwise re-running this cell appends a second copy
# of every vector and breaks the "row i <-> id i" alignment recall@k relies on.
image_index.reset()
text_index.reset()
image_index.add(image_embeddings)
text_index.add(text_embeddings)
assert image_index.ntotal == text_index.ntotal == len(text_embeddings)

In [None]:
# np.linalg.norm(text_embeddings, axis=1)

In [None]:
# Sanity check: after normalisation every embedding (text and image) should have
# unit L2 norm, i.e. its inner product with itself should be ~1. NaNs (e.g. from
# zero vectors) fail the comparison and are reported too.
for modality, embeddings in (("text", text_embeddings), ("image", image_embeddings)):
    self_similarity = np.sum(embeddings * embeddings, axis=1)
    bad_rows = np.flatnonzero(~((self_similarity > 0.99) & (self_similarity < 1.01)))
    assert bad_rows.size == 0, f"{modality} embeddings not unit-norm at rows {bad_rows[:10]}"

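## Evaluation

Recall@k is the fraction of queries whose paired item (the one at the same row index in the other modality) appears among the top-k retrieved results.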

In [None]:
def get_recall_at_k(a_modal_embeddings, b_modal_index, k=5) -> float:
    """Recall@k where query row i is considered correct iff id i is retrieved.

    Args:
        a_modal_embeddings: query vectors, one row per sample. Row order defines
            the ground truth: query i matches item i of the index.
        b_modal_index: an index exposing ``search(queries, k=...)`` that returns
            ``(scores, ids)``, e.g. a FAISS index.
        k: number of neighbours retrieved per query.

    Returns:
        Fraction (as a float) of queries whose own position appears in their
        top-``k`` ids.
    """
    _, neighbour_ids = b_modal_index.search(a_modal_embeddings, k=k)
    n_queries = len(a_modal_embeddings)
    rows = np.asarray(neighbour_ids)[:n_queries]
    own_ids = np.arange(len(rows)).reshape(-1, 1)
    hits = np.any(rows == own_ids, axis=1)
    return float(np.count_nonzero(hits)) / float(n_queries)

In [None]:
# Sanity check: every text should retrieve itself as its nearest neighbour.
text_to_text_recall_at_1 = get_recall_at_k(text_embeddings, text_index, k=1)
assert text_to_text_recall_at_1 > 0.99, f"text self-recall@1 too low: {text_to_text_recall_at_1:.4f}"
text_to_text_recall_at_1

In [None]:
# Sanity check: every image should retrieve itself as its nearest neighbour.
image_to_image_recall_at_1 = get_recall_at_k(image_embeddings, image_index, k=1)
assert image_to_image_recall_at_1 > 0.99, f"image self-recall@1 too low: {image_to_image_recall_at_1:.4f}"
image_to_image_recall_at_1

In [None]:
# Text -> image retrieval, recall@1
get_recall_at_k(a_modal_embeddings=text_embeddings, b_modal_index=image_index, k=1)

In [None]:
# Text -> image retrieval, recall@10
get_recall_at_k(a_modal_embeddings=text_embeddings, b_modal_index=image_index, k=10)

In [None]:
# Image -> text retrieval, recall@1
get_recall_at_k(a_modal_embeddings=image_embeddings, b_modal_index=text_index, k=1)

In [None]:
# Image -> text retrieval, recall@10
get_recall_at_k(a_modal_embeddings=image_embeddings, b_modal_index=text_index, k=10)

## Query Examples

In [None]:
# Example Thai query, roughly "a dog is running around / playing".
text = "หมากำลังวิ่งเล่น"

In [None]:
# Encode the query with the same text encoder that produced `text_embeddings`
# (`model`, the active model in the embedding cell). If you switch to one of the
# commented-out encoders, change this call to match (e.g. `text_model.encode([text])`
# for the sentence-transformers variant).
with torch.no_grad():
    query_embedding = np.asarray(model(text), dtype=np.float32).reshape(1, -1)
# Normalise like the indexed vectors so the returned scores are cosine similarities.
query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)
_, indices = image_index.search(query_embedding, k=5)

In [None]:
# Rank-1 retrieved image for the query
images[indices[0, 0]]

In [None]:
# Rank-2 retrieved image for the query
images[indices[0, 1]]

In [None]:
# Rank-3 retrieved image for the query
images[indices[0, 2]]