
0.27: calling a TensorFlow .pb model crashes (the model loads, but inference crashes) #3173

Open
LEEay opened this issue May 11, 2024 · 11 comments
Labels
bug Something isn't working

Comments

@LEEay

LEEay commented May 11, 2024

Description

With 0.27, calling a TensorFlow .pb model crashes (the model loads, but the process crashes during inference):
Java frames: (J=compiled Java code, j=interpreted, Vv=VM code)
j org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun(Lorg/tensorflow/internal/c_api/TF_Session;Lorg/tensorflow/internal/c_api/TF_Buffer;Lorg/tensorflow/internal/c_api/TF_Output;Lorg/bytedeco/javacpp/PointerPointer;ILorg/tensorflow/internal/c_api/TF_Output;Lorg/bytedeco/javacpp/PointerPointer;ILorg/bytedeco/javacpp/PointerPointer;ILorg/tensorflow/internal/c_api/TF_Buffer;Lorg/tensorflow/internal/c_api/TF_Status;)V+0
j ai.djl.tensorflow.engine.javacpp.JavacppUtils.runSession(Lorg/tensorflow/internal/c_api/TF_Session;Lorg/tensorflow/proto/framework/RunOptions;[Lorg/tensorflow/internal/c_api/TF_Tensor;[Lorg/tensorflow/internal/c_api/TF_Operation;[I[Lorg/tensorflow/internal/c_api/TF_Operation;[I[Lorg/tensorflow/internal/c_api/TF_Operation;)[Lorg/tensorflow/internal/c_api/TF_Tensor;+270
j ai.djl.tensorflow.engine.TfSymbolBlock.forwardInternal(Lai/djl/training/ParameterStore;Lai/djl/ndarray/NDList;ZLai/djl/util/PairList;)Lai/djl/ndarray/NDList;+202
j ai.djl.nn.AbstractBaseBlock.forward(Lai/djl/training/ParameterStore;Lai/djl/ndarray/NDList;ZLai/djl/util/PairList;)Lai/djl/ndarray/NDList;+36
j ai.djl.nn.Block.forward(Lai/djl/training/ParameterStore;Lai/djl/ndarray/NDList;Z)Lai/djl/ndarray/NDList;+5
j ai.djl.inference.Predictor.predictInternal(Lai/djl/translate/TranslatorContext;Lai/djl/ndarray/NDList;)Lai/djl/ndarray/NDList;+21
j ai.djl.inference.Predictor.batchPredict(Ljava/util/List;)Ljava/util/List;+133
j ai.djl.inference.Predictor.predict(Ljava/lang/Object;)Ljava/lang/Object;+5
j com.fly.ai.wdtag.WdTagService.detect(Lai/djl/modality/cv/Image;)Lai/djl/modality/Classifications;+17
j com.fly.ai.wdtag.WdTagService.tag(Lai/djl/modality/cv/Image;)Lai/djl/modality/Classifications;+2
j com.fly.ai.wdtag.WdTagService.tag(Ljava/io/InputStream;)Lai/djl/modality/Classifications;+10
j com.fly.ai.wdtag.WdTagController.tag(Lorg/springframework/web/multipart/MultipartFile;)Ljava/util/List;+10
v ~StubRoutines::call_stub
J 2923 jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Ljava/lang/reflect/Method;Ljava/lang/Object;[Ljava/lang/Object;)Ljava/lang/Object; java.base@17.0.8 (0 bytes) @ 0x000001712a9e2327 [0x000001712a9e22a0+0x0000000000000087]
J 2922 c1 jdk.internal.reflect.NativeMethodAccessorImpl.invoke(Ljava/lang/Object;[Ljava/lang/Object;)Ljava/lang/Object; java.base@17.0.8 (137 bytes) @ 0x000001712a9e31cc [0x000001712a9e2e80+0x000000000000034c]
J 2373 c1 jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(Ljava/lang/Object;[Ljava/lang/Object;)Ljava/lang/Object; java.base@17.0.8 (10 bytes) @ 0x000001712a8f7564 [0x000001712a8f7520+0x0000000000000044]
J 2399 c1 java.lang.reflect.Method.invoke(Ljava/lang/Object;[Ljava/lang/Object;)Ljava/lang/Object; java.base@17.0.8 (65 bytes) @ 0x000001712a90266c [0x000001712a902560+0x000000000000010c]
j org.springframework.web.method.support.InvocableHandlerMethod.doInvoke([Ljava/lang/Object;)Ljava/lang/Object;+28
j org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(Lorg/springframework/web/context/request/NativeWebRequest;Lorg/springframework/web/method/support/ModelAndViewContainer;[Ljava/lang/Object;)Ljava/lang/Object;+54
j org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(Lorg/springframework/web/context/request/ServletWebRequest;Lorg/springframework/web/method/support/ModelAndViewContainer;[Ljava/lang/Object;)V+4
j org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;Lorg/springframework/web/method/HandlerMethod;)Lorg/springframework/web/servlet/ModelAndView;+244
j org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;Lorg/springframework/web/method/HandlerMethod;)Lorg/springframework/web/servlet/ModelAndView;+81
j org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;Ljava/lang/Object;)Lorg/springframework/web/servlet/ModelAndView;+7
j org.springframework.web.servlet.DispatcherServlet.doDispatch(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;)V+259
j org.springframework.web.servlet.DispatcherServlet.doService(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;)V+241
j org.springframework.web.servlet.FrameworkServlet.processRequest(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;)V+71
j org.springframework.web.servlet.FrameworkServlet.doPost(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;)V+3
j javax.servlet.http.HttpServlet.service(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;)V+149
j org.springframework.web.servlet.FrameworkServlet.service(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;)V+33
j javax.servlet.http.HttpServlet.service(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;)V+36
j org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;)V+304
j org.apache.catalina.core.ApplicationFilterChain.doFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;)V+99
j org.apache.tomcat.websocket.server.WsFilter.doFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;Ljavax/servlet/FilterChain;)V+21
j org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;)V+135
j org.apache.catalina.core.ApplicationFilterChain.doFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;)V+99
j org.springframework.web.filter.RequestContextFilter.doFilterInternal(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;Ljavax/servlet/FilterChain;)V+21
j org.springframework.web.filter.OncePerRequestFilter.doFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;Ljavax/servlet/FilterChain;)V+147
j org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;)V+135
j org.apache.catalina.core.ApplicationFilterChain.doFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;)V+99
j org.springframework.web.filter.FormContentFilter.doFilterInternal(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;Ljavax/servlet/FilterChain;)V+38
j org.springframework.web.filter.OncePerRequestFilter.doFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;Ljavax/servlet/FilterChain;)V+147
j org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;)V+135
j org.apache.catalina.core.ApplicationFilterChain.doFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;)V+99
j org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(Ljavax/servlet/http/HttpServletRequest;Ljavax/servlet/http/HttpServletResponse;Ljavax/servlet/FilterChain;)V+53
j org.springframework.web.filter.OncePerRequestFilter.doFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;Ljavax/servlet/FilterChain;)V+147
j org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;)V+135
j org.apache.catalina.core.ApplicationFilterChain.doFilter(Ljavax/servlet/ServletRequest;Ljavax/servlet/ServletResponse;)V+99
j org.apache.catalina.core.StandardWrapperValve.invoke(Lorg/apache/catalina/connector/Request;Lorg/apache/catalina/connector/Response;)V+694
j org.apache.catalina.core.StandardContextValve.invoke(Lorg/apache/catalina/connector/Request;Lorg/apache/catalina/connector/Response;)V+169
j org.apache.catalina.authenticator.AuthenticatorBase.invoke(Lorg/apache/catalina/connector/Request;Lorg/apache/catalina/connector/Response;)V+260
j org.apache.catalina.core.StandardHostValve.invoke(Lorg/apache/catalina/connector/Request;Lorg/apache/catalina/connector/Response;)V+128
j org.apache.catalina.valves.ErrorReportValve.invoke(Lorg/apache/catalina/connector/Request;Lorg/apache/catalina/connector/Response;)V+6
j org.apache.catalina.core.StandardEngineValve.invoke(Lorg/apache/catalina/connector/Request;Lorg/apache/catalina/connector/Response;)V+59
j org.apache.catalina.connector.CoyoteAdapter.service(Lorg/apache/coyote/Request;Lorg/apache/coyote/Response;)V+203
j org.apache.coyote.http11.Http11Processor.service(Lorg/apache/tomcat/util/net/SocketWrapperBase;)Lorg/apache/tomcat/util/net/AbstractEndpoint$Handler$SocketState;+796
j org.apache.coyote.AbstractProcessorLight.process(Lorg/apache/tomcat/util/net/SocketWrapperBase;Lorg/apache/tomcat/util/net/SocketEvent;)Lorg/apache/tomcat/util/net/AbstractEndpoint$Handler$SocketState;+170
j org.apache.coyote.AbstractProtocol$ConnectionHandler.process(Lorg/apache/tomcat/util/net/SocketWrapperBase;Lorg/apache/tomcat/util/net/SocketEvent;)Lorg/apache/tomcat/util/net/AbstractEndpoint$Handler$SocketState;+495
j org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun()V+216
j org.apache.tomcat.util.net.SocketProcessorBase.run()V+21
j org.apache.tomcat.util.threads.ThreadPoolExecutor.runWorker(Lorg/apache/tomcat/util/threads/ThreadPoolExecutor$Worker;)V+92
j org.apache.tomcat.util.threads.ThreadPoolExecutor$Worker.run()V+5
j org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run()V+4
j java.lang.Thread.run()V+11 java.base@17.0.8
v ~StubRoutines::call_stub

The .h5 file and the converted .pb file both work correctly when called from Python.
Running the model with TensorFlow's Java code (TFJava) also works.

python 3.10.4
tensorflow 2.16.1

@LEEay LEEay added the bug Something isn't working label May 11, 2024
@frankfliu
Contributor

Can you provide a minimal reproducible project?

Does it work with older versions?

Do you have code that works with TFJava?

@LEEay
Author

LEEay commented May 13, 2024

I'm using the KichangKim/DeepDanbooru project on GitHub (tag v3-20211112-sgd-e28), converted to a .pb model following the DJL documentation.
My TFJava code:
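import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.LongNdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TInt64;

// Backslashes must be escaped in a Java string literal ("\D" is not a valid escape)
Session session = SavedModelBundle.load("E:\\Download\\mymodel\\deepdanbooru-v3-20211112",
        "serve").session();

// Test values; not used below
float[][] input = {
        {2.6327686f, -9.201903f},
        {-1.3209248f, 8.569574f},
        {-5.6642127f, 3.3681698f},
        {9.604832f, 5.9664965f},
        {-0.8812313f, -6.76733f}
};
// Deliberately wrong input (int64, shape [1, 100, 9]) to see how errors are reported
LongNdArray matrix3d = NdArrays.ofLongs(Shape.of(1, 100, 9));
TInt64 rank3Tensor = TInt64.tensorOf(matrix3d);
Tensor resultTensor = session.runner()
        .feed("serving_default_inputs:0", rank3Tensor)
        .fetch("StatefulPartitionedCall:0")
        .run().get(0);

System.out.println(resultTensor.shape());
session.close();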

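The model inputs above are deliberately wrong; this was only a test. TFJava reports the bad input as an error, but running inference the same way through DJL crashes the whole process.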
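TFJava rejects a mismatched input with an exception, while the same mismatch apparently takes down the JVM when it goes through DJL's native `TF_SessionRun` call. Until that is resolved, one defensive option is to check shape and dtype yourself before calling `predict`. The sketch below (Python, like the preprocessing in this thread) illustrates such a check. The expected signature `(None, 512, 512, 3)` / `float32` is an assumption based on the 512x512 RGB preprocessing used in this thread, not something read from the model, and `check_input` is a hypothetical helper.

```python
from typing import Optional, Sequence, Tuple

# Assumed signature (not read from the model): (batch, height, width, channels)
EXPECTED_SHAPE: Tuple[Optional[int], ...] = (None, 512, 512, 3)
EXPECTED_DTYPE = "float32"


def check_input(shape: Sequence[int], dtype: str,
                expected_shape: Sequence[Optional[int]] = EXPECTED_SHAPE,
                expected_dtype: str = EXPECTED_DTYPE) -> None:
    """Raise ValueError if shape/dtype differ from the expected signature.

    A None entry in expected_shape matches any size (e.g. the batch axis).
    """
    if dtype != expected_dtype:
        raise ValueError(f"dtype {dtype} does not match expected {expected_dtype}")
    if len(shape) != len(expected_shape):
        raise ValueError(f"rank {len(shape)} does not match expected {len(expected_shape)}")
    for axis, (got, want) in enumerate(zip(shape, expected_shape)):
        if want is not None and got != want:
            raise ValueError(f"axis {axis}: size {got} does not match expected {want}")


# The deliberately wrong TFJava test input (int64, [1, 100, 9]) is rejected here
try:
    check_input((1, 100, 9), "int64")
except ValueError as err:
    print("rejected:", err)
```

Failing fast in application code turns what would be a native crash into an ordinary, catchable error.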

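Because the preprocessing is written in Python and I haven't fully figured out how to port it to DJL, I implemented the pre- and post-processing in Python and run the model itself with DJL.
Code: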
public class PythonTranslator implements NoBatchifyTranslator<String, Classifications> {

    private ZooModel<Input, Output> model;
    private Predictor<Input, Output> predictor;

    @Override
    public void prepare(TranslatorContext ctx) throws Exception {
        if (predictor == null) {
            Criteria<Input, Output> criteria =
                    Criteria.builder()
                            .setTypes(Input.class, Output.class)
                            .optModelUrls(getRealUrl("/model/wdtag"))
                            .optEngine("Python")
                            .build();
            model = criteria.loadModel();
            predictor = model.newPredictor();
        }
    }

    @Override
    public Classifications processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
        System.out.println("##################################");
        return null;
    }

    @Override
    public NDList processInput(TranslatorContext translatorContext, String s) throws Exception {
        Input input = new Input();
        input.add("data", s);
        input.addProperty("Content-Type", "text/plain");
        input.addProperty("handler", "preprocess");
        Output output = predictor.predict(input);
        if (output.getCode() != 200) {
            throw new TranslateException("Python preprocess() failed: " + output.getMessage());
        }

        // NDList list = output.getDataAsNDList(translatorContext.getNDManager());
        NDArray ndArray = output.getDataAsNDList(translatorContext.getNDManager()).get(0);
        ndArray.setName("serving_default_inputs:0");
        System.out.println(new NDList(ndArray).toString());
        return new NDList(ndArray);
        // return output.getDataAsNDList(translatorContext.getNDManager());
    }
}
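Both the TFJava code (`feed("serving_default_inputs:0", ...)`, `fetch("StatefulPartitionedCall:0")`) and the translator (`setName("serving_default_inputs:0")`) use TensorFlow's `operation_name:output_index` tensor-name convention. A minimal sketch of how such a name splits into its parts; `split_tensor_name` is a hypothetical helper, and treating a bare operation name as output 0 follows the TFJava `Session.Runner.feed` convention as I understand it.

```python
from typing import Tuple


def split_tensor_name(name: str) -> Tuple[str, int]:
    """Split 'operation_name:output_index' into (operation_name, output_index).

    A name without ':' is treated as output 0 (assumed TFJava feed/fetch convention).
    """
    op, sep, index = name.rpartition(":")
    if not sep:
        return name, 0
    return op, int(index)


for tensor_name in ("serving_default_inputs:0", "StatefulPartitionedCall:0"):
    print(split_tensor_name(tensor_name))
```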

Model-loading code:
PythonTranslator translator = new PythonTranslator();
Criteria<String, Classifications> criteria = Criteria.builder()
        .optEngine(ENGINE_TENSORFLOW)
        .optDevice(device)
        .optModelUrls(getRealUrl(prop.getDetectUrl()))
        .setTypes(String.class, Classifications.class)
        // .optModelName(prop.getModelName())
        .optProgress(new ProgressBar())
        // .optOption("Tags", "")
        // .optOption("SignatureDefKey", "serving_default")
        .optTranslator(translator)
        .build();
// Engine.getEngine("TensorFlow");
detectionModel = criteria.loadModel();

The Python image pre- and post-processing follows the code in KichangKim/DeepDanbooru.

@LEEay
Author

LEEay commented May 31, 2024

Could anyone help take a look? Or is there a community channel, such as a WeChat or QQ group?

@frankfliu
Contributor

Can you provide your .pb or converted saved model bundle file?

@LEEay
Author

LEEay commented Jun 3, 2024

Model: https://github.com/KichangKim/DeepDanbooru/tags, using the v3-20211112-sgd-e28 tag, converted to .pb with DJL's official conversion code.
Code:
import tensorflow as tf

loaded_model = tf.keras.models.load_model(r"E:\Download\model-resnet_custom_v3.h5", compile=False)
tf.saved_model.save(loaded_model, r"E:\Download\1")

@frankfliu
Contributor

@LEEay

I looked into this model; it seems to be broken after conversion to SavedModelBundle. I used the following TFJava code to run it:

        Path path = Paths.get("resnet").toRealPath();
        Session session = SavedModelBundle.load(path.toString(), "serve").session();

        FloatNdArray matrix3d = NdArrays.ofFloats(org.tensorflow.ndarray.Shape.of(1, 255, 255, 3));
        TFloat32 rank3Tensor = TFloat32.tensorOf(matrix3d);
        Tensor resultTensor = session.runner()
                .feed("serving_default_inputs:0", rank3Tensor)
                .fetch("StatefulPartitionedCall:0")
                .run().get(0);
        System.out.println(resultTensor.shape());
        session.close();

I got the following error:

org.tensorflow.exceptions.TFFailedPreconditionException: Could not find variable batch_normalization_111/moving_mean. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Resource localhost/batch_normalization_111/moving_mean/N10tensorflow3VarE does not exist.
         [[{{function_node __inference_serving_default_11824}}{{node resnet_custom_v3_1/batch_normalization_111_1/Cast/ReadVariableOp}}]]
        at app//org.tensorflow.internal.c_api.AbstractTF_Status.throwExceptionIfNotOK(AbstractTF_Status.java:95)
        at app//org.tensorflow.Session.run(Session.java:835)
        at app//org.tensorflow.Session$Runner.runHelper(Session.java:558)
        at app//org.tensorflow.Session$Runner.run(Session.java:485)

You might need to work with the TFJava team to troubleshoot this issue.

I used the following DJL code, and it behaves as expected (the same error as TFJava):

        Criteria<NDList, NDList> criteria = Criteria.builder()
                .optEngine("TensorFlow")
                .optModelPath(path)
                .setTypes(NDList.class, NDList.class)
                .build();

        try (ZooModel<NDList, NDList> model = criteria.loadModel();
             Predictor<NDList, NDList> predictor = model.newPredictor()) {
            NDManager manager = model.getNDManager();
            NDArray array = manager.create(new Shape(1, 512, 512, 3));
            NDList input = new NDList(array);
            NDList result = predictor.predict(input);
            System.out.println(result.get(0).getShape());
        }

@frankfliu
Contributor

@LEEay

By the way, since you are using an image classification model, the following code should work for you once you fix the model issue:

        Path path = Paths.get("resnet").toRealPath();
        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optModelPath(path)
                        .optEngine("TensorFlow")
                        .optArgument("width", "512")
                        .optArgument("height", "512")
                        .optArgument("resize", "true")
                        .optArgument("toTensor", "false")
                        .optTranslatorFactory(new ImageClassificationTranslatorFactory())
                        .build();

        Path file = Paths.get("../../../examples/src/test/resources/kitten.jpg");
        Image img = ImageFactory.getInstance().fromFile(file);
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Classifications result = predictor.predict(img);
            System.out.println(result.best().getClassName());
        }
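One thing to watch: with `resize` set to `true`, the DJL translator (as far as I can tell) resizes the image straight to 512x512, whereas the DeepDanbooru preprocessing posted in this thread calls `tf.image.resize(..., preserve_aspect_ratio=True)` and then pads to 512x512, so scores can differ for non-square images. A sketch of the aspect-preserving target size; `aspect_preserving_size` and `padding_needed` are hypothetical helpers, and matching TensorFlow's exact rounding is an assumption.

```python
from typing import Tuple


def aspect_preserving_size(height: int, width: int,
                           target_height: int, target_width: int) -> Tuple[int, int]:
    """Largest (height, width) with the input's aspect ratio that fits the target box."""
    scale = min(target_height / height, target_width / width)
    # Python's round() is used here; TensorFlow's exact rounding is an assumption.
    return round(height * scale), round(width * scale)


def padding_needed(height: int, width: int,
                   target_height: int, target_width: int) -> Tuple[int, int]:
    """Rows and columns left to pad after the aspect-preserving resize."""
    new_h, new_w = aspect_preserving_size(height, width, target_height, target_width)
    return target_height - new_h, target_width - new_w


print(aspect_preserving_size(768, 1024, 512, 512))  # (384, 512)
print(padding_needed(768, 1024, 512, 512))  # (128, 0)
```

A 768x1024 (height x width) image therefore only fills 384 rows; the remaining 128 rows are padding, which a plain 512x512 resize would instead fill by stretching.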

@LEEay
Author

LEEay commented Jun 7, 2024

#3173 (comment)
Python code can run the converted model file:
import math

import tensorflow as tf
import tensorflow_io as tfio
import skimage.transform

# Keras model (overwritten below; only the SavedModel is used)
model = tf.keras.models.load_model(r"E:\Download\model-resnet_custom_v3.h5", compile=False)

model = tf.saved_model.load(r"E:\Download\deepdanbooru-v3-20211112")


def evaluate_image(
    image_input: str, model: object, threshold: float
):
    # width = model.input_shape[2]
    # height = model.input_shape[1]
    width = 512
    height = 512

    image = load_image_for_evaluate(image_input, width=width, height=height)

    image_shape = image.shape
    image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
    print(image)
    size = image.shape
    print("Array size:", size)
    # y = model.predict(image)[0]
    infer = model.signatures["serving_default"]
    print(infer.inputs[0])
    output = infer(inputs=tf.constant(image))
    print(output)
    # print(y)
    result_dict = {}

    # for i, tag in enumerate(tags):
    #     result_dict[tag] = y[i]

    # for tag in tags:
    #     if result_dict[tag] >= threshold:
    #         yield tag, result_dict[tag]


def load_image_for_evaluate(
    input_, width: int, height: int, normalize: bool = True
):
    print(width, height)
    image_raw = tf.io.read_file(input_)
    try:
        image = tf.io.decode_png(image_raw, channels=3)
    except:
        image = tfio.image.decode_webp(image_raw)
        image = tfio.experimental.color.rgba_to_rgb(image)

    image = tf.image.resize(
        image,
        size=(height, width),
        method=tf.image.ResizeMethod.AREA,
        preserve_aspect_ratio=True,
    )
    image = image.numpy()  # EagerTensor to np.array
    # print("image-up", image)
    image = transform_and_pad_image(image, width, height)
    # print("image-down", image)
    if normalize:
        image = image / 255.0

    return image


def transform_and_pad_image(
    image,
    target_width,
    target_height,
    scale=None,
    rotation=None,
    shift=None,
    order=1,
    mode="edge",
):
    """
    Transform image and pad by edge pixels.
    """
    image_width = image.shape[1]
    image_height = image.shape[0]
    image_array = image

    # centerize
    t = skimage.transform.AffineTransform(
        translation=(-image_width * 0.5, -image_height * 0.5)
    )

    if scale:
        t += skimage.transform.AffineTransform(scale=(scale, scale))

    if rotation:
        radian = (rotation / 180.0) * math.pi
        t += skimage.transform.AffineTransform(rotation=radian)

    t += skimage.transform.AffineTransform(
        translation=(target_width * 0.5, target_height * 0.5)
    )

    if shift:
        t += skimage.transform.AffineTransform(
            translation=(target_width * shift[0], target_height * shift[1])
        )

    warp_shape = (target_height, target_width)

    image_array = skimage.transform.warp(
        image_array, (t).inverse, output_shape=warp_shape, order=order, mode=mode
    )

    return image_array


yz: float = 0.5  # confidence threshold for generated tags
evaluate_image(r"E:\Download\111111111111.png", model, yz)
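The commented-out loop at the end of `evaluate_image` turns per-tag scores into `(tag, score)` pairs at or above the threshold. A standalone sketch of that step; `tags_above_threshold` is a hypothetical helper, `tags` is assumed to be the model's tag list in output order (it is not loaded in the script above), and the threshold is a float such as `0.5`.

```python
from typing import Iterator, Sequence, Tuple


def tags_above_threshold(tags: Sequence[str], scores: Sequence[float],
                         threshold: float) -> Iterator[Tuple[str, float]]:
    """Yield (tag, score) for every tag whose score is >= threshold, in output order."""
    if len(tags) != len(scores):
        raise ValueError("tags and scores must have the same length")
    for tag, score in zip(tags, scores):
        if score >= threshold:
            yield tag, score


# Hypothetical tags and scores, for illustration only
example_tags = ["tag_a", "tag_b", "tag_c"]
example_scores = [0.92, 0.50, 0.07]
print(list(tags_above_threshold(example_tags, example_scores, 0.5)))
```

Zipping tags and scores directly avoids building the intermediate `result_dict` the original loop uses; the selected pairs are the same.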

@LEEay
Author

LEEay commented Jun 7, 2024

#3173 (comment)
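Because I don't know how to implement the model example's image-transformation method in DJL, I currently use DJL to call the Python preprocessing to transform the image, and then use DJL to run inference on the model file.
This is my model.py, the image-transformation code: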
import math
import logging

import numpy as np
import tensorflow as tf
import tensorflow_io as tfio
import skimage.transform

from typing import Optional, Any
from djl_python import Input
from djl_python import Output
from djl_python.np_util import to_nd_list


class Processing(object):

    def __init__(self):
        self.topK = 5
        self.image_processing = None
        self.mapping = None
        self.initialized = False

    def preprocess(self, inputs: Input) -> Output:
        outputs = Output()
        try:
            width = 512
            height = 512
            imageinput = "test"
            batch = inputs.get_batches()
            for i, item in enumerate(batch):
                imageinput = item.get_as_string()
            image = self.load_image_for_evaluate(imageinput, width, height)

            image_shape = image.shape
            image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
            # print(image)
            # Assumption: the exported signature expects float32, while
            # skimage.transform.warp returns float64
            image = image.astype(np.float32)
            # add_as_numpy takes a list of arrays; passing the bare array would
            # likely be iterated along axis 0 and drop the batch dimension
            outputs.add_as_numpy([image])
            outputs.add_property("content-type", "tensor/ndlist")
        except Exception as e:
            logging.exception("pre-process failed")
            # error handling
            outputs = Output().error(str(e))

        return outputs

    def postprocess(self, inputs: Input) -> Output:
        outputs = Output()
        try:
            data = inputs.get_as_numpy(0)[0]
            print(data)
        except Exception as e:
            logging.exception("post-process failed")
            # error handling
            outputs = Output().error(str(e))

        return outputs

    def load_image_for_evaluate(
        self, input_, width: int, height: int, normalize: bool = True
    ):
        print(input_)
        # Read the path passed in from Java (it was hard-coded to
        # E:\Download\111111111111.png while testing)
        image_raw = tf.io.read_file(input_)
        try:
            image = tf.io.decode_png(image_raw, channels=3)
        except:
            image = tfio.image.decode_webp(image_raw)
            image = tfio.experimental.color.rgba_to_rgb(image)
        image = tf.image.resize(
            image,
            size=(height, width),
            method=tf.image.ResizeMethod.AREA,
            preserve_aspect_ratio=True,
        )
        image = image.numpy()  # EagerTensor to np.array
        # print("image-up", image)
        image = self.transform_and_pad_image(image, width, height)
        # print("image-down", image)
        if normalize:
            image = image / 255.0

        return image

    def transform_and_pad_image(
        self,
        image,
        target_width,
        target_height,
        scale=None,
        rotation=None,
        shift=None,
        order=1,
        mode="edge",
    ):
        """
        Transform image and pad by edge pixels.
        """
        image_width = image.shape[1]
        image_height = image.shape[0]
        image_array = image

        # centerize
        t = skimage.transform.AffineTransform(
            translation=(-image_width * 0.5, -image_height * 0.5)
        )

        if scale:
            t += skimage.transform.AffineTransform(scale=(scale, scale))

        if rotation:
            radian = (rotation / 180.0) * math.pi
            t += skimage.transform.AffineTransform(rotation=radian)

        t += skimage.transform.AffineTransform(
            translation=(target_width * 0.5, target_height * 0.5)
        )

        if shift:
            t += skimage.transform.AffineTransform(
                translation=(target_width * shift[0], target_height * shift[1])
            )

        warp_shape = (target_height, target_width)

        image_array = skimage.transform.warp(
            image_array, (t).inverse, output_shape=warp_shape, order=order, mode=mode
        )

        return image_array


_service = Processing()


def preprocess(inputs: Input) -> Output:
    return _service.preprocess(inputs)


def postprocess(inputs: Input) -> Output:
    return _service.postprocess(inputs)


def handle(inputs: Input) -> Optional[Output]:
    return None
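With `scale`, `rotation`, and `shift` left as `None`, the chain of `AffineTransform`s in `transform_and_pad_image` reduces to a single translation by `((target_width - image_width) / 2, (target_height - image_height) / 2)`: the resized image is centered, and `mode="edge"` fills the border by repeating the outermost pixels. The sketch below reproduces that idea on a plain 2D list. It is a simplified nearest-neighbour version that floors the offset, whereas `skimage.transform.warp` with `order=1` interpolates bilinearly when the offset is fractional; `center_edge_pad` is a hypothetical helper.

```python
from typing import List, Sequence


def center_edge_pad(image: Sequence[Sequence[float]], target_height: int,
                    target_width: int) -> List[List[float]]:
    """Center a 2D image in a target grid, repeating edge pixels into the border."""
    height, width = len(image), len(image[0])
    # Offset that centers the image (floored; skimage would interpolate a .5 offset)
    dy = (target_height - height) // 2
    dx = (target_width - width) // 2
    padded = []
    for y in range(target_height):
        sy = min(max(y - dy, 0), height - 1)  # clamp to the nearest edge row
        row = []
        for x in range(target_width):
            sx = min(max(x - dx, 0), width - 1)  # clamp to the nearest edge column
            row.append(image[sy][sx])
        padded.append(row)
    return padded


for row in center_edge_pad([[1, 2], [3, 4]], 4, 4):
    print(row)
```

The border is filled with copies of edge pixels, not zeros, so the padding never introduces black bars.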

DJL code that calls Python to process the image:
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;

public class PythonTranslator implements NoBatchifyTranslator<String, Classifications> {

    private ZooModel<Input, Output> model;
    private Predictor<Input, Output> predictor;

    @Override
    public void prepare(TranslatorContext ctx) throws Exception {
        if (predictor == null) {
            // Load the Python model (model.py) that runs the image preprocessing.
            // getRealUrl(...) is an application helper that is not shown here.
            Criteria<Input, Output> criteria =
                    Criteria.builder()
                            .setTypes(Input.class, Output.class)
                            .optModelUrls(getRealUrl("/model/wdtag"))
                            .optEngine("Python")
                            .build();
            model = criteria.loadModel();
            predictor = model.newPredictor();
        }
    }

    @Override
    public Classifications processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
        System.out.println("##################################");
        return null;
    }

    @Override
    public NDList processInput(TranslatorContext translatorContext, String s) throws Exception {
        Input input = new Input();
        input.add("data", s);
        input.addProperty("Content-Type", "text/plain");
        input.addProperty("handler", "preprocess");
        Output output = predictor.predict(input);
        if (output.getCode() != 200) {
            throw new TranslateException("Python preprocess() failed: " + output.getMessage());
        }

        // NDList list = output.getDataAsNDList(translatorContext.getNDManager());
        NDArray ndArray = output.getDataAsNDList(translatorContext.getNDManager()).get(0);
        ndArray.setName("serving_default_inputs:0");
        System.out.println(new NDList(ndArray).toString());
        return new NDList(ndArray);
        // return output.getDataAsNDList(translatorContext.getNDManager());
    }
}
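`processOutput` currently returns `null`, so `predictor.predict` returns `null` even when inference succeeds. The selection it would eventually need is small; a sketch in Python matching `self.topK = 5` in `model.py`, with `top_k` a hypothetical helper and `tags` assumed to be the model's tag list in output order. On the Java side, the result would then be wrapped in a `Classifications` object rather than returning `null`.

```python
import heapq
from typing import List, Sequence, Tuple


def top_k(tags: Sequence[str], scores: Sequence[float], k: int = 5) -> List[Tuple[str, float]]:
    """Return the k (tag, score) pairs with the highest scores, best first."""
    if len(tags) != len(scores):
        raise ValueError("tags and scores must have the same length")
    return heapq.nlargest(k, zip(tags, scores), key=lambda pair: pair[1])


# Hypothetical tags and scores, for illustration only
example_tags = ["tag_a", "tag_b", "tag_c", "tag_d"]
example_scores = [0.10, 0.95, 0.40, 0.80]
print(top_k(example_tags, example_scores, k=2))  # [('tag_b', 0.95), ('tag_d', 0.8)]
```

`heapq.nlargest` avoids sorting the full score vector, which matters little for a few thousand tags but keeps the intent explicit.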

Inference code:
Device device = Device.Type.CPU.equalsIgnoreCase(prop.getDeviceType()) ? Device.cpu() : Device.gpu();
PythonTranslator translator = new PythonTranslator();
// Translator<Image, Classifications> translator = ImageClassificationTranslator.builder().optFlag(Image.Flag.GRAYSCALE).setPipeline(new Pipeline(new ToTensor())).optApplySoftmax(true).build();
// Load the neural network model; text detection and recognition both use Baidu's PaddlePaddle
Criteria<String, Classifications> criteria = Criteria.builder()
        .optEngine(ENGINE_TENSORFLOW)
        .optDevice(device)
        .optModelUrls(getRealUrl(prop.getDetectUrl()))
        .setTypes(String.class, Classifications.class)
        // .optModelName(prop.getModelName())
        .optProgress(new ProgressBar())
        // .optOption("Tags", "")
        // .optOption("SignatureDefKey", "serving_default")
        .optTranslator(translator)
        .build();
// Engine.getEngine("TensorFlow");
detectionModel = criteria.loadModel();

@LEEay
Author

LEEay commented Jun 7, 2024

This is the model-loading code:
Device device = Device.Type.CPU.equalsIgnoreCase(prop.getDeviceType()) ? Device.cpu() : Device.gpu();
PythonTranslator translator = new PythonTranslator();
Criteria<String, Classifications> criteria = Criteria.builder()
        .optEngine(ENGINE_TENSORFLOW)
        .optDevice(device)
        .optModelUrls(getRealUrl(prop.getDetectUrl()))
        .setTypes(String.class, Classifications.class)
        .optProgress(new ProgressBar())
        .optTranslator(translator)
        .build();
detectionModel = criteria.loadModel();

@frankfliu
Contributor

The exception comes from TFJava code:

org.tensorflow.exceptions.TFFailedPreconditionException: Could not find variable batch_normalization_111/moving_mean. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Resource localhost/batch_normalization_111/moving_mean/N10tensorflow3VarE does not exist.
         [[{{function_node __inference_serving_default_11824}}{{node resnet_custom_v3_1/batch_normalization_111_1/Cast/ReadVariableOp}}]]
        at app//org.tensorflow.internal.c_api.AbstractTF_Status.throwExceptionIfNotOK(AbstractTF_Status.java:95)
        at app//org.tensorflow.Session.run(Session.java:835)
        at app//org.tensorflow.Session$Runner.runHelper(Session.java:558)
        at app//org.tensorflow.Session$Runner.run(Session.java:485)

I think you'd better ask on the TFJava GitHub to resolve this issue first.
