In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_io as tfio

# --- Training settings ---
BATCH_SIZE = 64     # number of records grouped by Dataset.batch
MAX_EPOCHS = 5      # epochs passed to model.fit
NUM_COLUMNS = 784   # CSV fields decoded from each Kafka message

# --- Kafka settings ---
KAFKA_SERVERS = "kafka:9092"
KAFKA_CONSUMER_GROUP = "mnistcg"
KAFKA_TRAIN_TOPIC = "tf.public.mnist_train"
KAFKA_TEST_TOPIC = "tf.public.mnist_test"
# Passed as stream_timeout to KafkaGroupIODataset; presumably milliseconds -- TODO confirm.
KAFKA_STREAM_TIMEOUT = 9000

In [None]:
def decode_kafka_record(record):
    """Turn one Kafka record into an ``(image, label)`` pair.

    Args:
        record: Dataset element exposing ``message`` and ``key`` fields.
            ``message`` is parsed as a CSV row with ``NUM_COLUMNS`` numeric
            fields; ``key`` is parsed as the integer label.

    Returns:
        A tuple ``(pixels, label)``: the decoded row as float32 divided by
        255, and the key converted to an int32 tensor.
    """
    # One float default per column, so every CSV field is decoded as a float.
    column_defaults = [[0.0] for _ in range(NUM_COLUMNS)]
    raw_pixels = tf.io.decode_csv(record.message, column_defaults)
    scaled_pixels = tf.cast(raw_pixels, tf.float32) / 255.
    label = tf.strings.to_number(record.key, out_type=tf.int32)
    return scaled_pixels, label

# Read the training topic (partition 0), decode each record into an
# (image, label) pair and group the pairs into batches of BATCH_SIZE.
train_ds = (
    tfio.IODataset.from_kafka(
        KAFKA_TRAIN_TOPIC,
        partition=0,
        offset=0,
        servers=KAFKA_SERVERS,
    )
    .map(decode_kafka_record)
    .batch(BATCH_SIZE)
)

In [None]:
# Dense classifier: a 128-unit ReLU layer followed by 10 output units.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(128, activation="relu"))
# No activation on the output layer: the loss below is configured with
# from_logits=True and therefore expects raw logits.
model.add(tf.keras.layers.Dense(10))

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Train on the batched Kafka training dataset.
model.fit(train_ds, epochs=MAX_EPOCHS)

In [None]:
def decode_kafka_stream_record(message, key):
    """Decode a streamed Kafka ``(message, key)`` pair into ``(image, label)``.

    Args:
        message: CSV-encoded row with ``NUM_COLUMNS`` numeric fields.
        key: String holding the integer label.

    Returns:
        A tuple ``(pixels, label)``: the decoded row as float32 divided by
        255, and the key converted to an int32 tensor.
    """
    # A float default for each column makes decode_csv return float fields.
    float_defaults = [[0.0] for _ in range(NUM_COLUMNS)]
    fields = tf.io.decode_csv(message, float_defaults)
    pixels = tf.cast(fields, tf.float32) / 255.
    label = tf.strings.to_number(key, out_type=tf.int32)
    return pixels, label

# Stream the test topic through a consumer group, decode each
# (message, key) pair and batch the results for evaluation.
test_ds = (
    tfio.experimental.streaming.KafkaGroupIODataset(
        topics=[KAFKA_TEST_TOPIC],
        group_id=KAFKA_CONSUMER_GROUP,
        servers=KAFKA_SERVERS,
        stream_timeout=KAFKA_STREAM_TIMEOUT,
        # Raw Kafka client settings, passed through as "name=value" strings.
        configuration=[
            "session.timeout.ms=10000",
            "max.poll.interval.ms=10000",
            "auto.offset.reset=earliest",
        ],
    )
    .map(decode_kafka_stream_record)
    .batch(BATCH_SIZE)
)

model.evaluate(test_ds)

In [None]:
def plot_and_predict(pixels):
    """Classify one flattened image with ``model`` and draw it.

    Args:
        pixels: Flat sequence of raw pixel intensities. The values are
            divided by 255 before prediction, and the sequence is laid out
            as 28 rows for display, so its length must be a multiple of 28.

    Returns:
        None. The image is drawn with ``plt.imshow`` and the predicted class
        index is used as the plot title.
    """
    # Wrap the sample in a batch of one and apply the same /255 scaling the
    # Kafka decoders in this notebook apply to their records.
    batch = tf.cast(tf.constant([pixels]), tf.float32) / 255.

    logits = model.predict(batch)
    number = tf.nn.softmax(logits).numpy().argmax()

    # reshape(28, -1) gives the same 28-row grid np.split(..., 28) produced,
    # but as a single 2-D array instead of a list of row arrays.
    image = np.asarray(pixels).reshape(28, -1)
    plt.imshow(image)
    plt.title(number)
    plt.axis("off")

