In [None]:
import sys
sys.path.append("..") # for sibling import

import numpy as np
import pandas as pd
import walnut

# Example 4.2

### Convolutional Neural Network: more complex data

The goal of this model is to classify images of clothing items.

### Step 1: Prepare data
You will need to download the dataset from https://www.kaggle.com/datasets/zalando-research/fashionmnist?resource=download and place it into the *data* directory, so that the training file is available at `../data/fashion_mnist/fashion-mnist_train.csv`. Only the official training file is used here; it is split into training, validation and test subsets, since the goal is just to showcase the framework.

In [None]:
# Location of the Fashion-MNIST training CSV, relative to this notebook.
train_csv_path = '../data/fashion_mnist/fashion-mnist_train.csv'

# Read the CSV into a DataFrame and preview its first rows.
data = pd.read_csv(train_csv_path)
data.head()

In [None]:
# Convert the DataFrame and keep only the first 5000 rows.
tensor = walnut.df_to_tensor(data)[:5000]
train, val, test = walnut.preprocessing.split_train_val_test(tensor, ratio_val=0.005, ratio_test=0.005)


def split_features_labels(samples):
    """Split a block of rows into model inputs and integer labels.

    Column 0 is read as the class label; all remaining columns are treated
    as pixel values. The pixels are reshaped to (rows, 1, 28, -1), cast to
    float32 and divided by 255.

    Every split must go through the same slicing: taking `[:, :-1]` instead
    of `[:, 1:]` would put the label into the first pixel position and drop
    the last pixel, leaking the target into the inputs.

    Returns:
        (x, y): the scaled image inputs and the integer labels.
    """
    labels = samples[:, 0].astype("int")
    pixels = samples[:, 1:]
    pixels = pixels.reshape((pixels.shape[0], 1, 28, -1))
    return pixels.astype("float32") / 255.0, labels


x_train, y_train = split_features_labels(train)
x_val, y_val = split_features_labels(val)
x_test, y_test = split_features_labels(test)

print(f'{x_train.shape=}')
print(f'{y_train.shape=}')

print(f'{x_val.shape=}')
print(f'{y_val.shape=}')

print(f'{x_test.shape=}')
print(f'{y_test.shape=}')

### Step 2: Build the neural network structure

In [None]:
import walnut.nn as nn
# Import the layers by name instead of `import *` so it is clear where each
# name comes from and the notebook namespace is not polluted.
from walnut.nn.layers import (
    Batchnorm,
    Convolution2d,
    Dropout,
    Flatten,
    Linear,
    MaxPooling2d,
    ReLU,
)

# Two convolution blocks (conv -> batchnorm -> ReLU -> max-pool), followed by
# dropout and a two-layer classifier head with 10 outputs.
# The convolution and first linear layers are created with use_bias=False;
# each is directly followed by a Batchnorm layer.
model = nn.Sequential([
    Convolution2d(1, 8, kernel_size=(5, 5), pad="same", use_bias=False), Batchnorm(8), ReLU(),
    MaxPooling2d(kernel_size=(2, 2)),
    Convolution2d(8, 32, kernel_size=(3, 3), pad="same", use_bias=False), Batchnorm(32), ReLU(),
    MaxPooling2d(kernel_size=(2, 2)),
    Dropout(0.1),
    Flatten(),
    # 7*7*32 inputs: presumably 28 -> 14 -> 7 per spatial dim after the two
    # 2x2 poolings, times 32 channels (assumes each pooling halves the size).
    Linear(7*7*32, 128, use_bias=False), Batchnorm(128), ReLU(),
    Linear(128, 10)
])

In [None]:
# Build the training components first, then attach them to the model.
# 3e-4 is Adam's first positional argument (presumably the learning rate).
adam = nn.optimizers.Adam(3e-4)
crossentropy = nn.losses.Crossentropy()

model.compile(optimizer=adam, loss_fn=crossentropy, metric=nn.metrics.get_accuracy)

In [None]:
from walnut.nn.analysis import model_summary

# Shape of one input sample as passed to the summary: (1, 28, 28),
# matching the per-sample shape produced by the reshape in the data step.
input_shape = (1, 28, 28)
model_summary(model, input_shape)

### Step 3: Train the model

In [None]:
# Training hyperparameters.
epochs = 10
batch_size = 32

# Fit on the training split, passing the validation split as val_data.
# The call returns the training and validation loss histories.
train_loss_hist, val_loss_hist = model.train(
    x_train,
    y_train,
    epochs=epochs,
    batch_size=batch_size,
    val_data=(x_val, y_val),
)

In [None]:
# Width of the trailing averaging window: one tenth of the epoch count,
# but never less than 1.
n = max(10, epochs) // 10

# Smooth the training loss with a trailing mean over at most n entries;
# the window is clipped at the start of the history.
smoothed_train_loss = []
for i in range(len(train_loss_hist)):
    window_start = max(0, i - n + 1)
    smoothed_train_loss.append(np.average(train_loss_hist[window_start:i + 1]))

traces = {
    "train_loss": smoothed_train_loss,
    "val_loss": val_loss_hist,
}

nn.analysis.plot_curve(
    traces=traces,
    figsize=(15, 3),
    title="loss history",
    x_label="epoch",
    y_label="loss",
)

### Step 4: Evaluate the model

In [None]:
# Score the model on the held-out test split and report both values,
# one per line; accuracy is shown multiplied by 100.
loss, accuracy = model.evaluate(x_test, y_test)
print(f'loss {loss:.4f}', f'accuracy {accuracy*100:.2f}', sep="\n")

In [None]:
# Run the model on the test inputs and plot its outputs against the
# test labels as a confusion matrix.
predictions = model(x_test)
nn.analysis.plot_confusion_matrix(
    predictions,
    y_test,
    figsize=(5, 5),
    cmap='Blues',
)