In [None]:
import torch

# Report whether PyTorch can see a CUDA device. None of the visible cells move
# the model or its inputs to a GPU, so this is an informational check only.
torch.cuda.is_available()

In [None]:
import torch
import torch.nn as nn
from torchvision import models , transforms
from PIL import Image


# Load ResNet-18 with its pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=` — consider migrating once the minimum torchvision version used
# by this project is known.
resnet = models.resnet18(pretrained=True)

# Keep every child module except the final one (ResNet's fully connected
# classifier) so the network emits pooled features instead of class scores.
feature_layers = list(resnet.children())[:-1]
model = nn.Sequential(*feature_layers)

# Switch to inference mode; as the last expression this also displays the model.
model.eval()

In [None]:
# Per-channel mean / standard deviation handed to transforms.Normalize.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Image preprocessing pipeline: resize to 256, centre-crop to 224, convert the
# PIL image to a tensor, then normalise each channel.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]
preprocess = transforms.Compose(_preprocess_steps)

def get_embedding(image_path: str):
    """Return the feature vector that `model` produces for one image file.

    The image is opened with PIL, converted to RGB, run through `preprocess`,
    and passed through `model` as a batch of one with gradient tracking
    disabled.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        A NumPy array holding the model output with all size-1 dimensions
        squeezed away. With the ResNet-18 feature extractor built in this
        notebook that is a 512-element vector, which is the dimension the
        vector collection is created with.
    """
    # convert() returns a fully loaded copy, so the underlying file handle can
    # be closed as soon as the conversion is done.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")
    img_t = preprocess(img).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        embedding = model(img_t).squeeze().numpy()  # shape (512,) for ResNet-18
    return embedding


In [None]:
# Embed one sample image and print the shape of the resulting array as a quick
# sanity check of the embedding length.
embedding = get_embedding("Images/airmax-97.jpg")
print(embedding.shape) 



In [None]:
# Embed a second image and display the raw vector.
# NOTE(review): this rebinds the global `embedding`, so any later cell that
# reads `embedding` sees the nike-sb.png vector rather than the airmax-97.jpg
# one computed in the previous cell.
embedding = get_embedding("Images/nike-sb.png")
embedding

In [None]:
import os
from dotenv import load_dotenv
from astrapy import DataAPIClient
from astrapy.constants import VectorMetric
from astrapy.info import CollectionDefinition

# Load environment variables from .env file
load_dotenv()

ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT")

# os.getenv returns None for unset variables; fail here with a clear message
# instead of passing None on to the Astra client.
_missing_settings = [
    name
    for name, value in (
        ("ASTRA_DB_APPLICATION_TOKEN", ASTRA_DB_APPLICATION_TOKEN),
        ("ASTRA_DB_API_ENDPOINT", ASTRA_DB_API_ENDPOINT),
    )
    if not value
]
if _missing_settings:
    raise RuntimeError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_settings)
        + ". Set them in the environment or in a .env file."
    )

client = DataAPIClient()

# Database handle for the configured endpoint, authenticated with the token.
db = client.get_database(
    ASTRA_DB_API_ENDPOINT,
    token=ASTRA_DB_APPLICATION_TOKEN
)

In [None]:


# Describe the vector collection: 512-dimensional vectors compared with the
# cosine metric. The dimension must match the length of the embeddings stored.
sneaker_collection_definition = (
    CollectionDefinition.builder()
    .set_vector_dimension(512)
    .set_vector_metric(VectorMetric.COSINE)
    .build()
)

# Create the "sneaker_search" collection and keep a handle to it.
# NOTE(review): presumably a no-op when the collection already exists with the
# same settings — confirm against the astrapy version in use.
my_collection = db.create_collection(
    "sneaker_search",
    definition=sneaker_collection_definition,
)


In [None]:
# Display whatever `embedding` currently holds. When the cells are run top to
# bottom this is the vector from the most recent get_embedding assignment
# (nike-sb.png), not the airmax-97.jpg one.
embedding

In [None]:
# Build the record from a single path so the stored vector always belongs to
# the image named in `image_path`. Reusing the global `embedding` here would
# store whichever image was embedded last (nike-sb.png when run in order)
# under the airmax-97.jpg label.
sample_image_path = 'Images/airmax-97.jpg'

# NOTE(review): 'id' is an ordinary document field here; if it is meant to be
# the document's primary key, check which key name the Data API expects.
sample = {
    'id': '1',
    'image_path': sample_image_path,
    '$vector': get_embedding(sample_image_path).tolist(),
}

my_collection.insert_one(sample)



In [None]:
# Run vector search: embed the query image and ask the collection for the five
# stored documents whose vectors are closest to it.
query_embedding = get_embedding("Images/nike-sb.png").tolist()

cursor = my_collection.find(
    {},
    sort={"$vector": query_embedding},
    limit=5,
    include_similarity=True,
)

# include_similarity=True attaches a "$similarity" score to every result, so
# print it next to the path instead of discarding it.
for result in cursor:
    print(f"{result['image_path']}  (similarity: {result['$similarity']:.4f})")