# Example: Gaussian-Mixture data

## Resume from here — test on Colab: https://colab.research.google.com/drive/1lLE7oZCxzRfujsR9zL8CaqUwwMvW38Fl?usp=sharing

2020-11-26 first created

In [None]:
from IPython.display import HTML
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import itertools
from model import InvertibleNet
from utils import *
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
print(tf.__version__)
!python --version

## Settings

In [None]:
# Mixture layout: n_means clusters whose means lie evenly on a circle.
n_means = 8                 # number of mixture components
radius = 10                 # distance of each component mean from the origin
sd = 1.5                    # spread parameter used when sampling each component
# One colour label per component; several components may share a label.
labels = ['red'] * 4 + ['blue'] * 2 + ['green', 'purple']
assert len(labels) == n_means

# Dimensions: x is zero-padded up to tot_dim = y_dim + z_dim.
x_dim, y_dim, z_dim = 2, 4, 2
tot_dim = y_dim + z_dim
pad_dim = tot_dim - x_dim

# Data size and network architecture.
n_sample = 500                  # samples drawn per component
n_data = n_means * n_sample     # total number of samples
n_couple_layer = 5
n_hid_layer = 3
n_hid_dim = 512

# Training schedule.
n_batch = 128
n_epoch = 1000
n_display = 100

In [None]:
# Make data: one isotropic Gaussian cluster per component, with the means
# spaced evenly on a circle of the given radius.
X_raw = np.zeros((n_means, n_sample, x_dim), dtype='float32')
for i in range(n_means):
    th = 2*np.pi / n_means * (i+1)
    mean = [radius*np.cos(th), radius*np.sin(th)]
    # multivariate_normal expects a covariance matrix; `sd` is a standard
    # deviation, so it must be squared (passing `sd` directly gave std sqrt(sd)).
    X_raw[i, :, :] = np.random.multivariate_normal(
        mean, np.identity(x_dim) * sd**2, size=n_sample)
print(X_raw.shape)

In [None]:
# Raw samples: one scatter call (and default colour) per mixture component.
fig, ax = plt.subplots(figsize=(5, 5), facecolor='white')
for component in X_raw:
    ax.scatter(component[:, 0], component[:, 1], s=5)

In [None]:
# Preprocess
# Assign label indices in first-appearance order. `list(set(labels))` depends
# on string hashing, which is randomized per Python session, so the
# label -> one-hot column mapping was not reproducible across runs.
uq_labels = list(dict.fromkeys(labels))
idx2lab = dict(enumerate(uq_labels))
lab2idx = {lab: i for i, lab in idx2lab.items()}

# Flatten the (n_means, n_sample, x_dim) samples and standardize each feature.
X = X_raw.reshape((-1, x_dim))
X = StandardScaler().fit_transform(X)
# One integer label per row, in the same component-major order as X.
y = [lab2idx[lab] for lab in labels for _ in range(n_sample)]
y_onehot = np.eye(len(uq_labels))[y].astype('int')

In [None]:
# Standardized data, coloured by the label each sample belongs to.
fig, ax = plt.subplots(figsize=(5, 5), facecolor='white')
for i, color in idx2lab.items():
    idx = np.asarray(y) == i
    ax.scatter(X[idx, 0], X[idx, 1], s=5, c=color)

In [None]:
# Pad data: zero-pad x up to tot_dim so both sides of the invertible map match.
pad_x = np.zeros((X.shape[0], pad_dim))
x_data = np.concatenate([X, pad_x], axis=-1).astype('float32')
# Targets are [z, one-hot label]. The latent part must be z_dim wide (not
# x_dim) so the target has z_dim + y_dim = tot_dim columns.
z = np.random.multivariate_normal([1.]*z_dim, np.eye(z_dim), X.shape[0])
y_data = np.concatenate([z, y_onehot], axis=-1).astype('float32')
# Fail early with a readable message if settings and data disagree
# (e.g. y_dim != number of distinct labels).
assert x_data.shape[1] == y_data.shape[1] == tot_dim, (x_data.shape, y_data.shape, tot_dim)

# Make dataset generator
x_data = tf.data.Dataset.from_tensor_slices(x_data)
y_data = tf.data.Dataset.from_tensor_slices(y_data)
dataset = (tf.data.Dataset.zip((x_data, y_data))
           .shuffle(buffer_size=X.shape[0])
           .batch(n_batch, drop_remainder=True)
           .repeat())

In [None]:
def MMD_multiscale(x, y):
    """Multiscale MMD between two sample sets using inverse-multiquadric kernels.

    The kernel is k(u, v) = a**2 / (a**2 + ||u - v||**2), summed over the
    bandwidths a in (0.05, 0.2, 0.9). All pairs are averaged, including the
    diagonal i == j terms.

    Args:
        x: Tensor of shape (n, d).
        y: Tensor of shape (m, d); m may differ from n.

    Returns:
        Scalar tensor: mean k(x, x) + mean k(y, y) - 2 * mean k(x, y).
    """
    xx = tf.linalg.matmul(x, x, transpose_b=True)
    yy = tf.linalg.matmul(y, y, transpose_b=True)
    zz = tf.linalg.matmul(x, y, transpose_b=True)

    # Squared norms of each sample.
    rx = tf.linalg.diag_part(xx)
    ry = tf.linalg.diag_part(yy)

    # Pairwise squared distances via explicit outer broadcasting. This works for
    # n != m and for unknown static batch sizes, unlike broadcast_to(xx.shape).
    dxx = rx[:, None] + rx[None, :] - 2. * xx
    dyy = ry[:, None] + ry[None, :] - 2. * yy
    dxy = rx[:, None] + ry[None, :] - 2. * zz

    XX = tf.zeros_like(dxx)
    YY = tf.zeros_like(dyy)
    XY = tf.zeros_like(dxy)

    for a in (0.05, 0.2, 0.9):
        a2 = a ** 2
        XX += a2 / (a2 + dxx)
        YY += a2 / (a2 + dyy)
        XY += a2 / (a2 + dxy)

    # Separate means, because the three kernel matrices may have different shapes.
    return tf.reduce_mean(XX) + tf.reduce_mean(YY) - 2. * tf.reduce_mean(XY)


