In [1]:
COLLECTION_NAME = "face_search_indian" 
DIMENSION = 512
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
BATCH_SIZE = 128

In [2]:
from pymilvus import connections

connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

In [3]:
from pymilvus import utility

if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)

In [4]:
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(
        name="filepath", dtype=DataType.VARCHAR, max_length=1000
    ),  
    FieldSchema(name="image_embedding", dtype=DataType.FLOAT_VECTOR, dim=DIMENSION),
]
schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)

In [5]:
index_params = {
    "metric_type": "L2",
    "index_type": "FLAT",
}
collection.create_index(field_name="image_embedding", index_params=index_params)
collection.load()

In [6]:
import glob

paths = glob.glob("./archive/**/Priyanka_Chopra/*.jpg", recursive=True)
paths2 = glob.glob("./archive/**/Pooja_Hegde/*.jpg", recursive=True)
paths = paths + paths2
len(paths)

283

In [9]:
import numpy as np
def embed_search_split(ratio, paths):
    # shuffle the paths
    paths = np.random.permutation(paths)
    split = int(len(paths) * ratio)
    return paths[:split], paths[split:]

train_paths, test_paths = embed_search_split(0.8, paths)

In [10]:
countPo = 0
countPr = 0
for path in train_paths:
    if "Pooj" in path:
        countPo+=1
    else:
        countPr+=1
        
countPo, countPr

(123, 103)

In [11]:
from facenet_pytorch import InceptionResnetV1, MTCNN

mtcnn = MTCNN(image_size=160, margin=0, min_face_size=20)

def extract_face(filepath):
    img = Image.open(filepath)
    img_cropped = mtcnn(img)
    return img_cropped

model = InceptionResnetV1(pretrained="vggface2").eval()

In [12]:
import torch
from torchvision import transforms

def embed(data):
    with torch.no_grad():
        output = model(torch.stack(data[0])).squeeze()
        collection.insert([data[1], output.tolist()])


data_batch = [[], []]

In [13]:
len(train_paths), len(test_paths)

(226, 57)

In [14]:
countPo = 0
countPr = 0
for path in test_paths:
    if "Pooj" in path:
        countPo += 1
    else:
        countPr += 1

countPo, countPr

(25, 32)

In [15]:
from PIL import Image
from tqdm import tqdm

for path in tqdm(train_paths):
    im = extract_face(path)
    if im is None:
        continue
    data_batch[0].append(im)
    data_batch[1].append(path)
    if len(data_batch[0]) % BATCH_SIZE == 0:
        embed(data_batch)
        data_batch = [[], []]

100%|██████████| 226/226 [02:23<00:00,  1.58it/s]


In [16]:
def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # If more than one image, use squeeze
        if len(ret) > 1:
            return ret.squeeze().tolist()
        # Squeeze would remove batch for single image, so using flatten
        else:
            return torch.flatten(ret, start_dim=1).tolist()

In [17]:
data_batch = [[], []]

for path in test_paths:
    im = extract_face(path)
    if im is None:
        continue
    data_batch[0].append(im)
    data_batch[1].append(path)

In [18]:
import time
embeds = embed(data_batch[0])

In [19]:
start = time.time()
res = collection.search(
    embeds,
    anns_field="image_embedding",
    output_fields=["filepath"],
    param={
        "metric_type": "L2",
        "params": {"radius": 1.0},
    },
    limit=150,
)
finish = time.time()

In [51]:
N = 7
len(res[N])

54

In [52]:
r = res[N]
rpaths = [x.entity.get("filepath") for x in r]

In [53]:
data_batch[1][N]

'./archive/bollywood_celeb_faces_1/Priyanka_Chopra/79.jpg'

In [54]:
countPo = 0
countPr = 0
for path in rpaths:
    if "Pooj" in path:
        countPo += 1
    else:
        countPr += 1

countPo, countPr

(3, 51)