In [1]:
import numpy as np
import tensorflow as tf
import matplotlib
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pickle as pkl


import tensorflow_probability as tfp

np.random.seed(0)
tf.random.set_seed(0)

In [2]:
DIM = 64 # Model dimensionality (number of neurons in the hidden layer(s))
CRITIC_ITERS = 10 # How many critic iterations (Sinkhorn iterations) per generator iteration#was 50
BATCH_SIZE = 1024#256 # Batch size
ITERS = 2500#100000 # how many generator iterations to train for
DATA_DIM = 32
LATENT_DIM = 4
INITIALIZATION = 'he'#'glorot'
COVARIANCE_SCALE = np.sqrt(DATA_DIM)
INITIALIZE_LAST = True
SAMPLE_SIZE = 100000
LAMBDA = 0.3#2/(COVARIANCE_SCALE)
MODE = 'divergence' #'loss'

In [3]:
from tensorflow import keras
from tensorflow.keras import layers

In [4]:
if INITIALIZATION == 'he':
    weight_initializer = keras.initializers.he_uniform()
if INITIALIZATION == 'glorot':
    weight_initializer = keras.initializers.glorot_uniform()
bias_initializer = keras.initializers.zeros()

In [5]:
#class WGAN:
 #   def __init__()self
latent_sample = keras.Input(shape=(LATENT_DIM,))
hidden_layer = latent_sample
for i in range(3):
    hidden_layer = layers.Dense(DIM, activation="relu", kernel_initializer=weight_initializer,
        bias_initializer=bias_initializer)(hidden_layer)
output = layers.Dense(DATA_DIM, kernel_initializer=weight_initializer,
                      bias_initializer=bias_initializer)(hidden_layer)
generator = keras.Model(inputs=latent_sample, outputs = output, name="generator")
print('Generator:')
print(generator.summary())

Generator:
Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 4)]               0         
_________________________________________________________________
dense (Dense)                (None, 64)                320       
_________________________________________________________________
dense_1 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_3 (Dense)              (None, 32)                2080      
Total params: 10,720
Trainable params: 10,720
Non-trainable params: 0
_________________________________________________________________
None


In [6]:
from scipy.special import logsumexp

In [7]:
def squared_l2_matrix_np(X, Y):
    squared_l2_X = np.sum(np.square(X), axis = 1, keepdims=True)
    squared_l2_Y = np.sum(np.square(Y), axis = 1, keepdims=True)
    XY = X.dot(Y.T)
    return squared_l2_X + tf.transpose(squared_l2_Y) - 2 * XY

def log_coupling_np(psi_X, psi_Y, cost_matrix, epsilon):
    C_tild = cost_matrix - np.expand_dims(psi_X, axis=1) - np.expand_dims(psi_Y, axis=0)
    return -C_tild/epsilon

def sinkhorn_step_np(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon, return_diff = False):
    psi_X_upd = epsilon * (log_X_prob - \
                       logsumexp((np.expand_dims(psi_Y, axis=0) - cost_matrix) / epsilon, axis = 1))
    psi_Y_upd = epsilon * (log_Y_prob - \
                       logsumexp((np.expand_dims(psi_X_upd, axis=1) - cost_matrix)/epsilon,axis = 0))
    if return_diff:
        diff = np.linalg.norm(psi_X_upd - psi_X) + np.linalg.norm(psi_Y_upd - psi_Y)
    psi_X = psi_X_upd
    psi_Y = psi_Y_upd
    if return_diff:
        return psi_X, psi_Y, diff
    return psi_X, psi_Y

def sinkhorn_loss_np(X, Y, epsilon, num_steps):
    cost_matrix = squared_l2_matrix_np(X,Y)
    log_X_prob = - np.log(X.shape[0])
    log_Y_prob = - np.log(Y.shape[0])
    psi_Y = np.zeros(Y.shape[0])
    for l in range(num_steps-1):
        psi_X, psi_Y = sinkhorn_step_np(None, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)
    psi_X, psi_Y, diff = sinkhorn_step_np(psi_X, psi_Y, log_X_prob, log_Y_prob, \
                                       cost_matrix, epsilon, return_diff = True)
    #return psi_X, psi_Y, diff
    log_pi = log_coupling_np(psi_X, psi_Y, cost_matrix, epsilon)
    pi = np.exp(log_pi)
    loss = np.sum(pi*(cost_matrix+epsilon*log_pi)) - epsilon*(log_X_prob + log_Y_prob)
    #X and Y are uniform, so entropy = -log(probability)
    return loss, diff

