This notebook trains and evaluates a random forest classifier and a linear SVM on the EMNIST Letters dataset.

In [20]:
import os

import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.svm import SVC
import numpy as np
from joblib import load, dump

In [22]:
# Load the predefined train and test partitions of the EMNIST "letters" split
# (downloaded into ./data on first run).
dataset_train = datasets.EMNIST(root='./data', split='letters', train=True, download=True)
dataset_test = datasets.EMNIST(root='./data', split='letters', train=False, download=True)


def _as_feature_matrix(images):
    """Turn a stack of images into a 2-D feature matrix for sklearn.

    Pixel values are divided by 255 (assumes 8-bit 0-255 pixels), then every
    image is unrolled into a single row, giving one row per sample and one
    column per pixel.
    """
    scaled = images.numpy() / 255.0
    return scaled.reshape(len(scaled), -1)


# sklearn's RandomForestClassifier and SVC need flat feature vectors, not images.
x_train, y_train = _as_feature_matrix(dataset_train.data), dataset_train.targets.numpy()
x_test, y_test = _as_feature_matrix(dataset_test.data), dataset_test.targets.numpy()

In [18]:
# Train a random forest on the flattened pixel features, then predict the test split.
# random_state is fixed so repeated runs build the same forest and report the same scores.
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=1)
rf_classifier.fit(x_train, y_train)
y_pred_rf = rf_classifier.predict(x_test)

# Evaluate: overall accuracy plus per-class precision / recall / F1.
accuracy_rf = accuracy_score(y_test, y_pred_rf)
classification_report_rf = classification_report(y_test, y_pred_rf)
print(f"Random Forest Classifier Accuracy: {accuracy_rf:.4f}")
print("Random Forest Classifier Report:")
print(classification_report_rf)

# Persist the fitted model. Create the target directory first so the write
# does not fail on a fresh checkout where ./Models is missing.
os.makedirs('./Models', exist_ok=True)
dump(rf_classifier, './Models/rf_classifier.joblib')

Random Forest Classifier Accuracy: 0.8855
Random Forest Classifier Report:
              precision    recall  f1-score   support

           1       0.81      0.88      0.84       800
           2       0.91      0.92      0.92       800
           3       0.93      0.92      0.93       800
           4       0.90      0.88      0.89       800
           5       0.90      0.92      0.91       800
           6       0.92      0.90      0.91       800
           7       0.84      0.70      0.76       800
           8       0.89      0.87      0.88       800
           9       0.71      0.73      0.72       800
          10       0.88      0.89      0.89       800
          11       0.88      0.90      0.89       800
          12       0.72      0.71      0.72       800
          13       0.94      0.95      0.95       800
          14       0.87      0.90      0.89       800
          15       0.89      0.96      0.92       800
          16       0.92      0.95      0.94       800
      

['./Models/rf_classifier.joblib']

In [23]:
# Train a linear-kernel SVM on the same flattened features, then predict the test split.
# NOTE(review): SVC fit time grows at least quadratically with the number of
# samples, so this cell is expected to be slow on the full training set.
svm_classifier = SVC(kernel='linear', random_state=1)
svm_classifier.fit(x_train, y_train)
y_pred_svm = svm_classifier.predict(x_test)

# Evaluate: overall accuracy plus per-class precision / recall / F1.
accuracy_svm = accuracy_score(y_test, y_pred_svm)
classification_report_svm = classification_report(y_test, y_pred_svm)
print(f"SVM Classifier Accuracy: {accuracy_svm:.4f}")
print("SVM Classifier Report:")
print(classification_report_svm)

# Persist the fitted model. Only `dump` (not the `joblib` module) is imported,
# and the target directory may not exist yet, so create it before writing.
os.makedirs('./Models', exist_ok=True)
dump(svm_classifier, './Models/svm_classifier.joblib')

SVM Classifier Accuracy: 0.7879
Random Forest Classifier Report:
              precision    recall  f1-score   support

           1       0.66      0.79      0.72       800
           2       0.77      0.83      0.80       800
           3       0.85      0.87      0.86       800
           4       0.77      0.78      0.77       800
           5       0.81      0.80      0.81       800
           6       0.78      0.81      0.80       800
           7       0.63      0.60      0.61       800
           8       0.73      0.77      0.75       800
           9       0.62      0.69      0.65       800
          10       0.82      0.83      0.83       800
          11       0.76      0.77      0.76       800
          12       0.67      0.67      0.67       800
          13       0.88      0.89      0.89       800
          14       0.76      0.73      0.74       800
          15       0.88      0.91      0.89       800
          16       0.88      0.87      0.87       800
          17    

['./Models/svm_classifier.joblib']