In [None]:
! pip install pandas

In [None]:
import os
import lancedb
import numpy as np
import timm
import torch
from PIL import Image
from torchvision import transforms
from sklearn.preprocessing import normalize

class EmbeddingGeneratorService:
    def __init__(
        self,
        model_path: str = "../model/resnet50d.ra2_in1k_fine_tune_51_classes_2024-10-06_12-01-37.pth",
    ):
        self.resnet_model = self.load_model(model_path)

    def load_model(self, model_path: str) -> torch.nn.Module:
        """
        Load the ResNet model
        """
        model = timm.create_model("resnet50d", pretrained=False, num_classes=51)
        model.reset_classifier(0)

        checkpoint = torch.load(model_path)

        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
        else:
            model.load_state_dict(checkpoint)

        model.eval()
        return model

    def preprocess_image(
        self, image: Image.Image, device: str, target_size: tuple = (224, 224)
    ) -> torch.Tensor:
        """
        Preprocess the input image
        """
        transform = transforms.Compose(
            [
                transforms.Resize(target_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=self.resnet_model.default_cfg["mean"],
                    std=self.resnet_model.default_cfg["std"],
                ),
            ]
        )

        input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
        return input_tensor

    def _generate_embedding(self, img_path: str) -> np.ndarray:
        """
        Generate embeddings for the given input image
        """
        # device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        device = "cpu"

        input_tensor = self.preprocess_image(Image.open(img_path), device)
        self.resnet_model.to(device)
        self.resnet_model.eval()

        with torch.no_grad():
            embedding = self.resnet_model(input_tensor)

        embedding = embedding.cpu().numpy().flatten()

        return embedding
    
    def generate_embedding(self, img_path: str) -> np.ndarray:
        """
        Generate embeddings for the given input image
        """
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

        input_tensor = self.preprocess_image(Image.open(img_path), device)
        self.resnet_model.to(device)
        self.resnet_model.eval()

        with torch.no_grad():
            embedding = self.resnet_model(input_tensor)

        embedding = embedding.cpu().numpy().flatten()

        return embedding    
    
    def get_all_embeddings(self, folder_path):
        """
        Get all embeddings of images in the folder using the embedding model
        """
        
        embeddings = []
        image_paths = []
        
        for root, _, files in os.walk(folder_path):
            for file in files:
                # if file == "1.jpg":
                image_path = os.path.join(root, file)
                print(f"Processing {image_path}")
                # image = Image.open(image_path)
                embedding = self.generate_embedding(image_path)
                embeddings.append(embedding)
                image_paths.append(image_path)

        return np.array(embeddings), image_paths

embedding_service = EmbeddingGeneratorService()

all_embeddings, image_paths = embedding_service.get_all_embeddings("../kaggle_data/all_birds")
normalized_embeddings = normalize(all_embeddings, axis=1)





Processing ../kaggle_data/all_birds/CREAM COLORED WOODPECKER/bird.jpg
Processing ../kaggle_data/all_birds/EASTERN TOWEE/bird.jpg
Processing ../kaggle_data/all_birds/SHOEBILL/bird.jpg
Processing ../kaggle_data/all_birds/BLACK-NECKED GREBE/bird.jpg
Processing ../kaggle_data/all_birds/AMERICAN COOT/bird.jpg
Processing ../kaggle_data/all_birds/AFRICAN PYGMY GOOSE/bird.jpg
Processing ../kaggle_data/all_birds/PARUS MAJOR/bird.jpg
Processing ../kaggle_data/all_birds/PUNA TEAL/bird.jpg
Processing ../kaggle_data/all_birds/BARROWS GOLDENEYE/bird.jpg
Processing ../kaggle_data/all_birds/BLACK VULTURE/bird.jpg
Processing ../kaggle_data/all_birds/HAWFINCH/bird.jpg
Processing ../kaggle_data/all_birds/BALD EAGLE/bird.jpg
Processing ../kaggle_data/all_birds/MERLIN/bird.jpg
Processing ../kaggle_data/all_birds/BLUE HERON/bird.jpg
Processing ../kaggle_data/all_birds/ROSE BREASTED COCKATOO/bird.jpg
Processing ../kaggle_data/all_birds/SAMATRAN THRUSH/bird.jpg
Processing ../kaggle_data/all_birds/GLOSSY IBIS/

In [4]:
# Initialize or connect to a LanceDB database
db = lancedb.connect("../lancedb")  # Local storage

# Define vector dimension (e.g., 1024 for your case)
vector_dim = 2048

bird_data = []

selected_bird = {}
for i, bird in enumerate(zip(all_embeddings, image_paths)):
    if i == 5:
        print(bird)
        selected_bird = bird
    bird_data.append({"image_path": bird[1], "vector": bird[0].tolist()})

# Create a table (or open if it exists)
table = db.create_table("embeddings", data=bird_data, mode="overwrite")  # Overwrites if exists
# table.create_index("cosine")

(array([6.3070692e-03, 0.0000000e+00, 1.5722051e-03, ..., 1.7533947e+00,
       4.2782438e-01, 3.1915866e-02], dtype=float32), '../kaggle_data/all_birds/AFRICAN PYGMY GOOSE/bird.jpg')


In [5]:
unknown_bird =  embedding_service.generate_embedding("./unknowns/1.jpg")
results = table \
    .search(unknown_bird) \
    .metric("cosine") \
    .limit(6) \
    .to_pandas()
results

