In [1]:
from catboost import CatBoostClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

from sneakers_ml.features.features import get_train_val_test
from sneakers_ml.models.onnx import save_sklearn_onnx

In [2]:
X_train, X_val, X_test, y_train, y_val, y_test = get_train_val_test("data/features/resnet", "brands-classification")

In [3]:
param_grid = {"C": [0.1, 1, 10], "gamma": ["scale", "auto"], "kernel": ["linear"]}

svc = SVC()
grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, scoring="accuracy", verbose=1)
grid_search.fit(X_train, y_train)

Fitting 5 folds for each of 6 candidates, totalling 30 fits


In [4]:
pred = grid_search.best_estimator_.predict(X_test)
accuracy_score(y_test, pred)

0.7027027027027027

In [6]:
save_sklearn_onnx(grid_search.best_estimator_, X_train, "data/models/resnet-SVC-70acc.onnx")

In [7]:
param_grid = {"loss": ["log_loss", "hinge"], "alpha": [0.0001, 0.001]}

sgd = SGDClassifier()
grid_search = GridSearchCV(estimator=sgd, param_grid=param_grid, cv=5, scoring="accuracy", verbose=1)
grid_search.fit(X_train, y_train)

Fitting 5 folds for each of 4 candidates, totalling 20 fits


In [9]:
pred = grid_search.best_estimator_.predict(X_test)
accuracy_score(y_test, pred)

0.7138863000931966

In [10]:
save_sklearn_onnx(grid_search.best_estimator_, X_train, "data/models/resnet-SGD-71acc.onnx")

In [14]:
model = CatBoostClassifier(verbose=True, iterations=200)

model.fit(X_train, y_train, eval_set=(X_val, y_val))

Learning rate set to 0.204354
0:	learn: 2.3832257	test: 2.4004832	best: 2.4004832 (0)	total: 1.39s	remaining: 4m 35s
1:	learn: 2.2793682	test: 2.3083602	best: 2.3083602 (1)	total: 2.69s	remaining: 4m 26s
2:	learn: 2.2002017	test: 2.2418206	best: 2.2418206 (2)	total: 4.02s	remaining: 4m 24s
3:	learn: 2.1195991	test: 2.1782485	best: 2.1782485 (3)	total: 5.35s	remaining: 4m 21s
4:	learn: 2.0589311	test: 2.1321917	best: 2.1321917 (4)	total: 6.67s	remaining: 4m 20s
5:	learn: 2.0100818	test: 2.0927373	best: 2.0927373 (5)	total: 7.99s	remaining: 4m 18s
6:	learn: 1.9699169	test: 2.0679113	best: 2.0679113 (6)	total: 9.36s	remaining: 4m 18s
7:	learn: 1.9309290	test: 2.0415723	best: 2.0415723 (7)	total: 10.7s	remaining: 4m 16s
8:	learn: 1.8972441	test: 2.0162372	best: 2.0162372 (8)	total: 12s	remaining: 4m 15s
9:	learn: 1.8598374	test: 1.9920902	best: 1.9920902 (9)	total: 13.4s	remaining: 4m 13s
10:	learn: 1.8262297	test: 1.9678768	best: 1.9678768 (10)	total: 14.6s	remaining: 4m 10s
11:	learn: 1.

<catboost.core.CatBoostClassifier at 0x7f110379e670>

In [12]:
pred = model.predict(X_test)
accuracy_score(y_test, pred)

0.587138863000932

In [15]:
model.save_model(
    "data/models/resnet-CatBoost-58acc.onnx",
    format="onnx",
    export_parameters={
        "onnx_domain": "ai.catboost",
        "onnx_model_version": 1,
        "onnx_doc_string": "iterations=100 default model",
        "onnx_graph_name": "CatBoostModel_for_MultiClassification",
    },
)