In [8]:
def squared_l2_matrix(X, Y):
    squared_l2_X = tf.reduce_sum(tf.square(X), axis = 1, keepdims=True)
    squared_l2_Y = tf.reduce_sum(tf.square(Y), axis = 1, keepdims=True)
    XY = tf.matmul(X,tf.transpose(Y))
    return squared_l2_X + tf.transpose(squared_l2_Y) - 2 * XY

def log_coupling(psi_X, psi_Y, cost_matrix, epsilon):
    C_tild = cost_matrix - tf.expand_dims(psi_X, axis=1) - tf.expand_dims(psi_Y, axis=0)
    return -C_tild/epsilon

def sinkhorn_step(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon, return_diff = False):
    psi_X_upd = epsilon * (log_X_prob - \
                       tf.reduce_logsumexp((tf.expand_dims(psi_Y, axis=0) - cost_matrix) / epsilon, axis = 1))
    psi_Y_upd = epsilon * (log_Y_prob - \
                       tf.reduce_logsumexp((tf.expand_dims(psi_X_upd, axis=1) - cost_matrix)/epsilon,axis = 0))
    if return_diff:
        diff = tf.norm(psi_X_upd - psi_X) + tf.norm(psi_Y_upd - psi_Y)
    psi_X = psi_X_upd
    psi_Y = psi_Y_upd
    if return_diff:
        return psi_X, psi_Y, diff
    return psi_X, psi_Y

def sinkhorn_loss(X, Y, epsilon, num_steps, return_diff = False, return_diff_only = False):
    cost_matrix = squared_l2_matrix(X,Y)
    log_X_prob = - tf.math.log(tf.cast(tf.shape(X)[0], tf.float32))
    log_Y_prob = - tf.math.log(tf.cast(tf.shape(Y)[0], tf.float32))
    psi_Y = tf.zeros([tf.shape(Y)[0]])
    for l in range(num_steps-1):
        psi_X, psi_Y = sinkhorn_step(None, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)
    psi_X, psi_Y, diff = sinkhorn_step(psi_X, psi_Y, log_X_prob, log_Y_prob, \
                                       cost_matrix, epsilon, return_diff = True)
    #return psi_X, psi_Y, diff
    log_pi = log_coupling(psi_X, psi_Y, cost_matrix, epsilon)
    pi = tf.exp(log_pi)
    loss = tf.reduce_sum(pi*(cost_matrix+epsilon*log_pi)) - epsilon*(log_X_prob + log_Y_prob)
    #X and Y are uniform, so entropy = -log(probability)
    if return_diff_only:
        return diff
    if return_diff:
        return loss, diff
    return loss

In [9]:
epsilon = LAMBDA
if MODE == 'divergence':
    model_loss = lambda x, y: sinkhorn_loss(x, y, epsilon, CRITIC_ITERS) - 0.5*\
        (sinkhorn_loss(x, x, epsilon, CRITIC_ITERS)-sinkhorn_loss(y, y, epsilon, CRITIC_ITERS))
else:
    model_loss = lambda x, y: sinkhorn_loss(x, y, epsilon, CRITIC_ITERS)

if MODE == 'divergence':
    args = [epsilon, CRITIC_ITERS, True, True]
    sinkhorn_diff = lambda x, y: (sinkhorn_loss(x, y, *args)\
        +sinkhorn_loss(x, x, *args)+sinkhorn_loss(y, y, *args))/3
else:
    sinkhorn_diff = lambda x, y: sinkhorn_loss(x, y, epsilon, CRITIC_ITERS)

metrics = lambda x, y: tf.norm(tfp.stats.covariance(x) - tfp.stats.covariance(y))
generator.compile(optimizer="Adam", loss=model_loss, metrics=metrics)

