In [1]:
import os
import pickle
import numpy as np
import pandas as pd
from glob import glob
from pathlib import Path
from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from sklearn.preprocessing import Normalizer

from pymilvus import (
    connections, 
    FieldSchema, 
    CollectionSchema, 
    DataType, 
    Collection, 
    drop_collection, 
    utility
)

import albumentations as albu
from albumentations.pytorch import ToTensorV2

import warnings
warnings.filterwarnings(action="ignore")

# Dataset

In [2]:
class Transforms:
    """Callable preprocessing pipeline for images fed to the embedding model.

    The pipeline resizes the longest side to 224, pads up to 224x224,
    applies ``albu.Normalize()`` with its default settings and converts the
    result to a tensor via ``ToTensorV2``.
    """

    def __init__(self):
        side = 224  # target edge length used for both the resize and the padding

        self.transforms = albu.Compose(
            [
                albu.LongestMaxSize(max_size=side, always_apply=True, p=1),
                # border_mode=0 with value=(255, 255, 255): presumably constant
                # white padding -- confirm against the cv2 border-mode constants.
                albu.PadIfNeeded(
                    min_height=side,
                    min_width=side,
                    always_apply=True,
                    border_mode=0,
                    value=(255, 255, 255),
                ),
                albu.Normalize(),
                ToTensorV2(),
            ]
        )

    def __call__(self, img, *args, **kwargs):
        """Convert ``img`` to a numpy array, run the pipeline, return its "image" entry.

        Extra positional and keyword arguments are accepted and ignored.
        """
        augmented = self.transforms(image=np.array(img))
        return augmented["image"]

# Root folder of the dataset; class sub-folders are expected below "train".
dataset_folder_to_use = Path("../data/interim/dataset_part/")

# NOTE: despite the `val_` prefix, the images are read from the "train" split.
val_dataset = ImageFolder(
    str(dataset_folder_to_use / "train"),
    transform=Transforms(),
)

# Fixed order (no shuffling) and no dropped tail, so every image is visited
# once in the same order as `val_dataset.imgs`.
val_dl = DataLoader(
    dataset=val_dataset,
    batch_size=32,
    shuffle=False,
    drop_last=False,
    num_workers=1,
    pin_memory=False,
)

# Model

In [3]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

# `map_location` remaps the stored tensors onto `device` while loading.
# Without it torch.jit.load tries to restore tensors to the device they were
# saved from, which raises on a host that lacks that device (e.g. a
# GPU-exported model opened on a CPU-only machine).
model = torch.jit.load("torchscript.pt", map_location=device)

Running on device: cuda:0


In [4]:
def get_embeddings(model, dataloader, device=None):
    """Run ``model`` over ``dataloader`` and collect embeddings with their labels.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(batch)``; it must return a 2-tuple whose second item
        is the embedding tensor (the first item is ignored).
    dataloader : iterable
        Yields ``(inputs, labels)`` pairs of tensors; labels must live on the
        CPU because they are converted with ``.numpy()``.
    device : torch.device or str, optional
        Device the input batches are moved to. When omitted, the device of the
        model's first parameter is used (CPU if the model has no parameters),
        so the function no longer assumes CUDA is available.

    Returns
    -------
    embeddings : numpy.ndarray
        Per-batch embeddings concatenated along the first axis.
    classes : numpy.ndarray
        Labels concatenated in the same order as ``embeddings``.
        Both arrays are empty when the dataloader yields nothing.
    """
    if device is None:
        # Follow the model's own placement instead of hard-coding `.cuda()`.
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    label_batches = []
    embedding_batches = []
    model.eval()
    with torch.no_grad():
        for x, y in tqdm(dataloader):
            label_batches.append(y.numpy())
            _, batch_embeddings = model(x.to(device))
            embedding_batches.append(batch_embeddings.cpu().numpy())

    # np.concatenate rejects an empty list, so handle the no-batch case explicitly.
    classes = np.concatenate(label_batches) if label_batches else np.array([])
    embeddings = np.concatenate(embedding_batches) if embedding_batches else np.array([])
    print(f"len embeddings: {len(embeddings)}, len classes {len(classes)}")
    return embeddings, classes

