# MDN-transformer with examples

- What kind of data can be predicted by a mixture density network Transformer?
    - Continuous sequential data
- Drawing data and RoboJam touch-screen data are good examples: their continuous values give high resolution in 2D space.

# 1. Kanji Generation

- Firstly, let's try modelling some drawing data for Kanji writing using an MDN-Transformer.

- This work is inspired by previous work "MDN-RNN for Kanji Generation", hardmaru's Kanji tutorial and the original Sketch-RNN repository:

    - http://blog.otoro.net/2015/12/28/recurrent-net-dreams-up-fake-chinese-characters-in-vector-format-with-tensorflow/
    - https://github.com/hardmaru/sketch-rnn

    - The idea is to learn how to draw kanji characters from a dataset of vector representations. 
    - This means learning how to move a pen in 2D space.
    - The data consists of a sequence of pen movements (locations in 2D) and whether the pen is up or down.
    - In this example, we will use one 3D MDN to model everything!

We will end up with a system that will continue writing Kanji given a short sequence, like this:



In [None]:
# Setup and modules
import sys
!{sys.executable} -m pip install keras-mdn-layer 
!{sys.executable} -m pip install tensorflow
!{sys.executable} -m pip install tensorflow-probability
!{sys.executable} -m pip install matplotlib
!{sys.executable} -m pip install pandas
!{sys.executable} -m pip install svgwrite

import mdn
import numpy as np
import random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D 
%matplotlib inline
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
%matplotlib notebook

# Only for GPU use:
import os
# Expose only GPU index 1 to this process. CUDA_VISIBLE_DEVICES is generally
# honoured only if it is set before TensorFlow initialises its GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import tensorflow as tf
from tensorflow.compat.v1.keras import backend as K

# Reach the TF1 session API through its full `tf.compat.v1` path instead of
# rebinding `tf` to `tensorflow.compat.v1`. The rest of the notebook is written
# against the TensorFlow 2 namespace (tf.data, tf.range, tf.shape, ...).
config = tf.compat.v1.ConfigProto()
# Grow GPU memory on demand rather than reserving it all up front.
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)
K.set_session(sess)

### Download and process the data set

In [None]:
# Train from David Ha's Kanji dataset from Sketch-RNN: https://github.com/hardmaru/sketch-rnn-datasets
# Other datasets in "Sketch 3" format should also work.
import os
import urllib.request

url = 'https://github.com/hardmaru/sketch-rnn-datasets/raw/master/kanji/kanji.rdp25.npz'
dataset_path = './kanji.rdp25.npz'
# Fetch the file only when it is missing. A fresh checkout then gets the data
# the next cell loads, and re-running the notebook does not download it again.
if not os.path.exists(dataset_path):
    urllib.request.urlretrieve(url, dataset_path)

### Dataset

The dataset includes about 11,000 handwritten kanji characters, divided into training, validation, and test sets.

In [None]:
# Load the three pre-split subsets from the Sketch-RNN kanji archive.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, so only
# load files from a trusted source.
split_names = ('train', 'valid', 'test')
with np.load('./kanji.rdp25.npz', allow_pickle=True) as data:
    train_set, valid_set, test_set = (data[name] for name in split_names)

for label, split in (("Training kanji:", train_set),
                     ("Validation kanji:", valid_set),
                     ("Testing kanji:", test_set)):
    print(label, len(split))

In [None]:
# Functions for slicing up data
def slice_sequence_examples(sequence, num_steps):
    """Cut `sequence` into every overlapping window of `num_steps` consecutive items.

    Windows start at each index 0 .. len(sequence) - num_steps, so a sequence
    of length L yields L - num_steps + 1 windows. A sequence shorter than
    `num_steps` yields none.

    Args:
        sequence: any sliceable sequence (list, numpy array, ...).
        num_steps: length of each window.

    Returns:
        A list of slices of `sequence`, in order of their start index.
    """
    # The last valid start index is len - num_steps, hence the +1. The previous
    # bound (len - num_steps - 1) silently dropped the final two windows.
    return [sequence[start:start + num_steps]
            for start in range(len(sequence) - num_steps + 1)]

