# Kanji Generation with a stateless decoder

This notebook demonstrates how to build an LSTM MDRNN decoder for generation without using `stateful=True` in the LSTM layers.

To do this, we create extra inputs and outputs for the decoder model. These let us feed in the two LSTM state vectors ($h$ and $c$) for each LSTM layer and collect their updated values at the output.

At the start of generating a Kanji character, the LSTM state vectors are initialised to zero.

Between generation steps, we store the LSTM state vectors returned by the model in a variable and feed them back in at the next step.

The decoder model is defined with the Keras functional API, and generation uses TensorFlow's eager execution.

This notebook does not train a model; see the Kanji Generation Example notebook for that. Make sure that the hyperparameters here are the same as those of the model whose weights are loaded.

In [None]:
# Decoding Model:
# uses functional Keras API
# takes LSTM state as extra inputs (2 per LSTM layer)
# returns LSTM state as extra outputs (2 per LSTM layer) 

# imports
from context import * # imports the MDN layer 

# Hyper parameters
# NOTE: these must match the values used to train the model whose weights are
# loaded at the end of this cell, otherwise the weights will not line up.
HIDDEN_UNITS = 64
OUTPUT_DIMENSION = 3
NUMBER_MIXTURES = 10

# A single timestep of OUTPUT_DIMENSION values per call: the decoder is
# stepped one point at a time during generation.
inputs = keras.layers.Input(shape=(1,OUTPUT_DIMENSION))
# Extra inputs carrying the (h, c) state of the first LSTM layer.
lstm_1_state_h_input = keras.layers.Input(shape=(HIDDEN_UNITS,))
lstm_1_state_c_input = keras.layers.Input(shape=(HIDDEN_UNITS,))
lstm_1_state_input = [lstm_1_state_h_input, lstm_1_state_c_input]
# Extra inputs carrying the (h, c) state of the second LSTM layer.
lstm_2_state_h_input = keras.layers.Input(shape=(HIDDEN_UNITS,))
lstm_2_state_c_input = keras.layers.Input(shape=(HIDDEN_UNITS,))
lstm_2_state_input = [lstm_2_state_h_input, lstm_2_state_c_input]
# return_state=True exposes each layer's updated (h, c) so the caller can feed
# them back in on the next step instead of relying on stateful=True.
lstm_1, state_h_1, state_c_1 = keras.layers.LSTM(HIDDEN_UNITS, return_sequences=True, return_state=True)(inputs, initial_state=lstm_1_state_input)
# No return_sequences here: only the final timestep's output goes to the MDN head.
lstm_2, state_h_2, state_c_2 = keras.layers.LSTM(HIDDEN_UNITS, return_state=True)(lstm_1, initial_state=lstm_2_state_input)
lstm_1_state_output = [state_h_1, state_c_1]
lstm_2_state_output = [state_h_2, state_c_2]
mdn_out = mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES)(lstm_2)

# Input order:  [point, lstm_1 h, lstm_1 c, lstm_2 h, lstm_2 c]
# Output order: [mdn params, lstm_1 h, lstm_1 c, lstm_2 h, lstm_2 c]
# Callers rely on this ordering when building and slicing the lists.
decoder = keras.Model(inputs=[inputs] + lstm_1_state_input + lstm_2_state_input, 
                                outputs=[mdn_out] + lstm_1_state_output + lstm_2_state_output,
                                name="kanji-decoder")
decoder.summary()
# NOTE(review): weight loading presumably matches this model's layers against
# the saved training model, so keep the layer definitions (and their creation
# order) consistent with that model — confirm against the training notebook.
decoder.load_weights('kanji_mdnrnn_model.keras') # load weights independently from file

In [None]:
import numpy as np
# Let's test that the model is working on just one prediction:

def zero_start_position():
    """Return a fresh float32 start point of shape (1, 1, 3).

    The x/y offsets are zero and the third (pen-state) component is 1.
    """
    return np.array([[[0.0, 0.0, 1.0]]], dtype=np.float32)

def random_start_position(limit=5, rng=None):
    """Return a random float32 start point of shape (1, 1, 3).

    The x and y offsets are drawn uniformly from (-limit, limit]. The third
    (pen-state) component is fixed at 1.

    NOTE(review): the original comment called this "pen down", but
    draw_strokes treats a third component of 1 as "lift the pen before the
    next point". Confirm which convention the training data uses.

    Args:
        limit: half-width of the uniform range for the x/y offsets.
        rng: optional ``numpy.random.Generator`` for reproducible draws. When
            None, the global ``np.random`` state is used, as before.

    Returns:
        A float32 array, matching ``zero_start_position`` and the float32
        LSTM state arrays fed to the decoder alongside it.
    """
    if rng is None:
        uniform = np.random.rand(1, 1, 3)
    else:
        uniform = rng.random((1, 1, 3))
    out = (limit - (2 * limit) * uniform).astype(np.float32)
    out[0, 0, 2] = 1  # pen-state component fixed at 1 (see NOTE in docstring)
    return out

def generate_initial_lstm_states(units):
    """Return a zeroed [h, c] pair for one LSTM layer.

    Each entry is a separate float32 array of shape (1, units).
    """
    return [np.zeros((1, units), dtype=np.float32) for _ in range(2)]

def _print_shapes(header, arrays):
    """Print ``header`` followed by one "Shape:" line per array."""
    print(header)
    for arr in arrays:
        print("Shape:", arr.shape)