In [5]:
%%time
embeddings, classes = get_embeddings(model, dataloader=val_dl)

  0%|          | 0/34 [00:00<?, ?it/s]

len embeddings: 1083, len classes 1083
CPU times: user 10 s, sys: 3.71 s, total: 13.7 s
Wall time: 16.1 s


In [11]:
# Normalise each embedding row with sklearn's default Normalizer settings and
# keep the fitted object so the same transform can be re-applied later.
normalizer = Normalizer().fit(embeddings)
embeddings_norm = normalizer.transform(embeddings)

# The output folder may not exist on a fresh checkout; create it up front so
# the dump below cannot fail with FileNotFoundError.
os.makedirs("pickles", exist_ok=True)
with open('pickles/normalizer.pickle', 'wb') as handle:
    pickle.dump(normalizer, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [13]:
# Inverse of `class_to_idx`: class index -> class (folder) name.
mapper = {idx: name for name, idx in val_dataset.class_to_idx.items()}

# One row per image, in the same order as the dataset.
data_pathes = pd.DataFrame(val_dataset.imgs, columns=["paths", "classes_idx"])
data_pathes["classes_names"] = data_pathes["classes_idx"].map(mapper)
# Keep only the file name; the directory part of each path is dropped.
data_pathes["paths"] = data_pathes["paths"].map(lambda full_path: Path(full_path).name)

data_pathes.to_csv("./data_pathes.csv")

# Milvus

In [None]:
# Open the Milvus connection; no alias is passed, so pymilvus uses its default.
# NOTE(review): `connections.connect` appears to return None in pymilvus 2.x,
# so `connection` likely holds nothing useful and the previously commented-out
# `connection.close()` would fail -- confirm, and use
# `connections.disconnect("default")` to close the connection instead.
connection = connections.connect(host='0.0.0.0', port='19530')

In [None]:
# Ask the server whether a collection named "demo_metric" is already present.
has = utility.has_collection(collection_name="demo_metric")
print(f"Does collection demo_metric exist in Milvus: {has}")

In [None]:
# Remove a stale copy of the collection so it can be rebuilt from scratch.
# The existence check keeps this cell safe on a first run against a fresh
# Milvus instance, where there is nothing to drop yet.
if utility.has_collection("demo_metric"):
    drop_collection(collection_name='demo_metric')

In [None]:
collection_name = 'demo_metric'

# Three fields per entity: an integer primary key, the class label id and the
# embedding vector itself (declared with 512 dimensions).
id_field = FieldSchema(name="embedding_id", dtype=DataType.INT64, is_primary=True)
label_field = FieldSchema(name="label_id", dtype=DataType.INT64)
vector_field = FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=512)

schema = CollectionSchema(fields=[id_field, label_field, vector_field])

collection = Collection(name=collection_name, schema=schema, using='default', shards_num=2)

In [23]:
# Column-wise payload, in the order: running integer ids, label ids,
# normalised embedding vectors.
data = [
    list(range(len(classes))),
    classes.tolist(),
    embeddings_norm.tolist(),
]

# The output folder may not exist on a fresh checkout; create it up front so
# the dump below cannot fail with FileNotFoundError.
os.makedirs("pickles", exist_ok=True)
with open('pickles/embeddings.pickle', 'wb') as handle:
    pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Upload the ids, labels and vectors prepared in `data`.
collection.insert(data)

# IVF_FLAT index over the vector field, compared with the L2 metric.
index_params = dict(
    metric_type="L2",
    index_type="IVF_FLAT",
    params=dict(nlist=1024),
)
collection.create_index(field_name="embeddings", index_params=index_params)

# Load the collection after the index has been created.
collection.load()

In [26]:
# Persist the class-name -> class-index mapping of the dataset.
# The output folder may not exist on a fresh checkout; create it up front so
# the dump below cannot fail with FileNotFoundError.
os.makedirs("pickles", exist_ok=True)
with open('pickles/mapper_faces.pickle', 'wb') as handle:
    pickle.dump(val_dataset.class_to_idx, handle, protocol=pickle.HIGHEST_PROTOCOL)