<img src="https://cdn.comet.ml/img/notebook_logo.png">

[Comet](https://www.comet.com/site/products/ml-experiment-tracking/?utm_campaign=keras&utm_medium=colab) is an MLOps platform designed to help data scientists and teams build better models faster. Comet provides tooling to track, explain, manage, and monitor your models in a single place. It works with Jupyter notebooks and scripts and, most importantly, it's 100% free to get started!

[Keras](https://keras.io/) is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages. It also has extensive documentation and developer guides.

Instrument Keras with Comet to start managing experiments, create dataset versions and track hyperparameters for faster and easier reproducibility and collaboration.

[Find more information about our integration with Keras](https://www.comet.ml/docs/v2/integrations/ml-frameworks/keras/)

Get a preview of what's to come: check out a completed experiment created from this notebook [here](https://www.comet.com/examples/comet-example-keras-notebook/ef20812c7b1d4876a7cbd5679a12c309).

This example is based on the [Keras MNIST convnet example](https://keras.io/examples/vision/mnist_convnet/).

# Install Dependencies

In [None]:
%pip install comet_ml tensorflow numpy

# Initialize Comet

In [None]:
import comet_ml

comet_ml.init(project_name="comet-example-keras-notebook")

# Import Dependencies

In [None]:
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Create a Comet Experiment

In [None]:
experiment = comet_ml.Experiment(
    auto_histogram_weight_logging=True,
    auto_histogram_gradient_logging=True,
    auto_histogram_activation_logging=True,
)

# Define hyperparameters

In [None]:
parameters = {
    "batch_size": 128,
    "epochs": 3,
    "optimizer": "adam",
    "loss": "categorical_crossentropy",
    "validation_split": 0.1,
}

experiment.log_parameters(parameters)

# Load and prepare the dataset

We use the `keras.datasets` module to download the MNIST dataset, then scale the images and one-hot encode the labels.

In [None]:
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")


# Convert class vectors to binary class matrices (one-hot encoding)
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
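
To make the label conversion above concrete, here is a minimal pure-Python sketch of one-hot encoding. The helper name `one_hot` is hypothetical and is not part of Keras; it only illustrates the kind of output `keras.utils.to_categorical` produces for integer labels, assuming labels in the range `[0, num_classes)`.

```python
def one_hot(labels, num_classes):
    """Return one row per integer label, with a single 1.0 at the label index.

    Hypothetical helper for illustration only; the notebook itself relies on
    keras.utils.to_categorical.
    """
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} is outside [0, {num_classes})")
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows


# Three example digit labels encoded over the 10 MNIST classes.
encoded = one_hot([5, 0, 4], num_classes=10)
for row in encoded:
    print(row)
```

Each row sums to 1 and has the same length as the model's softmax output, which is the format the `categorical_crossentropy` loss expects.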

# Create the model

In [None]:
model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()
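
The output shapes and parameter counts that `model.summary()` reports can be checked by hand. The sketch below redoes that arithmetic in plain Python. It assumes the Keras defaults that the model above does not override: `padding="valid"` and stride 1 for `Conv2D`, and a pooling stride equal to the pool size for `MaxPooling2D`. The function names are hypothetical helpers, not Keras APIs.

```python
def conv2d_valid(height, width, kernel):
    """Spatial size after a stride-1 convolution with 'valid' padding."""
    return height - kernel + 1, width - kernel + 1


def max_pool(height, width, pool):
    """Spatial size after non-overlapping max pooling (stride == pool size)."""
    return height // pool, width // pool


def conv2d_params(kernel, in_channels, filters):
    """Weights plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(in_features, units):
    """Weights plus one bias per unit."""
    return in_features * units + units


def summarize_convnet():
    """Walk the layers of the model above and return (flat_size, total_params)."""
    h, w, channels = 28, 28, 1
    total = 0

    h, w = conv2d_valid(h, w, 3)            # Conv2D(32): 26 x 26 x 32
    total += conv2d_params(3, channels, 32)
    channels = 32
    h, w = max_pool(h, w, 2)                # MaxPooling2D: 13 x 13 x 32

    h, w = conv2d_valid(h, w, 3)            # Conv2D(64): 11 x 11 x 64
    total += conv2d_params(3, channels, 64)
    channels = 64
    h, w = max_pool(h, w, 2)                # MaxPooling2D: 5 x 5 x 64

    flat = h * w * channels                 # Flatten; Dropout adds no parameters
    total += dense_params(flat, 10)         # Dense(10)
    return flat, total


flat_size, total_params = summarize_convnet()
print("Flattened features:", flat_size)
print("Trainable parameters:", total_params)
```

If the numbers printed here differ from the summary table, the most likely cause is a changed padding, stride, or kernel size in the model definition.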

# Train the model

In [None]:
model.compile(
    loss=parameters["loss"], optimizer=parameters["optimizer"], metrics=["accuracy"]
)

model.fit(
    x_train,
    y_train,
    batch_size=parameters["batch_size"],
    epochs=parameters["epochs"],
    validation_split=parameters["validation_split"],
)
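
As a rough check on the progress bar that `model.fit` prints, the number of optimizer steps per epoch follows from the hyperparameters. The sketch below assumes the standard MNIST training set of 60,000 images and that `validation_split` holds out the last fraction of the samples, as the Keras documentation describes. The helper `fit_schedule` is hypothetical, and the exact rounding Keras applies internally is an assumption here.

```python
import math


def fit_schedule(num_samples, validation_split, batch_size):
    """Return (train_samples, validation_samples, steps_per_epoch).

    Hypothetical helper: assumes the training portion is the floor of
    num_samples * (1 - validation_split) and that the final, possibly
    smaller, batch still counts as a step.
    """
    num_train = int(math.floor(num_samples * (1.0 - validation_split)))
    num_val = num_samples - num_train
    steps_per_epoch = math.ceil(num_train / batch_size)
    return num_train, num_val, steps_per_epoch


train_count, val_count, steps = fit_schedule(
    num_samples=60000, validation_split=0.1, batch_size=128
)
print("Training samples:", train_count)
print("Validation samples:", val_count)
print("Steps per epoch:", steps)
print("Total steps over 3 epochs:", steps * 3)
```

A larger `batch_size` lowers the step count per epoch, which also changes how many per-step data points appear on the Comet charts for the experiment.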

# Evaluate the trained model

It is now time to test the trained model against the test dataset.

In [None]:
score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])

# End experiment

Ending the experiment logs the notebook code and makes sure that all the data has been sent to Comet.

In [None]:
experiment.end()