# Ok listen!
93 and 346 are ACTUALLY not that bad, even I am surprised. Ignore 18...

In [17]:
import uuid

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from qdrant_client import QdrantClient, models
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 9 * 9, 1)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

ds = load_dataset("pantelism/cats-vs-dogs", trust_remote_code=True)
split = ds['train'].train_test_split(test_size=0.2, seed=42)
val_test = split['test'].train_test_split(test_size=0.5, seed=42)
test_dataset = val_test['test']

transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

class HFDataset(Dataset):
    def __init__(self, hf_ds, transform=None):
        self.ds = hf_ds
        self.transform = transform
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, idx):
        item = self.ds[idx]
        img = item["image"].convert("RGB")
        # Create a version for display before transforming it into a tensor
        display_img = transforms.Resize((150, 150))(img)

        if self.transform:
            tensor_img = self.transform(img)
        else:
            tensor_img = transforms.ToTensor()(display_img)

        return tensor_img, np.array(display_img)


test_ds = HFDataset(test_dataset, transform)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

model = CNN().to(device)
try:
    model.load_state_dict(torch.load('cat_dog_cnn_model.pth', map_location=device))
except FileNotFoundError:
    exit()

model.eval()

Using device: cpu


CNN(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=10368, out_features=1, bias=True)
  )
)

In [18]:
class FeatureExtractor(nn.Module):
    def __init__(self, original_model):
        super().__init__()
        self.features = original_model.features

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return x
feature_extractor = FeatureExtractor(model).to(device)
feature_extractor.eval()

FeatureExtractor(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
)

In [20]:
client = QdrantClient(":memory:")

collection_name = "image_embeddings"
vector_size = 128 * 9 * 9

client.recreate_collection(
    collection_name=collection_name,
    vectors_config=models.VectorParams(
        size=vector_size,
        distance=models.Distance.COSINE
    )
)

points_to_upload = []
with torch.no_grad():
    for i, (batch_images, _) in enumerate(test_loader):
        batch_images = batch_images.to(device)
        embeddings = feature_extractor(batch_images).cpu().numpy()
        for j, embedding in enumerate(embeddings):
            original_index = i * test_loader.batch_size + j
            points_to_upload.append(
                models.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=embedding.tolist(),
                    payload={"original_index": original_index}
                )
            )

client.upsert(
    collection_name=collection_name,
    points=points_to_upload,
    wait=True
)

  client.recreate_collection(


UpdateResult(operation_id=0, status=<UpdateStatus.COMPLETED: 'completed'>)

In [None]:

def display_results(query_img, search_results, test_dataset, query_index):
    fig, axes = plt.subplots(1, len(search_results) + 1, figsize=(15, 5))

    axes[0].imshow(query_img)
    axes[0].set_title(f"Query Image ({query_index})")
    axes[0].axis('off')

    for i, result in enumerate(search_results):
        retrieved_idx = result.payload['original_index']
        _, retrieved_img = test_dataset[retrieved_idx]
        axes[i+1].imshow(retrieved_img)
        axes[i+1].set_title(f"Idx: {retrieved_idx}\nScore: {result.score:.3f}")
        axes[i+1].axis('off')

    plot_filename = f"similarity_results_query_{query_index}.png"
    plt.savefig(plot_filename, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved similarity plot: {plot_filename}")

k = 5 # top 5 images
# random_indices = np.random.choice(len(test_ds), 3, replace=False)
random_indices = [18, 93, 346] # 4 consistency

for query_index in random_indices:
    query_index_int = int(query_index)
    print(f"\nPerforming search for image at index: {query_index_int}")
    query_tensor, query_img_display = test_ds[query_index_int]
    query_vector = feature_extractor(query_tensor.unsqueeze(0).to(device)).detach().cpu().numpy().flatten()
    search_results = client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=k
    )
    display_results(query_img_display, search_results, test_ds, query_index_int)


Performing search for image at index: 18


  search_results = client.search(


Saved similarity plot: similarity_results_query_18.png

Performing search for image at index: 93
Saved similarity plot: similarity_results_query_93.png

Performing search for image at index: 346
Saved similarity plot: similarity_results_query_346.png
