# Train your first model


In this tutorial, you will learn how to train an image classification model that can recognize handwritten digits.

## Preparation

This tutorial requires the Java kernel for Jupyter. To install it, follow the instructions [here](https://github.com/awslabs/djl/blob/master/jupyter/README.md).

In [None]:
%mavenRepo s3 https://djl-ai.s3.amazonaws.com/dev

In [None]:
%maven ai.djl:api:0.2.0-SNAPSHOT
%maven ai.djl:basicdataset:0.2.0-SNAPSHOT
%maven ai.djl:model-zoo:0.2.0-SNAPSHOT
%maven ai.djl:repository:0.2.0-SNAPSHOT
%maven ai.djl.mxnet:mxnet-engine:0.2.0-SNAPSHOT
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0
// %maven ai.djl.mxnet:mxnet-native-mkl:jar:osx-x86_64:1.6.0

### Include MXNet engine dependency

You may need to update the `<classifier>` element in the following XML to match your platform. Downloading the native library may take some time the first time.

* Mac OS: **osx**-x86_64
* Linux: **linux**-x86_64
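If you are not sure which classifier applies to your machine, the JVM's `os.name` and `os.arch` system properties can tell you. The following is a minimal sketch of a hypothetical helper, `PlatformClassifier` (not part of DJL), that maps those properties to the two classifiers listed above. It assumes a 64-bit x86 JVM, which typically reports its architecture as `amd64` on Linux and `x86_64` on macOS.

```java
import java.util.Locale;

/**
 * Hypothetical helper (not part of DJL): maps JVM system properties to the
 * MXNet native-library classifiers. Only macOS and Linux on 64-bit x86 are handled.
 */
class PlatformClassifier {

    static String classifierFor(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String arch = osArch.toLowerCase(Locale.ROOT);
        // 64-bit x86 is usually reported as "amd64" on Linux and "x86_64" on macOS.
        if (!arch.equals("amd64") && !arch.equals("x86_64")) {
            throw new IllegalArgumentException("Unsupported architecture: " + osArch);
        }
        if (os.startsWith("mac")) {
            return "osx-x86_64";
        }
        if (os.startsWith("linux")) {
            return "linux-x86_64";
        }
        throw new IllegalArgumentException("Unsupported OS: " + osName);
    }

    public static void main(String[] args) {
        // Prints the value to use in the <classifier> element (throws on other platforms).
        System.out.println(classifierFor(System.getProperty("os.name"), System.getProperty("os.arch")));
    }
}
```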

In [None]:
%%loadFromPOM
  <repositories>
    <repository>
      <id>djl.ai</id>
      <url>https://djl-ai.s3.amazonaws.com/dev</url>
    </repository>
  </repositories>

  <dependencies>
    <dependency>
      <groupId>ai.djl.mxnet</groupId>
      <artifactId>mxnet-native-mkl</artifactId>
      <version>1.6.0</version>
      <classifier>osx-x86_64</classifier>
    </dependency>
  </dependencies>


In [None]:
import java.nio.file.*;

import ai.djl.*;
import ai.djl.basicdataset.*;
import ai.djl.ndarray.types.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.initializer.*;
import ai.djl.training.loss.*;
import ai.djl.training.metrics.*;
import ai.djl.training.optimizer.*;
import ai.djl.training.optimizer.learningrate.*;
import ai.djl.training.util.*;
import ai.djl.zoo.cv.classification.*;

# Step 1: Create your neural network

In this tutorial, we use the built-in MLP (multilayer perceptron) block from the model zoo. To learn more about the MLP block, see [Create your first network](create_your_first_network.ipynb).

Images in the MNIST dataset are 28x28 grayscale images, so we create an MLP block with a 28 x 28 input.

In [None]:
Model model = Model.newInstance();
model.setBlock(new Mlp(28, 28));
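To show what such a block computes, here is a plain-Java sketch of an MLP forward pass that does not depend on DJL. A flattened 28 x 28 = 784-value image passes through fully connected layers with ReLU activations and ends in 10 scores (logits), one per digit. The hidden-layer sizes (128 and 64) and the random initialization are assumptions made for illustration. This is not the model zoo's `Mlp` implementation.

```java
import java.util.Random;

/**
 * Plain-Java sketch of a multilayer perceptron forward pass. The hidden sizes
 * used in main (128 and 64) are illustrative assumptions, not necessarily DJL's.
 */
class MlpSketch {
    private final float[][][] weights; // weights[layer][out][in]
    private final float[][] biases;    // biases[layer][out]

    MlpSketch(long seed, int... sizes) {
        Random random = new Random(seed);
        int layers = sizes.length - 1;
        weights = new float[layers][][];
        biases = new float[layers][];
        for (int l = 0; l < layers; l++) {
            weights[l] = new float[sizes[l + 1]][sizes[l]];
            biases[l] = new float[sizes[l + 1]]; // biases start at zero
            for (float[] row : weights[l]) {
                for (int i = 0; i < row.length; i++) {
                    row[i] = (float) (random.nextGaussian() * 0.01); // small random weights
                }
            }
        }
    }

    /** Maps one flattened input vector to output scores (logits). */
    float[] forward(float[] input) {
        if (input.length != weights[0][0].length) {
            throw new IllegalArgumentException(
                    "Expected " + weights[0][0].length + " inputs, got " + input.length);
        }
        float[] activation = input;
        for (int l = 0; l < weights.length; l++) {
            float[] next = new float[weights[l].length];
            for (int o = 0; o < next.length; o++) {
                float sum = biases[l][o];
                for (int i = 0; i < activation.length; i++) {
                    sum += weights[l][o][i] * activation[i];
                }
                // ReLU on hidden layers; the last layer returns raw logits.
                next[o] = (l < weights.length - 1) ? Math.max(0f, sum) : sum;
            }
            activation = next;
        }
        return activation;
    }

    public static void main(String[] args) {
        MlpSketch mlp = new MlpSketch(42L, 28 * 28, 128, 64, 10);
        float[] logits = mlp.forward(new float[28 * 28]); // an all-black image
        System.out.println("Number of logits: " + logits.length); // one per digit class
    }
}
```

The weights here are random and untrained, so the logits carry no meaning until training adjusts them.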

# Step 2: Set up your training configuration


The following are a few common items you need to configure:
* Batch size: to take advantage of the GPU, we usually train models on batches of examples rather than on one example at a time. Pick a batch size that suits your model and the available memory.
* [`Initializer`](https://djl-ai.s3.amazonaws.com/java-api/0.2.0/api/index.html?ai/djl/training/initializer/Initializer.html): Pick an `Initializer` to set the initial values of the model parameters.
* [`Loss`](https://djl-ai.s3.amazonaws.com/java-api/0.2.0/api/index.html?ai/djl/training/loss/Loss.html) function: A loss function measures how good (or bad) the model's predictions are. Because lower values are better, it is called a "loss" function.
* [`Optimizer`](https://djl-ai.s3.amazonaws.com/java-api/0.2.0/api/index.html?ai/djl/training/optimizer/Optimizer.html): An optimization algorithm repeatedly updates the model parameters to minimize the value of the loss function.
* Training metrics: `Accuracy` reports the fraction of correctly classified examples while training.
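To make "lower is better" concrete for the softmax cross-entropy loss used in this tutorial, here is a plain-Java sketch for a single example. Softmax turns the model's scores (logits) into probabilities, and the loss is the negative log of the probability assigned to the correct label. This is the standard textbook definition written for one example. DJL's `Loss.softmaxCrossEntropyLoss()` works on batched `NDArray`s, and its options are not shown here.

```java
/**
 * Plain-Java sketch of softmax cross-entropy for one example (standard definition,
 * not DJL's implementation, which operates on batched NDArrays).
 */
class SoftmaxCrossEntropySketch {

    /** Returns -log(softmax(logits)[label]), computed in a numerically stable way. */
    static double loss(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        // log-sum-exp with the max subtracted so Math.exp cannot overflow.
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - max);
        }
        double logSumExp = max + Math.log(sumExp);
        return logSumExp - logits[label];
    }

    public static void main(String[] args) {
        double[] logits = {0.5, 4.0, 0.1};
        System.out.printf("label 1 (highest score): %.4f%n", loss(logits, 1));
        System.out.printf("label 2 (low score):     %.4f%n", loss(logits, 2));
    }
}
```

With these logits, the loss for label 1 (the highest score) is far smaller than for label 2. With equal logits over k classes, the loss is log(k), the value of a model that is simply guessing.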


In [None]:
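int batchSize = 32;

// Xavier initialization scales random initial weights by each layer's fan-in and fan-out.
Initializer initializer = new XavierInitializer();
// Softmax cross-entropy loss for multi-class classification (10 digit classes).
Loss loss = Loss.softmaxCrossEntropyLoss();
// Adam optimizer; each gradient is multiplied by 1 / batchSize before the update.
Optimizer optimizer = Optimizer.adam()
        .setRescaleGrad(1.0f / batchSize)
        .build();
// Report accuracy during training.
Accuracy accuracy = new Accuracy();

// Bundle everything into a training configuration.
TrainingConfig config = new DefaultTrainingConfig(initializer, loss)
        .setOptimizer(optimizer)
        .addTrainingMetric(accuracy)
        .setBatchSize(batchSize);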
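The optimizer above is Adam with `setRescaleGrad(1.0f / batchSize)`, which multiplies each gradient by 1/32 before the update, for example turning a gradient summed over a batch into a per-example average. The following plain-Java sketch shows one common formulation of the Adam update (with bias correction) applied to a rescaled gradient. The hyperparameters (beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) are the commonly cited Adam defaults and are assumed here; DJL's defaults and exact update formula may differ.

```java
/**
 * Plain-Java sketch of the Adam update rule with gradient rescaling.
 * Hyperparameters are commonly cited Adam defaults, assumed here; not DJL's code.
 */
class AdamSketch {
    private final double learningRate;
    private final double beta1 = 0.9;
    private final double beta2 = 0.999;
    private final double epsilon = 1e-8;
    private final double rescaleGrad;
    private final double[] m; // first-moment (mean) estimate per parameter
    private final double[] v; // second-moment (uncentered variance) estimate per parameter
    private int t = 0;        // number of updates performed so far

    AdamSketch(int numParams, double learningRate, double rescaleGrad) {
        this.learningRate = learningRate;
        this.rescaleGrad = rescaleGrad;
        this.m = new double[numParams];
        this.v = new double[numParams];
    }

    /** Updates params in place from a gradient that has not yet been rescaled. */
    void step(double[] params, double[] rawGrad) {
        t++;
        for (int i = 0; i < params.length; i++) {
            double g = rawGrad[i] * rescaleGrad; // e.g. summed gradient -> mean gradient
            m[i] = beta1 * m[i] + (1 - beta1) * g;
            v[i] = beta2 * v[i] + (1 - beta2) * g * g;
            double mHat = m[i] / (1 - Math.pow(beta1, t)); // bias correction
            double vHat = v[i] / (1 - Math.pow(beta2, t));
            params[i] -= learningRate * mHat / (Math.sqrt(vHat) + epsilon);
        }
    }

    public static void main(String[] args) {
        int batchSize = 32;
        // Minimize f(x) = x^2, pretending the gradient 2x was summed over a batch of 32.
        double[] x = {3.0};
        AdamSketch adam = new AdamSketch(1, 0.1, 1.0 / batchSize);
        for (int step = 0; step < 200; step++) {
            double summedGrad = batchSize * 2 * x[0];
            adam.step(x, new double[] {summedGrad});
        }
        System.out.printf("x after 200 steps: %.4f%n", x[0]);
    }
}
```

In this formulation, a constant rescale factor largely cancels in `mHat / sqrt(vHat)`, with only epsilon breaking the symmetry, so with Adam it has little effect on the step size; with plain SGD it would scale the step directly.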

# Step 3: Prepare MNIST dataset for training

The [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset is a database of handwritten digits that is commonly used for training image classification models. 

We provide MNIST as a built-in dataset, which makes it easy to use.

In [None]:
Mnist mnist = Mnist.builder(model.getNDManager()).setSampling(batchSize, true).build();
mnist.prepare(); // Download MNIST dataset, this may take a few seconds.
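Here `setSampling(batchSize, true)` sets the batch size and, through its second argument, asks for random (shuffled) sampling. The plain-Java sketch below shows the idea behind shuffled mini-batch sampling: shuffle the example indices, then cut them into consecutive batches. It is an illustration, not DJL's sampler. In particular, whether a final partial batch is kept or dropped is a design choice, and this sketch keeps it.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Plain-Java sketch of shuffled mini-batch sampling (illustration only, not DJL's sampler). */
class BatchSamplerSketch {

    /** Shuffles 0..size-1 and groups the indices into batches; the last batch may be smaller. */
    static List<List<Integer>> batches(int size, int batchSize, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        // A different seed per epoch would give a different order each epoch.
        Collections.shuffle(indices, new Random(seed));
        List<List<Integer>> result = new ArrayList<>();
        for (int start = 0; start < size; start += batchSize) {
            int end = Math.min(start + batchSize, size);
            result.add(new ArrayList<>(indices.subList(start, end)));
        }
        return result;
    }

    public static void main(String[] args) {
        // The MNIST training split has 60,000 images.
        List<List<Integer>> epoch = batches(60_000, 32, 7L);
        System.out.println(epoch.size() + " batches of " + epoch.get(0).size());
    }
}
```

Assuming the training split, 60,000 images with a batch size of 32 gives exactly 1875 batches per epoch, the same quantity (`mnist.size() / batchSize`) that the training loop in Step 5 uses as the total for its progress bar.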

# Step 4: Create a Trainer

Now you can create a `Trainer` with your training configuration.
Before training, you need to initialize the model parameters in the trainer with the proper input shape:
* The first axis is the batch axis. Its size does not affect the parameter shapes, so we can just use 1 here.
* The second axis is the flattened MNIST image, which has 28 * 28 = 784 values.
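Concretely, this shape treats each image as one flat row of 784 pixel values with a batch axis of size 1. A plain-Java sketch of that reshaping, assuming the conventional row-major order (pixel (r, c) lands at index r * 28 + c), looks like this; it uses plain arrays rather than DJL's `NDArray`.

```java
/** Plain-Java sketch: flatten a 2D image into a (1, rows * cols) batch, assuming row-major order. */
class FlattenSketch {

    static float[][] toBatchOfOne(float[][] image) {
        int rows = image.length;
        int cols = image[0].length; // assumes a rectangular image
        float[][] batch = new float[1][rows * cols]; // shape (1, rows * cols)
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                batch[0][r * cols + c] = image[r][c]; // pixel (r, c) -> index r * cols + c
            }
        }
        return batch;
    }

    public static void main(String[] args) {
        float[][] image = new float[28][28];
        image[1][2] = 1f; // mark a single pixel
        float[][] batch = toBatchOfOne(image);
        System.out.println(batch.length + " x " + batch[0].length); // 1 x 784
        System.out.println(batch[0][1 * 28 + 2]); // 1.0
    }
}
```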


In [None]:
Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(1, 28 * 28));
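Initialization needs the input shape because parameter shapes depend on it: with 784 input features, the first fully connected layer needs 784 weights per output unit, while the batch axis plays no role. The following plain-Java sketch shows the classic Glorot/Xavier uniform rule, which draws each weight from U(-a, a) with a = sqrt(6 / (fanIn + fanOut)). The hidden width of 128 is an assumption, and DJL's `XavierInitializer` may default to a different variant (for example Gaussian sampling or another magnitude), so check its documentation before relying on exact numbers.

```java
import java.util.Random;

/** Plain-Java sketch of Glorot/Xavier uniform initialization (not DJL's XavierInitializer). */
class XavierSketch {

    /** Returns a weight matrix of shape (fanOut, fanIn) drawn from U(-a, a). */
    static float[][] xavierUniform(int fanIn, int fanOut, Random random) {
        double a = Math.sqrt(6.0 / (fanIn + fanOut)); // keeps activation variance roughly stable
        float[][] w = new float[fanOut][fanIn];
        for (float[] row : w) {
            for (int i = 0; i < row.length; i++) {
                row[i] = (float) ((random.nextDouble() * 2 - 1) * a);
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // Input shape (1, 784): only the 784 features matter; the batch axis does not.
        int fanIn = 28 * 28;
        int fanOut = 128; // assumed hidden-layer width
        float[][] w = xavierUniform(fanIn, fanOut, new Random(0));
        System.out.printf("shape %d x %d, bound %.4f%n",
                w.length, w[0].length, Math.sqrt(6.0 / (fanIn + fanOut)));
    }
}
```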

# Step 5: Train your model

In [None]:
// One progress-bar tick per batch.
ProgressBar progressBar = new ProgressBar("Training", (int)(mnist.size() / batchSize));
int epoch = 2;
float trainingAccuracy = 0f;
for (int i = 0; i < epoch; ++i) {
    int index = 0;
    for (Batch batch : trainer.iterateDataset(mnist)) {
        // Compute predictions, loss, and gradients for this batch.
        trainer.trainBatch(batch);
        // Update the parameters with the optimizer.
        trainer.step();
        // Release the NDArrays held by this batch.
        batch.close();

        trainingAccuracy = accuracy.getValue();
        progressBar.update(index++, String.format("Epoch: %d, Accuracy: %.3f", i, trainingAccuracy));
    }
    // Reset the training metrics at the end of each epoch.
    trainer.resetTrainingMetrics();
}
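In the loop above, the accuracy value is read after every batch, and the training metrics are reset at the end of each epoch so that each epoch's accuracy starts from zero. The plain-Java sketch below illustrates one way such a running metric can work. It counts a prediction as correct when the index of the largest logit equals the label, and it accumulates counts across batches until it is reset. The class is hypothetical; its method names mirror those used above only for readability, and it is not DJL's `Accuracy` implementation.

```java
/** Hypothetical running-accuracy metric (plain-Java sketch, not DJL's Accuracy). */
class RunningAccuracySketch {
    private long correct = 0;
    private long total = 0;

    /** Predicted class = index of the largest logit. */
    static int argMax(float[] logits) {
        int best = 0;
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > logits[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Accumulates one batch: logits[example][class] against integer labels. */
    void update(float[][] logits, int[] labels) {
        for (int n = 0; n < labels.length; n++) {
            if (argMax(logits[n]) == labels[n]) {
                correct++;
            }
            total++;
        }
    }

    /** Accuracy over everything seen since the last reset (0 if nothing seen yet). */
    float getValue() {
        return total == 0 ? 0f : (float) correct / total;
    }

    /** Called at the end of each epoch so the next epoch starts from zero. */
    void reset() {
        correct = 0;
        total = 0;
    }

    public static void main(String[] args) {
        RunningAccuracySketch accuracy = new RunningAccuracySketch();
        accuracy.update(new float[][] {{0.1f, 2.0f}, {3.0f, 0.5f}}, new int[] {1, 1}); // 1 of 2 correct
        accuracy.update(new float[][] {{0.2f, 0.9f}}, new int[] {1});                 // 1 of 1 correct
        System.out.println(accuracy.getValue()); // 2 of 3 correct overall
        accuracy.reset();
    }
}
```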

# Step 6: Save your model

When saving your model, you can attach metadata to it, such as the training accuracy and the number of epochs.

In [None]:
Path modelDir = Paths.get("build/mlp");
Files.createDirectories(modelDir);

model.setProperty("Epoch", String.valueOf(epoch));
model.setProperty("Accuracy", String.valueOf(trainingAccuracy));
model.save(modelDir, "mlp");

model
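The cell passes the property values as strings (note the `String.valueOf` calls), so anyone reading them back has to parse them again. How DJL persists these properties is not covered here. As a JDK-only illustration of the same idea, the sketch below writes the two values with `java.util.Properties` and parses them back. `MetadataSketch` is hypothetical and does not reproduce DJL's on-disk format.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.util.Properties;

/**
 * Hypothetical JDK-only sketch of string-valued training metadata.
 * This is NOT how DJL stores model properties.
 */
class MetadataSketch {

    static String write(int epoch, float accuracy) {
        Properties props = new Properties();
        props.setProperty("Epoch", String.valueOf(epoch));
        props.setProperty("Accuracy", String.valueOf(accuracy));
        StringWriter out = new StringWriter();
        try {
            props.store(out, "training metadata");
        } catch (IOException e) {
            throw new UncheckedIOException(e); // not expected for an in-memory writer
        }
        return out.toString();
    }

    static Properties read(String text) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(text));
        } catch (IOException e) {
            throw new UncheckedIOException(e); // not expected for an in-memory reader
        }
        return props;
    }

    public static void main(String[] args) {
        Properties loaded = read(write(2, 0.9f));
        // Values come back as strings and must be parsed by the reader.
        int epoch = Integer.parseInt(loaded.getProperty("Epoch"));
        float accuracy = Float.parseFloat(loaded.getProperty("Accuracy"));
        System.out.println("Epoch=" + epoch + ", Accuracy=" + accuracy);
    }
}
```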

# Summary

You have now successfully trained a model that can recognize handwritten digits. You can proceed to the next chapter: [Run image classification with your model](image_classification_with_your_model.ipynb).

You can find the complete source code in the [examples project](https://github.com/awslabs/djl/blob/master/examples/docs/train_your_first_model.md).