In [None]:
%%loadFromPOM
<dependencies>
    <!-- DJL core API. All ai.djl artifacts in this cell are pinned to the same version (0.20.0). -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
        <version>0.20.0</version>
    </dependency>
    <!-- NOTE(review): the code cells only select the "OnnxRuntime" engine (Models.getModel);
         confirm whether mxnet-engine and paddlepaddle-model-zoo are still required. -->
    <dependency>
        <groupId>ai.djl.mxnet</groupId>
        <artifactId>mxnet-engine</artifactId>
        <version>0.20.0</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.paddlepaddle</groupId>
        <artifactId>paddlepaddle-model-zoo</artifactId>
        <version>0.20.0</version>
    </dependency>
    <!-- DJL ONNX Runtime engine. Its default com.microsoft.onnxruntime:onnxruntime dependency is
         excluded so that the onnxruntime_gpu artifact declared below is used instead. -->
    <dependency>
        <groupId>ai.djl.onnxruntime</groupId>
        <artifactId>onnxruntime-engine</artifactId>
        <version>0.20.0</version>
        <scope>runtime</scope>
        <exclusions>
            <exclusion>
                <groupId>com.microsoft.onnxruntime</groupId>
                <artifactId>onnxruntime</artifactId>
            </exclusion>
        </exclusions>
    </dependency>
    <!-- GPU build of ONNX Runtime, replacing the excluded artifact above. -->
    <dependency>
        <groupId>com.microsoft.onnxruntime</groupId>
        <artifactId>onnxruntime_gpu</artifactId>
        <version>1.13.1</version>
        <scope>runtime</scope>
    </dependency>
    <!-- DJL OpenCV extension; presumably provides the ImageFactory used by ImageFactory.getInstance(). -->
    <dependency>
        <groupId>ai.djl.opencv</groupId>
        <artifactId>opencv</artifactId>
        <version>0.20.0</version>
    </dependency>
    <!-- Simple SLF4J binding so that log output from the libraries is printed. -->
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-simple</artifactId>
        <version>1.7.36</version>
    </dependency>
</dependencies>

<repositories>
    <!-- Aliyun mirror of Maven Central. -->
    <repository>
        <id>aliyun</id>
        <url>https://maven.aliyun.com/repository/central</url>
    </repository>
</repositories>

In [3]:
import ai.djl.MalformedModelException;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;

import java.io.IOException;
import java.nio.file.Paths;

public class Models {

    /** Model file loaded by {@link #getModel()}. */
    private static final String DEFAULT_MODEL_PATH = "/root/autodl-nas/pedestrian_yolov3_darknet.onnx";

    /** Confidence threshold used by {@link #getModel()}. */
    private static final float DEFAULT_THRESHOLD = .5f;

    /**
     * Loads the pedestrian detector from {@link #DEFAULT_MODEL_PATH} with a 0.5 confidence threshold.
     *
     * @return the loaded model; the caller owns it and must close it
     * @throws ModelNotFoundException if no model matches the criteria
     * @throws MalformedModelException if the model file cannot be parsed
     * @throws IOException if the model file cannot be read
     */
    public static ZooModel<Image, DetectedObjects> getModel() throws ModelNotFoundException, MalformedModelException, IOException {
        return getModel(DEFAULT_MODEL_PATH, DEFAULT_THRESHOLD);
    }

    /**
     * Loads the pedestrian detector from an arbitrary ONNX file.
     *
     * @param modelPath path to the model file (for models packed in a Jar/Zip, configure the
     *                  location through the Criteria API instead)
     * @param threshold minimum confidence a detection must reach to be returned
     * @return the loaded model; the caller owns it and must close it
     * @throws ModelNotFoundException if no model matches the criteria
     * @throws MalformedModelException if the model file cannot be parsed
     * @throws IOException if the model file cannot be read
     */
    public static ZooModel<Image, DetectedObjects> getModel(String modelPath, float threshold)
            throws ModelNotFoundException, MalformedModelException, IOException {
        return Criteria.builder()
                .optEngine("OnnxRuntime") // inference engine
                .setTypes(Image.class, DetectedObjects.class) // input / output types
                .optModelPath(Paths.get(modelPath))
                .optProgress(new ProgressBar()) // show a progress bar while loading
                .optTranslator(new PedestrianTranslator(threshold)) // custom pre/post-processing
                .build()
                .loadModel();
    }
}

