# Imports and configuration

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm

from sklearn import datasets, metrics
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder

In [None]:
RANDOM_STATE = 42

# Seed torch (weight initialisation, any torch-based shuffling) and numpy's
# global RNG so a fresh Restart & Run All reproduces the same trained model.
# The sklearn split below receives RANDOM_STATE explicitly.
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)

# Preparing the dataset

In [None]:
# Load scikit-learn's bundled handwritten-digits dataset
digits = datasets.load_digits()
n_samples = digits.images.shape[0]

# Collapse each image into one flat row so every sample is a single input vector
data = digits.images.reshape(n_samples, -1)

In [None]:
X_train, X_test, y_train, y_test = train_test_split(
    data,
    digits.target,
    test_size=0.3,
    stratify=digits.target,  # keep class proportions equal across splits
    random_state=RANDOM_STATE,
)

# Number of features per sample and number of distinct classes; these size
# the first and last layers of the model.
input_dim = X_train.shape[1]
output_dim = np.unique(y_train).size

In [None]:
# Learn the min/max ranges from the training split only, then apply the same
# mapping to the test split so no test statistics leak into preprocessing.
scaler = MinMaxScaler()
X_train = scaler.fit(X_train).transform(X_train)
X_test = scaler.transform(X_test)

In [None]:
# Dense one-hot targets for training: one row per sample, one column per class
encoder = OneHotEncoder(sparse_output=False)

y_train = encoder.fit_transform(y_train[:, np.newaxis])

In [None]:
# Convert the numpy arrays to tensors.
# - Features and the one-hot training targets are float32, as nn.Linear and
#   nn.CrossEntropyLoss (with class-probability targets) expect.
# - The test labels are integer class ids. Keeping them as int64 (torch.long)
#   lets the metrics report classes as 0..9 instead of 0.0..9.0.
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)

X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)

In [None]:
# Sanity check: feature and target tensor shapes for each split
print(*(t.shape for t in (X_train, y_train)))
print(*(t.shape for t in (X_test, y_test)))

torch.Size([1257, 64]) torch.Size([1257, 10])
torch.Size([540, 64]) torch.Size([540])


# Defining the model

In [None]:
# (number of input features, number of output classes)
(input_dim, output_dim)

(64, 10)

In [None]:
# Small fully connected network: input_dim -> 32 -> 16 -> output_dim.
# The last layer returns raw logits on purpose. nn.CrossEntropyLoss applies
# log-softmax internally, so an explicit LogSoftmax layer would be applied a
# second time. That is mathematically a no-op, but redundant and misleading.
# Taking argmax over logits selects the same class as argmax over
# (log-)probabilities, so predictions are unaffected.
model = nn.Sequential(
    nn.Linear(input_dim, 32),
    nn.ReLU(),
    nn.Linear(32, 16),
    nn.ReLU(),
    nn.Linear(16, output_dim),
)

model

Sequential(
  (0): Linear(in_features=64, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=16, bias=True)
  (3): ReLU()
  (4): Linear(in_features=16, out_features=10, bias=True)
  (5): LogSoftmax(dim=1)
)

In [None]:
# Mean cross-entropy over the batch. It applies log-softmax to its input
# internally, and it accepts the float one-hot rows built above as
# class-probability targets.
loss_fn = nn.CrossEntropyLoss(reduction="mean")

In [None]:
# Adam over every trainable parameter of the model, learning rate 0.005
optimizer = optim.Adam(params=model.parameters(), lr=5e-3)

# Defining the training loop

In [None]:
epochs = 10
batch_size = 3
# Ceiling division: when len(X_train) is not a multiple of batch_size, the
# final, smaller batch is still trained on instead of being silently dropped,
# as floor division (//) would do.
batches_per_epoch = (len(X_train) + batch_size - 1) // batch_size

In [None]:
for e in range(epochs):
  # Ensure training mode, e.g. if this cell is re-run after the evaluation cell
  model.train()
  # Draw a fresh sample order each epoch so minibatches differ between epochs
  perm = torch.randperm(len(X_train))

  with tqdm.trange(batches_per_epoch, unit="batch", mininterval=0) as bar:
    bar.set_description(f"Epoch {e}")

    for i in bar:
      # Get current batch; slicing past the end of perm just yields a shorter batch
      start = i * batch_size
      batch_idx = perm[start:start + batch_size]
      X_batch = X_train[batch_idx]
      y_batch = y_train[batch_idx]

      # Forward pass
      curr_y_pred = model(X_batch)
      loss_val = loss_fn(curr_y_pred, y_batch)

      # Backward pass
      optimizer.zero_grad()
      loss_val.backward()
      optimizer.step()

      # Surface the current minibatch loss on the progress bar
      bar.set_postfix(loss=f"{loss_val.item():.4f}")


Epoch 0: 100%|██████████| 419/419 [00:04<00:00, 85.82batch/s]
Epoch 1: 100%|██████████| 419/419 [00:05<00:00, 80.41batch/s] 
Epoch 2: 100%|██████████| 419/419 [00:03<00:00, 121.58batch/s]
Epoch 3: 100%|██████████| 419/419 [00:05<00:00, 81.98batch/s] 
Epoch 4: 100%|██████████| 419/419 [00:03<00:00, 116.91batch/s]
Epoch 5: 100%|██████████| 419/419 [00:04<00:00, 102.96batch/s]
Epoch 6: 100%|██████████| 419/419 [00:06<00:00, 66.84batch/s]
Epoch 7: 100%|██████████| 419/419 [00:03<00:00, 128.50batch/s]
Epoch 8: 100%|██████████| 419/419 [00:02<00:00, 159.44batch/s]
Epoch 9: 100%|██████████| 419/419 [00:02<00:00, 162.89batch/s]


# Evaluating the model

In [None]:
# Switch to inference mode before predicting. This is a no-op for this
# Linear/ReLU stack, but becomes required once dropout or batch-norm layers
# are added. Also disable gradient tracking, since no backward pass follows.
model.eval()
with torch.no_grad():
  y_pred = model(X_test)

In [None]:
# Per-class precision / recall / F1 on the held-out split. Each prediction is
# the index of the highest-scoring output column.
print(metrics.classification_report(y_true=y_test, y_pred=y_pred.argmax(dim=1)))

              precision    recall  f1-score   support

         0.0       1.00      0.98      0.99        54
         1.0       0.92      0.89      0.91        55
         2.0       0.91      0.98      0.95        53
         3.0       0.95      1.00      0.97        55
         4.0       0.96      0.96      0.96        54
         5.0       1.00      0.95      0.97        55
         6.0       0.98      0.98      0.98        54
         7.0       1.00      0.94      0.97        54
         8.0       0.92      0.88      0.90        52
         9.0       0.93      1.00      0.96        54

    accuracy                           0.96       540
   macro avg       0.96      0.96      0.96       540
weighted avg       0.96      0.96      0.96       540

