In [None]:
import os
import time

import numpy as np
import scipy.stats as stats

import matplotlib.pyplot as plt
import matplotlib.tri as tri
import seaborn as sns

%matplotlib inline

In [None]:
# NOTE(review): superseded - `dist` is reassigned to Dirichlet([3, 3, 2]) in a
# later cell before anything reads it, so this symmetric Dirichlet(0.5, 0.5, 0.5)
# instance is never used.
dist = stats.dirichlet([.5, .5, .5])

In [None]:
# Vertices of an equilateral triangle with unit side length; its apex sits at
# height sqrt(3) / 2.
TRI_HEIGHT = 0.75 ** 0.5
corners = np.array([[0, 0], [1, 0], [0.5, TRI_HEIGHT]])
# Area of that triangle: 1/2 * base (= 1) * height.
AREA = 0.5 * TRI_HEIGHT
# Triangulation of the three corners (x coordinates, y coordinates): a single triangle.
triangle = tri.Triangulation(*corners.T)

In [None]:
# Show the base triangle next to its uniform refinement (subdiv=4), i.e. the
# kind of mesh the pdf contours are evaluated on.
refiner = tri.UniformTriRefiner(triangle)
trimesh = refiner.refine_triangulation(subdiv=4)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, mesh in zip(axes, (triangle, trimesh)):
    ax.triplot(mesh)
    ax.axis('off')
    ax.axis('equal')

In [None]:
# For each corner of the triangle, the pair of other corners (the opposite edge).
pairs = [corners[np.roll(range(3), -i)[1:]] for i in range(3)]


def tri_area(xy, pair):
    '''Area of the triangle formed by point ``xy`` and the pair of points ``pair``.

    The 2D cross product is written out explicitly: ``np.cross`` on
    2-element vectors is deprecated as of NumPy 2.0.
    '''
    (ux, uy), (vx, vy) = np.asarray(pair, dtype=float) - np.asarray(xy, dtype=float)
    return 0.5 * abs(ux * vy - uy * vx)


def xy2bc(xy, eps=1e-5):
    '''Converts 2D Cartesian coordinates to barycentric.

    Args:
        xy: a single 2D point (length-2 sequence) inside the triangle.
        eps: value substituted for exact-zero weights (points on an edge or
            corner).

    Returns:
        ndarray of shape (3,): the weight of each corner, i.e. the area of
        the sub-triangle opposite that corner divided by ``AREA``. Weights
        are made strictly positive, capped at 1, and then forced to sum to 1
        so they are valid input for ``scipy.stats.dirichlet.pdf``.
    '''
    coords = np.array([tri_area(xy, p) for p in pairs]) / AREA

    # Exact zeros (boundary points) can be rejected by the Dirichlet pdf;
    # nudge them just inside the simplex.
    coords = np.where(coords > 0, coords, coords + eps)
    coords = np.minimum(coords, 1.0)

    # The pdf requires the weights to sum to 1; absorb the rounding/eps
    # residual into the largest weight, where it is relatively smallest.
    coords[np.argmax(coords)] += 1 - coords.sum()
    return coords

In [None]:
def draw_pdf_contours(dist, nlevels=200, subdiv=6, **kwargs):
    '''Draw filled contours of ``dist.pdf`` over the triangle (2-simplex).

    The global ``triangle`` is uniformly refined ``subdiv`` times. Each mesh
    vertex is converted to barycentric coordinates with ``xy2bc``, and
    ``dist.pdf`` is evaluated there. ``nlevels`` contour levels are drawn
    with the viridis colormap; extra ``kwargs`` are forwarded to
    ``plt.tricontourf``.
    '''
    fine_mesh = tri.UniformTriRefiner(triangle).refine_triangulation(subdiv=subdiv)

    densities = []
    for x, y in zip(fine_mesh.x, fine_mesh.y):
        densities.append(dist.pdf(xy2bc((x, y))))

    plt.tricontourf(fine_mesh, densities, nlevels, cmap='viridis', **kwargs)
    # Equal aspect keeps the triangle equilateral; frame it tightly, no axes.
    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75**0.5)
    plt.axis('off')


In [None]:
# Target distribution the GAN will learn: an asymmetric Dirichlet on 3 components.
dist = stats.dirichlet(alpha=[3, 3, 2])
draw_pdf_contours(dist=dist)

In [None]:
# One draw from the target distribution (a point on the simplex).
dist.rvs(size=1)

In [None]:
import tensorflow as tf
from tensorflow.keras import layers

