# MNIST

MNIST is a simple computer vision dataset that consists of images of handwritten digits.

Some examples:

![MNIST digits sample](https://github.com/user-attachments/assets/996c11e0-47f9-4b21-8e23-3867b8942e64)

It also includes labels for each image, which we use to train our example models.

## The `mnist` library

This package includes the following functionality:

  - Download the dataset from [storage.googleapis.com/cvdf-datasets/mnist](https://storage.googleapis.com/cvdf-datasets/mnist).
  - Create a `Dataset` object to iterate over it, for use in training and evaluation.
  - A linear and a CNN model demo.
  - A command-line demo (in the `demo` sub-directory).

This notebook serves as documentation and example for the [github.com/gomlx/gomlx/examples/mnist](https://github.com/gomlx/gomlx/examples/mnist) library. The demo code in one piece can be seen in [.../examples/mnist/demo/](https://github.com/gomlx/gomlx/tree/main/examples/mnist/demo).

## Environment Set Up

Let's set up `go.mod` to use the local copy of GoMLX, so that the dataset code can be developed jointly with the model. Data pre-processing and model code are often developed together during experimentation.

If you are not changing code, feel free to skip this cell. If you use a different directory for your projects, change it below.

Note that `${HOME}/Projects/gomlx` is the directory where the GoMLX code is copied by default in [its Docker](https://hub.docker.com/repository/docker/janpfeifer/gomlx_jupyterlab/general).

In [1]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx"
%goworkfix

	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".


## Data Preparation

### Downloading data files

To download to the local directory, do the following. If the files are already present in the given `--data` directory, it returns immediately.

In [2]:
import (
    "flag"
    "os"

    "github.com/gomlx/gomlx/examples/mnist"
    "github.com/gomlx/gomlx/pkg/support/fsutil"
    "github.com/janpfeifer/must"
)

var flagDataDir = flag.String("data", "~/work/mnist", "Directory to cache downloaded and generated dataset files.")

func AssertDownloaded() {
    *flagDataDir = must.M1(fsutil.ReplaceTildeInDir(*flagDataDir))
    if !fsutil.MustFileExists(*flagDataDir) {
        must.M(os.MkdirAll(*flagDataDir, 0777))
    }
    must.M(mnist.Download(*flagDataDir))
}

%%
AssertDownloaded()
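
The "return immediately if already downloaded" behavior can be sketched with the standard library alone. This is only an illustration of the caching pattern, not the code behind `mnist.Download`: the helper `downloadIfMissing` and the `demo` function are hypothetical names, and the demo serves a fake file from a local test server instead of contacting the real bucket.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"path"
	"path/filepath"
)

// downloadIfMissing fetches url into dir, naming the file after the last element of the URL path.
// If that file already exists it returns immediately. It reports whether a download happened.
// Hypothetical helper for illustration; not the mnist package's code.
func downloadIfMissing(url, dir string) (bool, error) {
	target := filepath.Join(dir, path.Base(url))
	if _, err := os.Stat(target); err == nil {
		return false, nil // Already cached: nothing to do.
	}
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return false, err
	}
	resp, err := http.Get(url)
	if err != nil {
		return false, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return false, fmt.Errorf("fetching %s: %s", url, resp.Status)
	}
	// Write under a temporary name, so an interrupted download is never mistaken for a complete file.
	tmp, err := os.CreateTemp(dir, "download-*")
	if err != nil {
		return false, err
	}
	_, copyErr := io.Copy(tmp, resp.Body)
	closeErr := tmp.Close()
	if copyErr != nil || closeErr != nil {
		os.Remove(tmp.Name())
		return false, fmt.Errorf("writing %s: copy=%v close=%v", target, copyErr, closeErr)
	}
	return true, os.Rename(tmp.Name(), target)
}

// demo serves a small fake file from a local test server and requests it twice.
func demo() (first, second bool, err error) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "fake-mnist-bytes")
	}))
	defer srv.Close()
	dir, err := os.MkdirTemp("", "mnist-sketch")
	if err != nil {
		return false, false, err
	}
	defer os.RemoveAll(dir)
	url := srv.URL + "/train-labels-idx1-ubyte.gz"
	if first, err = downloadIfMissing(url, dir); err != nil {
		return first, false, err
	}
	second, err = downloadIfMissing(url, dir)
	return first, second, err
}

func main() {
	first, second, err := demo()
	fmt.Printf("first call downloaded=%v, second call downloaded=%v, err=%v\n", first, second, err)
}
```

The first call downloads the file and the second one is a no-op, which is why calling `AssertDownloaded()` at the top of every cell is cheap.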

In [3]:
!ls -lh ~/work/mnist/

total 12M
drwxr-x--- 2 janpf janpf 4.0K Oct 11 11:33 cnn
drwxr-x--- 2 janpf janpf 4.0K Jun  3 15:39 cnn_triplet
drwxr-x--- 2 janpf janpf 4.0K Jun  3 15:38 linear
drwxr-x--- 2 janpf janpf 4.0K Feb 13  2025 linear_baseline
-rw-r--r-- 1 janpf janpf 1.6M Jun  8 08:10 t10k-images-idx3-ubyte.gz
-rw-r--r-- 1 janpf janpf 4.5K Jun  8 08:11 t10k-labels-idx1-ubyte.gz
-rw-r--r-- 1 janpf janpf 9.5M Jun  8 08:09 train-images-idx3-ubyte.gz
-rw-r--r-- 1 janpf janpf  29K Jun  8 08:09 train-labels-idx1-ubyte.gz
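
The four `.gz` files listed above use the IDX format of the original MNIST distribution: a 4-byte magic number (two zero bytes, a data-type byte that is `0x08` for unsigned bytes, and the number of dimensions), then one big-endian `uint32` per dimension, then the raw values. The sketch below decodes that layout using only the standard library. It is not the `mnist` package's loader; `readIDX` and `fakeIDX` are hypothetical helpers, and the demo parses a tiny in-memory file rather than the real data.

```go
package main

import (
	"bytes"
	"compress/gzip"
	"encoding/binary"
	"fmt"
	"io"
)

// readIDX decodes a gzip-compressed IDX stream of unsigned bytes and returns its dimensions and raw data.
func readIDX(r io.Reader) (dims []int, data []byte, err error) {
	zr, err := gzip.NewReader(r)
	if err != nil {
		return nil, nil, err
	}
	defer zr.Close()
	var magic [4]byte
	if _, err := io.ReadFull(zr, magic[:]); err != nil {
		return nil, nil, err
	}
	if magic[0] != 0 || magic[1] != 0 || magic[2] != 0x08 {
		return nil, nil, fmt.Errorf("unexpected IDX magic %v", magic)
	}
	size := 1
	for i := 0; i < int(magic[3]); i++ {
		var d uint32
		if err := binary.Read(zr, binary.BigEndian, &d); err != nil {
			return nil, nil, err
		}
		dims = append(dims, int(d))
		size *= int(d)
	}
	data = make([]byte, size)
	if _, err := io.ReadFull(zr, data); err != nil {
		return nil, nil, err
	}
	return dims, data, nil
}

// fakeIDX builds a small gzip-compressed IDX stream in memory, for the demo only.
func fakeIDX(dims []uint32, payload []byte) io.Reader {
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	zw.Write([]byte{0, 0, 0x08, byte(len(dims))})
	binary.Write(zw, binary.BigEndian, dims)
	zw.Write(payload)
	zw.Close()
	return bytes.NewReader(buf.Bytes())
}

func main() {
	// Two "images" of 2x3 pixels; the real training file has dimensions [60000 28 28].
	dims, data, err := readIDX(fakeIDX([]uint32{2, 2, 3}, make([]byte, 12)))
	fmt.Println(dims, len(data), err)
}
```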


### Sample some images
`mnist.NewDataset` creates a `data.InMemoryDataset` that can be used for training, for evaluation, or just to sample a few examples, which we do below:

In [4]:
import (
    "fmt"
    "strings"
    "strconv"
    "github.com/gomlx/gopjrt/dtypes"
    "github.com/gomlx/gomlx/backends"
    "github.com/gomlx/gomlx/examples/mnist"
    "github.com/gomlx/gomlx/pkg/core/shapes"
    "github.com/gomlx/gomlx/pkg/core/tensors/images"
    "github.com/gomlx/gomlx/pkg/ml/train"
    "github.com/janpfeifer/gonb/gonbui"

    _ "github.com/gomlx/gomlx/backends/default"
)

var (
    // Model DType, used everywhere.
    DType = dtypes.Float32
)

// sampleToNotebook generates a sample of MNIST in a GoNB Jupyter Notebook.
func sampleToNotebook() {
    // Load data into tensors.
    backend := backends.MustNew()
    if ds, err := mnist.NewDataset(backend, "Samples MNIST", *flagDataDir, "train", DType); err != nil {
        fmt.Printf("mnist.NewDataset: %v", err)
    } else {
        ds.Shuffle()
        sampleImages(ds, 10)
    }
   
}

// sampleImages outputs one HTML table row with numImages examples yielded by ds, each captioned with its label.
func sampleImages(ds train.Dataset, numImages int) {
    gonbui.DisplayHTML(fmt.Sprintf("<p>%s</p>\n", ds.Name()))
    
    parts := make([]string, 0, numImages+5) // Leave last part empty.
    parts = append(parts, "<table><tr>")
    for ii := 0; ii < numImages; ii++ {
        _, inputs, labels := must.M3(ds.Yield())
        imgTensor := inputs[0]
        img := images.ToImage().Single(imgTensor)
        label := labels[0].Value().([]int8)
    
        imgSrc := must.M1(gonbui.EmbedImageAsPNGSrc(img))
        size := imgTensor.Shape().Dimensions[0]
        parts = append(
            parts, 
            fmt.Sprintf(`<td><figure style="padding:4px;text-align: center;"><img width="%d" height="%d" src="%s">` + 
                        `<figcaption style="text-align: center;">(%d)</figcaption></figure></td>`, 
                        size*2, size*2, imgSrc, label),
        )
    }
    parts = append(parts, "</tr></table>", "")
    gonbui.DisplayHTML(strings.Join(parts, "\n"))
}

%%
AssertDownloaded()
sampleToNotebook()

[Output: an HTML table with one row of 10 sampled digit images, captioned with their labels: 4, 2, 3, 0, 7, 7, 6, 3, 1, 1.]
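
To show what embedding an image in the notebook's HTML involves, here is a standard-library sketch that turns one grayscale image into a PNG data URI suitable for an `<img src="...">` attribute. It is not the implementation of `gonbui.EmbedImageAsPNGSrc`; the function name `grayToPNGSrc` is hypothetical, and it assumes pixel intensities already scaled to the range [0, 1].

```go
package main

import (
	"bytes"
	"encoding/base64"
	"fmt"
	"image"
	"image/color"
	"image/png"
)

// grayToPNGSrc encodes row-major pixel intensities in [0, 1] as a PNG data URI.
// Assumption: values outside [0, 1] are clamped.
func grayToPNGSrc(pixels []float32, height, width int) (string, error) {
	if len(pixels) != height*width {
		return "", fmt.Errorf("got %d pixels, want %d", len(pixels), height*width)
	}
	img := image.NewGray(image.Rect(0, 0, width, height))
	for y := 0; y < height; y++ {
		for x := 0; x < width; x++ {
			v := pixels[y*width+x]
			if v < 0 {
				v = 0
			}
			if v > 1 {
				v = 1
			}
			img.SetGray(x, y, color.Gray{Y: uint8(v*255 + 0.5)})
		}
	}
	var buf bytes.Buffer
	if err := png.Encode(&buf, img); err != nil {
		return "", err
	}
	return "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}

func main() {
	// A blank 28x28 image stands in for an MNIST digit.
	src, err := grayToPNGSrc(make([]float32, 28*28), 28, 28)
	if err != nil {
		panic(err)
	}
	fmt.Printf("<img width=\"56\" height=\"56\" src=\"%s...\">\n", src[:40])
}
```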


## Training on MNIST

### Models Support

1. The `model` hyperparameter selects the model type, out of the options listed in `mnist.ModelList`.
1. `CreateDefaultContext` creates a context and sets the default values for the MNIST models.
1. `ContextFromSettings` uses `CreateDefaultContext` and incorporates changes passed by the `-set` flag.

In [5]:
import (
    "flag"
    
    "github.com/gomlx/gomlx/pkg/ml/layers"
    "github.com/gomlx/gomlx/ui/commandline"
    "github.com/gomlx/gomlx/pkg/ml/train/optimizers"
    "github.com/gomlx/gomlx/examples/mnist"
    "github.com/gomlx/gomlx/pkg/ml/context"
)

var (
	flagEval      = flag.Bool("eval", true, "Whether to evaluate the model on the validation data in the end.")
)

// settings is bound to a "-set" flag to be used to set context hyperparameters.
var settings = commandline.CreateContextSettingsFlag(CreateDefaultContext(), "set")

// CreateDefaultContext sets the context with default hyperparameters.
func CreateDefaultContext() *context.Context {
	ctx := context.New()
	ctx.RngStateReset()
	ctx.SetParams(map[string]any{
		// Model type to use
		"model":           "linear",
		"loss":            "sparse_cross_logits",
		"num_checkpoints": 3,
		"train_steps":     4000,

		// batch_size for training.
		"batch_size": 600,

		// eval_batch_size can be larger than training, it's more efficient.
		"eval_batch_size": 1000,

		// Debug parameters.
		"nan_logger": false, // Trigger nan error as soon as it happens -- expensive, but helps debugging.

		// "plots" trigger generating intermediary eval data for plotting, and if running in GoNB, to actually
		// draw the plot with Plotly.
		//
		// From the command-line, an easy way to monitor the metrics being generated during the training of a model
		// is using the gomlx_checkpoints tool:
		//
		//	$ gomlx_checkpoints --metrics --metrics_labels --metrics_types=accuracy  --metrics_names='E(Tra)/#loss,E(Val)/#loss' --loop=3s "<checkpoint_path>"
		plotly.ParamPlots: false,

		optimizers.ParamOptimizer:       "adamw",
		optimizers.ParamLearningRate:    1e-4,
		optimizers.ParamAdamEpsilon:     1e-7,
		optimizers.ParamAdamDType:       "",
		cosineschedule.ParamPeriodSteps: 0,
		activations.ParamActivation:     "relu",
		layers.ParamDropoutRate:         0.5,
		regularizers.ParamL2:            0.0,
		regularizers.ParamL1:            0.0,

		// CNN
		"cnn_dropout_rate":  0.5,
		"cnn_normalization": "layer", // "layer" or "batch".

		// Triplet
		losses.ParamTripletLossPairwiseDistanceMetric: "L2",
		losses.ParamTripletLossMiningStrategy:         "Hard",
		losses.ParamTripletLossMargin:                 0.5,
	})
	return ctx
}

// ContextFromSettings returns the default context (see CreateDefaultContext) with the changes passed in the -set flag applied.
func ContextFromSettings() (ctx *context.Context, paramsSet []string) {
    ctx = mnist.CreateDefaultContext()
    paramsSet = must.M1(commandline.ParseContextSettings(ctx, *settings))
    return
}

// Let's test that we can set hyperparameters by setting it in the "-set" flag:
%% -set="batch_size=17;model='cnn';train_steps=10"
fmt.Printf("Models: %q\n", mnist.ModelList)
ctx, parametersSet := ContextFromSettings()
fmt.Printf("Parameters set (-set): %q\n", parametersSet)
fmt.Println(commandline.SprintContextSettings(ctx))

Models: ["linear" "cnn"]
Parameters set (-set): ["batch_size" "model" "train_steps"]
	"/activation": (string) relu
	"/adam_dtype": (string) 
	"/adam_epsilon": (float64) 1e-07
	"/batch_size": (int) 17
	"/cnn_dropout_rate": (float64) 0.5
	"/cnn_normalization": (string) layer
	"/cosine_schedule_steps": (int) 0
	"/dropout_rate": (float64) 0.5
	"/eval_batch_size": (int) 1000
	"/l1_regularization": (float64) 0
	"/l2_regularization": (float64) 0
	"/learning_rate": (float64) 0.0001
	"/loss": (string) sparse_cross_logits
	"/model": (string) 'cnn'
	"/nan_logger": (bool) false
	"/num_checkpoints": (int) 3
	"/optimizer": (string) adamw
	"/plots": (bool) false
	"/train_steps": (int) 10
	"/triplet_loss_margin": (float64) 0.5
	"/triplet_loss_mining_strategy": (string) Hard
	"/triplet_loss_pairwise_distance_metric": (string) L2
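
The idea behind the `-set` flag, a `key=value;key=value` string that overrides typed defaults, can be sketched in plain Go. This is not the implementation of `commandline.ParseContextSettings`: `applySettings` is a hypothetical function working on a plain map instead of a `context.Context`, and it assumes each value is parsed according to the type of its default. String values are kept verbatim, quotes included, which mirrors the `'cnn'` value in the output above.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// applySettings overrides entries of params from a "key=value;key=value" string.
// Each value is parsed according to the type of the existing default, and unknown keys are rejected.
// It returns the keys that were set, in the order given.
func applySettings(params map[string]any, settings string) ([]string, error) {
	var set []string
	for _, pair := range strings.Split(settings, ";") {
		if strings.TrimSpace(pair) == "" {
			continue
		}
		key, value, found := strings.Cut(pair, "=")
		key = strings.TrimSpace(key)
		if !found {
			return nil, fmt.Errorf("setting %q is not in key=value form", pair)
		}
		current, ok := params[key]
		if !ok {
			return nil, fmt.Errorf("unknown hyperparameter %q", key)
		}
		var parsed any
		var err error
		switch current.(type) {
		case int:
			parsed, err = strconv.Atoi(value)
		case float64:
			parsed, err = strconv.ParseFloat(value, 64)
		case bool:
			parsed, err = strconv.ParseBool(value)
		default:
			parsed = value // Strings are kept verbatim, quotes included.
		}
		if err != nil {
			return nil, fmt.Errorf("parsing %q for %q: %w", value, key, err)
		}
		params[key] = parsed
		set = append(set, key)
	}
	return set, nil
}

func main() {
	params := map[string]any{"batch_size": 600, "model": "linear", "train_steps": 4000}
	set, err := applySettings(params, "batch_size=17;model='cnn';train_steps=10")
	fmt.Println(set, params["batch_size"], params["model"], params["train_steps"], err)
}
```

Parsing against the default's type is a design choice that catches typos early: a misspelled key or a non-numeric batch size fails before any training starts.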


### Linear model

A linear model can easily reach ~92% accuracy with 4000 steps (random guessing would get about 10%).

Later we define a CNN model to compare against it; for now we only set a placeholder model here.

> **Note**: 
>
> * The code is shown here only as an illustration. Training actually uses the same code from the [`mnist`](https://github.com/gomlx/gomlx/tree/main/examples/mnist) package.

In [6]:
import (
	"github.com/gomlx/gomlx/backends"
	. "github.com/gomlx/gomlx/pkg/core/graph"
	"github.com/gomlx/gomlx/pkg/ml/context"
	"github.com/gomlx/gomlx/pkg/ml/layers"
)

var _ = NewGraph  // Make sure the graph package is in use.

// LinearModelGraph builds a simple logistic (linear) model.
// It returns the logits, not the predictions, with shape `[batch_size, NumClasses]`, which works with most losses.
// inputs: only one tensor, with shape `[batch_size, height, width, depth]`.
func LinearModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
	ctx = ctx.In("model") // Create the model by default under the "/model" scope.
	batchSize := inputs[0].Shape().Dimensions[0]
	embeddings := Reshape(inputs[0], batchSize, -1)
	logits := layers.DenseWithBias(ctx, embeddings, mnist.NumClasses)
	return []*Node{logits}
}

%% -set="batch_size=10"
// Let's test that the logits are coming out with the right shape: we want [batch_size, 10], since there are 10 classes.
AssertDownloaded()
ctx, _ := ContextFromSettings()
g := NewGraph(backends.MustNew(), "placeholder")
batchSize := context.GetParamOr(ctx, "batch_size", int(100))
logits := LinearModelGraph(ctx, nil, []*Node{Parameter(g, "images", shapes.Make(DType, batchSize, mnist.Height, mnist.Width, mnist.Depth))})
fmt.Printf("Logits shape for batch_size=%d: %s\n", batchSize, logits[0].Shape())

Logits shape for batch_size=10: (Float32)[10 10]
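
For readers who want to see the arithmetic behind the graph above, here is a plain-Go sketch of the same computation: each image is flattened to `height*width*depth` values and multiplied by a weight matrix, and a bias is added per class. It uses nested slices instead of GoMLX tensors, and the names `linearLogits` and `filled` are hypothetical.

```go
package main

import "fmt"

const (
	height, width, depth = 28, 28, 1
	numClasses           = 10
)

// linearLogits computes logits[b][c] = bias[c] + sum_i images[b][i] * weights[i][c],
// where each image has already been flattened to height*width*depth values.
func linearLogits(images, weights [][]float32, bias []float32) [][]float32 {
	logits := make([][]float32, len(images))
	for b, img := range images {
		logits[b] = make([]float32, len(bias))
		copy(logits[b], bias)
		for i, pixel := range img {
			for c := range logits[b] {
				logits[b][c] += pixel * weights[i][c]
			}
		}
	}
	return logits
}

// filled returns a rows x cols matrix with every element set to v.
func filled(rows, cols int, v float32) [][]float32 {
	m := make([][]float32, rows)
	for r := range m {
		m[r] = make([]float32, cols)
		for c := range m[r] {
			m[r][c] = v
		}
	}
	return m
}

func main() {
	batchSize := 10
	images := filled(batchSize, height*width*depth, 1)
	weights := filled(height*width*depth, numClasses, 0)
	logits := linearLogits(images, weights, make([]float32, numClasses))
	fmt.Printf("logits shape: [%d %d]\n", len(logits), len(logits[0]))
}
```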


### Training Loop

With a model function defined, we use the training loop created for MNIST.

The trainer is provided in the [`mnist` package](https://github.com/gomlx/gomlx/tree/main/examples/mnist). It is straightforward (and almost the same for every project) and does the following for us:

- If a checkpoint is given (`--checkpoint`) and it holds a previously saved model, load the hyperparameters and trained variables.
- Create the trainer with the selected model function (see the [Linear model](#Linear-model) and [CNN model for MNIST](#CNN-model-for-MNIST) sections), optimizer, loss and metrics.
- Create a `train.Loop` and attach to it a progress bar, a periodic checkpoint saver and a plotter (`--set="plots=true"`).
- Train the selected number of train steps.
- Report results.
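
The control flow in the steps above, resuming from a checkpoint, training up to a target step, and saving periodically, can be sketched without any ML library. This is a toy illustration and not the `train.Loop` API: `checkpoint` and `runLoop` are hypothetical names, and the "training step" is an arbitrary callback.

```go
package main

import "fmt"

// checkpoint is a toy stand-in for saved training state (illustration only).
type checkpoint struct {
	GlobalStep int
}

// runLoop trains from the checkpoint's global step up to targetSteps. It calls trainStep once per step
// and save every saveEvery steps, plus once at the final step. It returns the number of steps actually run.
func runLoop(ckpt *checkpoint, targetSteps, saveEvery int, trainStep, save func(step int)) int {
	if ckpt.GlobalStep >= targetSteps {
		return 0 // Target already reached: nothing to train.
	}
	ran := 0
	for step := ckpt.GlobalStep + 1; step <= targetSteps; step++ {
		trainStep(step)
		ckpt.GlobalStep = step
		ran++
		if step%saveEvery == 0 || step == targetSteps {
			save(step)
		}
	}
	return ran
}

func main() {
	ckpt := &checkpoint{}
	noop := func(int) {}
	report := func(step int) { fmt.Printf("\t- saving checkpoint@%d\n", step) }

	fmt.Println("steps run:", runLoop(ckpt, 4000, 1000, noop, report))
	// Running again with the same target does no work, since the global step is already 4000.
	fmt.Println("steps run:", runLoop(ckpt, 4000, 1000, noop, report))
}
```

Because the target is an absolute step count rather than a number of additional steps, re-running a cell with an existing checkpoint and an unchanged `train_steps` trains nothing and only reports results.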

Below we train 4000 steps with the default settings just to check things are working.

In [7]:
var flagCheckpoint = flag.String("checkpoint", "", "Directory to save and load checkpoints from. If left empty, no checkpoints are created.")

// trainModel with hyperparameters configured with `-set=...`.
func trainModel() {
    ctx, paramsSet := ContextFromSettings()
    must.M(mnist.TrainModel(ctx, *flagDataDir, *flagCheckpoint, paramsSet))
}

// Train 4000 steps with plots enabled, to check that things are working.
%% --checkpoint=linear  --set="model=linear;train_steps=4000;plots=true"
trainModel()

Training linear model:
Backend stablehlo: stablehlo:cuda - PJRT "cuda" plugin (/home/janpf/.local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.76 [StableHLO]
	- checkpoint in /home/janpf/work/mnist/linear


	- restarting from global step 4000
	 - target train_steps=4000 already reached. To train further, set a number additional to current global step.

Results on train:
	Mean Loss+Regularization (#loss+): 0.287
	Mean Loss (#loss): 0.287
	Mean Accuracy (#acc): 92.07%
Results on test:
	Mean Loss+Regularization (#loss+): 0.284
	Mean Loss (#loss): 0.284
	Mean Accuracy (#acc): 92.09%
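
The two metrics reported above can be computed from logits as in the sketch below, assuming the loss is the usual sparse categorical cross-entropy, that is, the mean of `-log(softmax(logits)[label])`. This is a plain-Go illustration with hypothetical function names, not the GoMLX losses or metrics code. A useful reference point: uniform logits over 10 classes give a loss of ln(10) ≈ 2.30, the value expected from a model that has learned nothing yet.

```go
package main

import (
	"fmt"
	"math"
)

// sparseCrossEntropy returns the mean of -log(softmax(logits[i])[labels[i]]) over the batch.
func sparseCrossEntropy(logits [][]float64, labels []int) float64 {
	total := 0.0
	for i, row := range logits {
		// Log-sum-exp, subtracting the row maximum for numerical stability.
		maxLogit := row[0]
		for _, v := range row {
			if v > maxLogit {
				maxLogit = v
			}
		}
		sumExp := 0.0
		for _, v := range row {
			sumExp += math.Exp(v - maxLogit)
		}
		total += maxLogit + math.Log(sumExp) - row[labels[i]]
	}
	return total / float64(len(logits))
}

// accuracy returns the fraction of rows whose largest logit is at the label's index.
func accuracy(logits [][]float64, labels []int) float64 {
	correct := 0
	for i, row := range logits {
		best := 0
		for c, v := range row {
			if v > row[best] {
				best = c
			}
		}
		if best == labels[i] {
			correct++
		}
	}
	return float64(correct) / float64(len(logits))
}

func main() {
	uniform := [][]float64{make([]float64, 10), make([]float64, 10)}
	fmt.Printf("loss with uniform logits: %.4f\n", sparseCrossEntropy(uniform, []int{3, 7}))
	fmt.Printf("accuracy: %.2f\n", accuracy([][]float64{{0, 5, 0}, {3, 0, 0}}, []int{1, 2}))
}
```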



### CNN Model for MNIST

Let's now properly define a CNN model to compare.

The model was built following the [Deep MNIST for Experts](https://chromium.googlesource.com/external/github.com/tensorflow/tensorflow/+/r0.10/tensorflow/g3doc/tutorials/mnist/pros/index.md) tutorial.

In [8]:
// CnnModelGraph builds the CNN model for our demo.
// It returns the logits, not the predictions, with shape `[batch_size, NumClasses]`, which works with most losses.
// inputs: only one tensor, with shape `[batch_size, height, width, depth]`.
func CnnModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
	ctx = ctx.In("model") // Create the model by default under the "/model" scope.
	embeddings := CnnEmbeddings(ctx, inputs[0])
	logits := layers.Dense(ctx, embeddings, true, mnist.NumClasses)
	return []*Node{logits}
}

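// CnnEmbeddings applies two stages of convolution, ReLU, normalization and max-pooling (with dropout after
// the second stage) and returns the flattened result, with shape `[batch_size, 7*7*64]`.
func CnnEmbeddings(ctx *context.Context, images *Node) *Node {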
	batchSize := images.Shape().Dimensions[0]
	g := images.Graph()
	dtype := images.DType()

	layerIdx := 0
	nextCtx := func(name string) *context.Context {
		newCtx := ctx.Inf("%03d_%s", layerIdx, name)
		layerIdx++
		return newCtx
	}
	// Dropout.
	dropoutRate := context.GetParamOr(ctx, "cnn_dropout_rate", -1.0)
	if dropoutRate < 0 {
		dropoutRate = context.GetParamOr(ctx, layers.ParamDropoutRate, 0.0)
	}
	var dropoutNode *Node
	if dropoutRate > 0.0 {
		dropoutNode = Scalar(g, dtype, dropoutRate)
	}

	images = layers.Convolution(nextCtx("conv"), images).Filters(32).KernelSize(3).PadSame().Done()
	images.AssertDims(batchSize, 28, 28, 32)
	images = activations.Relu(images)
	images = normalizeCNN(nextCtx("norm"), images)
	images = MaxPool(images).Window(2).Done()
	images.AssertDims(batchSize, 14, 14, 32)

	images = layers.Convolution(nextCtx("conv"), images).Filters(64).KernelSize(3).PadSame().Done()
	images.AssertDims(batchSize, 14, 14, 64)
	images = activations.Relu(images)
	images = normalizeCNN(nextCtx("norm"), images)
	images = MaxPool(images).Window(2).Done()
	images = layers.DropoutNormalize(nextCtx("dropout"), images, dropoutNode, true)
	images.AssertDims(batchSize, 7, 7, 64)

	// Flatten images
	images = Reshape(images, batchSize, -1)
	return images
}

func normalizeCNN(ctx *context.Context, logits *Node) *Node {
	normalizationType := context.GetParamOr(ctx, "cnn_normalization", "none")
	switch normalizationType {
	case "layer":
		if logits.Rank() == 2 {
			return layers.LayerNormalization(ctx, logits, -1).Done()
		} else if logits.Rank() == 4 {
			return layers.LayerNormalization(ctx, logits, 2, 3).Done()
		} else {
			return logits
		}
	case "batch":
		return batchnorm.New(ctx, logits, -1).Done()
	case "none", "":
		return logits
	default:
		exceptions.Panicf("invalid normalization type %q -- set it with parameter %q", normalizationType, "cnn_normalization")
		return nil
	}
}

%% -set="batch_size=10"
// Let's test that the logits are coming out with the right shape: we want [batch_size, 10], since there are 10 classes.
AssertDownloaded()
ctx, _ := ContextFromSettings()
g := NewGraph(backends.MustNew(), "placeholder")
batchSize := context.GetParamOr(ctx, "batch_size", int(100))
logits := CnnModelGraph(ctx, nil, []*Node{Parameter(g, "images", shapes.Make(DType, batchSize, mnist.Height, mnist.Width, mnist.Depth))})
fmt.Printf("Logits shape for batch_size=%d: %s\n", batchSize, logits[0].Shape())

Logits shape for batch_size=10: (Float32)[10 10]
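
The `AssertDims` calls in `CnnEmbeddings` follow from simple shape arithmetic: a stride-1 convolution with "same" padding keeps the spatial size and only changes the number of channels, while a max-pool with window 2 halves each spatial dimension. The sketch below traces those shapes in plain Go. It assumes the pooling stride equals the window, which is consistent with the asserted dimensions; the type and function names are hypothetical.

```go
package main

import "fmt"

// shape is [height, width, channels] of one image; the batch dimension is left out.
type shape [3]int

// convSame returns the output shape of a stride-1 convolution with "same" padding: only channels change.
func convSame(in shape, filters int) shape {
	return shape{in[0], in[1], filters}
}

// maxPool returns the output shape of a non-overlapping pooling window (stride == window, no padding).
func maxPool(in shape, window int) shape {
	return shape{in[0] / window, in[1] / window, in[2]}
}

// flattenedSize is the number of values per example after reshaping to [batch_size, -1].
func flattenedSize(in shape) int {
	return in[0] * in[1] * in[2]
}

// cnnShapes traces the input through the two conv+pool stages used in this notebook.
func cnnShapes() []shape {
	s := shape{28, 28, 1}
	trace := []shape{s}
	for _, filters := range []int{32, 64} {
		s = convSame(s, filters)
		trace = append(trace, s)
		s = maxPool(s, 2)
		trace = append(trace, s)
	}
	return trace
}

func main() {
	trace := cnnShapes()
	for _, s := range trace {
		fmt.Println(s)
	}
	fmt.Println("flattened size:", flattenedSize(trace[len(trace)-1])) // 7*7*64 = 3136
}
```

So the final dense layer maps 3136 features to the 10 class logits.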


### CNN Model Training

Let's train the CNN for real this time. 

In [9]:
// Remove a previously trained model
!rm -rf ~/work/mnist/cnn  

In [10]:
%% --checkpoint=cnn --set="model=cnn;train_steps=4000;plots=true"
trainModel()

Training cnn model:
Backend stablehlo: stablehlo:cuda - PJRT "cuda" plugin (/home/janpf/.local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.76 [StableHLO]
	- checkpoint in /home/janpf/work/mnist/cnn


         0% [........................................] (17 steps/s) [1s:3m52s] [step=35] [loss+=2.05] [~loss+=2.97] [~loss=2.97] [~acc=24.69%]


	- saving checkpoint@4000


	- trained to step 4000, median train step: 3965 microseconds

Results on train:
	Mean Loss+Regularization (#loss+): 0.0201
	Mean Loss (#loss): 0.0201
	Mean Accuracy (#acc): 99.40%
Results on test:
	Mean Loss+Regularization (#loss+): 0.0356
	Mean Loss (#loss): 0.0356
	Mean Accuracy (#acc): 98.75%

