In [2]:
from src.model import TripletNetwork, FasterRCNNEmbedder
from src.data import *
from src.transforms import albumentations_transform

from torch.nn import TripletMarginLoss
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger

# Initialize feature extractor, model, loss, optimizer, lr_scheduler

model = FasterRCNNEmbedder()
loss = TripletMarginLoss(margin=1.0, p=2)
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
lr_sceduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.0001, last_epoch=-1)

# initialize TripletNetwork for training
network = TripletNetwork(model,
                          loss,
                          optimizer,
                          lr_sceduler)

# initialize datamodule

dm = TripletDataModule(data_dir='/home/georg/projects/university/C5/task3/dataset/COCO',
                          json_file='/home/georg/projects/university/C5/task3/dataset/COCO/mcv_image_retrieval_annotations.json',
                          batch_size=96,
                          #transforms=albumentations_transform(),
                          num_workers=16,
                          dims=(224, 224))

# Initialize callbacks 
checkpointer = ModelCheckpoint(
    monitor="val_loss", save_top_k=1, mode="min", save_weights_only=True)
early_stopper = EarlyStopping(monitor="val_loss", patience=3, mode="min")
logger = CSVLogger("logs", name="TripletNetworkCSV")

In [3]:
# # Train the network

# trainer = pl.Trainer(max_epochs=20, 
#                     devices=1,
#                     accelerator='gpu',
#                     callbacks=[checkpointer, early_stopper],
#                     logger=logger,
#                     num_sanity_val_steps=0) 
# trainer.fit(network, dm)


In [4]:
# define helper functions to order the data 

def get_img_file_name(img_id, set):
    return 'COCO_{}2014_{:012d}.jpg'.format(set, img_id)

def prepare_data(json_file, mode):
        with open(json_file, 'r') as file:
            # Load the JSON data
            data = json.load(file)[mode]
        print(f'Loaded {len(data)} classes from {json_file}')
        img_ids = []
        labels = []
        # loop over classes 
        for key in tqdm(data.keys(), desc=f'Preparing {mode} data'):
            class_ = key
            images_with_class = data[key]
            # loop over images with the class
            for image_id in images_with_class:
                # if it's a new image, add it to the list of images and create a label list for it
                if image_id not in img_ids:
                    img_ids.append(image_id)
                    labels.append([])
            # loop over images and add the class to the label list if it's in the list of images
            for i, img_id in enumerate(img_ids):
                if img_id in images_with_class:
                    labels[i].append(int(class_))

        data_split = 'train' if mode in ['train', 'database'] else 'val'

        img_files = [get_img_file_name(img_id, data_split) for img_id in img_ids]
        return img_files, labels

In [5]:
import torch 

#load model from checkpoint and set to eval mode
model.load_state_dict(torch.load('/home/georg/projects/university/C5/task3/task_3e/logs/TripletNetworkCSV/version_19/checkpoints/epoch=8-step=7695.ckpt'), strict=False)
model.eval()

# specify json file path
data_json = '/home/georg/projects/university/C5/task3/dataset/COCO/mcv_image_retrieval_annotations.json'

In [6]:
import json
from tqdm import tqdm
import os 
from PIL import Image
from src.transforms import preprocess
import numpy as np

# define helper functions to extract embeddings from images using the model
def extract_embeddings(img_files, imgs_path, model):
    embeddings = []
    for img_file in tqdm(img_files):
        img_path = os.path.join(imgs_path, img_file)
        image = Image.open(img_path).convert('RGB')
        image = preprocess([224,224])(image)
        image = image.unsqueeze(0)
        pred = model(image)
        embeddings.append(pred.squeeze(0).cpu().detach().numpy())
    return np.array(embeddings)


# extract embeddings from the training images
train_imgs_path = '/home/georg/projects/university/C5/task3/dataset/COCO/train2014'
train_img_files, train_labels = prepare_data(json_file=data_json, mode='database')
train_embeddings = extract_embeddings(train_img_files, train_imgs_path, model)

Loaded 80 classes from /home/georg/projects/university/C5/task3/dataset/COCO/mcv_image_retrieval_annotations.json


