In [1]:
from sentence_transformers import SentenceTransformer, util
from PIL import Image
import requests

# Load a pre-trained CLIP model
model = SentenceTransformer('clip-ViT-B-32')

  from .autonotebook import tqdm as notebook_tqdm
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [2]:
## declare constant and necessary variables
DB_path = "../DataCollector/result/places.db"
DB_name = "places"

In [3]:
import json
import sqlite3
from typing import List
import torch

def parseList(image_list):
    print(type(image_list))
    return json.loads(image_list)

def getImages(id):
    "Return a list of PIL images"
    conn = sqlite3.connect(DB_path)
    cursor = conn.cursor()
    cursor.execute(f"SELECT images FROM {DB_name} WHERE rowid = ?", (id,))
    image_url = cursor.fetchone()[0]
    image_url = parseList(image_url)
    pil_image = []
    for url in image_url:
        print(url)
        pil_image.append(Image.open(requests.get(url, stream=True).raw))

    return pil_image

def embedImageList(images):
    if len(images) == 0:
        return torch.zeros(1, 512)
    print("len(images) = ",len(images))
    encoding = [model.encode(image) for image in images]
    print(type(encoding[0]))
    # list of numpy.ndarray
    return sum(encoding) / len(encoding)
    

def embedText(text):
    return model.encode(text)

def getSimilarity(imageList, text):
    image_emb = embedImageList(imageList)
    text_emb = embedText(text)
    return util.cos_sim(image_emb, text_emb)

In [6]:
sample_text = "Đồ ăn ngon"
sample_images = [getImages(i) for i in range(2, 3)]

print(sample_images)


# my_embeddings = model.encode(sample_text)
print([getSimilarity(image, sample_text) for image in sample_images])

<class 'str'>
https://lh3.googleusercontent.com/p/AF1QipOSisTfpcYNbHV-0X3CB0E7pedEkh10avYId7eI=s0
https://lh3.googleusercontent.com/geougc-cs/AMBA38t2Kmr5_XJSv2PmDO10nabcToVMxqXZwW_M3PM9su_8PO1ED7uHBVqxnIGzYii6adp4DKOc07hYToYMGDCTK0A3sAnVZ0OcbUO7mh2w_5YBiC8MdqHRYBeFKFKBpXXethGqvS89=s0
https://lh3.googleusercontent.com/geougc-cs/AMBA38v8qd5Fo9KkbRzRKrYOOmP0JZxnFWPrjFUB2mbsDD04wT1pUBB01TM531RbG3_ZpfpUGoJgt2bvWrR0k5fFq7jkd3sp2m60IGig0JZJXf_wFPc2ANdfgFi1GH8-L_Wlg3dm-r9TpA=s0
[[<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1080x1440 at 0x288D4B52910>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=3024x4032 at 0x288D4B3CB50>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=4032x3024 at 0x288D4A0E710>]]
len(images) =  3
<class 'numpy.ndarray'>
[tensor([[0.2461]])]


In [25]:
import faiss

# Define the path to your .bin file
index_file_path = "test_images.bin"

# Load the index from the file
try:
    index = faiss.read_index(index_file_path)
    print(f"Index loaded successfully from {index_file_path}")
    print(f"Index type: {type(index)}")
    print(f"Total number of vectors in the index: {index.ntotal}")
    ## print missing id from 1 to 2900

    ids_np = faiss.vector_to_array(index.id_map)
    ids_in_index = set(ids_np.tolist())

    expected_ids = set(range(1, 2876))
    missing = expected_ids - ids_in_index

    print("Missing IDs:", missing)

    print("Missing IDs:", len(missing))
except RuntimeError as e:
    print(f"Error loading index: {e}")
    print("Ensure the file path is correct and the file is a valid FAISS index file.")


Index loaded successfully from test_images.bin
Index type: <class 'faiss.swigfaiss_avx2.IndexIDMap2'>
Total number of vectors in the index: 2875
Missing IDs: {1852}
Missing IDs: 1


In [10]:
import sqlite3
## add missing ids to index

conn = sqlite3.connect("../DataCollector/result/images_embedding.db")
cursor = conn.cursor()
cursor.execute(f"SELECT rowid, url FROM test_images WHERE rowid IN ({', '.join(map(str, missing))})")
results = cursor.fetchall()

cursor.close()
conn.close()
print(results)

[(1852, 'https://lh3.googleusercontent.com/gps-cs-s/AG0ilSzFnBa-B0UbGoTTsoChbLAtqHnBzOPgVTFz_H8CjThSplPNn2_sARxmTvt8gXQdt-9FDrBmRFZ3NCbVuV5nvLGCLeQI8-xNTMZs3W2_ITxyzzrkKnuN_YfrgO-VxwThLxCR0nJ57mVMuUk=s0')]


In [24]:
## drop faiss index at 1852
import numpy as np
import faiss
index = faiss.read_index("test_images.bin")
index.remove_ids(np.array([1852], dtype=np.int64))
faiss.write_index(index, "test_images.bin")


In [27]:
import numpy as np

import requests
from io import BytesIO
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True


result = results[0]
## open faiss index again
index = faiss.read_index("test_images.bin")
id, url = result
print(id, url)
resp = requests.get(url, timeout=10)
resp.raise_for_status()

