-
Notifications
You must be signed in to change notification settings - Fork 3
/
search.py
46 lines (34 loc) · 1.99 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from .models import create_tokenizer, create_and_load_from_hub
class MultiLingualSearch:
    """Rank a fixed collection of images against free-text queries.

    Image embeddings are supplied pre-computed by the caller. Each query is
    encoded with the model's ``text_encoder`` and scored against every image
    by cosine similarity scaled with ``exp(clip_model.logit_scale)``.
    """

    def __init__(self, images_embeddings, images_data, model=None, device='cpu'):
        """Set up the searcher.

        Args:
            images_embeddings: Image embeddings, one row per image (anything
                ``torch.as_tensor`` accepts: nested lists, NumPy array, tensor).
            images_data: Per-image payloads returned in search results; paired
                with ``images_embeddings`` by position.
            model: Model exposing ``text_encoder`` and
                ``clip_model.logit_scale``. When ``None``, one is loaded with
                ``create_and_load_from_hub()``.
            device: Device the text encoder runs on; the model is moved there.
        """
        # Compare against None explicitly: container modules (e.g. an empty
        # nn.Sequential) define __len__ and can be falsy.
        self.model = model if model is not None else create_and_load_from_hub()
        # Tokens are sent to ``device`` before encoding, so the model must live
        # on the same device or the forward pass fails with a device mismatch.
        self.model.to(device)
        self.model.eval()
        self.tokenizer = create_tokenizer()
        self.images_embeddings = images_embeddings
        self.images_data = images_data
        self.device = device

    def compare_embeddings(self, logit_scale, img_embs, txt_embs):
        """Return scaled cosine-similarity logits between images and texts.

        Args:
            logit_scale: Scalar multiplier applied to the cosine similarities.
            img_embs: Image embeddings, one row per image.
            txt_embs: Text embeddings, one row per text.

        Returns:
            ``(logits_per_image, logits_per_text)``; the second is the
            transpose of the first (rows = texts, columns = images).
        """
        # F.normalize divides by max(norm, eps), so an all-zero embedding
        # yields zero similarity instead of NaN; otherwise identical to x / ||x||.
        image_features = torch.nn.functional.normalize(img_embs, dim=-1)
        text_features = torch.nn.functional.normalize(txt_embs, dim=-1)
        logits_per_image = logit_scale * image_features @ text_features.t()
        # Same scores indexed text-first; no need for a second matmul.
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text

    def compare_text_images(self, model, text, images_embeddings):
        """Score every image against ``text``.

        Args:
            model: Model exposing ``text_encoder`` and ``clip_model.logit_scale``.
            text: Query passed straight to ``self.tokenizer``.
            images_embeddings: Image embeddings, one row per image.

        Returns:
            NumPy array of shape (num_images, num_texts). Each column is a
            softmax over the images, i.e. sums to 1 for its query.
        """
        tokens = {k: v.to(self.device) for k, v in self.tokenizer(text).items()}
        with torch.no_grad():
            txt_embs = model.text_encoder(tokens)
            logit_scale = model.clip_model.logit_scale.exp().float().cpu()
            # Cast both sides to float32 on CPU: a float64 source (e.g. a NumPy
            # array) would otherwise make the matmul fail on mismatched dtypes.
            images_tensors = torch.as_tensor(images_embeddings).float().cpu()
            txt_embs = txt_embs.float().cpu()
            logits_images, _ = self.compare_embeddings(
                logit_scale, images_tensors, txt_embs
            )
            probs = logits_images.softmax(dim=0)
        return probs.detach().numpy()

    def search_images(self, text, images_embeddings, images_data, amount=10):
        """Return the ``amount`` best-matching images for ``text``.

        Only the first query's scores (column 0) are used for ranking.

        Returns:
            List of ``{'image': payload, 'prob': float}`` dicts, highest
            probability first.
        """
        probs = self.compare_text_images(self.model, text, images_embeddings)
        scores = probs[:, 0].tolist()
        ranked = sorted(
            zip(images_data, scores), key=lambda pair: pair[1], reverse=True
        )
        return [{'image': image, 'prob': prob} for image, prob in ranked[:amount]]

    def search(self, text, amount=10):
        """Search the images this instance was constructed with."""
        return self.search_images(text, self.images_embeddings, self.images_data, amount)