---
title: Speeding up predict in TensorFlow 2
tags: 小书匠,tensorflow2,predict,inference,function
grammar_cjkRuby: true
renderNumberedHeading: true
---

[toc]

# Speeding up predict in TensorFlow 2

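Calling `model.predict` repeatedly on small inputs can be slow, because every call carries a fixed setup overhead. In that case, wrap `model.call` in `tf.function` and call the model directly (`model(x, training=False)`) instead of using `model.predict`. In the measurements below, 100 single-sample calls drop from about 2.7 s to about 0.04 s.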

Baseline, 100 single-sample calls to `model.predict`:

    from tensorflow.keras import Input, Model
    import time
    import numpy as np

    # Identity model: the output is the input tensor itself
    x = Input(shape=(1, 1))
    model = Model(inputs=x, outputs=x)

    t = time.time()
    for _ in range(100):
        model.predict(np.zeros((1, 1, 1)))
    print(time.time() - t)

Output:

    2.7391841411590576
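
A single wall-clock measurement of the whole loop also folds in any one-off cost paid on the first call. Below is a minimal sketch of a timing helper that runs warm-up calls first and then reports the mean time per call. The name `time_calls` and its parameters are hypothetical, and the lambda is a stand-in workload so that the snippet runs without TensorFlow.

```python
import time


def time_calls(fn, n_calls=100, n_warmup=1):
    """Return the mean wall-clock seconds per call of fn() (hypothetical helper)."""
    # Warm-up calls absorb one-off costs such as lazy initialization
    for _ in range(n_warmup):
        fn()
    start = time.perf_counter()
    for _ in range(n_calls):
        fn()
    return (time.perf_counter() - start) / n_calls


# Stand-in workload (assumption): records each call so the count can be checked
calls = []
mean_seconds = time_calls(lambda: calls.append(1), n_calls=100, n_warmup=2)
print(f"{mean_seconds:.9f} s per call, {len(calls)} calls in total")
```

In the notebook, the stand-in could be replaced with `lambda: model.predict(np.zeros((1, 1, 1)))`.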


With `model.call` wrapped in `tf.function`:

    from tensorflow.keras import Input, Model
    import time
    import numpy as np
    import tensorflow as tf

    x = Input(shape=(1, 1))
    model = Model(inputs=x, outputs=x)
    # Compile model.call into a TensorFlow graph
    model.call = tf.function(model.call)

    t = time.time()
    for _ in range(100):
        # Calling the model directly dispatches to model.call
        model(np.zeros((1, 1, 1)), training=False)
    print(time.time() - t)

Output:

    0.03891730308532715
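
The timing comparison does not show whether the faster path returns the same values. Note also that a direct call returns a `tf.Tensor` rather than the NumPy array that `model.predict` gives. Below is a minimal sketch of such a check. It assumes a small stand-in model with one `Dense` layer (not the identity model from the notebook) and a hypothetical wrapper named `fast_predict`. The `input_signature` is an optional addition, intended to keep TensorFlow from retracing the graph for different batch sizes.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

# Stand-in model (assumption): one Dense layer instead of the notebook's identity model
inputs = Input(shape=(1, 1))
outputs = Dense(1)(inputs)
model = Model(inputs=inputs, outputs=outputs)


# Hypothetical wrapper; the fixed signature lets one traced graph serve any batch size
@tf.function(input_signature=[tf.TensorSpec(shape=(None, 1, 1), dtype=tf.float32)])
def fast_predict(batch):
    return model(batch, training=False)


sample = np.ones((1, 1, 1), dtype=np.float32)

# The direct call returns a tf.Tensor; .numpy() converts it to an ndarray
out_fast = fast_predict(tf.constant(sample)).numpy()
out_ref = model.predict(sample, verbose=0)

# Both paths should agree up to floating-point tolerance
np.testing.assert_allclose(out_fast, out_ref, rtol=1e-5, atol=1e-6)
print(out_fast.shape)
```

If the assertion passes, the wrapper can be substituted for `model.predict` in the timing loop, with `.numpy()` applied only where an array is actually needed.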


# References
- [model.predict is much slower on TF 2.1+ · Issue #40261 · tensorflow/tensorflow](https://github.com/tensorflow/tensorflow/issues/40261)