In [None]:
import numpy as np
from catboost import CatBoostClassifier
from catboost.utils import get_gpu_device_count
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

from sneakers_ml.features.features import get_train_val_test
from sneakers_ml.models.onnx import save_sklearn_onnx

In [None]:
# Load the precomputed "hog" feature splits for brand classification.
FEATURES_DIR = "data/features/brands-classification-splits"
FEATURE_NAME = "hog"
(x_train, x_val, x_test,
 y_train, y_val, y_test) = get_train_val_test(FEATURES_DIR, FEATURE_NAME)

# Model selection below relies on GridSearchCV's own cross-validation, so the
# validation split is merged into the fitting set; the test split stays held out.
x_train_val = np.concatenate([x_train, x_val], axis=0)
y_train_val = np.concatenate([y_train, y_val])

In [None]:
# SVC: search C, gamma heuristic and kernel with 3-fold CV on train+val,
# optimising macro-F1, then score the best estimator on the held-out test split.
svc_param_grid = {
    "C": [0.1, 1, 10, 100],
    "gamma": ["scale", "auto"],
    "kernel": ["linear", "rbf"],
}
grid_search = GridSearchCV(
    estimator=SVC(),
    param_grid=svc_param_grid,
    cv=3,
    scoring="f1_macro",
    verbose=1,
    n_jobs=-1,
)
grid_search.fit(x_train_val, y_train_val)

pred = grid_search.best_estimator_.predict(x_test)
test_metrics = {
    "Acc": accuracy_score(y_test, pred),
    "F1-weighted": f1_score(y_test, pred, average="weighted"),
    "F1-macro": f1_score(y_test, pred, average="macro"),
}
for metric_name, metric_value in test_metrics.items():
    print(f"{metric_name}: {metric_value}")

Fitting 3 folds for each of 16 candidates, totalling 48 fits
Acc: 0.8126747437092264
F1-weighted: 0.8096470380854331
F1-macro: 0.784528576960828


In [None]:
# `grid_search` is rebound by the SGD cell further down; fail loudly instead of
# silently showing SGD hyperparameters if cells were run out of order.
assert isinstance(grid_search.best_estimator_, SVC), "grid_search no longer holds the SVC search"
grid_search.best_params_

{'C': 10, 'gamma': 'scale', 'kernel': 'rbf'}

In [None]:
# `grid_search` is shared with the SGD cell; make sure the SVC model (and not
# the SGD one) is what gets written to hog-svc.onnx.
# NOTE(review): x_train is presumably used by the exporter only to infer the input signature — confirm.
assert isinstance(grid_search.best_estimator_, SVC), "grid_search no longer holds the SVC search"
save_sklearn_onnx(grid_search.best_estimator_, x_train, "data/models/brands-classification/hog-svc.onnx")

In [None]:
# SGD linear classifier: search loss, regularisation strength and penalty with
# 3-fold CV on train+val (macro-F1), then score the best estimator on test.
param_grid = {"loss": ["log_loss", "hinge"], "alpha": [0.0001, 0.001, 0.00001], "penalty": ["l2", "elasticnet"]}
# SGD shuffles the training data each epoch; without a fixed seed the selected
# hyperparameters and the reported test scores change from run to run.
# GridSearchCV clones the estimator, so the seed carries over to every fit.
sgd = SGDClassifier(random_state=42)
grid_search = GridSearchCV(estimator=sgd, param_grid=param_grid, cv=3, scoring="f1_macro", verbose=1, n_jobs=-1)
grid_search.fit(x_train_val, y_train_val)
pred = grid_search.best_estimator_.predict(x_test)
print(f"Acc: {accuracy_score(y_test, pred)}")
print(f"F1-weighted: {f1_score(y_test, pred, average='weighted')}")
print(f"F1-macro: {f1_score(y_test, pred, average='macro')}")

Fitting 3 folds for each of 12 candidates, totalling 36 fits


Acc: 0.7651444547996272
F1-weighted: 0.7614728620491326
F1-macro: 0.7174377572851446