In [4]:
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslatorContext;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Translator for a YOLOv3 pedestrian detector exported to ONNX.
 *
 * <p>Implements {@link NoBatchifyTranslator} rather than {@code Translator} because inputs and
 * outputs are handled one image at a time; the batch dimension is added by the pipeline itself.
 *
 * <p>{@link #processInput} produces three model inputs, in order: the original image size, the
 * preprocessed image tensor, and the original/input size ratio. {@link #processOutput} expects
 * output 0 to hold one row per detection: {@code [classId, score, xMin, yMin, xMax, yMax]}.
 */
public class PedestrianTranslator implements NoBatchifyTranslator<Image, DetectedObjects> {
    /** Width in pixels the image is resized to before entering the network. */
    private static final float INPUT_WIDTH = 608f;
    /** Height in pixels the image is resized to before entering the network. */
    private static final float INPUT_HEIGHT = 608f;

    private final Pipeline pipeline;
    private final float threshold;
    private final List<String> classes;

    /**
     * Creates a translator.
     *
     * @param threshold minimum score a detection must reach to be kept
     */
    public PedestrianTranslator(float threshold) {
        // Image preprocessing pipeline.
        pipeline = new Pipeline();
        pipeline.add(new Resize((int) INPUT_WIDTH, (int) INPUT_HEIGHT)) // resize to 608 * 608 * 3, HWC
                .add(new ToTensor()) // HWC -> CHW
                .add(new Normalize(new float[]{0.485f, 0.456f, 0.406f}, new float[]{0.229f, 0.224f, 0.225f})) // normalize
                .add(array -> array.expandDims(0)); // CHW -> NCHW
        this.threshold = threshold;
        // Class names indexed by predicted class id; this model has a single class.
        classes = Collections.singletonList("pedestrian");
    }

    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        // Arrays created through the context manager are released by it.
        NDManager manager = ctx.getNDManager();
        float originalHeight = input.getHeight();
        float originalWidth = input.getWidth();
        // Convert the image into the network's tensor format; NDList behaves like List<NDArray>.
        NDList ndList = pipeline.transform(new NDList(input.toNDArray(manager, Image.Flag.COLOR)));
        // NOTE(review): PaddleDetection-style exports usually take im_shape = resized size and
        // scale_factor = resized / original. Here the original size and original / resized are
        // passed instead, which appears intended to make the model return boxes in
        // INPUT_WIDTH x INPUT_HEIGHT coordinates (processOutput normalizes by those values).
        // Confirm against the exported model before changing.
        // First input: original image size as (height, width), shape (1, 2).
        ndList.add(0, manager.create(new float[]{originalHeight, originalWidth}).expandDims(0));
        // Last input: original size divided by the network input size, shape (1, 2).
        ndList.add(manager.create(new float[]{originalHeight / INPUT_HEIGHT, originalWidth / INPUT_WIDTH}).expandDims(0));
        return ndList;
    }

    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        // Output 0 holds the detections; output 1 (the number of detections) is not needed.
        NDArray result = list.get(0);
        /*
        result demo:
        ND: (3, 6) cpu() float32
        [[  0.    ,   0.9759,  10.0805, 276.1631, 298.1623, 586.246 ],
         [  0.    ,   0.955 , 486.306 , 221.0572, 585.966 , 480.4897],
         [  0.    ,   0.8031, 295.0543, 206.104 , 395.3066, 485.3789],
        ]
         */
        // Column 0: class id.
        int[] classIndices = result.get(":, 0").toType(DataType.INT32, true).flatten().toIntArray();
        // Column 1: confidence score.
        double[] probs = result.get(":, 1").toType(DataType.FLOAT64, true).toDoubleArray();
        int detected = probs.length;

        // Columns 2-5: top-left (x, y) and bottom-right (x, y) corners, clipped to the network
        // input size and converted to fractions of it.
        NDArray xMin = result.get(":, 2:3").clip(0, INPUT_WIDTH).div(INPUT_WIDTH);
        NDArray yMin = result.get(":, 3:4").clip(0, INPUT_HEIGHT).div(INPUT_HEIGHT);
        NDArray xMax = result.get(":, 4:5").clip(0, INPUT_WIDTH).div(INPUT_WIDTH);
        NDArray yMax = result.get(":, 5:6").clip(0, INPUT_HEIGHT).div(INPUT_HEIGHT);

        // Drawable form: top-left x/y plus width/height, all as fractions.
        float[] boxX = xMin.toFloatArray();
        float[] boxY = yMin.toFloatArray();
        float[] boxWidth = xMax.sub(xMin).toFloatArray();
        float[] boxHeight = yMax.sub(yMin).toFloatArray();

        List<String> retClasses = new ArrayList<>(detected);
        List<Double> retProbs = new ArrayList<>(detected);
        List<BoundingBox> retBB = new ArrayList<>(detected);
        for (int i = 0; i < detected; i++) {
            int classIndex = classIndices[i];
            // Skip rows with no valid class (negative or unknown id) or a score below the threshold.
            if (classIndex < 0 || classIndex >= classes.size() || probs[i] < threshold) {
                continue;
            }
            retClasses.add(classes.get(classIndex));
            retProbs.add(probs[i]);
            retBB.add(new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]));
        }
        return new DetectedObjects(retClasses, retProbs, retBB);
    }
}

