MNIST: 28x28 grayscale image multi-class classification
=========================================================

In this tutorial we show how the green\_tsetlin Tsetlin Machine (TM) can be trained on the **MNIST dataset**. MNIST is a digit-recognition benchmark
containing 70,000 images of handwritten digits. Each image is a 28x28 pixel grayscale image with values between 0 and 255.

In [2]:
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split as split
import numpy as np

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

X_train, X_test, y_train, y_test = split(X, y, test_size=0.2, random_state=42, shuffle=True)

In [5]:
# Keep only the first 10,000 training and 1,000 test samples.
X_train = X_train[:10000]
y_train = y_train[:10000]
X_test = X_test[:1000]
y_test = y_test[:1000]

$(70000, 784) \leftarrow (70000, 28, 28)$

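With sklearn we import an easy-to-use version of MNIST. This version already stores each image as a flat row of 784 pixel values, so `X` is a 2D array
and no flattening is needed. Next, because the TM requires binary inputs, each pixel is binarized with a threshold of 75: values above 75 become 1 and all others become 0.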

In [6]:
X_train = np.where(X_train > 75, 1, 0)
X_train = X_train.astype(np.uint8)
    
X_test = np.where(X_test > 75, 1, 0)
X_test = X_test.astype(np.uint8)

y_train = y_train.astype(np.uint32)
y_test = y_test.astype(np.uint32)
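
To make the flattening and thresholding steps concrete, here is a minimal sketch on two synthetic 2x2 "images" rather than real MNIST data. The helper name `binarize` is hypothetical and not part of green\_tsetlin; it only wraps the same `>` comparison used in the cell above and adds a reshape for inputs that still have an image shape.

```python
import numpy as np

def binarize(images, threshold=75):
    """Flatten each image to one row and map pixels above `threshold` to 1, others to 0.

    `binarize` is an illustrative helper, not a green_tsetlin function.
    """
    images = np.asarray(images)
    # (n, h, w) -> (n, h*w); leaves an already flat (n, d) array unchanged.
    flat = images.reshape(len(images), -1)
    return (flat > threshold).astype(np.uint8)

# Two synthetic 2x2 "images" standing in for 28x28 MNIST digits.
toy = np.array([[[0, 75], [76, 255]],
                [[10, 200], [74, 80]]])

toy_bin = binarize(toy)
print(toy_bin.shape)  # (2, 4)
print(toy_bin)
```

Note that the comparison is strict, so a pixel equal to the threshold (75 here) maps to 0.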

We can now train the Tsetlin Machine. Running a hyperparameter search first is recommended.

In [13]:
from green_tsetlin.hpsearch import HyperparameterSearch


hpsearch = HyperparameterSearch(s_space=(3.0, 40.0),
                                clause_space=(1000, 8000),
                                threshold_space=(1000, 8000),
                                max_epoch_per_trial=20,
                                literal_budget=(5, 10),
                                k_folds=3,
                                n_jobs=4,
                                seed=42,
                                minimize_literal_budget=False)

hpsearch.set_train_data(X_train, y_train)
hpsearch.set_eval_data(X_test, y_test)

hpsearch.optimize(n_trials=10, study_name="MNIST hpsearch", show_progress_bar=True, storage=None)

[I 2024-06-19 13:29:54,872] A new study created in memory with name: MNIST hpsearch
Processing trial 9 of 10, best score: [0.9937278429233706]: 100%|██████████| 10/10 [1:11:24<00:00, 428.49s/it]


### Best parameters

best parameters:  {'s': 21.627727185060525, 'n_clauses': 6154, 'threshold': 1218, 'literal_budget': 10}

best score:  0.9937278429233706

In [1]:
import green_tsetlin as gt

best_params = {'s': 21.627727185060525, 'n_clauses': 6154, 'threshold': 1218, 'literal_budget': 10}

tm = gt.TsetlinMachine(n_literals=28*28,
                        n_clauses=best_params['n_clauses'],
                        s=best_params['s'],
                        threshold=int(best_params['threshold']),
                        n_classes=10,
                        literal_budget=best_params['literal_budget'])

In [2]:
# Reload the full dataset (no subsetting this time) and binarize it as before.
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split as split
import numpy as np

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

X_train, X_test, y_train, y_test = split(X, y, test_size=0.2, random_state=42, shuffle=True)

X_train = np.where(X_train > 75, 1, 0)
X_train = X_train.astype(np.uint8)
    
X_test = np.where(X_test > 75, 1, 0)
X_test = X_test.astype(np.uint8)

y_train = y_train.astype(np.uint32)
y_test = y_test.astype(np.uint32)

In [3]:
trainer = gt.Trainer(tm, k_folds=2, n_epochs=20, seed=42, n_jobs=7, progress_bar=True)

trainer.set_train_data(X_train, y_train)
trainer.set_eval_data(X_test, y_test)

res = trainer.train()

Processing epoch 20 of 20, train acc: 0.992, best eval score: 0.972 (epoch: 19): 100%|██████████| 20/20 [15:50<00:00, 47.51s/it]
Processing epoch 20 of 20, train acc: 0.995, best eval score: 0.992 (epoch: 0): 100%|██████████| 20/20 [15:03<00:00, 45.18s/it]
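
The dictionary returned by `trainer.train()` carries a flat `train_time_of_epochs` list alongside `k_folds`. A rough sketch of how those timings could be summarized per fold follows. It runs on a synthetic stand-in dictionary, and it assumes the list is stored fold by fold with the same number of epochs in each fold, which the output suggests but does not confirm. `summarize_timings` is a hypothetical helper, not part of green\_tsetlin.

```python
import numpy as np

def summarize_timings(result):
    """Split the flat per-epoch timing list into folds and report simple totals.

    Assumption: timings are stored fold by fold, with equally many epochs per fold.
    """
    times = np.asarray(result["train_time_of_epochs"], dtype=float)
    per_fold = times.reshape(result["k_folds"], -1)
    return {
        "epochs_per_fold": per_fold.shape[1],
        "fold_totals_s": per_fold.sum(axis=1).tolist(),
        "mean_epoch_s": float(times.mean()),
    }

# Synthetic stand-in with the same keys as the trainer result (values are made up).
example = {"k_folds": 2, "train_time_of_epochs": [4.0, 2.0, 3.0, 3.0]}

summary = summarize_timings(example)
print(summary)
```

Calling `summarize_timings(res)` on the real result would give the total training time per fold in seconds, under the same ordering assumption.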


In [5]:
res

{'best_eval_score': 0.9918857142857143,
 'k_folds': 2,
 'train_time_of_epochs': [41.99848390498664,
  30.74028508097399,
  30.064865624008235,
  29.525387891975697,
  29.982523320999462,
  30.06839102698723,
  30.43663672398543,
  31.987098273995798,
  28.460180500987917,
  28.762363646004815,
  28.464446795987897,
  28.478722686995752,
  28.687962658004835,
  27.64200901699951,
  27.71681642101612,
  28.097751585999504,
  28.664220188045874,
  27.467564610997215,
  27.849587303004228,
  27.281124531000387,
  27.562880382989533,
  27.764159907004796,
  29.250643727951683,
  28.007854113995563,
  27.766060939000454,
  27.767420815012883,
  28.402095064986497,
  28.594877657014877,
  27.333333007001784,
  27.278216135047842,
  27.83884705597302,
  27.217197816993576,
  27.081594954011962,
  26.943604013998993,
  27.043323835998308,
  28.250813858001493,
  26.777471681009047,
  27.486269582004752,
  27.345113909977954,
  26.679121892957482],
 'train_log': [0.8924285714285715,
  0.94857142