# Load MXNet model

In this tutorial, you learn how to load an existing MXNet model and use it to run a prediction task.


## Preparation

This tutorial requires the Java kernel for Jupyter. For installation instructions, see the [README](https://github.com/awslabs/djl/blob/master/jupyter/README.md).

In [1]:
%%loadFromPOM
    <dependency>
        <groupId>ai.djl.aws</groupId>
        <artifactId>aws-ai</artifactId>
        <version>0.10.0-SNAPSHOT</version>
    </dependency>

In [2]:
%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.10.0-SNAPSHOT
%maven ai.djl:model-zoo:0.10.0
%maven ai.djl.mxnet:mxnet-engine:0.10.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.10.0
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0
%maven com.fasterxml.jackson.core:jackson-annotations:2.10.5
%maven com.fasterxml.jackson.core:jackson-core:2.10.5
%maven com.fasterxml.jackson.core:jackson-databind:2.10.5.1
    
// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md
// for more MXNet library selection options
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport


In [3]:
import java.awt.image.*;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.util.List;

import ai.djl.*;
import ai.djl.aws.sagemaker.*;
import ai.djl.inference.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.transform.*;
import ai.djl.modality.cv.translator.*;
import ai.djl.modality.cv.util.*;
import ai.djl.ndarray.*;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.*;
import ai.djl.translate.*;
import ai.djl.util.*;
import com.google.gson.reflect.TypeToken;

## Step 1: Prepare your MXNet model

This tutorial assumes that you have an MXNet model trained using Python. An MXNet symbolic model usually contains the following files:
* Symbol file: {MODEL_NAME}-symbol.json - a JSON file that describes the network structure of the model
* Parameters file: {MODEL_NAME}-{EPOCH}.params - a binary file that stores the model parameters (weights and biases)
* Synset file: synset.txt - an optional text file that stores the classification labels
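The naming convention above can be checked before loading. The sketch below uses a hypothetical helper class, `MxnetArtifacts`, which is not part of DJL. It assumes the MXNet checkpoint convention of zero-padding the epoch to four digits, as in `resnet18_v1-0000.params`, and it treats `synset.txt` as optional.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of DJL.
class MxnetArtifacts {
    // Required files of an MXNet symbolic model: {MODEL_NAME}-symbol.json and
    // {MODEL_NAME}-{EPOCH}.params, with the epoch zero-padded to four digits.
    static List<String> expectedFiles(String modelName, int epoch) {
        List<String> files = new ArrayList<>();
        files.add(modelName + "-symbol.json");
        files.add(String.format("%s-%04d.params", modelName, epoch));
        return files;
    }

    // Returns the required files that are not present in modelDir.
    static List<String> missingFiles(Path modelDir, String modelName, int epoch) {
        List<String> missing = new ArrayList<>();
        for (String name : expectedFiles(modelName, epoch)) {
            if (!Files.isRegularFile(modelDir.resolve(name))) {
                missing.add(name);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        // prints [resnet18_v1-symbol.json, resnet18_v1-0000.params]
        System.out.println(expectedFiles("resnet18_v1", 0));
        // Prints an empty list once the files have been downloaded into build/resnet.
        System.out.println(missingFiles(Paths.get("build/resnet"), "resnet18_v1", 0));
    }
}
```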

This tutorial uses a pre-trained MXNet `resnet18_v1` model.

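We use `DownloadUtils` to download the files from the internet. The parameters file is published as `resnet18_v1-0000.params.gz`. Because `DownloadUtils` decompresses `.gz` downloads, it is saved as `resnet18_v1-0000.params`.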

In [4]:
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-symbol.json", "build/resnet/resnet18_v1-symbol.json", new ProgressBar());
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-0000.params.gz", "build/resnet/resnet18_v1-0000.params", new ProgressBar());
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset.txt", "build/resnet/synset.txt", new ProgressBar());
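In the cell above, the parameters URL ends in `.gz` but the target file name does not, so the download step also has to decompress the file. The sketch below shows that idea with the standard library only. `GzipAwareDownload` is a hypothetical class and is not DJL's `DownloadUtils` implementation.

```java
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.zip.GZIPInputStream;

// Hypothetical helper, not DJL's DownloadUtils.
class GzipAwareDownload {
    // Copies the resource at url to output. If the URL ends with ".gz",
    // the stream is gunzipped on the fly.
    static void download(String url, Path output) throws IOException {
        Path parent = output.toAbsolutePath().getParent();
        if (parent != null) {
            Files.createDirectories(parent);
        }
        try (InputStream raw = URI.create(url).toURL().openStream();
             InputStream in = url.endsWith(".gz") ? new GZIPInputStream(raw) : raw) {
            Files.copy(in, output, StandardCopyOption.REPLACE_EXISTING);
        }
    }

    public static void main(String[] args) throws IOException {
        String base = "https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/";
        // The file is stored without the .gz suffix, already decompressed.
        download(base + "resnet18_v1-0000.params.gz", Paths.get("build/resnet/resnet18_v1-0000.params"));
    }
}
```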


## Step 2: Load your model

In [5]:
Path modelDir = Paths.get("build/resnet");
Model model = Model.newInstance("resnet");
model.load(modelDir, "resnet18_v1");
System.out.println(model.getClass().getName());

ai.djl.mxnet.engine.MxModel



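The next cell is optional and requires AWS credentials. It loads the same `resnet18_v1` model from a URL with the `Criteria` API. It then uses the `SageMaker` class from the `aws-ai` extension to do the following:
* deploy the model to an Amazon SageMaker endpoint
* send an image to the endpoint and print the top class
* delete the endpoint, the endpoint configuration, and the SageMaker model

The S3 bucket, container image, and IAM execution role are specific to the original author's AWS account. Replace them with your own before running the cell.

In [8]: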
String modelUrl = "https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/mxnet_resnet18.zip?model_name=resnet18_v1";
Criteria<Image, Classifications> criteria = Criteria.builder()
  .setTypes(Image.class, Classifications.class)
  .optModelUrls(modelUrl)
  .optTranslator(ImageClassificationTranslator.builder()
                 .setPipeline(new Pipeline()
                              .add(new Resize(224, 224))
                              .add(new ToTensor()))
                 .optApplySoftmax(true).build())
  .build();


        try (ZooModel<Image, Classifications> zooModel = ModelZoo.loadModel(criteria)) {
            SageMaker sageMaker =
                    SageMaker.builder()
                            .setModel(zooModel)
                            .optBucketName("djl-sm-test")
                            .optModelName("resnet")
                            .optContainerImage("125045733377.dkr.ecr.us-east-1.amazonaws.com/djl")
                            .optExecutionRole(
                                    "arn:aws:iam::125045733377:role/service-role/DJLSageMaker-ExecutionRole-20210213T1027050")
                            .build();

            sageMaker.deploy();

            byte[] image;
            Path imagePath = Paths.get("../../examples/src/test/resources/0.png");
            try (InputStream is = Files.newInputStream(imagePath)) {
                image = Utils.toByteArray(is);
            }
            String ret = new String(sageMaker.invoke(image), StandardCharsets.UTF_8);
            Type type = new TypeToken<List<Classifications.Classification>>() {}.getType();
            List<Classifications.Classification> list = JsonUtils.GSON.fromJson(ret, type);
            String className = list.get(0).getClassName();
            System.out.println(className);

            sageMaker.deleteEndpoint();
            sageMaker.deleteEndpointConfig();
            sageMaker.deleteSageMakerModel();
        }
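In the previous cell, `modelUrl` points to a `.zip` archive and ends with the query string `?model_name=resnet18_v1`. DJL appears to use this query string to decide which model name, and therefore which file prefix, to load from the archive. The sketch below only shows how such a query parameter can be read with `java.net.URI`. `ModelUrlParams` is a hypothetical class and does not reflect how DJL parses URLs internally.

```java
import java.net.URI;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Optional;

// Hypothetical helper, not part of DJL.
class ModelUrlParams {
    // Returns the decoded value of `key` from the URL's query string, if present.
    static Optional<String> queryParam(String url, String key) {
        String query = URI.create(url).getRawQuery();
        if (query == null) {
            return Optional.empty();
        }
        for (String pair : query.split("&")) {
            int eq = pair.indexOf('=');
            String name = eq < 0 ? pair : pair.substring(0, eq);
            if (name.equals(key)) {
                String value = eq < 0 ? "" : pair.substring(eq + 1);
                return Optional.of(URLDecoder.decode(value, StandardCharsets.UTF_8));
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        String modelUrl = "https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/mxnet_resnet18.zip?model_name=resnet18_v1";
        // prints resnet18_v1
        System.out.println(queryParam(modelUrl, "model_name").orElse("(none)"));
    }
}
```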


## Step 3: Create a `Translator`

In [None]:
Pipeline pipeline = new Pipeline();
pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());
Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .optSynsetArtifactName("synset.txt")
                .optApplySoftmax(true)
                .build();
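Two of the translator options above shape the output:
* `optApplySoftmax(true)` turns the network's raw output scores into probabilities.
* `optSynsetArtifactName("synset.txt")` supplies a label for each output index, one label per line.

The plain-Java sketch below illustrates both ideas: a numerically stable softmax and reading a synset file. `SoftmaxWithSynset` is a hypothetical class, not DJL code, and the scores in `main` are made up.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of DJL.
class SoftmaxWithSynset {
    // Numerically stable softmax. Subtracting the largest score before
    // exponentiating avoids overflow without changing the result.
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] probs = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    // Reads one label per line: line i is the label of output index i.
    // Blank lines, such as a trailing empty line, are skipped.
    static List<String> readSynset(Path synsetFile) throws IOException {
        List<String> labels = new ArrayList<>();
        for (String line : Files.readAllLines(synsetFile, StandardCharsets.UTF_8)) {
            if (!line.isBlank()) {
                labels.add(line.trim());
            }
        }
        return labels;
    }

    public static void main(String[] args) throws IOException {
        List<String> labels = readSynset(Paths.get("build/resnet/synset.txt"));
        double[] probs = softmax(new double[] {2.0, 1.0, 0.1}); // made-up scores
        for (int i = 0; i < probs.length; i++) {
            System.out.printf("%s: %.3f%n", labels.get(i), probs[i]);
        }
    }
}
```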

## Step 4: Load image for classification

In [None]:
var img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/kitten.jpg");
img.getWrappedImage()
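On a desktop JVM, the loaded `Image` wraps a `java.awt.image.BufferedImage`, which is what the cell above displays. The `ToTensor` step in the Step 3 pipeline converts pixels from (H, W, C) layout with values in [0, 255] to (C, H, W) layout with values in [0, 1]. The sketch below performs that conversion on a `BufferedImage` in plain Java, assuming RGB channel order. `ImageToTensor` is a hypothetical class, and the sketch skips the `CenterCrop` and `Resize` steps.

```java
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.net.URI;
import javax.imageio.ImageIO;

// Hypothetical helper, not DJL's ToTensor implementation.
class ImageToTensor {
    // Returns a flat CHW float array (R plane, then G, then B) scaled to [0, 1].
    static float[] toChwTensor(BufferedImage img) {
        int h = img.getHeight();
        int w = img.getWidth();
        float[] out = new float[3 * h * w];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int rgb = img.getRGB(x, y); // packed as 0xAARRGGBB
                int idx = y * w + x;
                out[idx] = ((rgb >> 16) & 0xFF) / 255f;        // R plane
                out[h * w + idx] = ((rgb >> 8) & 0xFF) / 255f; // G plane
                out[2 * h * w + idx] = (rgb & 0xFF) / 255f;    // B plane
            }
        }
        return out;
    }

    public static void main(String[] args) throws IOException {
        BufferedImage img = ImageIO.read(URI.create("https://resources.djl.ai/images/kitten.jpg").toURL());
        float[] tensor = toChwTensor(img);
        System.out.println("3 x " + img.getHeight() + " x " + img.getWidth() + " = " + tensor.length);
    }
}
```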

## Step 5: Run inference

In [None]:
Predictor<Image, Classifications> predictor = model.newPredictor(translator);
Classifications classifications = predictor.predict(img);

classifications
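`Classifications` holds a probability for each class, and DJL offers methods such as `best()` and `topK(k)` to pick the most likely ones. The plain-Java sketch below shows the underlying top-k selection: pair each label with its probability, sort in descending order, and keep the first k. `TopK` and `Scored` are hypothetical names, and the labels and probabilities in `main` are made up.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of DJL.
class TopK {
    // A (label, probability) pair.
    static final class Scored {
        final String label;
        final double probability;

        Scored(String label, double probability) {
            this.label = label;
            this.probability = probability;
        }

        @Override
        public String toString() {
            return label + ": " + probability;
        }
    }

    // Returns the k most probable labels, highest probability first.
    static List<Scored> topK(List<String> labels, double[] probs, int k) {
        List<Scored> all = new ArrayList<>();
        for (int i = 0; i < probs.length; i++) {
            all.add(new Scored(labels.get(i), probs[i]));
        }
        all.sort((a, b) -> Double.compare(b.probability, a.probability));
        return new ArrayList<>(all.subList(0, Math.min(k, all.size())));
    }

    public static void main(String[] args) {
        List<String> labels = List.of("cat", "dog", "bird"); // made-up labels
        double[] probs = {0.2, 0.5, 0.3};                    // made-up probabilities
        // prints [dog: 0.5, bird: 0.3]
        System.out.println(topK(labels, probs, 2));
    }
}
```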

## Summary

Now, you can load any MXNet symbolic model and run inference.

You might also want to check out [load_pytorch_model.ipynb](https://github.com/awslabs/djl/blob/master/jupyter/load_pytorch_model.ipynb) which demonstrates loading a local model using the ModelZoo API.