In [8]:
from PIL import Image
import pandas as pd
import torch
from transformers import CLIPProcessor, CLIPModel
import numpy as np


# Load the dataset
dataset_path = 'reverse_image_search.csv'  # Replace with your dataset path
df = pd.read_csv(dataset_path)

# Load the CLIP model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")




In [12]:
# Initialize a list to store tensors
embeddings = []

for index, row in df.iterrows():
    image_path = row['path']  # Assuming the path is in a column named 'path'
    image = Image.open(image_path).convert('RGB')  # Ensure image is in RGB
    inputs = processor(images=image, return_tensors="pt")
    image_features = model.get_image_features(**inputs)
    # Ensure the tensor is detached from the computational graph before converting
    embedding = image_features.squeeze(0).detach().numpy().tolist()
    embedding = embedding/np.linalg.norm(embedding)

    embeddings.append(embedding)

# Concatenate all feature vectors into a single tensor
#image_features_tensor = torch.stack(embeddings)

# image_features_tensor now contains the feature vectors for all images in your dataset>

In [13]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

# Milvus parameters
HOST = '127.0.0.1'
PORT = '19530'
TOPK = 13

In [14]:
connections.connect(host=HOST, port=PORT)
collection_name = 'tranformers_clip_patch16'
dim = 512  # Dimension of the embeddings
METRIC_TYPE = 'L2'  # You can choose 'L2', 'IP', etc., based on your requirement
INDEX_TYPE = 'IVF_FLAT'  # Index type

In [15]:
utility.list_collections()

['text_image_search_blip',
 'image_search_collection',
 'tranformers_clip',
 'reverse_image_search',
 'transformers_clip',
 'text_image_search']

In [13]:
utility.drop_collection("tranformers_clip")

