In [1]:
from tensorflow import keras
keras.__version__

'2.4.0'

To start using `zoo.orca`, we first need to initialize the Orca context, specifying whether to run in local or distributed mode. This example uses local mode.

In [2]:
from zoo.orca import init_orca_context, stop_orca_context
from zoo.orca import OrcaContext

# Recommended when running Analytics Zoo in a Jupyter notebook:
# displays the terminal's stdout and stderr in the notebook.
OrcaContext.log_output = True

cluster_mode = "local"

if cluster_mode == "local":  
    init_orca_context(cluster_mode="local", cores=4) # run in local mode
elif cluster_mode == "k8s":  
    init_orca_context(cluster_mode="k8s", num_nodes=2, cores=2) # run on K8s cluster
elif cluster_mode == "yarn":  
    init_orca_context(cluster_mode="yarn-client", num_nodes=2, cores=2) # run on Hadoop YARN cluster

Initializing orca context
Current pyspark location is : /intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/__init__.py
Start to getOrCreate SparkContext
pyspark_submit_args is:  --driver-class-path /home/zhenhao/analytics-zoo/zoo/target/analytics-zoo-bigdl_0.12.1-spark_2.4.3-0.10.0-SNAPSHOT-dist-all/lib/analytics-zoo-bigdl_0.12.1-spark_2.4.3-0.10.0-SNAPSHOT-jar-with-dependencies.jar pyspark-shell 


SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/home/zhenhao/analytics-zoo/zoo/target/analytics-zoo-bigdl_0.12.1-spark_2.4.3-0.10.0-SNAPSHOT-dist-all/lib/analytics-zoo-bigdl_0.12.1-spark_2.4.3-0.10.0-SNAPSHOT-jar-with-dependencies.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/intern/spark/spark-2.4.3-bin-hadoop2.7/jars/slf4j-log4j12-1.7.16.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


2021-02-02 20:16:20 WARN  Utils:66 - Your hostname, intern01 resolves to a loopback address: 127.0.1.1; using 10.239.44.107 instead (on interface eno1)
2021-02-02 20:16:20 WARN  Utils:66 - Set SPARK_LOCAL_IP if you need to bind to another address
2021-02-02 20:16:21 WARN  NativeCodeLoader:62 - Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
cls.getname: com.intel.analytics.bigdl.python.api.Sample
BigDLBasePickler registering: bigdl.util.common  Sample
cls.getname: com.intel.analytics.bigdl.python.api.EvaluatedResult
BigDLBasePickler registering: bigdl.util.common  EvaluatedResult
cls.getname: com.intel.analytics.bigdl.python.api.JTensor
BigDLBasePickler registering: bigdl.util.common  JTensor
cls.getname: com.intel.analytics.bigdl.python.api.JActivity
Successfully got a SparkContext
BigDLBasePickler registering: bigdl.util.common  JActivity





# Text generation with LSTM
This notebook contains the code samples found in Chapter 8, Section 1 of Deep Learning with Python. Note that the original text features far more content, in particular further explanations and figures: in this notebook, you will only find source code and related comments.

-----------------------------------------

[...]

## Implementing character-level LSTM text generation
Let's put these ideas into practice in a Keras implementation. The first thing we need is a lot of text data that we can use to learn a language model. You could use any sufficiently large text file or set of text files -- Wikipedia, the Lord of the Rings, etc. In this example we will use some of the writings of Nietzsche, the late-19th-century German philosopher (translated to English). The language model we will learn will thus be specifically a model of Nietzsche's writing style and topics of choice, rather than a more generic model of the English language.

## Preparing the data
Let's start by downloading the corpus and converting it to lowercase:

In [3]:
import numpy as np

path = keras.utils.get_file(
    'nietzsche.txt',
    origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
text = open(path).read().lower()
print('Corpus length:', len(text))

Corpus length: 600893


Next, we will extract partially overlapping sequences of length `maxlen`, one-hot encode them, and pack them into a 3D NumPy array `x` of shape `(sequences, maxlen, unique_characters)`. Simultaneously, we prepare an array `y` containing the corresponding targets: the one-hot encoded characters that come right after each extracted sequence.

In [4]:
# Length of extracted character sequences
maxlen = 60

# We sample a new sequence every `step` characters
step = 3

# This holds our extracted sequences
sentences = []

# This holds the targets (the follow-up characters)
next_chars = []

for i in range(0, len(text) - maxlen, step):
    sentences.append(text[i: i + maxlen])
    next_chars.append(text[i + maxlen])
print('Number of sequences:', len(sentences))

# List of unique characters in the corpus
chars = sorted(list(set(text)))
print('Unique characters:', len(chars))
# Dictionary mapping unique characters to their index in `chars`
char_indices = dict((char, chars.index(char)) for char in chars)

# Next, one-hot encode the characters into binary arrays.
print('Vectorization...')
# Use the built-in `bool`: the `np.bool` alias is deprecated in newer NumPy releases
x = np.zeros((len(sentences), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(sentences), len(chars)), dtype=bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1

Number of sequences: 200278
Unique characters: 57
Vectorization...


In [5]:
x.shape

(200278, 60, 57)

In [6]:
y.shape

(200278, 57)
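
To make the windowing and one-hot encoding concrete, here is a minimal sketch on a toy string. The corpus `toy_text`, the values of `toy_maxlen` and `toy_step`, and the helper `one_hot` are all made up for illustration and are not part of the original notebook. The last line checks that the window arithmetic reproduces the sequence count reported above for a corpus of 600893 characters.

```python
# Toy illustration of the sliding-window extraction and one-hot encoding.
# All names and values here are hypothetical and chosen for readability.
toy_text = "hello world"
toy_maxlen = 4   # window length (plays the role of `maxlen`)
toy_step = 3     # stride between windows (plays the role of `step`)

windows = []
targets = []
for i in range(0, len(toy_text) - toy_maxlen, toy_step):
    windows.append(toy_text[i: i + toy_maxlen])   # input window
    targets.append(toy_text[i + toy_maxlen])      # character right after it

toy_chars = sorted(set(toy_text))
toy_index = {c: k for k, c in enumerate(toy_chars)}

def one_hot(char):
    """Return a list with a single 1 at the index of `char`."""
    vec = [0] * len(toy_chars)
    vec[toy_index[char]] = 1
    return vec

# Nested list of shape (number of windows, toy_maxlen, number of unique chars)
encoded = [[one_hot(c) for c in w] for w in windows]

print(windows)   # ['hell', 'lo w', 'worl']
print(targets)   # ['o', 'o', 'd']
print(len(encoded), len(encoded[0]), len(encoded[0][0]))  # 3 4 8

# Same arithmetic with the notebook's corpus length, maxlen=60, step=3
print(len(range(0, 600893 - 60, 3)))  # 200278
```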


## Building the network
Our network is a single LSTM layer followed by a Dense classifier and softmax over all possible characters. Note that recurrent neural networks are not the only way to generate sequence data; 1D convnets have also proven extremely successful at it in recent times. Since our targets are one-hot encoded, we will use `categorical_crossentropy` as the loss to train the model:

In [5]:
from tensorflow.keras import layers

def build_model(config):
    model = keras.Sequential()
    model.add(layers.LSTM(128, input_shape=(maxlen, len(chars))))
    model.add(layers.Dense(len(chars), activation='softmax'))
    # `learning_rate` is the current argument name; `lr` is a deprecated alias
    optimizer = keras.optimizers.RMSprop(learning_rate=0.01)
    model.compile(loss='categorical_crossentropy', optimizer=optimizer)
    return model
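
As a rough illustration of why categorical cross-entropy suits one-hot targets, the sketch below computes the loss by hand for a three-character vocabulary. The function name and the example distributions are assumptions made for this illustration; Keras computes the same quantity internally with additional numerical safeguards.

```python
import math

def categorical_crossentropy(one_hot_target, predicted_probs):
    """Cross-entropy between a one-hot target and a predicted distribution.

    With a one-hot target, only the probability assigned to the true
    class contributes, so the loss reduces to -log(p_true).
    """
    return -sum(t * math.log(p)
                for t, p in zip(one_hot_target, predicted_probs)
                if t > 0)

target = [0, 1, 0]                 # the true next character is index 1
confident = [0.05, 0.90, 0.05]     # model puts most mass on the right class
unsure = [0.40, 0.30, 0.30]        # model spreads its mass around

loss_confident = categorical_crossentropy(target, confident)
loss_unsure = categorical_crossentropy(target, unsure)

print(loss_confident)  # equals -log(0.9)
print(loss_unsure)     # equals -log(0.3), a larger penalty
```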

We also create a `data_creator` function that helps us feed the data into the Orca estimator.

In [6]:
import tensorflow as tf

def train_data_creator(config):
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    dataset = dataset.repeat()
    dataset = dataset.shuffle(1000)
    dataset = dataset.batch(config["batch_size"])
    return dataset
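
Because `dataset.repeat()` is called without a count, the dataset never runs out, so the number of batches that make up one epoch has to be stated explicitly when fitting. A quick sketch of that arithmetic, using the sequence count printed earlier and a batch size of 128 (the variable names here are illustrative only):

```python
# Sizes taken from the notebook output; names are hypothetical.
num_sequences = 200278
n_batch = 128

# Number of full batches needed to see (almost) every sequence once
steps = num_sequences // n_batch
# Sequences that do not fit into a full batch within one pass
leftover = num_sequences - steps * n_batch

print(steps)     # 1564
print(leftover)  # 86
```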

## Training the language model and sampling from it
Given a trained model and a seed text snippet, we generate new text by repeatedly:

* 1) Drawing from the model a probability distribution over the next character given the text available so far
* 2) Reweighting the distribution to a certain "temperature"
* 3) Sampling the next character at random according to the reweighted distribution
* 4) Adding the new character at the end of the available text

This is the code we use to reweight the original probability distribution coming out of the model, and draw a character index from it (the "sampling function"):

In [7]:
def sample(preds, temperature=1.0):
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)
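
To see what the temperature does, the following sketch applies the same reweighting (log, divide by temperature, exponentiate, renormalize) to a small made-up distribution, using only the standard library. The helper name `reweight` and the example probabilities are assumptions for illustration. Low temperatures concentrate the mass on the most likely character, while temperatures above 1 flatten the distribution.

```python
import math

def reweight(probs, temperature):
    """Rescale a probability distribution by a temperature and renormalize."""
    scaled = [math.exp(math.log(p) / temperature) for p in probs]
    total = sum(scaled)
    return [s / total for s in scaled]

probs = [0.7, 0.2, 0.1]   # hypothetical model output over three characters
for temperature in [0.2, 0.5, 1.0, 1.2]:
    reweighted = reweight(probs, temperature)
    print(temperature, [round(p, 3) for p in reweighted])
```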

Finally, this is the loop where we repeatedly train the model and generate text. We generate text using a range of different temperatures after every epoch. This allows us to see how the generated text evolves as the model starts converging, as well as the impact of temperature on the sampling strategy.
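
Before the full training loop, the four generation steps can be sketched in isolation with a stand-in for the trained model. Everything here is hypothetical: `fake_model`, the three-character `alphabet`, and the window size are invented so the sketch runs without TensorFlow or Orca, and it says nothing about the quality of text a real model would produce.

```python
import math
import random

alphabet = "ab "   # hypothetical three-character vocabulary
window = 5         # plays the role of `maxlen`

def fake_model(context):
    """Stand-in for a trained model: favors repeating the last character."""
    return [0.6 if c == context[-1] else 0.2 for c in alphabet]

def sample_index(probs, temperature, rng):
    """Reweight `probs` by `temperature`, then draw one index at random."""
    scaled = [math.exp(math.log(p) / temperature) for p in probs]
    total = sum(scaled)
    weights = [s / total for s in scaled]
    return rng.choices(range(len(probs)), weights=weights)[0]

rng = random.Random(0)
context = ("ab " * window)[:window]   # seed text of exactly `window` chars
generated = ""
for _ in range(20):
    probs = fake_model(context)            # 1) distribution over next char
    idx = sample_index(probs, 0.5, rng)    # 2) reweight and 3) sample
    next_char = alphabet[idx]
    context = (context + next_char)[1:]    # 4) append, then slide the window
    generated += next_char

print(repr(generated))
```

The window is kept at a fixed length by dropping the oldest character each time a new one is appended, which mirrors how the seed text is updated in the notebook's loop.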

In [8]:
from zoo.orca.learn.tf2 import Estimator
import random
import sys

batch_size = 128
est = Estimator.from_keras(model_creator=build_model,
                           config={},
                           workers_per_node=1,
                           verbose=0)

2021-02-02 20:16:44,936	INFO resource_spec.py:212 -- Starting Ray with 17.68 GiB memory available for workers and up to 8.85 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).


{'node_ip_address': '10.239.44.107', 'redis_address': '10.239.44.107:23207', 'object_store_address': '/tmp/ray/session_2021-02-02_20-16-44_935881_18041/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2021-02-02_20-16-44_935881_18041/sockets/raylet', 'webui_url': None, 'session_dir': '/tmp/ray/session_2021-02-02_20-16-44_935881_18041'}
(pid=18170) 2021-02-02 20:16:46.007307: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
(pid=18170) 2021-02-02 20:16:46.007333: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
(pid=18170) 2021-02-02 20:16:46.838793: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared obje

In [19]:
from zoo.orca.data import XShards

def transform_to_dict(data):
    return {"x": data[0]}


for epoch in range(1, 60):
    print('epoch', epoch)
    # Fit the model for 1 epoch on the available training data
    stats = est.fit(train_data_creator, 
                    epochs=1,
                    batch_size=batch_size,
                    steps_per_epoch=len(sentences) // batch_size,
                    verbose=0)
    # Select a text seed at random
    start_index = random.randint(0, len(text) - maxlen - 1)
    generated_text = text[start_index: start_index + maxlen]
    print('--- Generating with seed: "' + generated_text + '"')

    for temperature in [0.2, 0.5, 1.0, 1.2]:
        print('------ temperature:', temperature)
        sys.stdout.write(generated_text)

        # We generate 400 characters
        for i in range(400):
            sampled = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(generated_text):
                sampled[0, t, char_indices[char]] = 1.
                
            sample_shards = XShards.partition({"x": sampled})
            #sample_shards = sample_shards.transform_shard(transform_to_dict)

            preds = est.predict(sample_shards)
            next_index = sample(preds, temperature)
            next_char = chars[next_index]

            generated_text += next_char
            generated_text = generated_text[1:]

            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()
    

epoch 1
(pid=14333) Instructions for updating:
(pid=14333) Use `tf.data.Iterator.get_next_as_optional()` instead.




--- Generating with seed: " been only too glad to
look upon itself as the ultimate end "
------ temperature: 0.2
 been only too glad to
look upon itself as the ultimate end Try to unpersist an uncached rdd
Try to unpersist an uncached rdd
Try to unpersist an uncached rdd




2021-02-02 15:52:52 ERROR Executor:91 - Exception in task 1.0 in stage 10.0 (TID 41)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 377, in main
    process()
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 372, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 133, in <lambda>
    lambda idx, _: get_from_ray(idx, address, password, meta_store_name))
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 99, in get_from_ray
    partition = ray.get(object_id)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/ray/worker.py", line 1513, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(OverflowError): ray::TFRunner.predict() (pid=14333, ip=10.239.44.107)
  Fi

(pid=14333)   numdigits = int(np.log10(self.target)) + 1
(pid=14333)   numdigits = int(np.log10(self.target)) + 1
(pid=14333)   numdigits = int(np.log10(self.target)) + 1


Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 10.0 failed 1 times, most recent failure: Lost task 1.0 in stage 10.0 (TID 41, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 377, in main
    process()
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 372, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 133, in <lambda>
    lambda idx, _: get_from_ray(idx, address, password, meta_store_name))
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 99, in get_from_ray
    partition = ray.get(object_id)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/ray/worker.py", line 1513, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(OverflowError): ray::TFRunner.predict() (pid=14333, ip=10.239.44.107)
  File "python/ray/_raylet.pyx", line 452, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 407, in ray._raylet.execute_task.function_executor
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in predict
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in <listcomp>
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 448, in predict_fn
    y = local_model.predict(shard["x"], **params)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 130, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1613, in predict
    callbacks.on_predict_end()
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 582, in on_predict_end
    callback.on_predict_end(logs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 979, in on_predict_end
    self._finalize_progbar(logs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 1026, in _finalize_progbar
    self.progbar.update(self.seen, list(logs.items()), finalize=True)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 581, in update
    numdigits = int(np.log10(self.target)) + 1
OverflowError: cannot convert float infinity to integer

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:588)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:571)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:221)
	at org.apache.spark.storage.memory.MemoryStore.putIteratorAsBytes(MemoryStore.scala:349)
	at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:1182)
	at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:1156)
	at org.apache.spark.storage.BlockManager.doPut(BlockManager.scala:1091)
	at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1156)
	at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:882)
	at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:335)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:286)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2126)
	at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:945)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:944)
	at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:166)
	at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 377, in main
    process()
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 372, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 133, in <lambda>
    lambda idx, _: get_from_ray(idx, address, password, meta_store_name))
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 99, in get_from_ray
    partition = ray.get(object_id)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/ray/worker.py", line 1513, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(OverflowError): ray::TFRunner.predict() (pid=14333, ip=10.239.44.107)
  File "python/ray/_raylet.pyx", line 452, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 407, in ray._raylet.execute_task.function_executor
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in predict
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in <listcomp>
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 448, in predict_fn
    y = local_model.predict(shard["x"], **params)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 130, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1613, in predict
    callbacks.on_predict_end()
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 582, in on_predict_end
    callback.on_predict_end(logs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 979, in on_predict_end
    self._finalize_progbar(logs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 1026, in _finalize_progbar
    self.progbar.update(self.seen, list(logs.items()), finalize=True)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 581, in update
    numdigits = int(np.log10(self.target)) + 1
OverflowError: cannot convert float infinity to integer

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:588)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:571)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:221)
	at org.apache.spark.storage.memory.MemoryStore.putIteratorAsBytes(MemoryStore.scala:349)
	at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:1182)
	at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:1156)
	at org.apache.spark.storage.BlockManager.doPut(BlockManager.scala:1091)
	at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1156)
	at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:882)
	at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:335)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:286)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more


2021-02-02 15:52:57,190	ERROR worker.py:1012 -- Possible unhandled error from worker: ray::TFRunner.predict() (pid=14333, ip=10.239.44.107)
  File "python/ray/_raylet.pyx", line 452, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 407, in ray._raylet.execute_task.function_executor
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in predict
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in <listcomp>
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 448, in predict_fn
    y = local_model.predict(shard["x"], **params)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 130, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/zhenhao/.local/lib/python3.7/site-package

In [21]:
est.shutdown()
#stop_orca_context()

(pid=14333) 2021-02-02 16:13:41.082677: W tensorflow/core/common_runtime/eager/context.cc:566] Unable to destroy server_ object, so releasing instead. Servers don't support clean shutdown.


In [15]:
from zoo.orca.data import XShards

start_index = random.randint(0, len(text) - maxlen - 1)
generated_text = text[start_index: start_index + maxlen]
print('--- Generating with seed: "' + generated_text + '"')

for temperature in [0.2, 0.5, 1.0, 1.2]:
    print('------ temperature:', temperature)
    sys.stdout.write(generated_text)

    # We generate 400 characters
    for i in range(400):
        sampled = np.zeros((1, maxlen, len(chars)))
        for t, char in enumerate(generated_text):
            sampled[0, t, char_indices[char]] = 1.
                

--- Generating with seed: "y regard him on
that account with some apprehension, they wo"
------ temperature: 0.2
y regard him on
that account with some apprehension, they wo------ temperature: 0.5
y regard him on
that account with some apprehension, they wo------ temperature: 1.0
y regard him on
that account with some apprehension, they wo------ temperature: 1.2
y regard him on
that account with some apprehension, they wo

In [17]:
sample_shards = XShards.partition({"x": sampled})


In [18]:
preds = est.predict(sample_shards, batch_size=1).collect()



2021-02-02 20:27:33 ERROR Executor:91 - Exception in task 1.0 in stage 2.0 (TID 9)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 377, in main
    process()
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 372, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 133, in <lambda>
    lambda idx, _: get_from_ray(idx, address, password, meta_store_name))
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 99, in get_from_ray
    partition = ray.get(object_id)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/ray/worker.py", line 1513, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(OverflowError): ray::TFRunner.predict() (pid=18170, ip=10.239.44.107)
  File

(pid=18170)   numdigits = int(np.log10(self.target)) + 1
(pid=18170)   numdigits = int(np.log10(self.target)) + 1
(pid=18170)   numdigits = int(np.log10(self.target)) + 1


Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 2.0 failed 1 times, most recent failure: Lost task 1.0 in stage 2.0 (TID 9, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 377, in main
    process()
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 372, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 133, in <lambda>
    lambda idx, _: get_from_ray(idx, address, password, meta_store_name))
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 99, in get_from_ray
    partition = ray.get(object_id)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/ray/worker.py", line 1513, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(OverflowError): ray::TFRunner.predict() (pid=18170, ip=10.239.44.107)
  File "python/ray/_raylet.pyx", line 452, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 407, in ray._raylet.execute_task.function_executor
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in predict
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in <listcomp>
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 448, in predict_fn
    y = local_model.predict(shard["x"], **params)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 130, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1613, in predict
    callbacks.on_predict_end()
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 582, in on_predict_end
    callback.on_predict_end(logs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 979, in on_predict_end
    self._finalize_progbar(logs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 1026, in _finalize_progbar
    self.progbar.update(self.seen, list(logs.items()), finalize=True)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 581, in update
    numdigits = int(np.log10(self.target)) + 1
OverflowError: cannot convert float infinity to integer

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:588)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:571)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:221)
	at org.apache.spark.storage.memory.MemoryStore.putIteratorAsBytes(MemoryStore.scala:349)
	at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:1182)
	at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:1156)
	at org.apache.spark.storage.BlockManager.doPut(BlockManager.scala:1091)
	at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1156)
	at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:882)
	at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:335)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:286)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2126)
	at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:945)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:944)
	at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:166)
	at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 377, in main
    process()
  File "/intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 372, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 133, in <lambda>
    lambda idx, _: get_from_ray(idx, address, password, meta_store_name))
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/data/ray_rdd.py", line 99, in get_from_ray
    partition = ray.get(object_id)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/ray/worker.py", line 1513, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(OverflowError): ray::TFRunner.predict() (pid=18170, ip=10.239.44.107)
  File "python/ray/_raylet.pyx", line 452, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 407, in ray._raylet.execute_task.function_executor
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in predict
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in <listcomp>
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 448, in predict_fn
    y = local_model.predict(shard["x"], **params)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 130, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1613, in predict
    callbacks.on_predict_end()
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 582, in on_predict_end
    callback.on_predict_end(logs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 979, in on_predict_end
    self._finalize_progbar(logs)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 1026, in _finalize_progbar
    self.progbar.update(self.seen, list(logs.items()), finalize=True)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 581, in update
    numdigits = int(np.log10(self.target)) + 1
OverflowError: cannot convert float infinity to integer

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:588)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:571)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:221)
	at org.apache.spark.storage.memory.MemoryStore.putIteratorAsBytes(MemoryStore.scala:349)
	at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:1182)
	at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:1156)
	at org.apache.spark.storage.BlockManager.doPut(BlockManager.scala:1091)
	at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1156)
	at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:882)
	at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:335)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:286)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more


2021-02-02 20:27:38,505	ERROR worker.py:1012 -- Possible unhandled error from worker: ray::TFRunner.predict() (pid=18170, ip=10.239.44.107)
  File "python/ray/_raylet.pyx", line 452, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 407, in ray._raylet.execute_task.function_executor
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in predict
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 451, in <listcomp>
    new_part = [predict_fn(shard) for shard in partition]
  File "/home/zhenhao/analytics-zoo/pyzoo/zoo/orca/learn/tf2/tf_runner.py", line 448, in predict_fn
    y = local_model.predict(shard["x"], **params)
  File "/home/zhenhao/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 130, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/zhenhao/.local/lib/python3.7/site-package
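
The traceback above ends in `numdigits = int(np.log10(self.target)) + 1` inside the Keras progress bar. One plausible reading, not confirmed by the log itself, is that a shard reached `local_model.predict` with zero samples, so the progress bar's target was 0 and `np.log10(0)` evaluated to `-inf`. The sketch below reproduces only that arithmetic with NumPy; `progbar_numdigits` is a hypothetical helper written for illustration and is not part of Keras or Orca.

```python
import numpy as np


def progbar_numdigits(target):
    """Evaluate the expression from the traceback: int(np.log10(target)) + 1."""
    # Silence the divide-by-zero RuntimeWarning that np.log10(0) would emit.
    with np.errstate(divide="ignore"):
        return int(np.log10(target)) + 1


# A non-empty shard (hypothetical size) gives a finite digit count.
print(progbar_numdigits(60))

# Assumption: an empty shard leaves the progress bar target at 0.
try:
    progbar_numdigits(0)
except OverflowError as err:
    message = str(err)
    print("OverflowError:", message)
```

If this reading is right, the error comes from the progress display rather than from the model. Checking whether any partition is empty before calling predict would be a reasonable first diagnostic step.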

In [23]:
sampled[0].sum(axis=1)

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1.])