In [None]:
import cv2
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image

# ImageNet channel statistics; only needed if normalisation is re-enabled in `transform`.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Shared preprocessing for every split: image -> float tensor in [0, 1] -> 224x224.
# No Normalize step is applied.
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((224, 224))])


def _caltech_split(root):
    """Return (ImageFolder dataset, non-shuffled DataLoader) for one Caltech-101 split directory."""
    split_set = torchvision.datasets.ImageFolder(root=root, transform=transform)
    # shuffle=False keeps batches in dataset order. The evaluation cells map retrieved item
    # indices back to labels through `dataset.targets`, which presumably relies on the
    # training images having been indexed in exactly this order.
    split_loader = torch.utils.data.DataLoader(
        split_set, batch_size=128, shuffle=False, num_workers=2
    )
    return split_set, split_loader


dataset, dataloader = _caltech_split("./data/caltech101/train")
valset, valloader = _caltech_split("./data/caltech101/val")
testset, testloader = _caltech_split("./data/caltech101/test")

In [None]:
from cbir import *
from cbir.pipeline import *

# Use the GPU when one is present, otherwise fall back to CPU instead of hard-requiring CUDA.
# NOTE(review): assumes ResNetExtractor accepts any torch device string ("cpu"/"cuda") — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"

rgb_histogram = RGBHistogram(n_bin=8, h_type="region")
resnet = ResNetExtractor(model="resnet18", device=device)
siftbow = SIFTBOWExtractor(mode="tfidf")
# Feature store searched with cosine-distance k-NN.
array_store = NPArrayStore(retrieve=KNNRetrieval(metric="cosine"))

In [None]:
from tqdm import tqdm
import numpy as np
# Gather images as uint8 (N, H, W, C) arrays (the loader yields channel-first float
# batches in [0, 1]) and fit the SIFT bag-of-words vocabulary with k=64.
# NOTE(review): despite the name, `train_img` is built from the *validation* split
# (valloader), not the training split — confirm this is intended.
train_img = np.concatenate(
    [
        (batch.numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8)
        for batch, batch_labels in tqdm(valloader)
    ]
)
siftbow.fit(train_img, k=64)

In [None]:
# Retrieval engine built from the ResNet extractor and the cosine k-NN array store.
# Note: `siftbow`, fitted above, is not used by this engine.
# NOTE(review): this instance shares its name with the `cbir` package imported above;
# a distinct name (e.g. `engine`) would be less confusing.
cbir = CBIR(resnet, array_store)

In [None]:
# Index every training image, one channel-first float batch at a time. The loader is not
# shuffled, so images are added in dataset order.
for batch, batch_labels in tqdm(dataloader):
    cbir.indexing(batch.numpy())

In [None]:
# NOTE(review): `feature_extractor` is presumably the ResNetExtractor passed to CBIR above,
# but "tfidf" is the mode given to SIFTBOWExtractor — this looks like a leftover from a run
# that used `siftbow`. It is also set *after* indexing, so if it did change extraction, query
# features would no longer match the indexed ones. Confirm or remove.
cbir.feature_extractor.mode = "tfidf"

In [None]:
# Load the query the same way the indexed images were loaded: ImageFolder's default loader
# opens them with PIL and converts to RGB. cv2.imread returns BGR channel order, which would
# make the query features incomparable with the indexed ones.
# The path is relative to the same data root used for the datasets (no machine-specific path).
query_path = "./data/caltech101/test/ant/image_0024.jpg"
with Image.open(query_path) as pil_img:
    img = transform(pil_img.convert("RGB"))
# Add a batch dimension -> (1, 3, 224, 224), matching the batches used for indexing.
img = img.unsqueeze(0).numpy()

In [None]:
# Retrieve the 10 nearest indexed images for the query. `img` already has a batch dimension,
# so `rs` presumably holds one result list per query (rs[0] is displayed below).
rs = cbir.retrieve(img, k=10)