In [16]:
def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    
    fields = [
        FieldSchema(name='path', dtype=DataType.VARCHAR, description='path to image', max_length=500, 
                    is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)

    index_params = {
        'metric_type': METRIC_TYPE,
        'index_type': INDEX_TYPE,
        'params': {"nlist": 512}
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection

In [17]:
collection = create_milvus_collection(collection_name, dim)

In [18]:
paths = df['path'].tolist()


In [19]:
entities = [[path for path in paths],
            [embedding for embedding in embeddings]]
#preparing for insertion to milvus

In [20]:
mr = collection.insert(entities)


In [21]:
collection = Collection(collection_name)      # Get an existing collection.
collection.load()

In [22]:
search_params = {
    "metric_type": "L2", 
    "offset": 0, 
    "ignore_growing": False, 
    "params": {"nprobe": 10}
}

In [23]:

# search with image

query_image_path = 'aleren.jpeg'  
query_image = Image.open(query_image_path).convert('RGB')  
query_inputs = processor(images=query_image, return_tensors="pt")
query_image_features = model.get_image_features(**query_inputs)
embedding = query_image_features.squeeze(0).detach().numpy().tolist()


# Concatenate all feature vectors into a single tensor
#image_features_tensor = torch.stack(embeddings)

# image_features_tensor now contains the feature vectors for all images in your dataset

results = collection.search(
    data=[embedding], 
    anns_field="embedding", 
    # the sum of `offset` in `param` and `limit` 
    # should be less than 16384.
    param=search_params,
    limit=10,
    expr=None,
)



FileNotFoundError: [Errno 2] No such file or directory: 'aleren.jpeg'

In [45]:
results[0].ids


['./train/basketball/n02802426_24958.JPEG',
 './train/basketball/n02802426_7656.JPEG',
 './train/basketball/n02802426_3881.JPEG',
 './train/basketball/n02802426_12782.JPEG',
 './train/horizontal_bar/n03535780_16077.JPEG',
 './train/basketball/n02802426_7726.JPEG',
 './train/basketball/n02802426_26718.JPEG',
 './train/basketball/n02802426_8222.JPEG',
 './train/basketball/n02802426_10137.JPEG',
 './train/basketball/n02802426_12191.JPEG']

In [46]:
results[0].distances

[85.77123260498047,
 95.03716278076172,
 95.73870849609375,
 97.46134185791016,
 99.283447265625,
 101.21218872070312,
 101.49933624267578,
 101.66267395019531,
 102.91883087158203,
 103.99423217773438]

In [24]:
# search with text
query_text = "airplane"  
text_inputs = processor(text=query_text, return_tensors="pt", padding=True, truncation=True, max_length=77)
query_text_features = model.get_text_features(**text_inputs)
text_embedding = query_text_features.squeeze(0).detach().numpy().tolist()

results = collection.search(
    data=[text_embedding], 
    anns_field="embedding", 
    # the sum of `offset` in `param` and `limit` 
    # should be less than 16384.
    param=search_params,
    limit=10,
    expr=None,
)



In [25]:
results[0].ids


['./train/warplane/n04552348_16150.JPEG', './train/parachute/n03888257_4738.JPEG', './train/warplane/n04552348_13334.JPEG', './train/warplane/n04552348_12780.JPEG', './train/parachute/n03888257_25150.JPEG', './train/horizontal_bar/n03535780_40969.JPEG', './train/can_opener/n02951585_32431.JPEG', './train/bullet_train/n02917067_12974.JPEG', './train/parachute/n03888257_13175.JPEG', './train/can_opener/n02951585_24077.JPEG']

In [26]:
# search with text 
#patch 16
query_text = "cat"  
text_inputs = processor(text=query_text, return_tensors="pt")
query_text_features = model.get_text_features(**text_inputs)
text_embedding = query_text_features.squeeze(0).detach().numpy().tolist()

text_embedding = text_embedding/np.linalg.norm(text_embedding)

print(text_embedding)

[ 3.31491634e-02 -2.03781980e-02  1.16234710e-04 -2.73521694e-03
  8.73957664e-03  4.97225687e-02  1.51330768e-02  1.20761029e-02
  3.32283047e-02 -3.54530867e-02 -1.46911245e-02 -4.67191712e-03
  4.20575467e-02 -1.09020783e-02 -1.99637226e-04 -2.71906871e-02
  9.54397225e-03 -1.06539634e-02 -2.09423192e-02 -2.05251296e-02
  1.30982930e-03  1.34396611e-02 -8.03361199e-02  6.13869963e-03
  2.20196314e-02 -5.00463309e-02 -4.61846680e-03  2.88087900e-02
  6.34654870e-03 -4.31359486e-02  3.25238494e-02 -9.53553278e-03
  2.39674856e-02  1.79717915e-02  1.77497047e-02  2.02372686e-02
 -1.46177078e-02 -2.12734175e-02  6.84565874e-03 -2.95933201e-02
 -8.11491804e-03 -1.12608875e-02  1.91217761e-02 -9.20535082e-03
  8.33442296e-03 -1.01767487e-02 -1.29096461e-02  1.41746369e-02
 -8.06408686e-03  5.78491187e-03  2.05520558e-02  1.48469991e-02
 -2.18581839e-02  4.37485099e-03 -9.59421089e-03  1.37751745e-02
 -1.36853935e-02  1.35774674e-02 -8.23595600e-03 -1.55879756e-02
  8.44839876e-03 -3.45796

In [29]:
results = collection.search(
    data=[text_embedding], 
    anns_field="embedding", 
    # the sum of `offset` in `param` and `limit` 
    # should be less than 16384.
    param=search_params,
    limit=10,
    expr=None,
)



In [26]:
# search with text 
#patch 32
query_text = "cat"  
text_inputs = processor(text=query_text, return_tensors="pt", padding=True, truncation=True, max_length=77)
query_text_features = model.get_text_features(**text_inputs)
text_embedding = query_text_features.squeeze(0).detach().numpy().tolist()

print(text_embedding)

[0.2051691710948944, -0.03279348835349083, -0.06166143715381622, -0.03970229625701904, -0.2433445155620575, 0.21964505314826965, -0.34200319647789, 0.12462783604860306, -0.3632577955722809, 0.28449299931526184, 0.005944356322288513, -0.48552653193473816, 0.2448098212480545, -0.07519945502281189, 0.34867072105407715, 0.4418478310108185, 0.17099052667617798, -0.21698689460754395, -0.07976849377155304, 0.07817978411912918, 0.410218209028244, 0.3069245517253876, 0.2493753582239151, 0.10165336728096008, -0.19615575671195984, 0.3654637336730957, 0.352477103471756, 0.6314743757247925, -0.0062239691615104675, 0.0031716041266918182, 0.11212191730737686, -0.012524619698524475, 0.17180348932743073, 0.17517916858196259, -0.09132827818393707, 0.163207545876503, 0.17541882395744324, 0.3215360641479492, 0.3618620038032532, 0.19236218929290771, -0.0018282458186149597, -0.018739476799964905, 0.2335728406906128, 0.27410566806793213, 0.1741076558828354, 0.2329310178756714, -0.06656691431999207, -0.233122

In [30]:
results[0].ids


['./train/tub/n04493381_4746.JPEG', './train/tiger_cat/n02123159_2317.JPEG', './train/remote_control/n04074963_2858.JPEG', './train/dishwasher/n03207941_436.JPEG', './train/tiger_cat/n02123159_509.JPEG', './train/lynx/n02127052_3830.JPEG', './train/lynx/n02127052_9325.JPEG', './train/tiger_cat/n02123159_5205.JPEG', './train/lynx/n02127052_19993.JPEG', './train/dishwasher/n03207941_13200.JPEG']

In [32]:
for result_path in results[0].ids:
    result_image = Image.open(result_path).convert('RGB')
    result_image.show()


)07[?47h[1;24r[m[4l[?1h=[m[m[37m[40m[1;1H                                                                                [2;1H                                                                                [3;1H                                                                                [4;1H                                                                                [5;1H                                                                                [6;1H                                                                                [7;1H                                                                                [8;1H                                                                                [9;1H                                                                                [10;1H                                                                                [11;1H                                                                                [12

Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)
Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)
Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)
Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)
Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)
Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)
Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)
Error: no "view" rule for type "image/png" passed its test case
       (for more in

)07[?47h[1;24r[m[4l[?1h=[m[m[37m[40m[1;1H                                                                                [2;1H                                                                                [3;1H                                                                                [4;1H                                                                                [5;1H                                                                                [6;1H                                                                                [7;1H                                                                                [8;1H                                                                                [9;1H                                                                                [10;1H                                                                                [11;1H                                                                                [12

In [136]:
results[0].distances

[154.8622589111328,
 154.98194885253906,
 156.00918579101562,
 157.588134765625,
 158.6085968017578,
 158.73123168945312,
 159.7007293701172,
 159.90402221679688,
 160.07493591308594,
 160.310546875]