def seq_to_singleton_format(examples, seq_len=None):
    """Split windows into model inputs (first `seq_len` items) and full-length targets.

    Args:
        examples: iterable of sliceable windows.
        seq_len: number of leading items kept for each input. Defaults to the
            notebook-level SEQ_LEN, so existing calls behave as before.

    Returns:
        (xs, ys): xs[i] is examples[i][:seq_len]; ys[i] is examples[i] itself.
    """
    if seq_len is None:
        seq_len = SEQ_LEN
    xs = []
    ys = []
    # Single pass, so one-shot iterables such as generators also work.
    for ex in examples:
        xs.append(ex[:seq_len])
        ys.append(ex)
    return xs, ys

# Functions for making the data set
def format_dataset(x, y):
    """Map one batch to the (features, labels) structure the model trains on.

    x goes to the encoder input named "input". y[:, :-1] (every point of each
    window except the last) goes to the decoder input named "target". The
    labels are y[:, 1:], i.e. the same windows shifted forward by one point.

    NOTE(review): with the slicing used in this notebook (x = window[:SEQ_LEN],
    windows of SEQ_LEN + gap_len points, gap_len = 1), "input" and "target" are
    identical. The encoder can therefore see point t+1 while the decoder
    predicts it, at every position except the last. Confirm this is intended.
    """
    decoder_in = y[:, :-1, :]
    labels = y[:, 1:, :]
    features = {"input": x, "target": decoder_in}
    return features, labels

def make_dataset(X, y):
    """Build the training tf.data pipeline from window arrays X and y.

    Windows are grouped into batches of the notebook-level `batch_size`,
    converted to the model's ({"input", "target"}, labels) structure by
    format_dataset, cached, and then the batch order is shuffled every epoch.

    Args:
        X: encoder input windows, indexed along the first axis.
        y: full windows from which decoder inputs and labels are derived.

    Returns:
        A prefetched tf.data.Dataset of (features, labels) batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    dataset = dataset.batch(batch_size)
    # format_dataset slices with [:, ...], so it must run on batched tensors.
    dataset = dataset.map(format_dataset)
    # cache() must come BEFORE shuffle(). A cache placed after shuffle replays
    # the first epoch's order on every later epoch, which disables reshuffling.
    dataset = dataset.cache()
    dataset = dataset.shuffle(2048)
    return dataset.prefetch(16)

In [None]:
# Data shapes
NUM_FEATS = 3      # values per point
SEQ_LEN = 20       # points given to the model per example
gap_len = 1        # extra trailing point per window (the shifted target)
batch_size = 128

# Cut every training kanji into overlapping windows of SEQ_LEN + gap_len
# points, then split each window into an input prefix and the full window.
window = SEQ_LEN + gap_len
slices = [piece
          for seq in train_set
          for piece in slice_sequence_examples(seq, window)]
X, y = seq_to_singleton_format(slices)

X, y = np.array(X), np.array(y)
train_ds = make_dataset(X, y)

print("Number of training examples:")
for label, arr in (("X:", X), ("y:", y)):
    print(label, arr.shape)
print(train_ds)

### Constructing the MDN Transformer

Our MDN Transformer has the following settings:
- an embedding layer with positional embedding
- a transformer encoder
- a transformer decoder
- a three-dimensional mixture layer with 10 mixtures
- training on sequences of length 20 (`SEQ_LEN`)
- training for 20 epochs (`EPOCHS`) with a batch size of 128 (`batch_size`)

Here's a diagram:


In [None]:
class PositionalEmbedding(layers.Layer):
    """Project each input step to `output_dim` and add a learned position embedding.

    Args:
        sequence_length: size of the position-embedding table. Positions at or
            beyond it are out of range for the table.
        input_dim: number of features per input step (kept for get_config).
        output_dim: embedding size produced for every step.
    """

    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        # Inputs are continuous values, so the "token" embedding is a Dense
        # projection rather than an Embedding lookup.
        self.token_embeddings = layers.Dense(output_dim)
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim

    def call(self, inputs, padding_mask=None):
        """Return projected inputs plus position embeddings.

        `padding_mask` is accepted for call-signature compatibility and is unused.
        """
        # Dynamic length: works even when the time axis is None at build time
        # (the static `inputs.shape[1]` would be None there).
        length = tf.shape(inputs)[1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def get_config(self):
        config = super().get_config()
        config.update({
            "output_dim": self.output_dim,
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim,
        })
        return config

    
class TransformerEncoder(layers.Layer):
    """Transformer encoder block.

    Self-attention, then a two-layer feed-forward projection. Each sub-layer
    is followed by a residual connection and layer normalisation.

    Args:
        embed_dim: feature size of inputs/outputs (also each head's key_dim).
        dense_dim: hidden size of the feed-forward projection.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = keras.Sequential(
            [layers.Dense(dense_dim, activation="relu"),
             layers.Dense(embed_dim),]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs, mask=None):
        """Encode `inputs`, optionally restricting attention with `mask`."""
        if mask is not None and len(mask.shape) == 2:
            # MultiHeadAttention expects (batch, query_len, key_len). Expand a
            # (batch, seq) padding mask so it broadcasts over the query axis.
            mask = mask[:, tf.newaxis, :]
        attention_output = self.attention(inputs, inputs, attention_mask=mask)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "dense_dim": self.dense_dim,
        })
        return config