# Smoke test: push one start point plus zeroed LSTM states through the decoder
# and report the shapes of everything going in and coming out.
start_pos = random_start_position()
start_state_1 = generate_initial_lstm_states(HIDDEN_UNITS)
start_state_2 = generate_initial_lstm_states(HIDDEN_UNITS)

print("Start pos shape:", start_pos.shape)
print("Example state shape:", start_state_1[0].shape)
input_list = [start_pos, *start_state_1, *start_state_2]

_print_shapes("Input shapes:", input_list)

# run one prediction
output_list = decoder(input_list)
print("Output list length:", len(output_list))
_print_shapes("Output shapes:", output_list)

# This test shows that everything... seems to be working...

# Generating Kanji (again)

Now we can generate some Kanji with this model in the same way as in the previous example.

In [None]:
# Hardmaru's Drawing Functions from write-rnn-tensorflow
# Big hat tip
# Here's the source:
# https://github.com/hardmaru/write-rnn-tensorflow/blob/master/utils.py

import svgwrite
from IPython.display import SVG, display

def get_bounds(data, factor):
    """Return the bounding box of a relative-offset stroke sequence.

    Columns 0 and 1 of each row are x/y offsets. They are divided by
    ``factor`` and accumulated into absolute positions starting at the
    origin, which is always included in the box.

    Returns:
        (min_x, max_x, min_y, max_y)
    """
    pos_x = pos_y = 0
    lo_x = hi_x = lo_y = hi_y = 0
    for row in data:
        pos_x += float(row[0]) / factor
        pos_y += float(row[1]) / factor
        lo_x, hi_x = min(lo_x, pos_x), max(hi_x, pos_x)
        lo_y, hi_y = min(lo_y, pos_y), max(hi_y, pos_y)
    return (lo_x, hi_x, lo_y, hi_y)

def draw_strokes(data, factor=1, svg_filename='sample.svg'):
    """Render a relative-offset stroke sequence as an SVG and display it inline.

    Args:
        data: 2-D array. ``data[i, 0]`` and ``data[i, 1]`` are x/y offsets and
            ``data[i, 2]`` is a pen flag. A pen flag equal to 1 makes the
            following point start with an SVG "m" (move) command, which begins
            a new stroke.
        factor: divisor applied to every x/y offset before drawing.
        svg_filename: path the SVG is saved to.

    Side effects:
        Writes ``svg_filename`` and displays the drawing via IPython.
    """
    min_x, max_x, min_y, max_y = get_bounds(data, factor)
    # 25-unit margin on every side of the bounding box.
    canvas = (50 + max_x - min_x, 50 + max_y - min_y)

    dwg = svgwrite.Drawing(svg_filename, size=canvas)
    dwg.add(dwg.rect(insert=(0, 0), size=canvas, fill='white'))

    # Absolute starting point, shifted so the drawing sits inside the margin.
    segments = ["M%s,%s " % (25 - min_x, 25 - min_y)]

    pen_up = 1  # treat the pen as lifted before the first point
    command = "m"
    for row in data:
        if pen_up == 1:
            command = "m"
        elif command == "l":
            # SVG repeats the previous relative line-to when the letter is omitted.
            command = ""
        else:
            command = "l"
        dx = float(row[0]) / factor
        dy = float(row[1]) / factor
        pen_up = row[2]
        segments.append(command + str(dx) + "," + str(dy) + " ")

    path_data = "".join(segments)
    dwg.add(dwg.path(path_data).stroke("black", 1).fill("none"))

    dwg.save()
    display(SVG(dwg.tostring()))

def cutoff_stroke(x):
    """Binarise pen-state values: 1.0 where ``x > 0.5``, else 0.0 (as float64)."""
    return np.greater(x, 0.5).astype(np.float64)

In [None]:
import time
# Predict a character and plot the result.
pi_temperature = 2.5 # seems to work well with rather high temperature (2.5)
sigma_temp = 0.1 # seems to work well with low temp

# Every new character starts from a fresh start point and zeroed LSTM states.
start_pos = random_start_position()
start_state_1 = generate_initial_lstm_states(HIDDEN_UNITS)
start_state_2 = generate_initial_lstm_states(HIDDEN_UNITS)
input_list = [start_pos] + start_state_1 + start_state_2 # five inputs

sketch = [start_pos.reshape(3,)] # starting value for the sketch.

number_of_movements = 400
# perf_counter is monotonic and high-resolution, unlike time.time(), so it is
# the appropriate clock for measuring elapsed time.
start_time = time.perf_counter()

for _ in range(number_of_movements):
    output_list = decoder(input_list)
    # output_list[0] holds the MDN parameters; the trailing [0] picks the single
    # batch entry (inputs are built with batch size 1).
    mdn_values = output_list[0][0].numpy()
    next_point = mdn.sample_from_output(mdn_values, OUTPUT_DIMENSION, NUMBER_MIXTURES, temp=pi_temperature, sigma_temp=sigma_temp)
    sketch.append(next_point.reshape((3,)))
    # Feed the sampled point and the updated LSTM states back in for the next step.
    states = output_list[1:]
    input_list = [next_point.reshape(1, 1, 3)] + states

elapsed = time.perf_counter() - start_time
print("Finished. That took", round(elapsed/number_of_movements, 4), "seconds per generation.")

# Draw the sketch
sketch = np.array(sketch)
# Binarise the pen-state column before drawing.
sketch.T[2] = cutoff_stroke(sketch.T[2])
draw_strokes(sketch, factor=0.5)