In [1]:
from network_conf import *
from PIL import Image
import matplotlib.pyplot as plt
from canvas import DrawingWidget
from IPython.display import display

In [2]:
# Paths to the MNIST CSV exports, relative to the notebook's working directory.
MNIST_TRAIN = "data/mnist_train.csv"
MNIST_TEST = "data/mnist_test.csv"

# Network hyperparameters. They are declared before encoding so the class
# count has a single source of truth instead of a repeated literal `10`.
input_size = 784  # 28 x 28 pixels, flattened
hidden_size = 256
output_size = 10  # one output per digit class
lr = 0.1

X_train, y_train, X_test, y_test = load_data(MNIST_TRAIN, MNIST_TEST)

# Fail fast if the loaded features do not match the input layer width.
assert X_train.shape[1] == input_size, (
    f"expected {input_size} train features, got {X_train.shape[1]}"
)
assert X_test.shape[1] == input_size, (
    f"expected {input_size} test features, got {X_test.shape[1]}"
)

y_train_encoded = one_hot_encode(y_train, output_size)
y_test_encoded = one_hot_encode(y_test, output_size)

In [3]:
# Training schedule and checkpoint location. The file name records the
# hidden-layer width so differently sized models don't overwrite each other.
EPOCHS = 100
MODEL_PATH = f"mnist_relu_{hidden_size}.npz"

# Build the network from the hyperparameters declared in the data cell.
nn = Network(
    activation=Activation.relu,
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    lr=lr,
)

# NOTE: this retrains from scratch every time the cell runs, which makes
# Restart & Run All expensive.
nn.train(X_train, y_train_encoded, epochs=EPOCHS)
nn.save(MODEL_PATH)

Epoch 0, Loss: 4.5288


In [None]:
# Score the trained network on the held-out test set.
y_test_pred = nn.predict(X_test)

# Pass the class count explicitly, as the other one_hot_encode calls do.
# Without it the encoding may fail, or may be sized from the classes that
# happen to be predicted. Either way it would not line up with
# y_test_encoded in the accuracy comparison.
y_test_pred_encoded = one_hot_encode(y_test_pred, output_size)

test_accuracy = accuracy(y_test_encoded, y_test_pred_encoded)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

In [None]:
def plot_predictions(X_test, y_test, y_pred, num_samples=10, ncols=5, seed=None):
    """Show a grid of randomly chosen test images with true and predicted labels.

    Args:
        X_test: Test images, one flattened 28x28 image per row.
        y_test: True labels, indexable by row position.
        y_pred: Predicted labels, indexable by row position.
        num_samples: Number of images to show. Capped at the number of rows
            in X_test. Nothing is drawn if this is 0 or less.
        ncols: Maximum number of grid columns. Rows are added as needed.
        seed: Optional int seed or numpy Generator for reproducible sampling.
            None draws fresh entropy.
    """
    import math

    import numpy as np

    n_available = X_test.shape[0]
    num_samples = min(num_samples, n_available)
    if num_samples <= 0:
        return

    # Sample without replacement so the same image never appears twice.
    rng = np.random.default_rng(seed)
    indices = rng.choice(n_available, size=num_samples, replace=False)

    # Size the grid to the request instead of a fixed 5x5 layout. The fixed
    # layout raised for more than 25 samples and left blank rows for fewer.
    ncols = max(1, min(ncols, num_samples))
    nrows = math.ceil(num_samples / ncols)
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False
    )

    # Hide every axis first so unused cells in the last row stay blank.
    for ax in axes.flat:
        ax.axis("off")

    for ax, idx in zip(axes.flat, indices):
        ax.imshow(X_test[idx].reshape(28, 28), cmap="gray")
        ax.set_title(f"True: {y_test[idx]}, Pred: {y_pred[idx]}")

    fig.tight_layout()
    plt.show()


plot_predictions(X_test, y_test, y_test_pred, num_samples=10, seed=0)

In [None]:
# Interactive canvas for hand-drawing a digit: a black background with white
# as the default style (presumably the stroke colour), matching MNIST's
# light-on-dark digits.
canvas_options = {
    "width": 500,
    "height": 500,
    "background": "#000",
    "default_style": "#fff",
}
drawing_widget = DrawingWidget(**canvas_options)
drawing_widget.show()

In [None]:
# Read the canvas back as a 28x28 MNIST-style image plus a flattened form,
# which the prediction cell feeds to the network.
# NOTE(review): the variable names say float32, but dtype="int8" is requested.
# If get_image_data yields 0-255 intensities, int8 would wrap values above 127.
# Confirm the helper's scaling and the dtype the network was trained on.
image_kwargs = {"mnist": True, "size": 28, "dtype": "int8"}
x28_float32, x28_float32_flat = drawing_widget.get_image_data(**image_kwargs)

preview = Image.fromarray(x28_float32)
display(preview)

In [None]:
# Run the drawn digit through the network and show the first (only) prediction.
sample_predictions = nn.predict(x28_float32_flat)
y_sample_pred = sample_predictions[0]
print(y_sample_pred)