In [26]:
import os
import sys
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
from torch.utils.data import DataLoader
import multiprocessing
import tqdm
import ast
from annoy import AnnoyIndex



In [3]:
# Project root: the parent directory of the current working directory.
dir_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

In [4]:
# Make the project's `src` directory importable. The membership check keeps
# the cell idempotent: re-running it does not pile up duplicate entries in
# sys.path.
src_dir = os.path.join(dir_path, "src")
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [5]:
from data_module import ImageDataModule
from resnet import Resnet50
from utils import collate_batch
from dataset import ImageDataset

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Input locations, both resolved relative to the project root `dir_path`.
data_channels = dict(
    image_path=os.path.join(dir_path, "images", "raw/"),
    dataset=os.path.join(dir_path, "data", "dataset/"),
)

In [7]:
# Build the project dataset from the dataset directory and the raw-image directory.
dataset = ImageDataset(
    _dir=data_channels["dataset"],
    image_path=data_channels["image_path"],
)

In [8]:
# Batches of 32 assembled by the project's `collate_batch`, with one worker
# process per CPU core.
# `shuffle=False` is the DataLoader default; it is spelled out because the
# loader is iterated twice below (embeddings, then metadata) and the two
# passes presumably have to line up row-for-row -- verify against the lookup code.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    num_workers=multiprocessing.cpu_count(),
    collate_fn=collate_batch,
)

In [9]:
# Checkpoint to load: epoch 7 / step 1952 of Lightning run `version_39`.
ckpt = os.path.join(
    dir_path,
    "notebooks",
    "lightning_logs",
    "version_39",
    "checkpoints",
    "epoch=7-step=1952.ckpt",
)

In [10]:
# Restore the network from the checkpoint; the two keyword arguments are
# forwarded to `Resnet50.load_from_checkpoint` unchanged.
model = Resnet50.load_from_checkpoint(
    ckpt,
    embedding_size=512,
    num_classes=19,
)



In [11]:
def extract_embeddings(model, dataloader):
    """Embed every batch of ``dataloader`` with ``model`` and gather the results.

    ``model`` is put into evaluation mode (and left in it) and the whole pass
    runs with gradient tracking disabled. Each batch's ``"image"`` tensor is
    moved to the device of the model's first parameter before the forward
    pass, and each batch of embeddings is copied to the CPU as soon as it is
    produced so that accelerator memory is not held for the full dataset.

    Args:
        model: Module-like callable with an ``eval()`` method; called as
            ``model(batch["image"])`` and must return a tensor.
        dataloader: Iterable of mapping-style batches providing the keys
            ``"image"`` and ``"targets"`` (both tensors).

    Returns:
        ``(all_embeddings, all_labels)``: NumPy arrays built by concatenating
        the per-batch model outputs and ``"targets"`` along dimension 0, in
        iteration order.
    """
    model.eval()

    # Device holding the model's weights, or None when the model exposes no
    # parameters (in which case inputs are passed through untouched).
    first_param = None
    if hasattr(model, "parameters"):
        first_param = next(iter(model.parameters()), None)
    device = None if first_param is None else first_param.device

    all_embeddings = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader):
            images = batch["image"]
            if device is not None and torch.is_tensor(images):
                images = images.to(device)
            # Offload right away: keeping every batch on the accelerator until
            # the final concatenation scales device memory with dataset size.
            all_embeddings.append(model(images).detach().cpu())
            all_labels.append(batch["targets"].detach().cpu())

    all_embeddings = torch.cat(all_embeddings, dim=0).numpy()
    all_labels = torch.cat(all_labels, dim=0).numpy()

    return all_embeddings, all_labels

In [12]:
def save_metadata(dataloader):
    """Collect the ``"metadata"`` entries of every batch into one flat list.

    Despite the name, nothing is written to disk here: the function only
    iterates ``dataloader`` once and returns the per-sample metadata items
    in iteration order.
    """
    flat_metadata = []

    for batch in tqdm.tqdm(dataloader):
        # Each batch carries an iterable of per-sample metadata items.
        flat_metadata.extend(batch["metadata"])

    return flat_metadata

In [13]:
# First pass over the loader: embeddings and targets for every sample.
index_embeddings, index_labels = extract_embeddings(
    model=model,
    dataloader=dataloader,
)

100%|██████████| 305/305 [06:01<00:00,  1.18s/it]


In [14]:
# Second pass over the same loader: per-sample metadata records.
index_metadata = save_metadata(dataloader=dataloader)

100%|██████████| 305/305 [01:01<00:00,  4.96it/s]


In [17]:
# Tabulate the collected metadata records as a DataFrame.
index_df = pd.DataFrame(data=index_metadata)

In [33]:
# Display the metadata frame as the cell's output.
index_df

Unnamed: 0,genre,image_name
0,['Animation'],qNBAXBIQlnOThrVvA6mA2B5ggV6.jpg
1,['Drama'],vJU3rXSP9hwUuLeq8IpfsJShLOk.jpg
2,['Science Fiction'],t6HIqrRAclMCA60NsSmeqe9RmNV.jpg
3,['Animation'],qVdrYN8qu7xUtsdEFeGiIVIaYd.jpg
4,['Comedy'],swzMoIVn6xjB857ziYJ8KBV440g.jpg
...,...,...
9734,['Comedy'],vkF8VLrazGtk9OjdEhihG6kKAhP.jpg
9735,['Thriller'],yw8x2i3vaHZZzpvqvF75E8q2N6M.jpg
9736,['Drama'],bFOmE3zCFU01TuomOOwClAWdvOD.jpg
9737,['Action'],kziBJGQFo9f0Vkj9s37qI0G9I0I.jpg


## Create Search index

In [27]:
# Index settings.
embedding_size = 512  # vector length the index is created with
num_trees = 40        # initial choice: roughly twice the 19 genre classes

# Annoy index using euclidean distance.
annoy_index = AnnoyIndex(embedding_size, 'euclidean')

# Register each embedding under its row position in `index_embeddings`.
for i in range(len(index_embeddings)):
    embedding = index_embeddings[i]
    annoy_index.add_item(i, embedding)

# Build the forest; the call's return value is shown as the cell output.
annoy_index.build(num_trees)


True

## Save Search index and Metadata DF

In [28]:
# Output locations; both are relative paths, so they resolve against the
# current working directory.
metadata_file = 'metadata.csv'         # metadata table
annoy_index_file = 'annoy_index.ann'   # serialized Annoy index

In [29]:
# Persist the built index to `annoy_index_file`; the call's return value is
# shown as the cell output.
annoy_index.save(annoy_index_file)

True

In [34]:
# Write the metadata table next to the index. `to_csv` keeps its default
# `index=True`, so the row labels are written as the first (unnamed) column.
index_df.to_csv(path_or_buf=metadata_file)
