In [None]:
from pathlib import Path

import numpy as np
from PIL import Image

from sneakers_ml.features.features import load_features
from sneakers_ml.features.hog import get_hog
from sneakers_ml.features.resnet152 import get_resnet152_feature
from sneakers_ml.features.sift import get_sift_feature
from sneakers_ml.models.onnx import load_catboost_onnx, load_sklearn_onnx, predict_catboost_onnx, predict_sklearn_onnx

In [None]:
list(Path("data/training/brands-classification-splits/test").iterdir())

[PosixPath('data/training/brands-classification-splits/test/adidas'),
 PosixPath('data/training/brands-classification-splits/test/nike'),
 PosixPath('data/training/brands-classification-splits/test/vans'),
 PosixPath('data/training/brands-classification-splits/test/jordan'),
 PosixPath('data/training/brands-classification-splits/test/saucony'),
 PosixPath('data/training/brands-classification-splits/test/kangaroos'),
 PosixPath('data/training/brands-classification-splits/test/converse'),
 PosixPath('data/training/brands-classification-splits/test/clarks'),
 PosixPath('data/training/brands-classification-splits/test/puma'),
 PosixPath('data/training/brands-classification-splits/test/new balance'),
 PosixPath('data/training/brands-classification-splits/test/reebok'),
 PosixPath('data/training/brands-classification-splits/test/karhu'),
 PosixPath('data/training/brands-classification-splits/test/asics')]

In [None]:
sklearn_hog_models = [
    "data/models/brands-classification/hog-sgd.onnx",
    "data/models/brands-classification/hog-svc.onnx",
]
sklearn_resnet_models = [
    "data/models/brands-classification/resnet-sgd.onnx",
    "data/models/brands-classification/resnet-svc.onnx",
]
sklearn_sift_models = [
    "data/models/brands-classification/sift-sgd.onnx",
    "data/models/brands-classification/sift-svc.onnx",
]
catboost_hog_models = ["data/models/brands-classification/hog-catboost.onnx"]
catboost_resnet_models = ["data/models/brands-classification/resnet-catboost.onnx"]
catboost_sift_models = ["data/models/brands-classification/sift-catboost.onnx"]

In [None]:
models = {}

for model in sklearn_hog_models + sklearn_resnet_models + sklearn_sift_models:
    models[Path(model).stem] = load_sklearn_onnx(model)

for model in catboost_hog_models + catboost_resnet_models + catboost_sift_models:
    models[Path(model).stem] = load_catboost_onnx(model)

In [None]:
def predict_using_all_models(image: Image.Image) -> dict:
    preds = {}
    kmeans = load_sklearn_onnx("data/models/brands-classification/sift-kmeans.onnx")
    hog_embedding = get_hog(image)[np.newaxis]
    resnet_embedding = get_resnet152_feature(image)[np.newaxis]
    sift_embedding = get_sift_feature(image, kmeans, 2000)[np.newaxis]
    for model in models:
        print(model)
        if "hog" in model:
            if "catboost" in model:
                preds[model] = predict_catboost_onnx(models[model], hog_embedding)[0][0]
            else:
                preds[model] = predict_sklearn_onnx(models[model], hog_embedding)[0]
        elif "resnet" in model:
            if "catboost" in model:
                preds[model] = predict_catboost_onnx(models[model], resnet_embedding)[0][0]
            else:
                preds[model] = predict_sklearn_onnx(models[model], resnet_embedding)[0]
        elif "sift" in model:
            if "catboost" in model:
                preds[model] = predict_catboost_onnx(models[model], sift_embedding)[0][0]
            else:
                preds[model] = predict_sklearn_onnx(models[model], sift_embedding)[0]
    return preds

In [None]:
models

{'hog-sgd': <onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7fbb2c213a30>,
 'hog-svc': <onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7fbb2c202dc0>,
 'resnet-sgd': <onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7fbb2c202790>,
 'resnet-svc': <onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7fbb2c2024f0>,
 'sift-sgd': <onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7fba2e9e7b20>,
 'sift-svc': <onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7fba2e9e7d30>,
 'hog-catboost': <onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7fbb2c20d910>,
 'resnet-catboost': <onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7fba2e9e7dc0>,
 'sift-catboost': <onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7fba2e9e79d0>}

In [None]:
image = Image.open("data/training/brands-classification-splits/test/asics/6.jpeg")

In [None]:
predict_using_all_models(image)

hog-sgd
hog-svc
resnet-sgd
resnet-svc
sift-sgd
sift-svc
hog-catboost
resnet-catboost
sift-catboost


{'hog-sgd': '1',
 'hog-svc': '1',
 'resnet-sgd': '1',
 'resnet-svc': '1',
 'sift-sgd': '8',
 'sift-svc': '8',
 'hog-catboost': '1',
 'resnet-catboost': '1',
 'sift-catboost': '8'}

In [None]:
_, _, class_2_idx = load_features("data/features/brands-classification-splits/hog-train.pickle")
class_2_idx

{'adidas': 0,
 'asics': 1,
 'clarks': 2,
 'converse': 3,
 'jordan': 4,
 'kangaroos': 5,
 'karhu': 6,
 'new balance': 7,
 'nike': 8,
 'puma': 9,
 'reebok': 10,
 'saucony': 11,
 'vans': 12}