## Torch

In [1]:
import torch
import random

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

from catboost import CatBoostClassifier 

import torch.nn as nn
import numpy as np


In [2]:
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


In [3]:
model = nn.Sequential(nn.Linear(2,3),nn.ReLU(),nn.Linear(3,1),nn.Sigmoid())

In [4]:
rand_vec=(torch.tensor([0.4983, 0.4915]))

In [5]:
rand_vec

tensor([0.4983, 0.4915])

In [6]:
model(rand_vec)

tensor([0.3709], grad_fn=<SigmoidBackward>)

In [7]:
model

Sequential(
  (0): Linear(in_features=2, out_features=3, bias=True)
  (1): ReLU()
  (2): Linear(in_features=3, out_features=1, bias=True)
  (3): Sigmoid()
)

In [8]:
torch.onnx.export(model.to(torch.float32),               
                  rand_vec.to(torch.float32),                         
                  "torch.onnx",  
                  input_names = ['input_1'],)


## Sklearn

In [9]:
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = LogisticRegression()
clr.fit(X_train, y_train)


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


LogisticRegression()

In [10]:
initial_type = [('input_1', FloatTensorType([]))]
onx = convert_sklearn(clr, initial_types=initial_type,
                      options = {id(clr): {'zipmap': False, 'output_class_labels': False}}
)
with open("rf.onnx", "wb") as f:
    f.write(onx.SerializeToString())


In [11]:
tst=np.array([5.1, 3.4, 1.5, 0.2])

In [12]:
clr.predict_proba(tst.reshape(1, -1))

array([[9.67403244e-01, 3.25965891e-02, 1.67295612e-07]])

In [13]:
cb=CatBoostClassifier()
cb.fit(X_train, y_train)

Learning rate set to 0.070767
0:	learn: 1.0190575	total: 57.2ms	remaining: 57.2s
1:	learn: 0.9551038	total: 58.5ms	remaining: 29.2s
2:	learn: 0.8998060	total: 59.7ms	remaining: 19.8s
3:	learn: 0.8374843	total: 60.5ms	remaining: 15.1s
4:	learn: 0.7880622	total: 61.7ms	remaining: 12.3s
5:	learn: 0.7385566	total: 62.4ms	remaining: 10.3s
6:	learn: 0.7033668	total: 63.4ms	remaining: 9s
7:	learn: 0.6635131	total: 64.1ms	remaining: 7.95s
8:	learn: 0.6240725	total: 81.4ms	remaining: 8.96s
9:	learn: 0.5910182	total: 82.7ms	remaining: 8.19s
10:	learn: 0.5592191	total: 85.3ms	remaining: 7.67s
11:	learn: 0.5318448	total: 86.6ms	remaining: 7.13s
12:	learn: 0.5041223	total: 88.1ms	remaining: 6.68s
13:	learn: 0.4873621	total: 89.8ms	remaining: 6.32s
14:	learn: 0.4627092	total: 91.3ms	remaining: 6s
15:	learn: 0.4452085	total: 93.5ms	remaining: 5.75s
16:	learn: 0.4232752	total: 95ms	remaining: 5.49s
17:	learn: 0.4052171	total: 96.4ms	remaining: 5.26s
18:	learn: 0.3877373	total: 97.4ms	remaining: 5.03s


199:	learn: 0.0297729	total: 369ms	remaining: 1.48s
200:	learn: 0.0295984	total: 372ms	remaining: 1.48s
201:	learn: 0.0293904	total: 374ms	remaining: 1.48s
202:	learn: 0.0291273	total: 377ms	remaining: 1.48s
203:	learn: 0.0289393	total: 378ms	remaining: 1.47s
204:	learn: 0.0286868	total: 378ms	remaining: 1.47s
205:	learn: 0.0285361	total: 379ms	remaining: 1.46s
206:	learn: 0.0283961	total: 380ms	remaining: 1.45s
207:	learn: 0.0281896	total: 381ms	remaining: 1.45s
208:	learn: 0.0280284	total: 381ms	remaining: 1.44s
209:	learn: 0.0279071	total: 382ms	remaining: 1.44s
210:	learn: 0.0276520	total: 383ms	remaining: 1.43s
211:	learn: 0.0274742	total: 384ms	remaining: 1.43s
212:	learn: 0.0273284	total: 385ms	remaining: 1.42s
213:	learn: 0.0271940	total: 385ms	remaining: 1.42s
214:	learn: 0.0270403	total: 386ms	remaining: 1.41s
215:	learn: 0.0268861	total: 387ms	remaining: 1.4s
216:	learn: 0.0267053	total: 389ms	remaining: 1.4s
217:	learn: 0.0265725	total: 393ms	remaining: 1.41s
218:	learn: 0.

