### This code connects to a GCP Kubernetes workload and performs a simple MNIST prediction. Observe the following:
1. Most of the time (over 90% in the runs below) is spent communicating with the server — uploading the images and polling for the result — rather than predicting. The same code running inside GCP would be much quicker, since the server is closer.
2. After the first one or two runs, the client moves on to a new client ID but is already "warmed up", so communicating with the server is faster.

In [3]:
import logging, grpc, time
import numpy as np

import server_tools_pb2
import server_tools_pb2_grpc

PORT = '50051'

# Read the server's IP address from IP.txt. The `with` block guarantees the
# file is closed, and strip() drops surrounding whitespace such as a trailing
# '\n' or '\r\n' without failing on an empty file.
with open("IP.txt", encoding="utf-8") as f:
    IP = f.read().strip()
if not IP:
    raise ValueError("IP.txt is empty; expected the server's IP address")

# Set this flag to indicate whether the client should block until the
# prediction is finished (True) or submit the job and check in with the server
# periodically until it is complete (False).
WAIT = False

# Change this parameter depending on how many images you want to send at once.
# Each image is sent as 28*28 float64 values (6,272 bytes), so 668 images make
# a 4,189,696-byte payload, just under gRPC's default 4 MiB (4,194,304-byte)
# maximum message size; 669 images would exceed it. That default is most
# likely the upper limit that was observed as "668" on the author's machine:
# larger requests are rejected and the client reports an error.
NUM_IMAGES = 668

def run(*, poll_interval=0.05):
    """Send one batch of random MNIST-shaped images to the server and report timing.

    Opens an insecure gRPC channel to ``IP:PORT``, requests a client ID, and
    submits ``NUM_IMAGES`` random 28x28x1 images as raw float64 bytes with
    ``batch_size=32``. If ``WAIT`` is true the call blocks in ``StartJobWait``;
    otherwise the job is started with ``StartJobNoWait`` and polled with
    ``ProbeJob`` until it reports completion or an error. On success it prints
    the total round-trip time, the server-reported ``infer_time``, and the
    percentage of time not spent predicting. A gRPC failure or a job error is
    printed and the function returns early; the channel is closed on every
    path.

    Args:
        poll_interval: Seconds to sleep between ``ProbeJob`` calls in the
            non-blocking mode, so the client does not flood the server with
            back-to-back RPCs. Pass 0 to poll as fast as possible.
    """
    # The context manager closes the channel on every exit path, including the
    # early returns below.
    with grpc.insecure_channel(IP + ':' + PORT) as channel:
        stub = server_tools_pb2_grpc.MnistServerStub(channel)

        # Get a client ID which you need to talk to the server
        try:
            response = stub.RequestClientID(server_tools_pb2.NullParam())
        except grpc.RpcError as err:
            print("Connection to the server could not be established:", err.code())
            return
        client_id = response.new_id

        # Generate lots of random data shaped like MNIST images and serialize
        # it as raw float64 bytes (tobytes() replaces the deprecated tostring()).
        data = np.random.rand(NUM_IMAGES, 28, 28, 1).tobytes()
        request = server_tools_pb2.DataMessage(images=data, client_id=client_id, batch_size=32)

        # Send the data to the server and receive an answer
        start_time = time.time()
        try:
            if WAIT:
                print("Submitting images and waiting")
                response = stub.StartJobWait(request)
            else:
                print("Submitting images")
                id_package = stub.StartJobNoWait(request)
                response = stub.ProbeJob(id_package)
                print("Checking in with server")
                # Poll until the job is done or the server reports an error;
                # the sleep keeps this from being a tight busy-wait loop.
                while not response.complete and response.error == '':
                    time.sleep(poll_interval)
                    response = stub.ProbeJob(id_package)
                if response.error != '':
                    # Treat a reported error as a failed job and stop instead
                    # of trying to decode a prediction.
                    print(response.error)
                    return
        except grpc.RpcError as err:
            # gRPC rejects a message larger than the receiver's maximum
            # message size with RESOURCE_EXHAUSTED; report other failures as-is.
            if err.code() == grpc.StatusCode.RESOURCE_EXHAUSTED:
                print("NUM_IMAGES is too high")
            else:
                print("Request to the server failed:", err.code(), err.details())
            return

        # Decode the predictions; the reshape doubles as a sanity check that
        # the server returned 10 values per image. np.frombuffer defaults to
        # float64 — presumably what the server sends; TODO confirm.
        predictions = np.frombuffer(response.prediction).reshape(NUM_IMAGES, 10)
        whole_time = time.time() - start_time
        print("Total time:", whole_time)
        print("Predict time:", response.infer_time)
        print("Fraction of time spent not predicting:", (1 - response.infer_time / whole_time) * 100, '%')

In [4]:
logging.basicConfig()
# Run the client several times in a row so the first (cold) runs can be
# compared with the later, warmed-up ones; the loop index itself is unused.
for _ in range(5):
    run()

Submitting images
Checking in with server
Total time: 11.17451024055481
Predict time: 0.705836296081543
Fraction of time spent not predicting: 93.68351470545971 %
Submitting images
Checking in with server
Total time: 22.498075246810913
Predict time: 0.8005023002624512
Fraction of time spent not predicting: 96.44190762329359 %
Submitting images
Checking in with server
Total time: 8.726824522018433
Predict time: 0.7406020164489746
Fraction of time spent not predicting: 91.51349938823246 %
Submitting images
Checking in with server
Total time: 8.583730697631836
Predict time: 0.7114744186401367
Fraction of time spent not predicting: 91.7113613683567 %
Submitting images
Checking in with server
Total time: 9.714072942733765
Predict time: 0.7159483432769775
Fraction of time spent not predicting: 92.62978209554711 %