In [None]:
def make_generator_model(input_length=16, n_dense_nodes=16, output_size=3,
                         output_activation=None):
    '''Build the generator: two Dense -> BatchNorm -> LeakyReLU blocks, then a
    Dense projection to ``output_size`` values.

    Args:
        input_length: length of the latent noise vector.
        n_dense_nodes: width of both hidden layers.
        output_size: number of generated values per sample.
        output_activation: activation for the output layer. ``None`` (default)
            keeps the original linear output. ``'softmax'`` would constrain
            samples to the probability simplex, where the Dirichlet training
            data lives.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    '''
    model = tf.keras.Sequential()
    model.add(layers.Dense(n_dense_nodes, use_bias=False, input_shape=(input_length,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Shape checks follow the parameters; hard-coded sizes would break any
    # non-default configuration.
    assert model.output_shape == (None, n_dense_nodes)

    model.add(layers.Dense(n_dense_nodes, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    assert model.output_shape == (None, n_dense_nodes)
    model.add(layers.Dense(output_size, use_bias=False, activation=output_activation))

    assert model.output_shape == (None, output_size)

    return model

In [None]:
# Smoke test: push a single latent vector through the untrained generator.
generator = make_generator_model()
z = tf.random.normal(shape=[1, 16])
generated_sample = generator(z, training=False)

In [None]:
# Show the untrained generator's output for the single latent vector `z`.
generated_sample

In [None]:
def make_discriminator_model(input_length=3, n_dense_nodes=32):
    '''Build the discriminator.

    It has two identical hidden blocks of Dense (no bias) -> BatchNorm ->
    LeakyReLU -> Dropout(0.3), each ``n_dense_nodes`` wide, followed by a
    single-unit Dense layer with no activation (one logit per sample).
    ``input_length`` is the size of each input sample.
    '''
    model = tf.keras.Sequential()
    for block_idx in range(2):
        # Only the first hidden layer declares the model's input shape.
        dense_kwargs = {"input_shape": (input_length,)} if block_idx == 0 else {}
        model.add(layers.Dense(n_dense_nodes, use_bias=False, **dense_kwargs))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())
        model.add(layers.Dropout(0.3))

        assert model.output_shape == (None, n_dense_nodes)

    # Unactivated output: a raw logit.
    model.add(layers.Dense(1))

    return model

In [None]:
# Build the discriminator with its default sizes (3 inputs, 32-unit hidden layers).
discriminator = make_discriminator_model()

In [None]:
# The untrained discriminator's raw logit for the generated sample: a (1, 1) tensor.
decision = discriminator(generated_sample)
print(decision)


In [None]:
# Shared binary cross-entropy loss object; it takes raw logits (from_logits=True).
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_output, fake_output):
    '''Discriminator loss: push real logits towards label 1 and fake logits towards 0.'''
    real_term = cross_entropy(tf.ones_like(real_output), real_output)
    fake_term = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_term + fake_term


def generator_loss(fake_output):
    '''Generator loss: reward fake logits that the discriminator labels as real (1).'''
    target = tf.ones_like(fake_output)
    return cross_entropy(target, fake_output)


# Separate Adam optimizers for the two networks, same learning rate.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)


In [None]:
# Checkpoint files are written under ./training_checkpoints with prefix "ckpt".
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Track both networks together with their optimizers.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)


In [None]:
# Training hyper-parameters.
EPOCHS = 50
noise_dim = 16  # the generator above was built with the default input_length=16
num_examples_to_generate = 16

# Fixed latent vectors, presumably meant for the "generate samples" TODOs in
# `train`; not referenced by any later cell shown here.
seed = tf.random.normal(shape=(num_examples_to_generate, noise_dim))


In [None]:
# Per-step loss history filled in by `train_step` (one float per step).
loss_trace = {
    "generator_loss": [],
    "discriminator_loss": []
}


def train_step(images):
    '''Run one simultaneous update of the generator and the discriminator.

    Args:
        images: one batch of real samples from the training dataset
            (presumably shape (batch, 3) here - TODO confirm for other data).

    Side effects:
        Applies one optimizer step to each network and appends both scalar
        losses, as Python floats, to ``loss_trace``.

    Deliberately NOT decorated with ``@tf.function``. Under tracing, the
    Python list appends and ``float()`` conversions below would not run per
    step.
    '''
    # One latent vector per real sample. This keeps the real and fake batches
    # the same size, even for the dataset's final partial batch, which a
    # fixed global BATCH_SIZE would not.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Store plain floats rather than EagerTensors; they plot directly.
    loss_trace["generator_loss"].append(float(gen_loss))
    loss_trace["discriminator_loss"].append(float(disc_loss))

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))