In [None]:
pixels = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,108,43,6,6,6,6,5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10,84,248,254,254,254,254,254,241,45,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,90,254,254,254,223,173,173,173,253,156,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,79,157,228,245,251,188,63,17,0,0,54,252,132,0,0,0,0,0,0,0,0,0,0,0,0,0,0,32,254,254,254,244,131,0,0,0,0,13,220,254,122,0,0,0,0,0,0,0,0,0,0,0,0,0,0,83,254,225,160,47,0,0,0,0,59,211,254,206,50,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,21,14,0,0,0,2,17,146,245,250,194,12,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,81,140,140,171,254,254,254,203,55,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,211,254,254,254,254,179,211,254,254,202,171,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,167,233,193,69,16,3,9,16,107,231,248,195,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,73,229,182,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,26,99,252,254,146,0,0,0,0,0,0,0,0,79,142,0,0,0,0,0,0,0,0,0,26,28,116,147,247,254,239,150,22,0,0,0,0,0,0,0,0,175,230,174,155,66,66,132,174,174,174,174,250,255,254,192,189,99,36,0,0,0,0,0,0,0,0,0,0,106,226,254,254,254,254,254,254,254,254,217,151,80,43,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4,7,114,114,114,46,5,5,5,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
# Draw the hand-entered sample above, titled with the model's predicted class.
plot_and_predict(pixels)

In [None]:
pixels = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,128,191,255,255,255,255,255,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,255,255,255,255,255,255,255,0,64,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,255,255,255,255,255,255,255,255,255,255,255,191,0,0,0,0,0,0,0,0,0,0,0,0,0,0,64,255,255,255,191,0,191,255,255,255,255,255,255,255,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,255,255,128,255,255,255,255,255,255,255,191,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,255,255,255,255,255,255,255,255,191,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,255,255,255,255,255,255,128,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,128,255,255,255,255,255,255,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,128,255,255,255,255,255,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,255,255,255,255,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,255,255,255,255,255,255,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,128,255,255,255,255,0,128,255,255,255,255,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,255,64,0,64,191,255,255,255,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,255,128,0,0,0,64,255,255,255,255,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,255,0,0,0,0,0,255,255,255,255,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,64,0,0,0,0,255,255,255,191,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,128,255,255,191,128,191,128,191,255,255,255,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,64,255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,64,255,255,255,255,255,255,255,255,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,64,191,255,255,255,255,255,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
# Draw the second hand-entered sample, titled with the model's predicted class.
plot_and_predict(pixels)

In [None]:
# Second stream over the test topic under its own consumer group
# ("mnistcg2"); records are decoded but left unbatched.
manual_ds = (
    tfio.experimental.streaming.KafkaGroupIODataset(
        topics=[KAFKA_TEST_TOPIC],
        group_id="mnistcg2",
        servers=KAFKA_SERVERS,
        stream_timeout=KAFKA_STREAM_TIMEOUT,
        # Raw Kafka client settings, passed through as "name=value" strings.
        configuration=[
            "session.timeout.ms=10000",
            "max.poll.interval.ms=10000",
            "auto.offset.reset=earliest",
        ],
    )
    .map(decode_kafka_stream_record)
)

In [None]:
def predict(pixels):
    """Return the class index ``model`` predicts for one flattened image.

    Args:
        pixels: Flat sequence of pixel values. It is fed to the model as-is;
            no rescaling is applied here.

    Returns:
        The argmax of the softmax over the model output for the
        single-image batch.
    """
    # Wrap the sample in a list so the model receives a batch of one.
    batch = tf.constant([pixels])
    logits = model.predict(batch)
    return tf.nn.softmax(logits).numpy().argmax()


# Pull four records from the stream and show them in a 2x2 grid, each
# titled with the class predicted for it.
iterator = iter(manual_ds)
figure, axis = plt.subplots(2, 2)
for i, j in np.ndindex(2, 2):  # row-major: (0, 0), (0, 1), (1, 0), (1, 1)
    ax = axis[i, j]
    record = iterator.get_next()
    image = record[0].numpy()  # first element of the decoded (image, label) pair
    ax.imshow(np.split(np.asarray(image), 28))
    ax.set_title(predict(image))
    # Hide both axes so only the image and its title are shown.
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
plt.show()