In [None]:
global project, model
# Hopsworks project id and served model name; both are interpolated into the
# inference URL built in send_single_request.
project, model = "1", "featurestore"

In [None]:
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

#!/usr/bin/env python2.7

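"""A client that sends concurrent inference requests to a model server.

Adapted from the tensorflow_model_server mnist client, which downloads test
images of the mnist data set, queries the service with those images to get
predictions, and calculates the inference error rate. In its current form
each request draws an mnist test batch but posts a fixed payload to the
Hopsworks inference REST endpoint and prints the response; no error rate is
computed.

Typical usage example:

    mnist_client.py --num_tests=100 --server=localhost:9000
"""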

from __future__ import print_function

import sys
import threading

import json

# This is a placeholder for a Google-internal import.

import numpy
import requests
import tensorflow as tf
import time
from hops import util
from hops import constants
from multiprocessing import Pool, TimeoutError
from concurrent.futures import ThreadPoolExecutor, wait, as_completed

from tensorflow.examples.tutorials.mnist import input_data as mnist_input_data


# Command-line options for the benchmark, registered with tf.app.flags.
tf.app.flags.DEFINE_integer(
    'concurrency', 1,
    'maximum number of concurrent inference requests')
tf.app.flags.DEFINE_integer(
    'num_tests', 5,
    'Number of test images')
tf.app.flags.DEFINE_string(
    'server', '',
    'PredictionService host:port')
tf.app.flags.DEFINE_string(
    'work_dir', '/tmp',
    'Working directory. ')
FLAGS = tf.app.flags.FLAGS

# Executor read by do_inference; stays None until the driver code at the end
# of this cell assigns a ThreadPoolExecutor.
pool = None


class _ResultCounter(object):
    """Lock-protected tally of prediction outcomes.

    Keeps three counters: finished requests, erroneous requests, and requests
    currently in flight. Every counter is read and written while holding one
    ``threading.Condition``; ``throttle`` and ``get_error_rate`` wait on that
    condition, and ``inc_done`` / ``dec_active`` notify it.
    """

    def __init__(self, num_tests, concurrency):
        # Single condition variable guarding all of the counters below.
        self._condition = threading.Condition()
        self._num_tests = num_tests
        self._concurrency = concurrency
        self._error = self._done = self._active = 0

    def inc_error(self):
        """Count one failed or mispredicted request."""
        with self._condition:
            self._error = self._error + 1

    def inc_done(self):
        """Count one finished request and wake one waiting thread."""
        with self._condition:
            self._done = self._done + 1
            self._condition.notify()

    def dec_active(self):
        """Free one in-flight slot and wake one waiting thread."""
        with self._condition:
            self._active = self._active - 1
            self._condition.notify()

    def get_error_rate(self):
        """Wait until exactly ``num_tests`` requests are done.

        Returns:
          The number of errors divided by ``num_tests``, as a float.
        """
        with self._condition:
            while self._num_tests != self._done:
                self._condition.wait()
            return float(self._error) / self._num_tests

    def throttle(self):
        """Wait while ``concurrency`` requests are in flight, then take a slot."""
        with self._condition:
            while self._concurrency == self._active:
                self._condition.wait()
            self._active += 1


def _create_rpc_callback(label, result_counter):
    """Build the completion callback for a single prediction RPC.

    Args:
      label: Expected class for the example that was sent.
      result_counter: Object providing ``inc_error``, ``inc_done`` and
        ``dec_active`` (e.g. a ``_ResultCounter``).
    Returns:
      A function taking the finished RPC future.
    """
    def _callback(result_future):
        """Record the outcome of ``result_future`` on ``result_counter``.

        A raised exception or an argmax over ``outputs['scores']`` that does
        not equal ``label`` is counted as an error. In every case the request
        is then marked done and its in-flight slot released.

        Args:
          result_future: Future of the completed RPC.
        """
        failure = result_future.exception()
        if not failure:
            # Progress marker: one dot per successful response.
            sys.stdout.write('.')
            sys.stdout.flush()
            scores = result_future.result().outputs['scores'].float_val
            predicted = numpy.argmax(numpy.array(scores))
            if label != predicted:
                result_counter.inc_error()
        else:
            result_counter.inc_error()
            print(failure)
        result_counter.inc_done()
        result_counter.dec_active()

    return _callback


def do_inference(hostport, work_dir, concurrency, num_tests):
    """Send ``num_tests`` requests concurrently and wait for all of them.

    Loads the mnist test split from ``work_dir`` and submits
    ``send_single_request(test_data_set)`` ``num_tests`` times to a thread
    pool. It then blocks until every request has finished.

    Args:
      hostport: Host:port address of the PredictionService. Currently unused;
        the target URL is built inside ``send_single_request``.
      work_dir: The full path of working directory for test data set.
      concurrency: Maximum number of concurrent requests. Used to size a
        temporary executor only when the module-level ``pool`` is None.
      num_tests: Number of requests to send.

    Returns:
      A list with the return value of each ``send_single_request`` call, in
      submission order. (No classification error rate is computed.)

    Raises:
      IOError: An error occurred processing test data set.
      Any exception raised by a request is re-raised here.
    """
    test_data_set = mnist_input_data.read_data_sets(work_dir).test

    # Prefer the shared module-level executor; previously a None ``pool``
    # crashed with AttributeError on ``.submit``.
    executor = pool
    owns_executor = executor is None
    if owns_executor:
        # ThreadPoolExecutor rejects max_workers <= 0.
        executor = ThreadPoolExecutor(max_workers=max(1, concurrency))
    try:
        # NOTE(review): every worker calls test_data_set.next_batch, which
        # mutates the data set's cursor without a lock; confirm that is
        # acceptable (the drawn example is currently unused).
        futures = [executor.submit(send_single_request, test_data_set)
                   for _ in range(num_tests)]
        # .result() re-raises any exception from the worker thread.
        return [future.result() for future in futures]
    finally:
        if owns_executor:
            executor.shutdown(wait=True)


def send_single_request(test_data_set, timeout=None):
    """POST one prediction request to the Hopsworks inference endpoint.

    Draws the next example from ``test_data_set``, which advances it, but
    does not send that example. The request body is the fixed payload
    ``{"instances": [[1, 2, 3, 4]]}``. The response body is printed.

    Args:
      test_data_set: Object providing ``next_batch(n)``; presumably the
        mnist test split loaded in ``do_inference``.
      timeout: Optional ``requests`` timeout in seconds, or a
        ``(connect, read)`` tuple. The default ``None`` waits indefinitely,
        as before.

    Returns:
      The ``requests.Response`` of the POST. It was previously discarded,
      leaving callers nothing to inspect.
    """
    # Keep advancing the data set as before, even though the example is unused.
    _image, _label = test_data_set.next_batch(1)
    # NOTE(review): hard-coded 4-value instance; presumably matches the
    # input expected by the served model — confirm before sending real data.
    request_data = {'instances': [[1, 2, 3, 4]]}
    headers = {
        constants.HTTP_CONFIG.HTTP_AUTHORIZATION: "Bearer " + util.get_jwt(),
    }
    inference_url = ("https://localhost:8181/hopsworks-api/api/project/"
                     + project + "/inference/models/" + model + ":predict")
    # verify=False disables TLS certificate verification for this endpoint.
    response = requests.post(inference_url, headers=headers,
                             data=json.dumps(request_data), verify=False,
                             timeout=timeout)
    print(response.text)
    return response



    #if FLAGS.num_tests > 10000:
    #  print('num_tests should not be greater than 10k')
    #  return
    #if not FLAGS.server:
    #  print('please specify server host:port')
    #  return
    #error_rate = do_inference(FLAGS.server, FLAGS.work_dir,
    #                          FLAGS.concurrency, FLAGS.num_tests)
    #print('\nInference error rate: %s%%dfadsfasdf' % (error_rate * 100))
# Time one full benchmark run.
start = time.time()
# No ``global pool`` here: at module level it is a no-op. Because ``pool`` is
# already assigned earlier in this cell, it is a SyntaxError on Python 3.6+.
pool = ThreadPoolExecutor(FLAGS.concurrency)
try:
    do_inference(FLAGS.server, FLAGS.work_dir, FLAGS.concurrency, FLAGS.num_tests)
    end = time.time()
finally:
    # Release the worker threads (previously leaked on every re-run),
    # even if a request raised.
    pool.shutdown(wait=True)
print(end - start)