457:	learn: 0.0110802	total: 718ms	remaining: 850ms
458:	learn: 0.0110538	total: 724ms	remaining: 853ms
459:	learn: 0.0110170	total: 726ms	remaining: 852ms
460:	learn: 0.0109900	total: 727ms	remaining: 850ms
461:	learn: 0.0109585	total: 728ms	remaining: 848ms
462:	learn: 0.0109315	total: 729ms	remaining: 846ms
463:	learn: 0.0109071	total: 731ms	remaining: 844ms
464:	learn: 0.0108818	total: 734ms	remaining: 844ms
465:	learn: 0.0108523	total: 736ms	remaining: 843ms
466:	learn: 0.0108241	total: 737ms	remaining: 841ms
467:	learn: 0.0108001	total: 738ms	remaining: 839ms
468:	learn: 0.0107748	total: 741ms	remaining: 839ms
469:	learn: 0.0107525	total: 742ms	remaining: 837ms
470:	learn: 0.0107195	total: 743ms	remaining: 835ms
471:	learn: 0.0106928	total: 744ms	remaining: 833ms
472:	learn: 0.0106740	total: 745ms	remaining: 831ms
473:	learn: 0.0106510	total: 747ms	remaining: 828ms
474:	learn: 0.0106260	total: 747ms	remaining: 826ms
475:	learn: 0.0105880	total: 749ms	remaining: 824ms
476:	learn: 

746:	learn: 0.0063687	total: 1.07s	remaining: 363ms
747:	learn: 0.0063599	total: 1.07s	remaining: 362ms
748:	learn: 0.0063510	total: 1.08s	remaining: 361ms
749:	learn: 0.0063418	total: 1.08s	remaining: 359ms
750:	learn: 0.0063274	total: 1.08s	remaining: 357ms
751:	learn: 0.0063182	total: 1.08s	remaining: 356ms
752:	learn: 0.0063067	total: 1.08s	remaining: 355ms
753:	learn: 0.0062956	total: 1.08s	remaining: 353ms
754:	learn: 0.0062875	total: 1.08s	remaining: 352ms
755:	learn: 0.0062797	total: 1.09s	remaining: 351ms
756:	learn: 0.0062726	total: 1.09s	remaining: 349ms
757:	learn: 0.0062642	total: 1.09s	remaining: 347ms
758:	learn: 0.0062537	total: 1.09s	remaining: 346ms
759:	learn: 0.0062442	total: 1.09s	remaining: 344ms
760:	learn: 0.0062387	total: 1.09s	remaining: 343ms
761:	learn: 0.0062284	total: 1.09s	remaining: 341ms
762:	learn: 0.0062211	total: 1.09s	remaining: 340ms
763:	learn: 0.0062120	total: 1.09s	remaining: 338ms
764:	learn: 0.0062012	total: 1.09s	remaining: 336ms
765:	learn: 

938:	learn: 0.0049558	total: 1.25s	remaining: 81ms
939:	learn: 0.0049480	total: 1.25s	remaining: 79.7ms
940:	learn: 0.0049443	total: 1.25s	remaining: 78.3ms
941:	learn: 0.0049404	total: 1.25s	remaining: 77ms
942:	learn: 0.0049343	total: 1.25s	remaining: 75.6ms
943:	learn: 0.0049285	total: 1.25s	remaining: 74.2ms
944:	learn: 0.0049212	total: 1.25s	remaining: 72.9ms
945:	learn: 0.0049154	total: 1.25s	remaining: 71.6ms
946:	learn: 0.0049087	total: 1.25s	remaining: 70.2ms
947:	learn: 0.0049039	total: 1.25s	remaining: 68.9ms
948:	learn: 0.0048987	total: 1.26s	remaining: 67.5ms
949:	learn: 0.0048937	total: 1.26s	remaining: 66.2ms
950:	learn: 0.0048873	total: 1.26s	remaining: 64.8ms
951:	learn: 0.0048822	total: 1.26s	remaining: 63.5ms
952:	learn: 0.0048764	total: 1.26s	remaining: 62.3ms
953:	learn: 0.0048728	total: 1.26s	remaining: 60.9ms
954:	learn: 0.0048690	total: 1.26s	remaining: 59.5ms
955:	learn: 0.0048620	total: 1.26s	remaining: 58.3ms
956:	learn: 0.0048564	total: 1.27s	remaining: 56.9

<catboost.core.CatBoostClassifier at 0x7f8ec96d7a30>

In [14]:
tst=np.array([5.1, 3.4, 1.5, 0.2])

In [15]:
cb.predict_proba(X_train[0].reshape(1, -1))

array([[0.00248104, 0.99562836, 0.0018906 ]])

In [16]:
cb.save_model("cb.cbm",
           format="cbm",
)