def MSE(y_true, y_pred):
    """Return the batch-averaged mean squared error as a scalar tensor.

    Keras' `mean_squared_error` averages over the last axis, giving one value
    per sample. Those per-sample values are then averaged into a scalar.
    """
    per_sample = tfk.losses.mean_squared_error(y_true, y_pred)
    return tf.reduce_mean(per_sample)


class Trainer(tfk.Model):
    """Keras wrapper that trains an invertible network with a custom train_step.

    Args:
        model: The invertible network; called on padded x in the forward direction.
        x_dim: Width of the data part of x (before padding).
        y_dim: Width of the label part of the target.
        z_dim: Width of the latent part of the target (its first z_dim columns).
        n_couple_layer, n_hid_layer, n_hid_dim: Network hyper-parameters,
            stored on the instance but not used by train_step.
    """

    def __init__(self, model, x_dim, y_dim, z_dim, n_couple_layer, n_hid_layer, n_hid_dim):
        super(Trainer, self).__init__()
        self.model = model
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.z_dim = z_dim
        self.tot_dim = y_dim + z_dim
        self.pad_dim = self.tot_dim - x_dim
        self.n_couple_layer = n_couple_layer
        self.n_hid_layer = n_hid_layer
        self.n_hid_dim = n_hid_dim

        # Check if total dimension is even (presumably so coupling layers can
        # split it into two equal halves -- confirm against InvertibleNet).
        assert self.tot_dim % 2 == 0, f'tot_dim must be even, got {self.tot_dim}'

    def train_step(self, data):
        """Run one optimisation step on a batch of (x, y) pairs.

        Only the forward prediction loss on the label part of y (columns
        z_dim onward) is optimised. latent_loss and inverse_loss are
        placeholders that are always 0 for now.

        Returns:
            Dict of loss values reported to Keras callbacks/history.
        """
        x_data, y_data = data

        # Forward loss
        with tf.GradientTape() as tape:
            y_pred = self.model(x_data)
            pred_loss = MSE(y_data[:, self.z_dim:], y_pred[:, self.z_dim:])
            latent_loss = 0  # TODO: loss on the z part (e.g. MMD) not implemented
            forward_loss = pred_loss + latent_loss
        weights = self.model.trainable_weights
        grads = tape.gradient(forward_loss, weights)
        # Element-wise clipping; variables with no gradient (None) are skipped
        # because tf.clip_by_value cannot take None.
        grads_and_vars = [(tf.clip_by_value(grad, -15., 15.), weight)
                          for grad, weight in zip(grads, weights)
                          if grad is not None]
        self.optimizer.apply_gradients(grads_and_vars)

        # Backward loss (not implemented yet)
        inverse_loss = 0

        total_loss = pred_loss + latent_loss + inverse_loss
        return {'loss': total_loss,
                'pred_loss': pred_loss,
                'latent_loss': latent_loss,
                'inverse_loss': inverse_loss}

#     def test_step(self, data):
#         x_data, y_data = data
#         return {}

In [None]:
# Build the invertible network, call it once on a symbolic tot_dim-wide input,
# then print its summary.
# NOTE(review): the network is constructed with x_dim but called on a tot_dim
# input -- confirm which width InvertibleNet's first argument expects (model.py).
INN = InvertibleNet(
    x_dim,
    n_couple_layer,
    n_hid_layer,
    n_hid_dim,
    name='INN',
)
x = tfk.Input(shape=(tot_dim,))
INN(x)
INN.summary()

In [None]:
# Wrap the network in the custom training loop and attach an Adam optimizer.
trainer = Trainer(
    INN,
    x_dim=x_dim,
    y_dim=y_dim,
    z_dim=z_dim,
    n_couple_layer=n_couple_layer,
    n_hid_layer=n_hid_layer,
    n_hid_dim=n_hid_dim,
)
trainer.compile(optimizer='Adam')

In [None]:
%%time
logger = NBatchLogger(n_display, n_epoch)
# `dataset` is already batched and repeated, so batch_size is not passed: the
# Keras `Model.fit` docs say not to specify batch_size for tf.data inputs.
# Because the dataset repeats forever, steps_per_epoch is required; it matches
# the number of full batches (the remainder is dropped when batching).
hist = trainer.fit(dataset,
                   epochs=n_epoch,
                   steps_per_epoch=n_data // n_batch,
                   callbacks=[logger],
                   verbose=0)

In [None]:
# Rebind INN to the trainer's network. This is the same object that was passed
# to Trainer(...), now holding the trained weights.
INN = trainer.model

In [None]:
# Sample latents from a mean-1, identity-covariance Gaussian, append each
# sample's one-hot label, and map the result back to x-space via INN.inverse.
# NOTE: this rebinds the global `y` (earlier the integer label list).
z = np.random.multivariate_normal([1.]*z_dim, np.eye(z_dim), y_onehot.shape[0])
y = np.concatenate([z, y_onehot], axis=-1).astype('float32')
x_pred = INN.inverse(y).numpy()

# Scatter the reconstructed points, one colour per label.
fig, ax = plt.subplots(figsize=(5, 5), facecolor='white')
for i, color in idx2lab.items():
    idx = y_onehot.argmax(axis=-1) == i
    ax.scatter(x_pred[idx, 0], x_pred[idx, 1], s=5, c=color)