Preparing database data: 100%|██████████| 80/80 [00:00<00:00, 5942.94it/s]
  0%|          | 0/1959 [00:00<?, ?it/s]

100%|██████████| 1959/1959 [01:30<00:00, 21.65it/s]


In [7]:
# Create FAISS index and add the training embeddings to it
import faiss  
             
index = faiss.IndexFlatL2(1024)   # build the index, d=size of vectors 
faiss.normalize_L2(train_embeddings)
print(train_embeddings.shape)
index.add(train_embeddings)                 # add vectors to the index
print(index.ntotal)


(1959, 1024)
1959


In [9]:
# Extract embeddings from the test/val images (can be configured using 'mode')

val_imgs_path = '/home/georg/projects/university/C5/task3/dataset/COCO/val2014'
val_img_files, val_labels = prepare_data(json_file=data_json, mode='test')
val_embeddings = extract_embeddings(val_img_files, val_imgs_path, model)

# Search for similar vectors k in the FAISS index
k = 5                       # we want 4 similar vectors
D, I = index.search(val_embeddings, k)     # actual search
print(I)

Loaded 80 classes from /home/georg/projects/university/C5/task3/dataset/COCO/mcv_image_retrieval_annotations.json


Preparing test data: 100%|██████████| 80/80 [00:00<00:00, 5519.27it/s]
  0%|          | 0/1917 [00:00<?, ?it/s]

100%|██████████| 1917/1917 [01:25<00:00, 22.47it/s]


[[  25   26   90   43   47]
 [1125   38 1117 1118 1129]
 [ 386  836  777  371   68]
 ...
 [ 400 1535 1374  548 1484]
 [1656 1695 1892  403 1281]
 [1117  854  466 1506  764]]


In [35]:
import matplotlib.pyplot as plt 
from sklearn.metrics import recall_score, precision_score, accuracy_score, average_precision_score


k = 5
visualize = False

targets = []
preds = []
for i, retrieval_indices in enumerate(I):
    query_img_file = val_img_files[i]
    query_img_path = os.path.join(val_imgs_path, query_img_file)
    query_img_labels = val_labels[i]
    retrieved_image_files = []
    retrieved_image_labels = []
    targets.append(1)

    for train_idx in retrieval_indices[:k]:
        retrieved_image_files.append(train_img_files[train_idx])
        retrieved_image_paths = [os.path.join(train_imgs_path, file) for file in retrieved_image_files]
        retrieved_image_labels.extend(train_labels[train_idx])
    
    
    preds.append(len(set(query_img_labels).intersection(set(retrieved_image_labels)))>0)
    if visualize ==True:
        query_img = cv2.imread(query_img_path)
        query_img = cv2.resize(query_img, (224, 224))
        retrieved_imgs = [cv2.imread(file) for file in retrieved_image_paths]
        retrieved_imgs = [cv2.resize(img, (224, 224)) for img in retrieved_imgs]
        
        fig, ax = plt.subplots(1, 6, figsize=(18, 3))
        ax[0].imshow(cv2.cvtColor(query_img, cv2.COLOR_BGR2RGB))
        ax[0].set_title('Query Image')
        for j, img in enumerate(retrieved_imgs):
            ax[j+1].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            ax[j+1].set_title(f'Retrieved Image {j+1}')
        plt.savefig('query_plots/{:03d}.png'.format(i))
        plt.close()

print(len(targets))
print(len(preds))
precision = precision_score(targets, preds, average='binary')
recall = recall_score(targets, preds, average='binary')
accuracy = accuracy_score(targets, preds)
f1 = 2 * (precision * recall) / (precision + recall)
print(f'Precision: {precision}, \nRecall: {recall}, \nAccuracy: {accuracy}, \nF1: {f1}')
average_precision = average_precision_score(targets, preds)
print(f'Average precision: {average_precision}')



    

1917
1917
0.5101721439749609
Precision: 1.0, 
Recall: 0.5101721439749609, 
Accuracy: 0.5101721439749609, 
F1: 0.6756476683937824
Average precision: 1.0