In [None]:
def train(dataset, epochs, save_every=15):
    '''Train the GAN for ``epochs`` passes over ``dataset``.

    Args:
        dataset: iterable of real-sample batches, each passed to ``train_step``.
        epochs: number of passes over ``dataset``.
        save_every: write a checkpoint every this many epochs (must be >= 1).
            A checkpoint is also always written after the final epoch, so the
            end-of-training weights are never lost.

    Prints the wall-clock time of each epoch.
    '''
    if save_every < 1:
        raise ValueError("save_every must be >= 1")

    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        # TODO: generate samples from generator.

        # Save periodically AND after the last epoch. Otherwise, e.g. with
        # 50 epochs and save_every=15, epochs 46-50 would never be saved.
        is_last_epoch = (epoch + 1) == epochs
        if (epoch + 1) % save_every == 0 or is_last_epoch:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

    # TODO: generate samples from generator.

In [None]:
# NOTE(review): redundant - `train_samples` is redrawn (BUFFER_SIZE = 60000
# samples) in the next cell before it is used, so this draw only costs time.
train_samples = dist.rvs(60000)

In [None]:
BATCH_SIZE = 256
BUFFER_SIZE = 60000  # also used as the number of training samples drawn

# The Dirichlet draws come back as float64, while Keras layers compute in
# float32 by default. Cast once here rather than relying on an implicit cast
# on every batch (this also halves the dataset's memory).
train_samples = dist.rvs(BUFFER_SIZE).astype(np.float32)
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_samples)
    .shuffle(buffer_size=BUFFER_SIZE)
    .batch(BATCH_SIZE)
)


In [None]:
# Run the training loop. `train` prints per-epoch timings and periodically
# writes checkpoints under `checkpoint_dir`.
train(train_dataset, EPOCHS)

In [None]:
def get_generator_input(batch_size=1):
    '''Sample ``batch_size`` standard-normal latent vectors for the generator.

    Returns a tensor of shape ``(batch_size, noise_dim)``. It uses
    ``noise_dim`` rather than a literal 16, so it stays in sync with the noise
    fed to the generator in ``train_step``.
    '''
    return tf.random.normal([batch_size, noise_dim])


z1 = get_generator_input()
generator(z1)

In [None]:
# 1000 independent latent vectors, each of shape (1, noise_dim).
# NOTE(review): one batched draw (and one batched generator call) would be much
# faster than 1000 separate calls.
generator_inputs = [get_generator_input() for _ in range(1000)]

In [None]:
# Generate one sample per latent vector in inference mode (the Keras default
# when `training` is not passed), each of shape (1, 3).
generated_samples = [generator(latent, training=False) for latent in generator_inputs]

In [None]:
# First coordinate of each generated sample. Each sample has shape (1, 3), so
# `v[0]` would be the whole 3-vector; `v[0, 0]` selects the first component.
v0 = [float(v[0, 0]) for v in generated_samples]

In [None]:
# Distribution of the generated first coordinate vs. its true marginal.
# `sns.distplot` is deprecated; `histplot(kde=True, stat="density")` gives the
# same normalised histogram with a KDE overlay.
# Accept either scalars or per-sample vectors in `v0`; keep column 0.
first_coord = np.asarray(v0, dtype=float).reshape(len(v0), -1)[:, 0]
ax = sns.histplot(first_coord, kde=True, stat="density")

# Marginal of component i of Dirichlet(alpha) is Beta(alpha_i, sum(alpha) - alpha_i).
dirichlet_alpha = np.asarray(dist.alpha, dtype=float)
grid = np.linspace(0, 1, 200)
true_marginal = stats.beta(dirichlet_alpha[0], dirichlet_alpha.sum() - dirichlet_alpha[0])
ax.plot(grid, true_marginal.pdf(grid), label="true marginal")
ax.set(xlabel="first coordinate", ylabel="density",
       title="Generated samples: first coordinate")
ax.legend();

In [None]:
# Loss curves recorded by `train_step`, one point per training step.
plt.plot(loss_trace["discriminator_loss"], label="discriminator")
plt.plot(loss_trace["generator_loss"], label="generator")
plt.xlabel("training step")
plt.ylabel("loss")
plt.title("GAN training losses")
plt.legend();