img = Image.open(BytesIO(resp.content)).convert("RGB")
embedding = model.encode(img, normalize_embeddings=True, convert_to_numpy=True).astype(np.float32)
index.add_with_ids(embedding.reshape(1, -1), np.array([id], dtype=np.int64))

faiss.write_index(index, "test_images.bin")

1852 https://lh3.googleusercontent.com/gps-cs-s/AG0ilSzFnBa-B0UbGoTTsoChbLAtqHnBzOPgVTFz_H8CjThSplPNn2_sARxmTvt8gXQdt-9FDrBmRFZ3NCbVuV5nvLGCLeQI8-xNTMZs3W2_ITxyzzrkKnuN_YfrgO-VxwThLxCR0nJ57mVMuUk=s0


In [31]:
food = "https://lh3.googleusercontent.com/geougc-cs/AMBA38tjkOByp4DU0-4-YrmbeXz3sHpteXD6IelGuoMxZfNOLg2TJVLoqHnWMEhBrHeb3KdQkt2InycNnnijlDYAi6X-N3UqykUm1rHMf0Xsp1rNuKYlNnGqaNgfqKZgx_k2jMpmyeIY=s0"
restaurant = "https://lh3.googleusercontent.com/gps-cs-s/AG0ilSzMtPXeE1uEkq2fRV1rpRaBW5MZMN-caZKmFtpRkjdvwC2LdEHP6xHOcwxsBEr2mObNWrCh5pFkq9InYJFr50TAUlfVGJS1P62hnKukoXOCmCglpwUz3skQhOcDLKx5H-JsW3Gq=s0"
mrInmenu = "https://lh3.googleusercontent.com/geougc-cs/AMBA38vz3isPKRASPEXTJDq8sgL9Ve0jp4SepgQhw6lWUY6CnwOqC8hFjuqPY6eFtG0u2gJZvxRBM46HpxRisrhhtSulyqVrxiBvFbbVO8bdR_XJ4lAeA7d90vdpailpeUfrQkMzd-R8yA=s0"
phomenu = "https://lh3.googleusercontent.com/geougc-cs/AMBA38v1aQd-mplhpMrPgxU5-YjF3bk_OYHpU60lXEdJ-vnLHSOpENZwiTfYQr2kDboZIDTe2roVrObhLtWQtDnTatAXuN-pTxMNIs3RSo9HMx13qJQWyvfge1lXzUop93gTpChGpVrBZA=s0"
phoimage = "https://lh3.googleusercontent.com/gps-cs-s/AG0ilSwJ-cwjeEWaqQx33AkZLrVGk3SRQ8WmijDHXKn7elBvB6lATdCQvzJhLJUOw5CPP2oIMvUzMXOlVGJ6HmZpl8O_2K0nNmm9Vz_Z0sd1qwEwdWbUlPF6Wgz73F16cJOBEr39uOFn=s0"
dauhurangmuoipho = "https://lh3.googleusercontent.com/geougc-cs/AMBA38uMoMAaMREZj9SFUMWOPZtUqn_7CnziB2vRK2N5yMed0jon3twSykZDevZfkGa73GHGlPa5-HAoUbSWMDFs37iIvN6EV8hu0K9HusUxLbZ6z5F9GDVPVYLLCBrvma_pQMRgHi-t9Q=s0"
banhtrangtron = "https://lh3.googleusercontent.com/geougc-cs/AMBA38t0abjIRTf0wDT7zlKsr1lFddm20spRsVIP1jJ4EtQFNb51JBWzz1KZj_2h3gWmq0e72_viEDGlhBpqGf8Foi4ScQ6ZSAbPsFlNe45WJfsglZJkGzinXsHEPic-sb9_BbAsz05L=s0"


text = "Japanese food"

def getSimilarity(url, text):
    # model = SentenceTransformer('clip-ViT-B-32')
    image = Image.open(requests.get(url, stream=True).raw)
    embedding = model.encode(image, normalize_embeddings=True, convert_to_numpy=True)
    text_embedding = model.encode(text, normalize_embeddings=True, convert_to_numpy=True)
    sim = util.cos_sim(embedding, text_embedding)
    return sim

print(getSimilarity(food, text))
print(getSimilarity(restaurant, text))
print(getSimilarity(phoimage, text))
print(getSimilarity(mrInmenu, text))
print(getSimilarity(dauhurangmuoipho, text))
# print(getSimilarity(souvenir, text))
print(getSimilarity(banhtrangtron, text))
print(getSimilarity(phomenu, text))


tensor([[0.2949]])
tensor([[0.2328]])
tensor([[0.2362]])
tensor([[0.2016]])
tensor([[0.2157]])
tensor([[0.1786]])
tensor([[0.2072]])


In [32]:
doannhatban = "https://lh3.googleusercontent.com/geougc-cs/AMBA38vGIgxWXqEhkjMJeiHGHDv7DZkXGnZNfjt0cGUaBEEpuV3kbLC9s6AwjjsThe09R0gwa3LQn6y6uRmpouFn5Y3va4GU5UTJdQdDJ6IukWoKO5Jy4HMjIjETBoJHNGmXBJRH915t=s0"
print(getSimilarity(doannhatban, text))

tensor([[0.2823]])
