<a href="https://colab.research.google.com/github/daniel-falk/ai-ml-principles-exercises/blob/log_initial_loss/more_logging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install dependencies and import packages
First we need to install the libraries we will be using. We will use `numpy` for generic matrix operations, `tensorflow` for deep learning operations such as convolutions, pooling and training (backpropagation), and `wandb` (Weights & Biases) for logging the training run.

In [None]:
import sys
!{sys.executable} -m pip install numpy tensorflow wandb
from IPython.display import clear_output
clear_output()
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import wandb

## Create a neural network
Next we define a function that builds a neural network. The network is a simple CNN (convolutional neural network) used for classification. Its structure is not important for this exercise; you can treat it as a black box that can be trained to classify an input image.

In [None]:
def create_cnn(input_shape, output_classes):
    return keras.Sequential(
        [
            keras.Input(shape=input_shape),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dense(output_classes, activation="softmax"),
        ]
    )
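
For readers who do want to peek inside the black box, the following sketch traces the feature-map size through the layers above using plain arithmetic. It assumes the Keras defaults (no padding and stride 1 for `Conv2D`, non-overlapping windows for `MaxPooling2D`); the helper names `conv_out` and `pool_out` are made up for this illustration.

```python
def conv_out(size, kernel):
    # An unpadded, stride-1 convolution shrinks each side by kernel - 1
    return size - kernel + 1

def pool_out(size, pool):
    # Non-overlapping pooling divides each side and rounds down
    return size // pool

side = 28                    # MNIST images are 28x28
side = conv_out(side, 3)     # first Conv2D
side = pool_out(side, 2)     # first MaxPooling2D
side = conv_out(side, 3)     # second Conv2D
side = pool_out(side, 2)     # second MaxPooling2D
flat_features = side * side * 64   # Flatten over 64 channels

# Trainable parameters: kernel weights plus one bias per output unit
params_conv1 = 3 * 3 * 1 * 32 + 32
params_conv2 = 3 * 3 * 32 * 64 + 64
params_dense = flat_features * 10 + 10
total_params = params_conv1 + params_conv2 + params_dense

print(flat_features, total_params)
```

If the assumptions hold, `model.summary()` on the network returned by `create_cnn((28, 28, 1), 10)` should report the same totals.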

## Prepare the data
The neural network will be trained on a digit classification dataset called *MNIST*. This code downloads and loads the images together with their true labels. It also preprocesses the data to make it more suitable for a neural network: pixel values are scaled to the [0, 1] range, a channel dimension is added, and the labels are one-hot encoded.

In [None]:
dataset_name = "mnist"
num_classes = 10
shape = (28, 28, 1)

def get_mnist_data():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = getattr(keras.datasets, dataset_name).load_data()

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255

    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)

    # Convert class vectors to binary class matrices (one-hot encoding)
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)

    return (x_train, y_train), (x_test, y_test)
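
To make the effect of these preprocessing steps concrete, here is a small sketch that applies them to a made-up batch of four random "images", so it runs without downloading anything. The names `fake_images` and `fake_labels` are invented for the example, and `np.eye(...)[labels]` is used as a NumPy stand-in that yields the same one-hot layout as `keras.utils.to_categorical` for integer labels.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for raw MNIST data: uint8 pixels and integer labels
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
fake_labels = np.array([3, 0, 9, 3])

# Scale to [0, 1] and add a trailing channel dimension
x = fake_images.astype("float32") / 255
x = np.expand_dims(x, -1)

# One-hot encode: row i has a single 1 at column fake_labels[i]
y = np.eye(10, dtype="float32")[fake_labels]

print(x.shape, x.dtype)
print(y.shape, y[0])
```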

## Train the network
Finally we train the network on the data to teach it how to classify a digit. We create a model that expects a 28x28 pixel single-channel (grayscale) image, since this is the format of the images in the *MNIST* dataset. We then create an optimizer and call the `fit()` method to start the training.

In [None]:
batch_size = 128
epochs = 50
learning_rate = 1e-3
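
As a rough guide to what these settings mean in practice, the sketch below estimates the number of weight updates. It assumes the standard MNIST training split of 60,000 images, and that one of the 50 epochs is used by the zero-learning-rate logging pass in the training cell further down.

```python
import math

train_images = 60_000   # assumed size of the MNIST training split
batch_size = 128
epochs = 50

# The last batch of each epoch is smaller, hence the ceiling
steps_per_epoch = math.ceil(train_images / batch_size)

# One epoch runs at learning rate 0 and does not change the weights
training_epochs = epochs - 1
total_updates = steps_per_epoch * training_epochs

print(steps_per_epoch, total_updates)
```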

In [None]:
# Get the training data
(x_train, y_train), (x_valid, y_valid) = get_mnist_data()

In [None]:
# Start a WandB run and record the hyperparameters and dataset details
wandb.init(project="ai-ml-exercise", config={
    "batch_size": batch_size,
    "epochs": epochs,
    "dataset_size": len(x_train),
    "validation_size": len(x_valid),
    "dataset": dataset_name,
    "num_classes": num_classes,
    "shape": shape,
})

wandb_callback = wandb.keras.WandbCallback()

# Create a Convolutional Neural Network that
# expects a 28x28 pixel image with 1 color channel (gray) as input
model = create_cnn(shape, num_classes)

optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
model.compile(loss="categorical_crossentropy",
              optimizer=optimizer, metrics=["accuracy"])

# Hack to log the initial (untrained) loss and accuracy to WandB: run one
# epoch with the learning rate forced to 0 so the weights are not updated
model.fit(
    x_train, y_train,
    batch_size=batch_size, epochs=1,
    validation_data=(x_valid, y_valid),
    callbacks=[
        wandb_callback,
        keras.callbacks.LearningRateScheduler(lambda e, lr: 0., verbose=1)
    ],
)

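# Now train the model. With initial_epoch=1 the epoch counter continues after
# the zero-learning-rate epoch, so this call runs epochs - 1 training epochs.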
model.fit(
    x_train, y_train,
    batch_size=batch_size, epochs=epochs,
    validation_data=(x_valid, y_valid),
    callbacks=[
        wandb_callback,
        keras.callbacks.LearningRateScheduler(lambda e, lr: learning_rate, verbose=1)
    ],
    initial_epoch=1,
)
wandb.finish()
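
One way to sanity-check the values logged by the zero-learning-rate epoch is to compare them with what pure guessing would give. The sketch below computes the categorical cross-entropy of a classifier that assigns equal probability to all 10 classes. This is an idealization: a freshly initialized network is only approximately uniform, so the logged initial loss should be near this value rather than exactly equal to it.

```python
import numpy as np

num_classes = 10
n_samples = 5

# One-hot targets for a handful of made-up labels (0, 1, 2, 3, 4)
y_true = np.eye(num_classes)[np.arange(n_samples) % num_classes]

# Idealized untrained model: the same probability for every class
y_pred = np.full((n_samples, num_classes), 1.0 / num_classes)

# Categorical cross-entropy averaged over samples; equals ln(num_classes) here
loss = float(-np.mean(np.sum(y_true * np.log(y_pred), axis=1)))

# On a class-balanced dataset, uniform guessing is right 1 time in num_classes
chance_accuracy = 1.0 / num_classes

print(round(loss, 4), chance_accuracy)
```

An initial loss far above ln(10) ≈ 2.30, or an initial accuracy far from 0.1, would suggest a problem with the initialization or the labels rather than with training itself.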