In [None]:
# `grid_search` is also used by the SVC cells; make sure it holds the SGD search
# before displaying its best hyperparameters.
assert isinstance(grid_search.best_estimator_, SGDClassifier), "grid_search does not hold the SGD search"
grid_search.best_params_

{'alpha': 0.0001, 'loss': 'log_loss', 'penalty': 'l2'}

In [None]:
# `grid_search` is shared with the SVC cells; make sure the SGD model is what
# gets written to hog-sgd.onnx.
# NOTE(review): x_train is presumably used by the exporter only to infer the input signature — confirm.
assert isinstance(grid_search.best_estimator_, SGDClassifier), "grid_search does not hold the SGD search"
save_sklearn_onnx(grid_search.best_estimator_, x_train, "data/models/brands-classification/hog-sgd.onnx")

In [None]:
# Train on GPU when one is visible, otherwise fall back to CPU so the notebook
# still runs top-to-bottom on a machine without CUDA.
catboost_task_type = "GPU" if get_gpu_device_count() > 0 else "CPU"
# An integer `verbose` sets the logging period: report every 100th iteration
# instead of dumping all 1000 iteration lines into the notebook output.
model = CatBoostClassifier(verbose=100, iterations=1000, task_type=catboost_task_type)
model.fit(x_train_val, y_train_val)

Learning rate set to 0.087133
0:	learn: 2.4083200	total: 619ms	remaining: 10m 17s
1:	learn: 2.2942286	total: 1.09s	remaining: 9m 3s
2:	learn: 2.2007154	total: 1.58s	remaining: 8m 45s
3:	learn: 2.1220182	total: 2.07s	remaining: 8m 34s
4:	learn: 2.0616302	total: 2.57s	remaining: 8m 30s
5:	learn: 2.0127098	total: 3.04s	remaining: 8m 23s
6:	learn: 1.9621315	total: 3.53s	remaining: 8m 21s
7:	learn: 1.9198365	total: 3.53s	remaining: 8m 21s
8:	learn: 1.8835596	total: 6.1s	remaining: 12m 35s
9:	learn: 1.8490590	total: 6.59s	remaining: 12m 4s
10:	learn: 1.8169042	total: 7.05s	remaining: 11m 37s
11:	learn: 1.7956355	total: 7.52s	remaining: 11m 15s
12:	learn: 1.7722110	total: 7.99s	remaining: 10m 57s
13:	learn: 1.7483722	total: 8.46s	remaining: 10m 41s
14:	learn: 1.7301773	total: 8.91s	remaining: 10m 26s
15:	learn: 1.7083295	total: 9.38s	remaining: 10m 15s
16:	learn: 1.6881834	total: 9.86s	remaining: 10m 5s
17:	learn: 1.6713641	total: 10.3s	remaining: 9m 56s
18:	learn: 1.6536725	total: 10.8s	rema

<catboost.core.CatBoostClassifier at 0x7fe15f4b3e80>

In [None]:
# Score the CatBoost model on the held-out test split.
# CatBoostClassifier.predict can return a 2-D (n, 1) column for multiclass
# targets; flatten it so the sklearn metrics receive a 1-D label vector
# (np.ravel is a no-op if the prediction is already 1-D).
pred = np.ravel(model.predict(x_test))
print(f"Acc: {accuracy_score(y_test, pred)}")
print(f"F1-weighted: {f1_score(y_test, pred, average='weighted')}")
print(f"F1-macro: {f1_score(y_test, pred, average='macro')}")

Acc: 0.7362534948741846
F1-weighted: 0.7245682177032027
F1-macro: 0.6796463543685718


In [None]:
# Export the fitted CatBoost model to ONNX, attaching explicit graph metadata.
catboost_onnx_metadata = {
    "onnx_domain": "ai.catboost",
    "onnx_model_version": 1,
    "onnx_doc_string": "default model",
    "onnx_graph_name": "CatBoostModel_for_MultiClassification",
}
catboost_onnx_path = "data/models/brands-classification/hog-catboost.onnx"
model.save_model(catboost_onnx_path, format="onnx", export_parameters=catboost_onnx_metadata)