In [None]:
pip install pymilvus torch gdown torchvision tqdm matplotlib

In [None]:
import gdown
import zipfile

import os

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'

# Only download when the archive is missing, so re-running the notebook
# does not fetch the dataset again.
if not os.path.exists(output):
    gdown.download(url, output)

# Extract from the same path the archive was downloaded to (no duplicated literal).
with zipfile.ZipFile(output, "r") as zip_ref:
    zip_ref.extractall("./paintings")

In [None]:
# --- Milvus setup ---
COLLECTION_NAME = 'image_search'  # collection (re)created by the setup cell
DIMENSION = 2048                  # embedding length; must equal the model's output feature size

# --- Inference ---
BATCH_SIZE = 128  # images per forward pass / insert call during ingestion
TOP_K = 3         # neighbours returned per query image

In [None]:
from pymilvus import MilvusClient, DataType

# Milvus Lite database stored in a local file next to the notebook
client = MilvusClient("./milvus_demo.db")

# Drop any collection left over from a previous run so this cell can be re-run
if client.has_collection(COLLECTION_NAME):
    client.drop_collection(COLLECTION_NAME)

# Schema: auto-generated INT64 primary key, the image's file path, and its embedding.
# auto_id=True on the schema agrees with the primary field below; the inserts in
# this notebook never supply an "id" value.
schema = client.create_schema(
    auto_id=True,
    enable_dynamic_field=True,
)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
schema.add_field(field_name="filepath", datatype=DataType.VARCHAR, max_length=256)
schema.add_field(field_name="image_embedding", datatype=DataType.FLOAT_VECTOR, dim=DIMENSION)

# FLAT index with cosine metric on the embedding field
index_params = client.prepare_index_params()
index_params.add_index(
    field_name="image_embedding",
    index_type="FLAT",
    index_name="vector_index",
    metric_type="COSINE",
)

# Passing index_params creates the index together with the collection
client.create_collection(
    collection_name=COLLECTION_NAME,
    schema=schema,
    index_params=index_params,
)

# Report the collection's load state before inserting / searching
res = client.get_load_state(collection_name=COLLECTION_NAME)
print(res)

In [None]:
import torch
from torchvision import transforms

# ResNet-50 from the torchvision v0.10.0 hub entry with pretrained weights; the
# final child module (the classifier) is dropped so the network emits features.
_resnet50 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
model = torch.nn.Sequential(*list(_resnet50.children())[:-1]).eval()

# Per-channel normalisation statistics used with the pretrained weights
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize the shorter side to 256, center-crop to 224x224, convert to a tensor, normalise.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

In [None]:
import glob
from PIL import Image
from tqdm import tqdm

def embed(data_array):
    """Embed one batch of images and insert the vectors into the Milvus collection.

    Args:
        data_array: a pair of parallel lists; ``data_array[0]`` holds the
            preprocessed image tensors and ``data_array[1]`` the matching
            file paths.
    """
    tensors, filepaths = data_array[0], data_array[1]
    with torch.no_grad():
        # Flatten everything after the batch dimension instead of squeeze():
        # squeeze() also removes the batch dimension when the batch contains a
        # single image, which turned each "vector" into a bare float.
        vectors = torch.flatten(model(torch.stack(tensors)), start_dim=1).tolist()
    data = [
        {"filepath": path, "image_embedding": vector}
        for path, vector in zip(filepaths, vectors)
    ]
    client.insert(collection_name=COLLECTION_NAME, data=data)

# Get the filepaths of the images; sorted so the insertion order is reproducible.
paths = sorted(glob.glob('./paintings/paintings/**/*.jpg', recursive=True))
print(f'Found {len(paths)} images to index')

# data_batch[0]: preprocessed tensors, data_batch[1]: matching file paths.
# Each full batch of BATCH_SIZE images is embedded and inserted, then reset.
data_batch = [[], []]
for path in tqdm(paths):
    with Image.open(path) as im:
        data_batch[0].append(preprocess(im.convert('RGB')))
    data_batch[1].append(path)
    if len(data_batch[0]) == BATCH_SIZE:
        embed(data_batch)
        data_batch = [[], []]

# Embed and insert the remainder (a final partial batch, possibly a single image)
if data_batch[0]:
    embed(data_batch)

In [None]:
def embed(data):
    """Embed the search images and return their vectors (no insertion).

    NOTE(review): this redefines ``embed`` from the ingestion cell, which took a
    ``[tensors, paths]`` pair and inserted into Milvus. After this cell runs,
    ``embed`` only accepts a list of tensors, so re-running the ingestion loop
    requires re-running its own cell first.

    Args:
        data: list of preprocessed image tensors, stacked into one batch.

    Returns:
        One embedding (list of floats) per input image, in input order.
    """
    with torch.no_grad():
        features = model(torch.stack(data))
    # Flattening everything after the batch dimension yields (batch, features)
    # for every batch size: the same as squeeze() for batches of several images
    # (the model's trailing spatial dims are 1x1), while a single-image batch
    # keeps its batch dimension.
    return torch.flatten(features, start_dim=1).tolist()

import time

# Get the filepaths of the search images; sorted so result rows appear in a stable order.
search_paths = sorted(glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True))
print(f'Found {len(search_paths)} search images')

# data_batch[0]: preprocessed tensors, data_batch[1]: matching file paths
data_batch = [[], []]
for path in search_paths:
    with Image.open(path) as im:
        data_batch[0].append(preprocess(im.convert('RGB')))
    data_batch[1].append(path)

embeds = embed(data_batch[0])

# Time only the vector search; the plotting cell reports `finish - start`
# in each row's title, so both names must be defined here.
start = time.perf_counter()
res = client.search(
    collection_name=COLLECTION_NAME,
    data=embeds,
    anns_field='image_embedding',
    limit=TOP_K,
    output_fields=['filepath'])
finish = time.perf_counter()
print(res)

In [None]:
from matplotlib import pyplot as plt

# Show the image results: one row per query image, the query in column 0
# followed by its TOP_K hits. Height scales with the number of rows.
n_rows = len(data_batch[1])
f, axarr = plt.subplots(n_rows, TOP_K + 1, figsize=(20, 5 * n_rows), squeeze=False)

# Hide every axis up front so cells left empty (a query with fewer than
# TOP_K hits) do not render as blank framed plots.
for ax in axarr.flat:
    ax.set_axis_off()

search_time = finish - start
for row, (query_path, hits) in enumerate(zip(data_batch[1], res)):
    # imshow copies the pixel data, so the file can be closed right after.
    with Image.open(query_path) as query_img:
        axarr[row][0].imshow(query_img)
    axarr[row][0].set_title('Search Time: ' + str(search_time))
    for col, hit in enumerate(hits, start=1):
        with Image.open(hit['entity']['filepath']) as hit_img:
            axarr[row][col].imshow(hit_img)
        # With the COSINE metric Milvus' 'distance' is a similarity: higher means closer.
        axarr[row][col].set_title(f"Cosine similarity: {hit['distance']:.4f}")

# Save the search result in a separate image file alongside the notebook, then display it.
f.savefig('search_result.png')
plt.show()