class TransformerDecoder(layers.Layer):
    """Transformer decoder block.

    Causal self-attention, then cross-attention over the encoder outputs, then
    a two-layer feed-forward projection. Each sub-layer is followed by a
    residual connection and layer normalisation.

    Args:
        embed_dim: feature size of inputs/outputs (also each head's key_dim).
        dense_dim: hidden size of the feed-forward projection.
        num_heads: number of heads in both attention layers.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = keras.Sequential(
            [layers.Dense(dense_dim, activation="relu"),
             layers.Dense(embed_dim),]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def get_causal_attention_mask(self, inputs):
        """Return a (batch, seq, seq) int32 mask letting position i attend to j <= i."""
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, sequence_length, sequence_length))
        # Repeat the single (seq, seq) mask once per batch element.
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1),
             tf.constant([1, 1], dtype=tf.int32)], axis=0)
        return tf.tile(mask, mult)

    def call(self, inputs, encoder_outputs, padding_mask=None):
        """Decode `inputs` while attending to `encoder_outputs`.

        Args:
            inputs: decoder-side embeddings.
            encoder_outputs: output of the encoder stack.
            padding_mask: optional mask for the cross-attention. A 2-D
                (batch, key_len) mask is expanded to (batch, 1, key_len).
                Higher-rank masks are passed through unchanged.
        """
        causal_mask = self.get_causal_attention_mask(inputs)
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=causal_mask)
        attention_output_1 = self.layernorm_1(inputs + attention_output_1)
        if padding_mask is not None and len(padding_mask.shape) == 2:
            padding_mask = padding_mask[:, tf.newaxis, :]
        # attention_mask=None is MultiHeadAttention's default ("no masking"),
        # so one call covers both the masked and the unmasked case.
        attention_output_2 = self.attention_2(
            query=attention_output_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask)
        attention_output_2 = self.layernorm_2(
            attention_output_1 + attention_output_2)
        proj_output = self.dense_proj(attention_output_2)
        return self.layernorm_3(attention_output_2 + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "dense_dim": self.dense_dim,
        })
        return config

In [None]:
# Training Hyperparameters:
input_dim = 3
# Keep the model's window lengths in sync with the data windows built earlier.
# The encoder input is the first SEQ_LEN points of each window. The decoder
# input is the (SEQ_LEN + gap_len)-point window minus its last point
# (see format_dataset).
sequence_length = SEQ_LEN
target_length = SEQ_LEN + gap_len - 1
embed_dim = 256
dense_dim = 128
num_heads = 2
output_dim = 3
number_mixtures = 10

EPOCHS = 20
SEED = 2345  # set random seed for reproducibility
random.seed(SEED)
np.random.seed(SEED)
# Keras weight initialisation and dropout draw from TensorFlow's RNG, which the
# two calls above do not seed. The TF2 namespace is imported under its own name
# so this works even if `tf` was rebound to tensorflow.compat.v1 earlier.
import tensorflow as _tf2
_tf2.random.set_seed(SEED)

encoder_inputs = keras.Input(shape=(sequence_length, input_dim), dtype="float64", name="input")
x = PositionalEmbedding(sequence_length, input_dim, embed_dim)(encoder_inputs)
encoder_outputs = TransformerEncoder(embed_dim, dense_dim, num_heads)(x)
print(encoder_outputs.shape)

decoder_inputs = keras.Input(shape=(target_length, input_dim), dtype="float64", name="target")
x = PositionalEmbedding(target_length, input_dim, embed_dim)(decoder_inputs)
x = TransformerDecoder(embed_dim, dense_dim, num_heads)(x, encoder_outputs)
# The decoder features go straight into the MDN, which has its own output
# projections. A Dense(input_dim, activation="softmax") layer here would
# squeeze the embed_dim features into 3 values that sum to 1 before the MDN
# sees them.
decoder_outputs = layers.Dropout(0.2)(x)
outputs = mdn.MDN(output_dim, number_mixtures)(decoder_outputs)
model = keras.Model([encoder_inputs, decoder_inputs], outputs)
model.compile(loss=mdn.get_mixture_loss_func(output_dim, number_mixtures),
              optimizer=keras.optimizers.Adam())
model.summary()

In [None]:
import os

callbacks = [
    # fit() receives no validation data, so monitor the training loss. With
    # the default monitor ("val_loss") the metric never exists, and
    # save_best_only would skip every save.
    keras.callbacks.ModelCheckpoint("full_transformer.keras",
                                    monitor="loss",
                                    save_best_only=True)
]

# No batch_size here: train_ds is a tf.data.Dataset that is already batched.
history = model.fit(train_ds, epochs=EPOCHS, callbacks=callbacks)

# Save under saved_model/, the directory later cells load models from.
os.makedirs("saved_model", exist_ok=True)
model.save('saved_model/my_model_100_128_256_128_2_02')

In [None]:
# Training-loss curve, one point per epoch.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(history.history['loss'])
plt.show()

In [None]:
# List the saved model directory in plain Python. The bare `ls` line relied on
# IPython's automagic alias being enabled.
import os
print(sorted(os.listdir('saved_model/my_model6')))

In [None]:
# Custom layers and the MDN loss the model was compiled with are passed
# explicitly so Keras can rebuild them when loading.
model = keras.models.load_model(
    'saved_model/my_model6',
    custom_objects={
        'PositionalEmbedding': PositionalEmbedding,
        'TransformerEncoder': TransformerEncoder,
        'TransformerDecoder': TransformerDecoder,
        'MDN': mdn.MDN,
        'mdn_loss_func': mdn.get_mixture_loss_func(output_dim, number_mixtures),
    },
)

## Generating drawings

First, we need some helper functions to view the output.

In [None]:
def zero_start_position():
    """Return a single float32 point of shape (1, 1, 3): offset (0, 0), flag column set to 1.

    draw_strokes treats a 1 in the third column as "lift the pen before the
    next point" (the next point becomes an SVG move, not a line).
    NOTE(review): the previous comment called this "pen down"; check which
    convention the training data uses.
    """
    start = np.zeros((1, 1, 3), dtype=np.float32)
    start[..., 2] = 1
    return start

def generate_sketch(model, start_pos, num_points=100, pi_temperature=3, sigma_temp=0.1):
    """Extend `start_pos` by sampling `num_points` new points from an MDN model.

    Follows the notebook's generation loop. The current window is passed as
    both model inputs, one point is sampled from the mixture parameters of the
    LAST time step, and the window slides forward by one point.

    Args:
        model: Keras model called as model.predict([window, window]) that
            returns per-step MDN parameters.
        start_pos: seed points, presumably shaped (1, window_len, n_features)
            with window_len equal to the model's input length (TODO confirm).
        num_points: number of points to sample and append.
        pi_temperature: passed to mdn.sample_from_output as `temp`.
        sigma_temp: passed to mdn.sample_from_output as `sigma_temp`.

    Returns:
        The seed concatenated (along axis 1) with the sampled points.
    """
    window = np.asarray(start_pos)
    sketch = window
    n_dims = window.shape[-1]
    for _ in range(num_points):
        params = model.predict([window, window], verbose=0)
        step_params = params[0][-1]
        # Assumes the keras-mdn-layer layout (n_dims*M means, n_dims*M sigmas,
        # M mixture logits) to recover M. TODO confirm.
        n_mixes = step_params.shape[-1] // (2 * n_dims + 1)
        sample = mdn.sample_from_output(step_params, n_dims, n_mixes,
                                        temp=pi_temperature, sigma_temp=sigma_temp)
        # Wrap the sample as the generation cell does before concatenating.
        new_point = np.array([sample])
        window = np.concatenate((window[:, 1:], new_point), axis=1)
        sketch = np.concatenate((sketch, new_point), axis=1)
    return sketch

def cutoff_stroke(x):
    """Binarise pen-state values: 1.0 where x > 0.5, otherwise 0.0 (float64)."""
    above_threshold = np.greater(x, 0.5)
    return above_threshold.astype(np.float64)

def plot_sketch(sketch_array):
    """Plot a sketch quickly to see what it looks like.

    Rows of `sketch_array` are read as (dx, dy, pen_state).
    - The offsets are cumulatively summed into absolute positions.
    - y is negated before plotting.
    - A new stroke begins after every point whose pen_state is > 0.5. This
      matches draw_strokes (a lifted pen makes the next point a move) and the
      0.5 threshold in cutoff_stroke.

    Each stroke is drawn separately, so lifted-pen jumps are not joined by lines.
    """
    sketch_df = pd.DataFrame({'x': sketch_array.T[0], 'y': sketch_array.T[1], 'z': sketch_array.T[2]})
    sketch_df.x = sketch_df.x.cumsum()
    sketch_df.y = -1 * sketch_df.y.cumsum()
    # A lifted pen ends its stroke, so the stroke id increments on the row after it.
    pen_lifted = (sketch_df.z > 0.5).astype(int)
    stroke_id = pen_lifted.shift(1, fill_value=0).cumsum()
    # Do the plot
    fig = plt.figure(figsize=(8, 8))
    ax1 = fig.add_subplot(111)
    for _, stroke in sketch_df.groupby(stroke_id):
        ax1.plot(stroke.x, stroke.y, 'r-')
    ax1.set_aspect('equal')
    plt.show()

## SVG Drawing Function

Here are hardmaru's drawing functions from _write-rnn-tensorflow_. Big hat tip to hardmaru for these!

Here's the source: https://github.com/hardmaru/write-rnn-tensorflow/blob/master/utils.py

In [None]:
import svgwrite
from IPython.display import SVG, display

def get_bounds(data, factor):
    """Return (min_x, max_x, min_y, max_y) of the path traced by `data`.

    Columns 0 and 1 are treated as relative offsets. Each is divided by
    `factor` and accumulated from an origin at (0, 0); the origin always lies
    inside the returned bounds.
    """
    pos_x = pos_y = 0
    visited_x = [0]
    visited_y = [0]
    for dx, dy in zip(data[:, 0], data[:, 1]):
        pos_x += float(dx) / factor
        pos_y += float(dy) / factor
        visited_x.append(pos_x)
        visited_y.append(pos_y)
    return (min(visited_x), max(visited_x), min(visited_y), max(visited_y))

def draw_strokes(data, factor=1, svg_filename='sample.svg'):
    """Render a stroke sequence as an SVG path, save it to a file and display it inline.

    Each row of `data` is read as data[i, 0] = dx, data[i, 1] = dy and
    data[i, 2] = pen flag. Offsets are emitted as relative SVG path commands.
    A pen flag equal to 1 makes the NEXT point a move ("m") instead of a line.

    Args:
        data: array indexed as data[i, 0..2].
        factor: divisor applied to every dx/dy (values < 1 enlarge the drawing).
        svg_filename: file written by dwg.save().
    """
    min_x, max_x, min_y, max_y = get_bounds(data, factor)
    # Canvas = path extent plus a 25-unit margin on every side.
    dims = (50 + max_x - min_x, 50 + max_y - min_y)

    dwg = svgwrite.Drawing(svg_filename, size=dims)
    dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))

    # Start "lifted" so the first point is a move, not a line.
    lift_pen = 1

    # Absolute start point chosen so the whole path lands inside the margin.
    abs_x = 25 - min_x
    abs_y = 25 - min_y
    p = "M%s,%s " % (abs_x, abs_y)

    command = "m"

    for i in range(len(data)):
        # The previous point's pen flag decides this point's command. While the
        # pen stays down, "l" is written on alternate points; the others use ""
        # and rely on SVG's implicit repetition of the previous command.
        if (lift_pen == 1):
            command = "m"
        elif (command != "l"):
            command = "l"
        else:
            command = ""
        x = float(data[i, 0]) / factor
        y = float(data[i, 1]) / factor
        lift_pen = data[i, 2]
        p += command + str(x) + "," + str(y) + " "

    the_color = "black"
    stroke_width = 1

    dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill("none"))

    dwg.save()
    display(SVG(dwg.tostring()))

In [None]:
# Seed for generation: the first SEQ_LEN points of the first validation kanji,
# copied into a new array with a leading batch axis of size 1.
original = valid_set[0]
x0 = np.stack([original[:SEQ_LEN]])
y0 = x0
# Alternative longer decoder seed: np.array([valid_set[0][:(SEQ_LEN+9)]])

In [None]:
# Predict a character and plot the result.
pi_temperature = 3  # mixture temperature; rather high values (~2.5-3) seem to work well
sigma_temp = 0.1    # sigma temperature; low values seem to work well

# Continue the seed sequence x0 taken from the validation set.
p = x0
sketch = p

for _ in range(100):
    # verbose=0 avoids printing a progress bar for each of the 100 calls.
    params = model.predict([p, p], verbose=0)
    # Use the LAST time step. Labels are the inputs shifted by one point, so
    # the last step predicts the point after the window. A fixed index such as
    # 49 runs past a SEQ_LEN-step output.
    out = mdn.sample_from_output(params[0][-1], output_dim, number_mixtures,
                                 temp=pi_temperature, sigma_temp=sigma_temp)
    # Slide the window forward by one point and record the new point.
    p = np.concatenate((p[:, 1:], np.array([out])), axis=1)
    sketch = np.concatenate((sketch, np.array([out])), axis=1)

# Snap the sampled pen-state column to 0/1 before drawing.
sketch.T[2] = cutoff_stroke(sketch.T[2])
draw_strokes(sketch[0], factor=0.5)
draw_strokes(x0[0], factor=0.5)