In [10]:
def inf_train_gen():
    np.random.seed(1)
    full_dataset = np.random.randn(SAMPLE_SIZE,DATA_DIM) / np.sqrt(COVARIANCE_SCALE) 
    i = 0
    offset = 0
    while True:
        dataset = full_dataset[i*BATCH_SIZE+offset:(i+1)*BATCH_SIZE+offset,:]
        if (i+1)*BATCH_SIZE+offset > SAMPLE_SIZE: 
            offset = (i+1)*BATCH_SIZE+offset - SAMPLE_SIZE
            np.random.shuffle(full_dataset)
            dataset = np.concatenate([dataset,full_dataset[:offset,:]], axis = 0)
            i = -1 
        i+=1
        yield dataset
data_gen = inf_train_gen()

In [11]:
y_batch = next(data_gen)
x_batch = np.random.normal(size=(y_batch.shape[0], LATENT_DIM))

In [12]:
y_sample = generator.predict(x_batch, batch_size=BATCH_SIZE)

In [13]:
np.linalg.norm(np.cov(y_sample) - np.cov(y_batch))

801.1129233615842

In [14]:
model_loss(tf.constant(y_sample, dtype='float32'), tf.constant(y_batch, dtype='float32'))

<tf.Tensor: shape=(), dtype=float32, numpy=28.487322>

In [15]:
def cost_mat(X,Y,N,M): #for Sinkhorm divergence
    XX = tf.reduce_sum(tf.multiply(X,X),axis=1)#norms of X
    YY = tf.reduce_sum(tf.multiply(Y,Y),axis=1)#norms of Y
    C1 = tf.transpose(tf.reshape(tf.tile(XX,[M]),[M,N]))
    C2 = tf.reshape(tf.tile(YY,[N]),[N,M])
    C3 = tf.transpose(tf.matmul(Y,tf.transpose(X)))
    C = C1 + C2 - 2*C3; #squared norms of difference
    return C

def K_tild(u,v,C,N,M,epsilon):
    C_tild = C - tf.transpose(tf.reshape(tf.tile(u[:,0],[M]),[M,N])) - tf.reshape(tf.tile(v[:,0],[N]),[N,M])
    K_tild = tf.exp(-C_tild/epsilon)
    return K_tild


def log_K_tild(u,v,C,N,M,epsilon):
    C_tild = C - tf.transpose(tf.reshape(tf.tile(u[:,0],[M]),[M,N])) - tf.reshape(tf.tile(v[:,0],[N]),[N,M])
    return -C_tild/epsilon

def sinkhorn_step_log(j,u,v,C, N,M,epsilon,diff,Lambda = 1):
    mu = tf.cast(1/N, tf.float32)
    nu = tf.cast(1/M, tf.float32)
    Ku = tf.reshape( tf.reduce_sum(K_tild(u,v,C,N,M,epsilon),axis = 1) ,[N,1] )
    u_new = Lambda * ( epsilon*(tf.math.log(mu) - tf.math.log(Ku +10**(-6))) + u )
    diff = tf.norm(u_new - u)
    u = u_new
    
    Kv = tf.reshape( tf.reduce_sum(K_tild(u,v,C,N,M,epsilon),axis = 0), [M,1] )
    v_new = Lambda * ( epsilon*(tf.math.log(nu) - tf.math.log(Kv +10**(-6))) + v )
    diff += tf.norm(v - v_new)
    v = v_new
    j += 1
    return j,u,v,C,N,M,epsilon,diff

def sinkhorn_loss1(X,Y):#LOSS, NOT DIVERGENCE
    epsilon = tf.constant(LAMBDA, dtype=tf.float32) # smoothing sinkhorn
    Lambda = tf.constant(1.) # unbalanced parameter
    k = tf.constant(CRITIC_ITERS) # number of iterations for sinkhorn
    N = tf.shape(X)[0] # sample size from mu_theta
    M = tf.shape(Y)[0] # sample size from \hat nu
    
    mu = tf.cast(1/N, tf.float32)
    nu = tf.cast(1/M, tf.float32)
    
    D = tf.shape(Y)[1] # dimension of the obervation space
    C = cost_mat(X,Y,N,M)
    K = tf.exp(-C/epsilon)
    #sinkhorn iterations
    j0 = tf.constant(0)
    u0 = tf.zeros([N,1])
    v0 = tf.zeros([M,1])
    diff = tf.cast(0., tf.float32)
    cond_iter = lambda j, u, v, C, N, M, epsilon, diff: j < k
    j,u,v,C,N,M,epsilon,diff = tf.while_loop(
    cond_iter, sinkhorn_step_log, loop_vars=[j0, u0, v0,C, N,M,epsilon,diff])
    gamma_log = K_tild(u,v,C,N,M,epsilon)
    log_gamma_log = log_K_tild(u,v,C,N,M,epsilon)
    final_cost = tf.reduce_sum(gamma_log*(C+epsilon*(log_gamma_log - tf.math.log(mu) - tf.math.log(nu))))
    return final_cost, diff

