# In [None]:
import torch
import numpy as np
import pickle
import json
from torchvision import transforms

# Load components
def load_components(base_dir="app"):
    """Load the model, catalogue embeddings, catalogue items and preprocessing.

    Args:
        base_dir: Root directory containing the ``models/`` and ``data/``
            folders. Defaults to ``"app"`` so existing callers read the same
            paths as before.

    Returns:
        A tuple ``(model, embeddings, zalando_items, transform)``:

        - ``model``: the object unpickled from ``models/model.pth`` on CPU,
          switched to eval mode (presumably a full ``torch.nn.Module``, since
          ``.eval()`` is called on it).
        - ``embeddings``: the array stored under the ``"embeddings"`` key of
          ``data/model_embeddings.npz``.
        - ``zalando_items``: whatever object was pickled into
          ``data/downloaded_items.pkl``.
        - ``transform``: Resize to 224x224 -> ToTensor -> Normalize, using the
          ``"mean"`` / ``"std"`` entries of ``data/transformations.json``
          (presumably per-channel lists; verify against the JSON file).

    Security:
        ``torch.load(..., weights_only=False)`` and ``pickle.load`` can execute
        arbitrary code embedded in the files. Only point ``base_dir`` at
        trusted, locally produced artifacts.
    """
    import inspect
    from pathlib import Path

    root = Path(base_dir)
    models_dir = root / "models"
    data_dir = root / "data"

    # The checkpoint is a whole pickled module (we call .eval() on it), not a
    # state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True, which
    # refuses such checkpoints, so opt out explicitly whenever the keyword
    # exists (older releases without it always unpickle fully).
    load_kwargs = {"map_location": torch.device("cpu")}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(models_dir / "model.pth", **load_kwargs)
    model.eval()

    # NpzFile keeps the archive open until closed. Indexing it reads the array
    # fully into memory, so the result stays valid after the `with` exits.
    with np.load(data_dir / "model_embeddings.npz") as embeddings_data:
        embeddings = embeddings_data["embeddings"]

    # SECURITY: pickle.load runs arbitrary code from the file; trusted data only.
    with open(data_dir / "downloaded_items.pkl", "rb") as f:
        zalando_items = pickle.load(f)

    with open(data_dir / "transformations.json", "r", encoding="utf-8") as f:
        transformations = json.load(f)

    # Preprocessing pipeline applied to every query image.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=transformations["mean"], std=transformations["std"])
    ])

    return model, embeddings, zalando_items, transform

# Preprocess image
def preprocess_image(image, transform):
    """Apply ``transform`` to one image and prepend a batch axis of size 1.

    Args:
        image: Any input ``transform`` accepts (for the pipeline built in
            ``load_components`` this is typically a PIL image).
        transform: Callable that returns a ``torch.Tensor``.

    Returns:
        The transformed tensor with a new leading dimension of size 1.
    """
    sample = transform(image)
    # Models consume batches; wrap the single sample as a batch of one.
    batched = sample.unsqueeze(dim=0)
    return batched

# Classify image
def classify_image(model, image):
    """Return the index of the highest-scoring class for a single input.

    Args:
        model: Callable (e.g. a ``torch.nn.Module``) mapping ``image`` to a
            score tensor whose dimension 1 indexes classes. This function does
            not call ``model.eval()``; the caller is responsible for that.
        image: Batched input tensor, e.g. the output of ``preprocess_image``.

    Returns:
        The arg-max class index as a Python ``int``. Because ``.item()`` is
        used, the reduced result must contain exactly one element, i.e. the
        batch must hold a single image.
    """
    with torch.no_grad():  # inference only: skip autograd graph construction
        scores = model(image)
        best_class = torch.max(scores, dim=1).indices
    return best_class.item()

# Find the most similar item
def find_similar_item(image, embeddings, zalando_items):
    """Return the catalogue item whose embedding is most cosine-similar to ``image``.

    Args:
        image: Query as a ``torch.Tensor`` or array-like. It is flattened, so a
            ``(1, C, H, W)`` batch becomes a single vector of length C*H*W.
        embeddings: 2-D array-like with one row per catalogue item; the row
            length must equal the flattened size of ``image``.
        zalando_items: Sequence indexed by row position of ``embeddings``.

    Returns:
        ``zalando_items[i]`` for the row ``i`` with the largest cosine
        similarity (first such row on ties). Rows with zero norm have no
        defined similarity and never beat a row that does. If every similarity
        is undefined (e.g. the query has zero norm), item 0 is returned.

    Raises:
        ValueError: if ``embeddings`` is not a non-empty 2-D array, or its row
            length differs from the flattened size of ``image``.

    NOTE(review): the query vector is the raw input tensor itself. It is not
    passed through any model. If ``embeddings`` holds model features, as the
    source filename suggests, the caller must pass the matching feature vector
    here. Confirm against the caller.
    """
    if isinstance(image, torch.Tensor):
        # detach() so tensors that require grad can be converted: .numpy()
        # raises on them, and torch.no_grad() does not change that.
        image = image.detach().cpu().numpy()
    query = np.asarray(image, dtype=np.float64).ravel()
    matrix = np.asarray(embeddings, dtype=np.float64)

    if matrix.ndim != 2 or matrix.shape[0] == 0:
        raise ValueError(
            f"embeddings must be a non-empty 2-D array, got shape {matrix.shape}"
        )
    if matrix.shape[1] != query.size:
        raise ValueError(
            f"embedding size {matrix.shape[1]} does not match query size {query.size}"
        )

    dots = matrix @ query
    denominators = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)

    # A zero (or NaN) norm makes cosine similarity 0/0 = NaN, and np.argmax
    # treats NaN as the maximum, so a blank row would "win". Score such rows
    # as -inf instead, dividing only where the denominator is positive.
    similarities = np.full(dots.shape, -np.inf)
    np.divide(dots, denominators, out=similarities, where=denominators > 0)

    most_similar_index = int(np.argmax(similarities))
    return zalando_items[most_similar_index]