In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import walnut
import walnut.nn as nn

# Example 4

### Convolutional neural network

The goal of this model is to classify images of hand-written digits.

### Step 1: Prepare data
Download the dataset from https://www.kaggle.com/competitions/digit-recognizer/data and place `train.csv` in the *data/mnist* directory. Only the official training data is used, split into training, validation and test sets, since the goal is just to showcase the framework.

In [None]:
# Load the Kaggle digit-recognizer training CSV and preview its first rows.
data = pd.read_csv(filepath_or_buffer='data/mnist/train.csv')
data.head(5)

The labels are stored as plain integers, so they are one-hot encoded to match the model's 10-way softmax output.

In [None]:
# One-hot encode the integer `label` column and preview the result.
data_enc = walnut.preprocessing.encoding.pd_one_hot_encode(
    data,
    columns=['label'],
)
data_enc.head()

In [None]:
# Convert the encoded frame to a walnut tensor, then carve out validation and test subsets.
tensor = walnut.pd_to_tensor(data_enc)
train, val, test = walnut.preprocessing.split_train_val_test(
    tensor,
    ratio_val=0.01,   # presumably the fraction of rows used for validation
    ratio_test=0.02,  # presumably the fraction of rows used for testing
)

In [None]:
# Separate pixel features from labels in each split.
# Assumes the first 28 * 28 = 784 columns are the pixels and the remaining columns are the
# one-hot label columns produced by the encoding cell. TODO confirm column order.
N_PIXELS = 28 * 28  # matches the (1, 28, 28) input shape used by the model
x_train, y_train = walnut.preprocessing.split_features_labels(train, N_PIXELS)
x_val, y_val = walnut.preprocessing.split_features_labels(val, N_PIXELS)
x_test, y_test = walnut.preprocessing.split_features_labels(test, N_PIXELS)

In [None]:
# Reshape each flat feature matrix into (N, 1, 28, -1) channel-first images, in train/val/test order.
x_train, x_val, x_test = (
    split.reshape((split.shape[0], 1, 28, -1))
    for split in (x_train, x_val, x_test)
)

# Report every feature/label shape, one per line.
print(
    f'{x_train.shape=}',
    f'{y_train.shape=}',
    f'{x_val.shape=}',
    f'{y_val.shape=}',
    f'{x_test.shape=}',
    f'{y_test.shape=}',
    sep='\n',
)

Normalization: scale the pixel intensities (0–255) down to the range [0, 1].

In [None]:
# Divide every split by 255. Note: this rebinds the same names, so re-running the cell
# divides again; restart from the reshape cell if you need to re-run it.
x_train, x_val, x_test = (split / 255 for split in (x_train, x_val, x_test))

### Step 2: Build the neural network structure

In [None]:
# Small CNN: a 16-filter 3x3 convolution with batch norm and ReLU, 2x2 max pooling,
# then two fully connected layers ending in a 10-way softmax.
# `nn` (walnut.nn) is imported in the notebook's import cell at the top.
model = nn.Sequential(layers=[
    nn.layers.Convolution2d(16, input_shape=(1, 28, 28), kernel_size=(3, 3), act="relu", norm="batch", use_bias=False), # bias not needed when using norm
    nn.layers.MaxPooling2d(kernel_size=(2, 2)),
    nn.layers.Reshape(),  # presumably flattens the pooled feature maps for the Linear layers
    nn.layers.Linear(100, act="relu", norm="batch", use_bias=False),
    nn.layers.Linear(10, act="softmax")
])

The network is compiled to internally connect its layers and initialize the model.

In [None]:
# Attach the optimizer, loss function and metric used for training and evaluation.
model.compile(
    optimizer=nn.optimizers.Adam(l_r=2e-3),  # Adam with learning rate 2e-3
    loss_fn=nn.losses.Crossentropy(),        # pairs with the softmax output layer
    metric=nn.metrics.Accuracy(),
)

In [None]:
# Display the compiled model (via the nn.Sequential object's repr) as this cell's output.
model

### Step 3: Train the model

In [None]:
# Training hyperparameters; `epochs` is also read by the loss-curve cell below.
epochs = 100
batch_size = 32

# Train on the training split; with val_data supplied, train() returns both the
# training-loss and validation-loss histories.
train_loss_hist, val_loss_hist = model.train(
    x_train,
    y_train,
    epochs=epochs,
    batch_size=batch_size,
    val_data=(x_val, y_val),
)

In [None]:
# Smooth the training-loss curve with a trailing moving average over the last n entries
# (the first n entries are plotted raw); the validation curve is plotted as-is.
# max(1, ...) keeps the window non-empty: with epochs < 10, epochs // 10 == 0 would divide by zero.
n = max(1, epochs // 10)
traces = {
    "train_loss": [
        value if i < n else sum(train_loss_hist[i - (n - 1):i + 1]) / n
        for i, value in enumerate(train_loss_hist)
    ],
    "val_loss": val_loss_hist,
}
nn.analysis.plot_curve(traces=traces, figsize=(15, 3), title="loss history", x_label="epoch", y_label="loss")

### Step 4: Evaluate the model

In [None]:
# Score the held-out test split; accuracy is reported as a percentage.
loss, accuracy = model.evaluate(x_test, y_test)
print(f'loss {loss:.4f}', f'accuracy {accuracy * 100:.2f}', sep='\n')

In [None]:
# Run the whole test split through the model and compare predictions to the true labels.
predictions = model(x_test)
nn.analysis.plot_confusion_matrix(
    predictions,
    y_test,
    figsize=(5, 5),
)

### Step 5: Explore the inner workings
Pick a random image from the test set.

In [None]:
# Pick a random test image and show it in grayscale.
# random.randint includes its upper bound, so randint(0, N) could return N (out of range);
# randrange(N) draws from [0, N) instead.
i = random.randrange(x_test.shape[0])
# x_test was reshaped to (N, 1, 28, -1), so each sample is channel-first; move the channel
# axis last for matplotlib.
image = np.moveaxis(x_test[i].data, 0, -1)
plt.imshow(image, cmap='gray')
plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)  # hide ticks and tick labels

Feed it to the model to predict the digit, and plot the predicted probability for each class.

In [None]:
# Print the true class (argmax of the one-hot label row), then predict the same image.
print(f"correct label: {np.argmax(y_test[i].data)}")
image_tensor = walnut.expand_dims(x_test[i], 0)  # presumably adds a leading batch axis, like np.expand_dims
predictions = model(image_tensor)
nn.analysis.plot_probabilities(predictions, figsize=(6, 3))

Every layer of the model can be accessed to inspect its output. Here we collect the output feature maps of each convolutional layer for the image above, to see what its kernels have learned to respond to.

In [None]:
# Gather the first sample of each Convolution2d layer's stored output `.y`, keyed by
# "<1-based layer position> <layer class name>". Presumably `.y` holds the most recent
# forward pass, i.e. the single-image prediction above. TODO confirm.
channels = {
    f"{idx + 1} {type(layer).__name__}": layer.y.data[0].copy()
    for idx, layer in enumerate(model.layers)
    if type(layer).__name__ == "Convolution2d"
}
nn.analysis.plot_images(channels, figsize=(40, 40), cmap="gray")