In [16]:
X = tf.constant(y_sample, dtype = 'float32')
Y = tf.constant(y_batch, dtype = 'float32')
loss_their = sinkhorn_loss1(X,Y)
loss_my = sinkhorn_loss(X,Y, epsilon=LAMBDA, num_steps=CRITIC_ITERS)

In [None]:
y_batch = next(data_gen)
x_batch = np.random.normal(size=(y_batch.shape[0], LATENT_DIM))
res_train = {'loss':{}, 'cov_diff' : {}}
res_test = {'loss':{}, 'cov_diff' : {}, 'sample':{}, 'sink_eps':{}}
for i in range(ITERS):
    out = generator.train_on_batch(x = x_batch,y = y_batch)
    res_train['loss'][i] = out[0]
    res_train['cov_diff'][i] = out[1]
    y_batch = next(data_gen)
    x_batch = np.random.normal(size=(y_batch.shape[0], LATENT_DIM))
    #validate on the next data
    y_sample = generator.predict(x_batch, batch_size=BATCH_SIZE)
    res_test['loss'][i], res_test['sink_eps'][i] = sinkhorn_loss_np(y_sample, y_batch, epsilon, CRITIC_ITERS)
    res_test['sample'][i] = y_sample
    res_test['cov_diff'][i] = np.linalg.norm(np.cov(y_sample) - np.cov(y_batch))
    print('Iteration ', i)
    print('Training: loss {}, covariance difference {}'.format(res_train['loss'][i], res_train['cov_diff'][i]))
    print('Validation: loss {}, covariance difference {}, sinkhorn epsilon {}'.format(
        res_test['loss'][i], res_test['cov_diff'][i], res_test['sink_eps'][i]))
    if i % 10 == 0:
        with open('logs.pkl', 'wb') as f:
            pkl.dump([res_test, res_train], f)

Iteration  0
Training: loss 37.74919128417969, covariance difference 8.588347434997559
Validation: loss 25.007077780213667, covariance difference 639.828849980119, sinkhorn epsilon 8.804343744228428
Iteration  1
Training: loss 31.41640281677246, covariance difference 7.111966133117676
Validation: loss 23.21416514396919, covariance difference 530.1967824664677, sinkhorn epsilon 7.139346413990609
Iteration  2
Training: loss 27.95612144470215, covariance difference 6.575160980224609
Validation: loss 21.57743422653142, covariance difference 444.7808165818492, sinkhorn epsilon 6.338020658890348
Iteration  3
Training: loss 25.23994255065918, covariance difference 6.028720378875732
Validation: loss 19.01532778122752, covariance difference 334.78127336813276, sinkhorn epsilon 4.513818354099044
Iteration  4
Training: loss 21.1697940826416, covariance difference 5.084527015686035
Validation: loss 16.630164845733763, covariance difference 265.7067050705343, sinkhorn epsilon 4.2487144999017525
Ite

Iteration  41
Training: loss 5.571918487548828, covariance difference 0.8134355545043945
Validation: loss 5.86058089765997, covariance difference 34.774679257994435, sinkhorn epsilon 0.03824589523424893
Iteration  42
Training: loss 5.514727592468262, covariance difference 0.823392391204834
Validation: loss 5.817507722233056, covariance difference 34.08927667514661, sinkhorn epsilon 0.08628549305521464
Iteration  43
Training: loss 5.432528495788574, covariance difference 0.8387850522994995
Validation: loss 5.867210309481747, covariance difference 34.64568196914918, sinkhorn epsilon 0.09217519188227913
Iteration  44
Training: loss 5.47402811050415, covariance difference 0.8481283783912659
Validation: loss 5.87071034917437, covariance difference 34.61756097686972, sinkhorn epsilon 0.027788824048907083
Iteration  45
Training: loss 5.476345539093018, covariance difference 0.8542080521583557
Validation: loss 5.81745618965108, covariance difference 34.25037473430803, sinkhorn epsilon 0.168086

