In [1]:
from torch import nn
from torch.optim import Adam
import torchvision.datasets as datasets # pip install torch vision if you dont have this
from sklearn.metrics import classification_report
from tamnun.core import TorchEstimator

Load data using torch vision

In [2]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)

Get the data tensors for train and test

In [3]:
train_X = mnist_trainset.data.reshape(-1, 28*28).numpy()
train_y = mnist_trainset.targets.numpy()

test_X = mnist_testset.data.reshape(-1, 28*28).numpy()
test_y = mnist_testset.targets.numpy()

Create simple linear classifier with 28x28 (the image size) as input and 10 classes as output

In [4]:
module = nn.Linear(28*28, 10)

Create the tamnun estimator

In [5]:
clf = TorchEstimator(module, optimizer=Adam(module.parameters(), lr=1e-4))

### fit().predict()!

In [6]:
clf = clf.fit(train_X, train_y, epochs=20, batch_size=32)

Epoch 1/20:
1874/1875 batch loss: 3.3726372718811035    avg loss: 6.262858227777481
Epoch 2/20:
1874/1875 batch loss: 0.9372385740280151   5 avg loss: 2.11575478618145
Epoch 3/20:
1874/1875 batch loss: 2.048184633255005 9    avg loss: 1.6549429415345192
Epoch 4/20:
1874/1875 batch loss: 0.0006771087646484375  avg loss: 1.4070107420285543
Epoch 5/20:
1874/1875 batch loss: 0.8950674533843994 6  avg loss: 1.234210046708584
Epoch 6/20:
1874/1875 batch loss: 1.1697653532028198     avg loss: 1.1113991823037466
Epoch 7/20:
1874/1875 batch loss: 0.6111218333244324     avg loss: 1.0094971430778503
Epoch 8/20:
1874/1875 batch loss: 0.008098363876342773  avg loss: 0.9214095135052999
Epoch 9/20:
1874/1875 batch loss: 1.3127366304397583     avg loss: 0.8733639629522959
Epoch 10/20:
1874/1875 batch loss: 0.2436842918395996    avg loss: 0.8153708161791166
Epoch 11/20:
1874/1875 batch loss: 0.3469241261482239  6 avg loss: 0.786324020353953
Epoch 12/20:
1874/1875 batch loss: 0.07478591799736023  6 avg 

In [7]:
predicted = clf.predict(test_X)
print(classification_report(test_y, predicted))

              precision    recall  f1-score   support

           0       0.98      0.92      0.95       980
           1       0.97      0.97      0.97      1135
           2       0.91      0.87      0.89      1032
           3       0.86      0.90      0.88      1010
           4       0.90      0.94      0.92       982
           5       0.78      0.89      0.83       892
           6       0.96      0.90      0.93       958
           7       0.97      0.81      0.88      1028
           8       0.81      0.86      0.84       974
           9       0.84      0.89      0.86      1009

   micro avg       0.89      0.89      0.89     10000
   macro avg       0.90      0.89      0.89     10000
weighted avg       0.90      0.89      0.90     10000

