# Train your first model


In this tutorial, you learn how to train an image classification model that can recognize handwritten digits.

## Preparation

This tutorial requires the installation of the Java Kernel. To install the Java Kernel, see the [README](https://github.com/awslabs/djl/blob/master/jupyter/README.md).

In [None]:
%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.3.0-SNAPSHOT
%maven ai.djl:basicdataset:0.3.0-SNAPSHOT
%maven ai.djl:model-zoo:0.3.0-SNAPSHOT
%maven ai.djl:repository:0.3.0-SNAPSHOT
%maven ai.djl.mxnet:mxnet-engine:0.3.0-SNAPSHOT
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0

### Include MXNet engine dependency

This tutorial uses the MXNet engine as its backend. MXNet comes in several [build flavors](https://mxnet.apache.org/get_started?version=v1.5.1&platform=linux&language=python&environ=pip&processor=cpu), and each one is platform specific.
See [engine selection](https://github.com/awslabs/djl/blob/master/examples/README.md#engine-selection) for how to choose an MXNet engine flavor.

In [None]:
String osName = System.getProperty("os.name");
String classifier = osName.startsWith("Mac") ? "osx-x86_64" : osName.startsWith("Win") ? "win-x86_64" : "linux-x86_64";


%maven ai.djl.mxnet:mxnet-native-mkl:jar:${classifier}:1.6.0-b-SNAPSHOT

In [None]:
import java.nio.file.*;

import ai.djl.*;
import ai.djl.basicdataset.*;
import ai.djl.ndarray.types.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.initializer.*;
import ai.djl.training.loss.*;
import ai.djl.training.metrics.*;
import ai.djl.training.optimizer.*;
import ai.djl.training.optimizer.learningrate.*;
import ai.djl.training.util.*;
import ai.djl.zoo.cv.classification.*;

# Step 1: Create your neural network

In this tutorial, you use the built-in MLP block from the Model Zoo. To learn more about the MLP block, see [Create Your First Network](create_your_first_network.ipynb).

Images in the MNIST dataset are 28x28 grayscale images, so create an MLP block with a 28 x 28 input.

In [None]:
Model model = Model.newInstance();
model.setBlock(new Mlp(28, 28));
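
An MLP works on flat vectors, so a 28 x 28 image ends up as 28 * 28 = 784 input values. The following is a minimal sketch of that flattening in plain Java, assuming row-by-row (row-major) order. `FlattenSketch` and `flatten` are hypothetical names for illustration and are not part of DJL.

```java
// Hypothetical helper, not part of DJL: flattens a 2-D grayscale image row by row.
class FlattenSketch {

    static float[] flatten(float[][] image) {
        int height = image.length;
        int width = image[0].length;
        float[] flat = new float[height * width];
        for (int row = 0; row < height; row++) {
            // Copy one image row into its slot in the flat vector
            System.arraycopy(image[row], 0, flat, row * width, width);
        }
        return flat;
    }

    public static void main(String[] args) {
        float[][] image = new float[28][28];
        image[1][2] = 0.5f; // pixel at row 1, column 2
        float[] flat = flatten(image);
        System.out.println(flat.length);      // prints 784
        System.out.println(flat[1 * 28 + 2]); // prints 0.5
    }
}
```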

# Step 2: Set up your training configuration


The following are a few common items you need to configure for training:
* Batch size: To take advantage of a GPU, you usually train models in batches. Pick a batch size that suits your model and fits in the available memory.
* [`Initializer`](https://javadoc.djl.ai/api/0.2.1/index.html?ai/djl/training/initializer/Initializer.html): An `Initializer` sets the initial values of the model parameters.
* [`Loss`](https://javadoc.djl.ai/api/0.2.1/index.html?ai/djl/training/loss/Loss.html) function: A loss function measures how far the model's predictions are from the expected values. Lower values are better, which is why it is called a "loss" function.
* [`Optimizer`](https://javadoc.djl.ai/api/0.2.1/index.html?ai/djl/training/optimizer/Optimizer.html): An optimization algorithm repeatedly updates the model parameters to minimize the value of the loss function.
* `Device`: DJL can automatically detect whether a GPU is available. If GPUs are available, training runs on a single GPU by default. To train with multiple GPUs, set the devices with `config.setDevices(Devices.getDevices(maxNumberOfGPUs))`.
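
To make the idea of a loss function concrete, here is a minimal sketch of softmax cross-entropy for a single example, written in plain Java. It illustrates the general technique only and is not DJL's implementation; `SoftmaxCrossEntropySketch` and its methods are hypothetical names.

```java
// Illustrative only: softmax cross-entropy for one example, not DJL's implementation.
class SoftmaxCrossEntropySketch {

    /** Converts raw scores (logits) into probabilities that sum to 1. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max); // subtract the max for numerical stability
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    /** Negative log-probability assigned to the correct class. */
    static double loss(double[] logits, int label) {
        return -Math.log(softmax(logits)[label]);
    }

    public static void main(String[] args) {
        double[] logits = {5.0, 0.0, 0.0};
        System.out.println(loss(logits, 0)); // small: the model favors the correct class
        System.out.println(loss(logits, 1)); // large: the model favors a wrong class
    }
}
```

A model that spreads its probability evenly over the 10 digit classes gets a loss of ln(10), roughly 2.30, so values well below that suggest the model has learned something.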


In [None]:
int batchSize = 32;

Initializer initializer = new XavierInitializer();
Loss loss = Loss.softmaxCrossEntropyLoss();
// Scale gradients by 1 / batchSize before each parameter update
Optimizer optimizer = Optimizer.adam()
        .setRescaleGrad(1.0f / batchSize)
        .build();
Accuracy accuracy = new Accuracy();

TrainingConfig config = new DefaultTrainingConfig(initializer, loss)
        .setOptimizer(optimizer)
        .addTrainingMetric(accuracy)
        .setBatchSize(batchSize);
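
The `setRescaleGrad(1.0f / batchSize)` call multiplies gradients by a constant factor. One common reason for such a factor, assumed here rather than confirmed for DJL, is that gradients are summed over the examples in a batch, so scaling by 1 / batchSize turns the sum into an average. The sketch below shows only that arithmetic; `RescaleGradSketch` is a hypothetical name.

```java
// Illustrative arithmetic only; assumes the gradient was summed over the batch.
class RescaleGradSketch {

    /** Multiplies every gradient component by the rescale factor. */
    static float[] rescale(float[] summedGradient, float rescaleGrad) {
        float[] scaled = new float[summedGradient.length];
        for (int i = 0; i < summedGradient.length; i++) {
            scaled[i] = summedGradient[i] * rescaleGrad;
        }
        return scaled;
    }

    public static void main(String[] args) {
        int batchSize = 32;
        float[] summed = {32f, -16f}; // hypothetical gradient summed over 32 examples
        float[] averaged = rescale(summed, 1.0f / batchSize);
        System.out.println(averaged[0]); // prints 1.0
        System.out.println(averaged[1]); // prints -0.5
    }
}
```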

# Step 3: Prepare MNIST dataset for training

The [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset is a database of handwritten digits that is commonly used for training image classification models. 

DJL provides MNIST as a built-in dataset, so you can load it directly.

In [None]:
Mnist mnist = Mnist.builder(model.getNDManager()).setSampling(batchSize, true).build();
mnist.prepare(new ProgressBar());
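
The call `setSampling(batchSize, true)` passes a batch size and a boolean that presumably enables shuffling. As a rough sketch of what shuffled batch sampling means in general, the plain-Java code below shuffles example indices and cuts them into batches. `BatchSamplerSketch` is a hypothetical name, the seed parameter is added only for reproducibility, and keeping the last partial batch is an assumption. DJL's own sampler may behave differently.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Illustrative only: shuffled mini-batch index sampling, not DJL's sampler.
class BatchSamplerSketch {

    static List<List<Integer>> sample(int datasetSize, int batchSize, boolean shuffle, long seed) {
        List<Integer> indices = new ArrayList<>(datasetSize);
        for (int i = 0; i < datasetSize; i++) {
            indices.add(i);
        }
        if (shuffle) {
            Collections.shuffle(indices, new Random(seed));
        }
        List<List<Integer>> batches = new ArrayList<>();
        for (int start = 0; start < datasetSize; start += batchSize) {
            // The last batch may be smaller than batchSize (assumption: it is kept)
            int end = Math.min(start + batchSize, datasetSize);
            batches.add(new ArrayList<>(indices.subList(start, end)));
        }
        return batches;
    }

    public static void main(String[] args) {
        // The MNIST training set has 60,000 images
        System.out.println(sample(60000, 32, true, 42L).size()); // prints 1875
    }
}
```

With 60,000 training images and a batch size of 32, each epoch consists of 1,875 batches.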

# Step 4: Create a Trainer

Now, you can create a `Trainer` with your training configuration.
You need to initialize the model parameters in the trainer with the proper input shape:
* The first axis of the input is the batch axis. It doesn't affect the parameters, so you can use 1 here.
* The second axis is the size of a flattened MNIST image, which is 28 * 28 = 784 pixels.


In [None]:
Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(1, 28 * 28));
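
To see why the batch axis has no effect on the parameters, consider how many parameters a fully connected network has. The sketch below counts weights and biases from the input size and the layer sizes alone, and the batch size never appears. `MlpParameterCountSketch` is a hypothetical name, and the layer sizes 128, 64, and 10 are assumed for illustration, not taken from the `Mlp` block used here.

```java
// Illustrative only: parameter count of a fully connected network.
class MlpParameterCountSketch {

    /** Sums weights (in * out) and biases (out) for each dense layer. */
    static long countParameters(int inputSize, int[] layerSizes) {
        long total = 0;
        int in = inputSize;
        for (int out : layerSizes) {
            total += (long) in * out + out; // weight matrix plus bias vector
            in = out;
        }
        return total;
    }

    public static void main(String[] args) {
        int[] layerSizes = {128, 64, 10}; // assumed sizes: two hidden layers, 10 digit classes
        System.out.println(countParameters(28 * 28, layerSizes)); // prints 109386
    }
}
```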

# Step 5: Train your model

In [None]:
// One progress-bar step per batch
ProgressBar progressBar = new ProgressBar("Training", (int)(mnist.size() / batchSize));
int epoch = 2; // number of passes over the training data
float trainingAccuracy = 0f;
for (int i = 0; i < epoch; ++i) {
    int index = 0;
    for (Batch batch : trainer.iterateDataset(mnist)) {
        trainer.trainBatch(batch);
        trainer.step();
        batch.close();

        trainingAccuracy = accuracy.getValue();
        progressBar.update(index++, String.format("Epoch: %d, Accuracy: %.3f", i, trainingAccuracy));
    }
    // reset the training metrics at the end of each epoch
    trainer.resetTrainingMetrics();
}
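
The loop above reads a running accuracy after every batch and resets it after every epoch. As a rough illustration of how such a metric can work, the plain-Java sketch below counts correct predictions across batches. `RunningAccuracySketch` is a hypothetical class that only mirrors the `getValue` and reset pattern used above. It is not DJL's `Accuracy` implementation.

```java
// Illustrative only: a running classification accuracy, not DJL's Accuracy class.
class RunningAccuracySketch {
    private long correct;
    private long total;

    /** Returns the index of the highest score, i.e. the predicted class. */
    static int argMax(float[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Adds one batch of per-class scores and their true labels. */
    void update(float[][] predictions, int[] labels) {
        for (int i = 0; i < labels.length; i++) {
            if (argMax(predictions[i]) == labels[i]) {
                correct++;
            }
            total++;
        }
    }

    /** Fraction of correct predictions since the last reset. */
    float getValue() {
        return total == 0 ? 0f : (float) correct / total;
    }

    void reset() {
        correct = 0;
        total = 0;
    }

    public static void main(String[] args) {
        RunningAccuracySketch accuracy = new RunningAccuracySketch();
        accuracy.update(new float[][] {{0.1f, 0.9f}, {0.8f, 0.2f}}, new int[] {1, 1});
        System.out.println(accuracy.getValue()); // prints 0.5
    }
}
```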

# Step 6: Save your model

While saving your model, you can add metadata to it, such as the training accuracy and the number of epochs.

In [None]:
Path modelDir = Paths.get("build/mlp");
Files.createDirectories(modelDir);

model.setProperty("Epoch", String.valueOf(epoch));
model.setProperty("Accuracy", String.valueOf(trainingAccuracy));
model.save(modelDir, "mlp");

model

# Summary

Now, you've successfully trained a model that can recognize handwritten digits. You'll learn how to apply this model in the next chapter: [Run image classification with your model](image_classification_with_your_model.ipynb).

You can find the complete source code in the [examples project](https://github.com/awslabs/djl/blob/master/examples/docs/train_your_first_model.md).