Iteration  82
Training: loss 4.823050022125244, covariance difference 0.9741846919059753
Validation: loss 5.626899723962247, covariance difference 32.925259731146234, sinkhorn epsilon 0.003973317168813634
Iteration  83
Training: loss 4.780470848083496, covariance difference 0.9736776351928711
Validation: loss 5.625148445920805, covariance difference 33.01283093022833, sinkhorn epsilon 0.0008437895346624984
Iteration  84
Training: loss 4.7795209884643555, covariance difference 0.9729753136634827
Validation: loss 5.6186944263194585, covariance difference 32.8944899030192, sinkhorn epsilon 0.009862588354578504
Iteration  85
Training: loss 4.759727954864502, covariance difference 0.9734035134315491
Validation: loss 5.680182304370346, covariance difference 33.227705940663675, sinkhorn epsilon 0.0005415776847359871
Iteration  86
Training: loss 4.823807239532471, covariance difference 0.982150137424469
Validation: loss 5.690797902394639, covariance difference 33.33929366835476, sinkhorn epsil

Iteration  122
Training: loss 4.646111488342285, covariance difference 1.008522391319275
Validation: loss 5.672703170742475, covariance difference 33.063296502061135, sinkhorn epsilon 0.0003255529351237172
Iteration  123
Training: loss 4.670071125030518, covariance difference 1.0121859312057495
Validation: loss 5.654226913226512, covariance difference 33.083823068933626, sinkhorn epsilon 3.130682922469983e-05
Iteration  124
Training: loss 4.653210639953613, covariance difference 1.0100913047790527
Validation: loss 5.517174886088744, covariance difference 32.175440691190225, sinkhorn epsilon 0.0005361614588134216
Iteration  125
Training: loss 4.513741493225098, covariance difference 0.9825506806373596
Validation: loss 5.618091446327301, covariance difference 32.748191887066255, sinkhorn epsilon 2.2966691325535604e-06
Iteration  126
Training: loss 4.610605716705322, covariance difference 1.0023987293243408
Validation: loss 5.65691151320774, covariance difference 33.09238356241306, sinkho

Iteration  162
Training: loss 4.594809532165527, covariance difference 1.0082141160964966
Validation: loss 5.697498813427376, covariance difference 33.31276838891705, sinkhorn epsilon 5.219799492429218e-08
Iteration  163
Training: loss 4.6634931564331055, covariance difference 1.0223615169525146
Validation: loss 5.6927263927181935, covariance difference 33.210091857654156, sinkhorn epsilon 2.375479258965467e-13
Iteration  164
Training: loss 4.659740447998047, covariance difference 1.0208648443222046
Validation: loss 5.67690097728609, covariance difference 33.13003023923635, sinkhorn epsilon 2.491775177880678e-07
Iteration  165
Training: loss 4.643725395202637, covariance difference 1.016624927520752
Validation: loss 5.693576392794467, covariance difference 33.260687163542826, sinkhorn epsilon 1.0344007412938269e-13
Iteration  166
Training: loss 4.659294128417969, covariance difference 1.0200093984603882
Validation: loss 5.6576503626981545, covariance difference 33.035130455053974, sink

Iteration  202
Training: loss 4.58421516418457, covariance difference 1.0097519159317017
Validation: loss 5.639323500043002, covariance difference 32.89797767048951, sinkhorn epsilon 4.228049100380474e-15
Iteration  203
Training: loss 4.601114749908447, covariance difference 1.0116932392120361
Validation: loss 5.691661465764227, covariance difference 33.23179055870439, sinkhorn epsilon 1.8112397691977804e-14
Iteration  204
Training: loss 4.654012680053711, covariance difference 1.0218461751937866
Validation: loss 5.719849097036582, covariance difference 33.38482886657158, sinkhorn epsilon 9.051038349787262e-15
Iteration  205
Training: loss 4.681933403015137, covariance difference 1.025163173675537
Validation: loss 5.647175500035607, covariance difference 32.93832813452978, sinkhorn epsilon 2.0097844821418505e-14
Iteration  206
Training: loss 4.610136032104492, covariance difference 1.013540267944336
Validation: loss 5.740390318755755, covariance difference 33.50022814831523, sinkhorn e