In [1]:
import numpy as np
from catboost import CatBoostClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

from sneakers_ml.features.features import get_train_val_test
from sneakers_ml.models.onnx import save_sklearn_onnx

In [2]:
# Load the six arrays (features and labels for train / val / test) that
# get_train_val_test returns for the "resnet" feature set stored under the
# brands-classification splits directory.
x_train, x_val, x_test, y_train, y_val, y_test = get_train_val_test(
    "data/features/brands-classification-splits",
    "resnet",
)

# The grid searches below cross-validate internally, so train and validation
# are stacked into one fitting set; the test split is kept apart for scoring.
x_train_val = np.concatenate([x_train, x_val], axis=0)
y_train_val = np.concatenate([y_train, y_val], axis=0)

In [3]:
# Grid search over SVC hyper-parameters on the merged train+validation data,
# selecting by macro F1 with 3-fold cross-validation.
# NOTE(review): per the scikit-learn docs `gamma` only applies to the
# rbf/poly/sigmoid kernels, so every kernel="linear" candidate is effectively
# evaluated twice here — consider a list of per-kernel grids to drop 4 of the
# 16 candidates.
param_grid = {
    "C": [0.1, 1, 10, 100],
    "gamma": ["scale", "auto"],
    "kernel": ["linear", "rbf"],
}
svc = SVC()
grid_search = GridSearchCV(
    estimator=svc,
    param_grid=param_grid,
    scoring="f1_macro",
    cv=3,
    n_jobs=-1,
    verbose=1,
)
grid_search.fit(x_train_val, y_train_val)

# Score the best estimator found by the search on the untouched test split.
pred = grid_search.best_estimator_.predict(x_test)
print(f"Acc: {accuracy_score(y_test, pred)}")
for average in ("weighted", "macro"):
    print(f"F1-{average}: {f1_score(y_test, pred, average=average)}")

Fitting 3 folds for each of 16 candidates, totalling 48 fits
Acc: 0.7632805219012115
F1-weighted: 0.7627610507119207
F1-macro: 0.7620054668979055


In [4]:
# Show the hyper-parameter combination selected by the SVC grid search above.
grid_search.best_params_

{'C': 100, 'gamma': 'scale', 'kernel': 'rbf'}

In [5]:
# Persist the winning SVC as an ONNX file.
# NOTE(review): x_train is passed as the second argument — presumably a sample
# input used to derive the ONNX input signature; confirm in save_sklearn_onnx.
best_svc = grid_search.best_estimator_
save_sklearn_onnx(
    best_svc,
    x_train,
    "data/models/brands-classification/resnet-svc.onnx",
)

In [6]:
# Grid search over a linear model trained with SGD: the grid covers the loss
# function, the regularisation strength `alpha`, and the penalty type, selecting
# by macro F1 with 3-fold cross-validation on the merged train+validation data.
param_grid = {"loss": ["log_loss", "hinge"], "alpha": [0.0001, 0.001, 0.00001], "penalty": ["l2", "elasticnet"]}

# SGDClassifier shuffles the training data between epochs, so without a fixed
# random_state every run of this cell can pick different hyper-parameters and
# report different scores. Pin the seed so the search is reproducible after
# Restart & Run All. (Outputs recorded from an earlier unseeded run may differ.)
sgd = SGDClassifier(random_state=42)
grid_search = GridSearchCV(estimator=sgd, param_grid=param_grid, cv=3, scoring="f1_macro", verbose=1, n_jobs=-1)
grid_search.fit(x_train_val, y_train_val)

# Score the best estimator found by the search on the untouched test split.
pred = grid_search.best_estimator_.predict(x_test)
print(f"Acc: {accuracy_score(y_test, pred)}")
print(f"F1-weighted: {f1_score(y_test, pred, average='weighted')}")
print(f"F1-macro: {f1_score(y_test, pred, average='macro')}")