In [None]:
import matplotlib.pyplot as plt
# Show the images retrieved for the single query in a 2-row grid. The column count adapts to
# the number of results, so a different k no longer overflows a hard-coded 2x5 grid.
top_hits = rs[0]
n_cols = max(1, (len(top_hits) + 1) // 2)
# squeeze=False guarantees a 2-D axes array even when n_cols == 1.
fig, ax = plt.subplots(2, n_cols, squeeze=False)
for slot, hit in zip(ax.flat, top_hits):
    # Stored images are presumably channel-first (C, H, W); imshow expects (H, W, C).
    slot.imshow(hit.image.transpose((1, 2, 0)))
# Hide every axis, including unused slots when the result count is odd.
for slot in ax.flat:
    slot.axis("off")
plt.show()

In [None]:
# Retrieve the top-10 neighbours for every test image and keep each query's true label next
# to its results (presumably one entry per query in both lists).
rs, ground_truth = [], []
for batch, batch_labels in tqdm(testloader):
    rs.extend(cbir.retrieve(batch.numpy(), k=10))
    ground_truth.extend(batch_labels)

In [None]:
# Sanity check: class labels of the training images retrieved for the first test query.
# Self-contained, so it also runs on a fresh kernel (uses only `rs` and `dataset` from above).
first_query_indices = [item.index for item in rs[0]]
np.take(dataset.targets, first_query_indices, axis=0)

In [None]:
def average_precision(predictions, ground_truths, k):
    """Fraction of the first ``k`` predictions that appear in ``ground_truths``.

    Note: despite its name this is precision@k (every position weighted equally), not
    rank-weighted average precision. The name is kept so existing callers keep working.

    Args:
        predictions: Ranked predicted labels, best first (a list).
        ground_truths: Collection of labels counted as relevant.
        k: Number of leading predictions to score.

    Returns:
        float in [0, 1]: relevant count divided by the number of predictions actually
        considered (``len(predictions[:k])``). Returns 0.0 when nothing is considered
        (empty ``predictions`` or ``k == 0``) instead of raising ZeroDivisionError.
    """
    top_k_predictions = predictions[:k]
    # len() rather than truthiness, so array-like inputs do not raise on bool().
    if len(top_k_predictions) == 0:
        return 0.0

    relevant_items = sum(1 for pred in top_k_predictions if pred in ground_truths)
    return relevant_items / len(top_k_predictions)

def hit_rate(predictions, ground_truths, k):
    """Return 1 if any of the first ``k`` predictions is in ``ground_truths``, else 0.

    Args:
        predictions: Ranked predicted labels, best first.
        ground_truths: Collection of labels counted as relevant.
        k: Number of leading predictions to inspect.

    Returns:
        int: 1 for a hit within the top ``k``, 0 otherwise.
    """
    return int(any(label in ground_truths for label in predictions[:k]))

In [None]:
# Score every test query: map each retrieved item's index to its training-set class label and
# compare those labels with the query's true class (top 10 results per query).
ap = []
hit = []
for retrieved, true_label in zip(rs, ground_truth):
    predicted = [item.index for item in retrieved]
    class_preds = np.take(dataset.targets, predicted, axis=0)
    relevant = [true_label.tolist()]
    ap.append(average_precision(class_preds.tolist(), relevant, 10))
    hit.append(hit_rate(class_preds.tolist(), relevant, 10))

In [None]:
# Bring pandas into scope here. A bare `pd.concat` reference in this cell raised NameError on
# a fresh kernel because pandas was only imported further down; it did nothing else.
import pandas as pd

In [None]:
# Mean precision@10 and mean hit rate@10 over all test queries. Returned as one tuple so both
# are displayed: only a cell's last expression is shown, so a separate `np.mean(ap)` line was lost.
np.mean(ap), np.mean(hit)

In [None]:
import pandas as pd
# One-row summary of this run. NOTE(review): the "map" column holds the mean of
# `average_precision`, which is precision@10 rather than true mean average precision.
pd.DataFrame([{"options": 1, "map": np.mean(ap), "hit_rate": np.mean(hit)}])

In [None]:
def grid(*args):
    """Yield every combination (Cartesian product) of the given value lists, as lists.

    ``grid([1, 2], [3, 4])`` yields ``[1, 3], [1, 4], [2, 3], [2, 4]``: the last argument
    varies fastest. Arguments may be any finite iterables, including one-shot generators,
    because they are materialised once up front. The previous recursive version re-iterated
    the later arguments for every outer value and so silently dropped combinations for
    iterator inputs. With no arguments, a single empty combination ``[]`` is yielded.
    """
    import itertools

    for combo in itertools.product(*args):
        yield list(combo)

In [None]:
# Quick check of `grid`: prints all 27 combinations, last argument varying fastest.
for combo in grid([1, 2, 3], [4, 5, 6], [1, 3, 2]):
    print(combo)

In [None]:
# Scratch: take the first key of a parameter dict (dicts iterate keys in insertion order).
p = next(iter({"a": [1, 3, 4], "b": [12, 5, 7]}))

p