In [5]:
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

/** Runs the pedestrian detector on a single image and saves an annotated copy. */
public class Inference {
    /** File written by {@link #run(String)}. */
    private static final String DEFAULT_OUTPUT_PATH = "output.png";

    /**
     * Detects pedestrians in an image and writes the image with bounding boxes to {@code output.png}.
     *
     * @param imageFilePath path of the image to analyse
     * @throws IOException if the model or image cannot be read, or the output cannot be written
     * @throws MalformedModelException if the model file cannot be parsed
     * @throws TranslateException if pre/post-processing or inference fails
     * @throws ModelNotFoundException if the model cannot be located
     */
    public static void run(String imageFilePath) throws IOException, MalformedModelException, TranslateException, ModelNotFoundException {
        run(imageFilePath, DEFAULT_OUTPUT_PATH);
    }

    /**
     * Detects pedestrians in an image and writes the image with bounding boxes as PNG.
     *
     * @param imageFilePath path of the image to analyse
     * @param outputFilePath path the annotated PNG is written to
     * @throws IOException if the model or image cannot be read, or the output cannot be written
     * @throws MalformedModelException if the model file cannot be parsed
     * @throws TranslateException if pre/post-processing or inference fails
     * @throws ModelNotFoundException if the model cannot be located
     */
    public static void run(String imageFilePath, String outputFilePath)
            throws IOException, MalformedModelException, TranslateException, ModelNotFoundException {
        // Model and predictor are both closed automatically, in reverse order of creation.
        try (ZooModel<Image, DetectedObjects> model = Models.getModel();
             Predictor<Image, DetectedObjects> predictor = model.newPredictor(Device.gpu())) {
            Image image = ImageFactory.getInstance().fromFile(Paths.get(imageFilePath));
            DetectedObjects result = predictor.predict(image);
            // Boxes are drawn onto `image`, which is then saved.
            image.drawBoundingBoxes(result);
            // Close the stream so the PNG is fully flushed and the file handle is released.
            try (java.io.OutputStream out = Files.newOutputStream(Paths.get(outputFilePath))) {
                image.save(out, "png");
            }
        }
    }
}

In [None]:
// Run detection on a sample image; Inference.run writes the annotated image to output.png.
Inference.run("/root/autodl-nas/958*604.png")