Unnamed: 0,image_path,vector,_distance
0,../kaggle_data/all_birds/CREAM COLORED WOODPEC...,"[0.339799, 0.8893364, 0.0, 0.10170445, 1.26368...",0.355239
1,../kaggle_data/all_birds/CEDAR WAXWING/bird.jpg,"[2.3690917, 0.031060576, 0.0, 0.0, 0.71125716,...",0.547832
2,../kaggle_data/all_birds/EMERALD TANAGER/bird.jpg,"[0.15029724, 0.045553826, 0.0, 0.0, 0.06174791...",0.547924
3,../kaggle_data/all_birds/ORANGE BREASTED TROGO...,"[0.01295435, 0.022253562, 0.0, 0.20944391, 0.2...",0.575007
4,../kaggle_data/all_birds/COPPERSMITH BARBET/bi...,"[0.112756304, 0.0, 0.0, 0.0, 0.80765474, 0.126...",0.582224
5,../kaggle_data/all_birds/HAWFINCH/bird.jpg,"[0.02649048, 0.019148355, 0.5708676, 0.1696113...",0.583555


In [6]:
unknown_bird =  embedding_service.generate_embedding("./unknowns/2.jpg")
results = table \
    .search(unknown_bird) \
    .metric("cosine") \
    .limit(6) \
    .to_pandas()
results

Unnamed: 0,image_path,vector,_distance
0,../kaggle_data/all_birds/JAVA SPARROW/bird.jpg,"[0.06697277, 0.001832604, 0.0, 0.0, 3.5173216,...",0.030168
1,../kaggle_data/all_birds/PARUS MAJOR/bird.jpg,"[0.09032001, 0.025061814, 0.0, 0.2794906, 0.14...",0.52491
2,../kaggle_data/all_birds/BOBOLINK/bird.jpg,"[0.13251407, 0.0, 0.0, 0.0, 0.14810395, 0.0254...",0.572058
3,../kaggle_data/all_birds/PATAGONIAN SIERRA FIN...,"[0.0, 0.03416685, 0.0, 0.0, 0.17393155, 0.0865...",0.578968
4,../kaggle_data/all_birds/COPPERSMITH BARBET/bi...,"[0.112756304, 0.0, 0.0, 0.0, 0.80765474, 0.126...",0.588615
5,../kaggle_data/all_birds/AFRICAN PYGMY GOOSE/b...,"[0.006307069, 0.0, 0.0015722051, 0.08193158, 0...",0.617619


In [9]:
unknown_bird =  embedding_service.generate_embedding("./unknowns/3.jpg")
results = table \
    .search(unknown_bird) \
    .metric("cosine") \
    .limit(6) \
    .to_pandas()
results

Unnamed: 0,image_path,vector,_distance
0,../kaggle_data/all_birds/PARUS MAJOR/bird.jpg,"[0.09032001, 0.025061814, 0.0, 0.2794906, 0.14...",0.079511
1,../kaggle_data/all_birds/ORANGE BREASTED TROGO...,"[0.01295435, 0.022253562, 0.0, 0.20944391, 0.2...",0.41063
2,../kaggle_data/all_birds/PATAGONIAN SIERRA FIN...,"[0.0, 0.03416685, 0.0, 0.0, 0.17393155, 0.0865...",0.480254
3,../kaggle_data/all_birds/INLAND DOTTEREL/bird.jpg,"[0.0, 0.22410323, 0.0, 0.01811927, 0.86598057,...",0.532046
4,../kaggle_data/all_birds/PUNA TEAL/bird.jpg,"[0.040451776, 0.96030307, 0.0, 0.14561963, 0.4...",0.532208
5,../kaggle_data/all_birds/AFRICAN PYGMY GOOSE/b...,"[0.006307069, 0.0, 0.0015722051, 0.08193158, 0...",0.537022


In [16]:
import json

unknown_bird = embedding_service.generate_embedding("./unknowns/5.jpg")
results = table \
    .search(unknown_bird) \
    .metric("cosine") \
    .limit(6) \
    .to_pandas()
results

# Keep only the image_path and _distance columns
results = results[["image_path", "_distance"]]

# Extract the label from the image_path
results["label"] = results["image_path"].apply(lambda x: x.split("/")[-2])
results["similarity"] = 1 - results["_distance"]


results_list = results.to_dict(orient="records")

json_results = json.dumps(results_list, indent=4, default=str)
print(json_results)

[
    {
        "image_path": "../kaggle_data/all_birds/SCARLET TANAGER/bird.jpg",
        "_distance": 0.08229416608810425,
        "label": "SCARLET TANAGER",
        "similarity": 0.9177058339118958
    },
    {
        "image_path": "../kaggle_data/all_birds/BORNEAN BRISTLEHEAD/bird.jpg",
        "_distance": 0.46663737297058105,
        "label": "BORNEAN BRISTLEHEAD",
        "similarity": 0.533362627029419
    },
    {
        "image_path": "../kaggle_data/all_birds/STRIPPED SWALLOW/bird.jpg",
        "_distance": 0.6153758764266968,
        "label": "STRIPPED SWALLOW",
        "similarity": 0.3846241235733032
    },
    {
        "image_path": "../kaggle_data/all_birds/BOBOLINK/bird.jpg",
        "_distance": 0.6520496010780334,
        "label": "BOBOLINK",
        "similarity": 0.34795039892196655
    },
    {
        "image_path": "../kaggle_data/all_birds/GREATER PEWEE/bird.jpg",
        "_distance": 0.657173752784729,
        "label": "GREATER PEWEE",
        "similarity": 0