Fitting 3 folds for each of 12 candidates, totalling 36 fits
Acc: 0.6952469711090401
F1-weighted: 0.6911633638749781
F1-macro: 0.6479479888116414


In [7]:
# Show the hyper-parameter combination selected by the SGD grid search above.
grid_search.best_params_

{'alpha': 1e-05, 'loss': 'log_loss', 'penalty': 'elasticnet'}

In [8]:
# Persist the winning SGD classifier as an ONNX file.
# NOTE(review): x_train is passed as the second argument — presumably a sample
# input used to derive the ONNX input signature; confirm in save_sklearn_onnx.
best_sgd = grid_search.best_estimator_
save_sklearn_onnx(
    best_sgd,
    x_train,
    "data/models/brands-classification/resnet-sgd.onnx",
)

In [9]:
# Gradient-boosting baseline: 1000 boosting iterations, fitted on the merged
# train+validation data with no hyper-parameter search.
# verbose=True logs the training loss for every iteration (see the output).
# NOTE(review): task_type="GPU" presumably requires a CUDA-capable device on
# the machine running this notebook — confirm before a full re-run.
model = CatBoostClassifier(
    iterations=1000,
    task_type="GPU",
    verbose=True,
)
model.fit(x_train_val, y_train_val)

Learning rate set to 0.087133
0:	learn: 2.4402012	total: 272ms	remaining: 4m 31s
1:	learn: 2.3488970	total: 437ms	remaining: 3m 38s
2:	learn: 2.2647879	total: 603ms	remaining: 3m 20s
3:	learn: 2.1980965	total: 763ms	remaining: 3m 10s
4:	learn: 2.1408256	total: 933ms	remaining: 3m 5s
5:	learn: 2.0888312	total: 1.08s	remaining: 2m 59s
6:	learn: 2.0456744	total: 1.25s	remaining: 2m 56s
7:	learn: 2.0058022	total: 1.41s	remaining: 2m 55s
8:	learn: 1.9712451	total: 1.57s	remaining: 2m 53s
9:	learn: 1.9375039	total: 1.74s	remaining: 2m 51s
10:	learn: 1.9088136	total: 1.9s	remaining: 2m 50s
11:	learn: 1.8810334	total: 2.06s	remaining: 2m 49s
12:	learn: 1.8538776	total: 2.22s	remaining: 2m 48s
13:	learn: 1.8306177	total: 2.38s	remaining: 2m 47s
14:	learn: 1.8082692	total: 2.53s	remaining: 2m 46s
15:	learn: 1.7849467	total: 2.69s	remaining: 2m 45s
16:	learn: 1.7626295	total: 2.86s	remaining: 2m 45s
17:	learn: 1.7437637	total: 3.01s	remaining: 2m 44s
18:	learn: 1.7246571	total: 3.18s	remaining: 2

<catboost.core.CatBoostClassifier at 0x7f075daf38e0>

In [10]:
# Score the fitted CatBoost model on the test split with the same three
# metrics reported for the other classifiers in this notebook.
pred = model.predict(x_test)
print(f"Acc: {accuracy_score(y_test, pred)}")
for average in ("weighted", "macro"):
    print(f"F1-{average}: {f1_score(y_test, pred, average=average)}")

Acc: 0.6999068033550793
F1-weighted: 0.6787267217149069
F1-macro: 0.6360292362180385


In [11]:
# Export the fitted CatBoost model to ONNX next to the other brand classifiers.
# The export parameters below are metadata passed through to the ONNX exporter
# (domain, model version, doc string and graph name).
onnx_export_parameters = {
    "onnx_domain": "ai.catboost",
    "onnx_model_version": 1,
    "onnx_doc_string": "default model",
    "onnx_graph_name": "CatBoostModel_for_MultiClassification",
}
model.save_model(
    "data/models/brands-classification/resnet-catboost.onnx",
    format="onnx",
    export_parameters=onnx_export_parameters,
)