In [1]:
import numpy as np
import tensorflow as tf
import matplotlib
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pickle as pkl


import tensorflow_probability as tfp

np.random.seed(0)
tf.random.set_seed(0)

In [2]:
DIM = 64 # Model dimensionality (number of neurons in the hidden layer(s))
CRITIC_ITERS = 50 # How many critic iterations (Sinkhorn iterations) per generator iteration#was 50
BATCH_SIZE = 256 # Batch size
ITERS = 2500#100000 # how many generator iterations to train for
DATA_DIM = 32
LATENT_DIM = 4
INITIALIZATION = 'he'#'glorot'
COVARIANCE_SCALE = np.sqrt(DATA_DIM)
INITIALIZE_LAST = True
SAMPLE_SIZE = 100000
LAMBDA = 0.3#2/(COVARIANCE_SCALE)
MODE = 'divergence' #'loss'

In [3]:
from tensorflow import keras
from tensorflow.keras import layers

In [4]:
if INITIALIZATION == 'he':
    weight_initializer = keras.initializers.he_uniform()
if INITIALIZATION == 'glorot':
    weight_initializer = keras.initializers.glorot_uniform()
bias_initializer = keras.initializers.zeros()

In [5]:
import scipy.special 
def logsumexp(x, axis = None):
    if axis is None:
        val = scipy.special.logsumexp(x)
        if val != val:
            return np.max(x)
    val = scipy.special.logsumexp(x, axis = axis)
#    val = np.nan_to_num(val, -5)
#     idx = np.where(val!=val)[0]
#     if len(idx) == 0:
#         return val
#     print(len(idx))
#     if axis == 0:
#         val[idx] = np.max(x[:, idx], axis = 0)
#     if axis == 1:
#         val[idx] = np.max(x[idx, :], axis = 1)
    return val

In [14]:
def squared_l2_matrix_np(X, Y):
    squared_l2_X = np.sum(np.square(X), axis = 1, keepdims=True)
    squared_l2_Y = np.sum(np.square(Y), axis = 1, keepdims=True)
    XY = X.dot(Y.T)
    return squared_l2_X + squared_l2_Y.T - 2 * XY

def log_coupling_np(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon):
    C_tild = cost_matrix - np.expand_dims(psi_X, axis=1) - np.expand_dims(psi_Y, axis=0)
    return -C_tild/epsilon + log_X_prob + log_Y_prob

def sinkhorn_step_np(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon, return_diff = False):
    #print(psi_X, psi_Y)
    K_tld = np.exp(log_coupling_np(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)- log_X_prob -log_Y_prob)    
    Ku = np.sum(K_tld, axis = 1)
    #print(Ku)
    psi_X_upd = epsilon * (log_X_prob - np.log(Ku + 10**(-6)))+psi_X
    K_tld = np.exp(log_coupling_np(psi_X_upd, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)-log_X_prob-log_Y_prob)  
    Kv = np.sum(K_tld, axis = 0)
    #print(Kv)
    psi_Y_upd = epsilon * (log_Y_prob - np.log(Kv + 10**(-6)))+psi_Y
    
    #psi_X_upd = - epsilon * logsumexp(log_Y_prob+(np.expand_dims(psi_Y, axis=0) - cost_matrix) / epsilon, axis = 1)
    #psi_Y_upd = - epsilon * logsumexp(log_X_prob+(np.expand_dims(psi_X_upd, axis=1) - cost_matrix)/epsilon,axis = 0)
    if return_diff:
        diff = np.linalg.norm(psi_X_upd - psi_X) + np.linalg.norm(psi_Y_upd - psi_Y)
    psi_X = psi_X_upd
    psi_Y = psi_Y_upd
    if return_diff:
        return psi_X, psi_Y, diff
    return psi_X, psi_Y

def sinkhorn_step_symm_np(psi_X, log_X_prob, cost_matrix, epsilon, return_diff = False):
    psi_X_upd = (psi_X - epsilon * logsumexp(log_X_prob+(np.expand_dims(psi_X, axis=0) - 
                                                         cost_matrix) / epsilon, axis = 1))/2
    if return_diff:
        diff = np.linalg.norm(psi_X_upd - psi_X) 
    psi_X = psi_X_upd
    if return_diff:
        return psi_X, diff
    return psi_X

def sinkhorn_loss_np(X, Y, epsilon, num_steps):
    cost_matrix = squared_l2_matrix_np(X,Y)
    log_X_prob = - np.log(X.shape[0])
    log_Y_prob = - np.log(Y.shape[0])
    psi_Y = np.zeros(Y.shape[0])
    psi_X = np.zeros(X.shape[0])
    for l in range(num_steps-1):
        psi_X, psi_Y = sinkhorn_step_np(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)
    psi_X, psi_Y, diff = sinkhorn_step_np(psi_X, psi_Y, log_X_prob, log_Y_prob, \
                                       cost_matrix, epsilon, return_diff = True)
    #return psi_X, psi_Y, diff
    log_pi = log_coupling_np(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)
    pi = np.exp(log_pi)
    loss = np.sum(pi*(cost_matrix+epsilon*log_pi)) - epsilon*(log_X_prob + log_Y_prob)
    #X and Y are uniform, so entropy = -log(probability)
    return loss, diff

def sinkhorn_dual_potentials_np(X, Y, epsilon, num_steps, psi_Y = None):
    cost_matrix = squared_l2_matrix_np(X,Y)
    log_X_prob = - np.log(X.shape[0])
    log_Y_prob = - np.log(Y.shape[0])
    if psi_Y is None:
        psi_Y = np.zeros(Y.shape[0])
        psi_X = np.zeros(X.shape[0])
    for l in range(num_steps-1):
        psi_X, psi_Y = sinkhorn_step_np(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)
    psi_X, psi_Y, diff = sinkhorn_step_np(psi_X, psi_Y, log_X_prob, log_Y_prob, \
                                       cost_matrix, epsilon, return_diff = True)
    #return psi_X, psi_Y, diff
    #X and Y are uniform, so entropy = -log(probability)
    return psi_X, psi_Y, diff

def sinkhorn_dual_potentials_symm_np(X, epsilon, num_steps):
    cost_matrix = squared_l2_matrix_np(X,X)
    log_X_prob = - np.log(X.shape[0])
    psi_X = np.zeros(X.shape[0])
    for l in range(num_steps-1):
        psi_X = sinkhorn_step_symm_np(psi_X, log_X_prob, cost_matrix, epsilon)
    psi_X, diff = sinkhorn_step_symm_np(psi_X, log_X_prob, cost_matrix, epsilon, return_diff = True)
    #return psi_X, psi_Y, diff
    #X and Y are uniform, so entropy = -log(probability)
    return psi_X, diff

def K_tild(u,v,C,N,M,epsilon):
    C_tild = C - tf.transpose(tf.reshape(tf.tile(u[:,0],[M]),[M,N])) - tf.reshape(tf.tile(v[:,0],[N]),[N,M])
    K_tild = tf.exp(-C_tild/epsilon)
    return K_tild

In [15]:
def squared_l2_matrix(X, Y):
    squared_l2_X = tf.reduce_sum(tf.square(X), axis = 1, keepdims=True)
    squared_l2_Y = tf.reduce_sum(tf.square(Y), axis = 1, keepdims=True)
    XY = tf.matmul(X,tf.transpose(Y))
    return squared_l2_X + tf.transpose(squared_l2_Y) - 2 * XY

def log_coupling(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon):
    C_tild = cost_matrix - tf.expand_dims(psi_X, axis=1) - tf.expand_dims(psi_Y, axis=0)
    return -C_tild/epsilon + log_X_prob + log_Y_prob

# def sinkhorn_step(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon, return_diff = False):
#     psi_X_upd = epsilon * (log_X_prob - \
#                        tf.reduce_logsumexp((tf.expand_dims(psi_Y, axis=0) - cost_matrix) / epsilon, axis = 1))
#     psi_Y_upd = epsilon * (log_Y_prob - \
#                        tf.reduce_logsumexp((tf.expand_dims(psi_X_upd, axis=1) - cost_matrix)/epsilon,axis = 0))
#     if return_diff:
#         diff = tf.norm(psi_X_upd - psi_X) + tf.norm(psi_Y_upd - psi_Y)
#     psi_X = psi_X_upd
#     psi_Y = psi_Y_upd
#     if return_diff:
#         return psi_X, psi_Y, diff
#     return psi_X, psi_Y

# def sinkhorn_loss(X, Y, epsilon, num_steps, return_diff = False, return_diff_only = False):
#     cost_matrix = squared_l2_matrix(X,Y)
#     log_X_prob = - tf.math.log(tf.cast(tf.shape(X)[0], tf.float32))
#     log_Y_prob = - tf.math.log(tf.cast(tf.shape(Y)[0], tf.float32))
#     psi_Y = tf.zeros([tf.shape(Y)[0]])
#     for l in range(num_steps-1):
#         psi_X, psi_Y = sinkhorn_step(None, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)
#     psi_X, psi_Y, diff = sinkhorn_step(psi_X, psi_Y, log_X_prob, log_Y_prob, \
#                                        cost_matrix, epsilon, return_diff = True)
#     #return psi_X, psi_Y, diff
#     log_pi = log_coupling(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)
#     pi = tf.exp(log_pi)
#     loss = tf.reduce_sum(pi*(cost_matrix+epsilon*log_pi)) - epsilon*(log_X_prob + log_Y_prob)
#     #X and Y are uniform, so entropy = -log(probability)
#     if return_diff_only:
#         return diff
#     if return_diff:
#         return loss, diff
#     return loss

class dual_potentials:
    def __init__(self):
        self.psi_X = None
        self.psi_Y = None
        self.diff = 0

class dual_potentials_symm:
    def __init__(self):
        self.psi_X = np.zeros(BATCH_SIZE)
        self.diff = 0

def sinkhorn_loss_from_potentials(X, Y, epsilon, dual_vars):
    cost_matrix = squared_l2_matrix(X,Y)
    log_X_prob = - tf.math.log(tf.cast(tf.shape(X)[0], tf.float32))
    log_Y_prob = - tf.math.log(tf.cast(tf.shape(Y)[0], tf.float32))
    
    log_pi = log_coupling(dual_vars.psi_X, dual_vars.psi_Y, \
                          log_X_prob, log_Y_prob, cost_matrix, epsilon)
    log_pi = tf.minimum(log_pi, log_X_prob)
    pi = tf.exp(log_pi)
    loss = tf.reduce_sum(pi*(cost_matrix+epsilon*log_pi))
    #X and Y are uniform, so entropy = -log(probability)
    return loss

In [16]:
def inf_train_gen():
    np.random.seed(1)
    full_dataset = np.random.randn(SAMPLE_SIZE,DATA_DIM) / np.sqrt(COVARIANCE_SCALE) 
    i = 0
    offset = 0
    while True:
        dataset = full_dataset[i*BATCH_SIZE+offset:(i+1)*BATCH_SIZE+offset,:]
        if (i+1)*BATCH_SIZE+offset > SAMPLE_SIZE: 
            offset = (i+1)*BATCH_SIZE+offset - SAMPLE_SIZE
            np.random.shuffle(full_dataset)
            dataset = np.concatenate([dataset,full_dataset[:offset,:]], axis = 0)
            i = -1 
        i+=1
        yield dataset
data_gen = inf_train_gen()

In [17]:
#class WGAN:
 #   def __init__()self
latent_sample = keras.Input(shape=(LATENT_DIM,))
hidden_layer = latent_sample
for i in range(3):
    hidden_layer = layers.Dense(DIM, activation="relu", kernel_initializer=weight_initializer,
        bias_initializer=bias_initializer)(hidden_layer)
output = layers.Dense(DATA_DIM, kernel_initializer=weight_initializer,
                      bias_initializer=bias_initializer)(hidden_layer)
generator = keras.Model(inputs=latent_sample, outputs = output, name="generator")
print('Generator:')
print(generator.summary())

Generator:
Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 4)]               0         
_________________________________________________________________
dense_4 (Dense)              (None, 64)                320       
_________________________________________________________________
dense_5 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_6 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_7 (Dense)              (None, 32)                2080      
Total params: 10,720
Trainable params: 10,720
Non-trainable params: 0
_________________________________________________________________
None


In [18]:
dual_vars = dual_potentials()
dual_vars_x = dual_potentials()#true
dual_vars_y = dual_potentials()#pred
#MODE = 'loss'
epsilon = LAMBDA
if MODE == 'divergence':
    model_loss = lambda x, y: sinkhorn_loss_from_potentials(x, y, epsilon, dual_vars) - 0.5*\
        (sinkhorn_loss_from_potentials(x, x, epsilon, dual_vars_x))
else:
    model_loss = lambda x, y: sinkhorn_loss(x, y, epsilon, CRITIC_ITERS, dual_vars)

if MODE == 'divergence':
    args = [epsilon, CRITIC_ITERS, True, True]
    sinkhorn_diff = lambda x, y: (sinkhorn_loss(x, y, *args)\
        +sinkhorn_loss(x, x, *args)+sinkhorn_loss(y, y, *args))/3
else:
    sinkhorn_diff = lambda x, y: sinkhorn_loss(x, y, epsilon, CRITIC_ITERS)

metric = lambda x, y: tf.norm(tfp.stats.covariance(x) - tfp.stats.covariance(y))

generator.compile(optimizer="RMSprop", loss=model_loss, metrics=metric)

In [19]:
y_batch = next(data_gen)
x_batch = np.random.normal(size=(y_batch.shape[0], LATENT_DIM))
y_sample = np.array(generator.predict(x_batch, batch_size=BATCH_SIZE))
res_train = {'loss':{}, 'cov_diff' : {}}
res_test = {'loss':{}, 'cov_diff' : {}, 'sample':{}, 'sink_eps':{}}

In [20]:
for i in range(ITERS):
    psi_X, psi_Y, dual_vars.diff = sinkhorn_dual_potentials_np(y_batch, y_sample, epsilon, CRITIC_ITERS)
    dual_vars.psi_X = tf.cast(psi_X, 'float32')
    dual_vars.psi_Y = tf.cast(psi_Y, 'float32')
    
    #while dual_vars.diff > .1:
    #    psi_X, psi_Y, dual_vars.diff = sinkhorn_dual_potentials_np(y_batch, y_sample, epsilon, CRITIC_ITERS, psi_Y)
    #    dual_vars.psi_X = tf.cast(psi_X, 'float32')
    #    dual_vars.psi_Y = tf.cast(psi_Y, 'float32')
    print(np.isfinite(psi_X).all())
    print(np.isfinite(psi_Y).all())
        
    psi_X, dual_vars_x.diff = sinkhorn_dual_potentials_symm_np(y_batch, epsilon, CRITIC_ITERS)
    dual_vars_x.psi_X = tf.cast(psi_X, 'float32')
    dual_vars_x.psi_Y = dual_vars_x.psi_X
    psi_X, dual_vars_y.diff = sinkhorn_dual_potentials_symm_np(y_sample, epsilon, CRITIC_ITERS)
    dual_vars_x.psi_X = tf.cast(psi_X, 'float32')
    dual_vars_y.psi_Y = dual_vars_y.psi_X
    out = generator.train_on_batch(x = x_batch,y = y_batch)
    res_train['loss'][i] = out[0]
    res_train['cov_diff'][i] = out[1]
    y_batch = next(data_gen)
    x_batch = np.random.normal(size=(y_batch.shape[0], LATENT_DIM))
    #validate on the next data
    y_sample = np.array(generator.predict(x_batch, batch_size=BATCH_SIZE))
    #res_test['loss'][i], res_test['sink_eps'][i] = sinkhorn_loss_np(y_sample, y_batch, epsilon, CRITIC_ITERS)
    #res_test['sample'][i] = y_sample
    #res_test['cov_diff'][i] = np.linalg.norm(np.cov(y_sample) - np.cov(y_batch))
    print('Iteration ', i)
    print('Training: loss {}, covariance difference {}'.format(res_train['loss'][i], res_train['cov_diff'][i]))
    #print('Validation: loss {}, covariance difference {}, sinkhorn epsilon {}'.format(
    #    res_test['loss'][i], res_test['cov_diff'][i], res_test['sink_eps'][i]))
    #if i % 10 == 0:
    #    with open('logs.pkl', 'wb') as f:
    #        pkl.dump([res_test, res_train], f)

True
True
Iteration  0
Training: loss 0.827548086643219, covariance difference 13.74793815612793
True
True
Iteration  1
Training: loss 3831.65478515625, covariance difference 18.534156799316406
True
True
Iteration  2
Training: loss 3215.572509765625, covariance difference 26.967411041259766
True
True
Iteration  3
Training: loss 2261.848388671875, covariance difference 33.0068244934082
True
True
Iteration  4
Training: loss 1352.07177734375, covariance difference 37.229461669921875
True
True
Iteration  5
Training: loss 1082.1693115234375, covariance difference 55.39617156982422
True
True
Iteration  6
Training: loss 1275.5467529296875, covariance difference 52.75816345214844
True
True
Iteration  7
Training: loss 973.2342529296875, covariance difference 58.38312530517578
True
True
Iteration  8
Training: loss 1088.186767578125, covariance difference 67.35969543457031
True
True
Iteration  9
Training: loss 509.3439636230469, covariance difference 77.8211441040039
True
True
Iteration  10
Train

True
True
Iteration  84
Training: loss 0.826879620552063, covariance difference 383.71917724609375
True
True
Iteration  85
Training: loss 0.8268800973892212, covariance difference 312.02490234375
True
True
Iteration  86
Training: loss 0.8268795609474182, covariance difference 348.30877685546875
True
True
Iteration  87
Training: loss 109.18159484863281, covariance difference 363.69580078125
True
True
Iteration  88
Training: loss 0.8268797397613525, covariance difference 311.3428039550781
True
True
Iteration  89
Training: loss 0.8268803358078003, covariance difference 394.6433410644531
True
True
Iteration  90
Training: loss 24.17708969116211, covariance difference 458.63970947265625
True
True
Iteration  91
Training: loss 0.8268795013427734, covariance difference 459.36395263671875
True
True
Iteration  92
Training: loss 0.8268795013427734, covariance difference 413.78857421875
True
True
Iteration  93
Training: loss 0.8268795609474182, covariance difference 420.3114929199219
True
True
Iter

Iteration  167
Training: loss 0.826880931854248, covariance difference 452.7852478027344
True
True
Iteration  168
Training: loss 0.8268794417381287, covariance difference 529.7664794921875
True
True
Iteration  169
Training: loss 0.8268791437149048, covariance difference 525.42919921875
True
True
Iteration  170
Training: loss 0.826879620552063, covariance difference 475.8092346191406
True
True
Iteration  171
Training: loss 0.8268797397613525, covariance difference 667.373046875
True
True
Iteration  172
Training: loss 0.8268798589706421, covariance difference 456.39447021484375
True
True
Iteration  173
Training: loss 0.8268795013427734, covariance difference 559.6903686523438
True
True
Iteration  174
Training: loss 0.826880156993866, covariance difference 525.8646240234375
True
True
Iteration  175
Training: loss 0.8268799185752869, covariance difference 538.6388549804688
True
True
Iteration  176
Training: loss 0.8268803954124451, covariance difference 436.82171630859375
True
True
Iterati

KeyboardInterrupt: 

In [None]:
log_X_prob = - tf.math.log(tf.cast(tf.shape(hY)[0], tf.float32))
log_Y_prob = - tf.math.log(tf.cast(tf.shape(Y)[0], tf.float32))

In [94]:
log_pi = log_coupling(dual_vars.psi_X, dual_vars.psi_Y, \
                          log_X_prob, log_Y_prob, cost_matrix, epsilon)

In [95]:
np.max(dual_vars.psi_Y)

522.78375

In [81]:
np.max(cost_matrix)

403.02386

In [72]:
np.max(np.array(log_pi))

1781.4918

In [73]:
log_pi

<tf.Tensor: shape=(256, 256), dtype=float32, numpy=
array([[-113.68503 ,  -26.168352,  191.39458 , ...,  -63.165714,
        -147.3079  , -204.64305 ],
       [   9.470531,  118.72862 ,  322.0069  , ...,   49.36162 ,
         -28.387571,  -88.39143 ],
       [-127.35713 ,  -19.360134,  168.58821 , ...,  -66.24303 ,
        -146.08418 , -216.43121 ],
       ...,
       [ -59.572117,    8.696758,  222.1756  , ...,  -49.007694,
        -101.53757 , -165.60045 ],
       [  40.280746,  124.44064 ,  362.07047 , ...,   66.759384,
          30.082405,  -47.806503],
       [  89.34366 ,  186.66008 ,  402.57306 , ...,  138.01866 ,
          47.45031 ,  -19.337582]], dtype=float32)>

In [96]:
pi = tf.exp(np.minimum(log_pi, log_X_prob))
np.sum(pi, axis = 0)

array([5.27525902e-01, 7.00110734e-01, 9.10156250e-01, 3.64875168e-01,
       1.11972511e-01, 1.82493269e-01, 5.31922817e-01, 3.63683224e-01,
       6.51517451e-01, 8.64191175e-01, 1.82750538e-01, 8.04702520e-01,
       3.71962599e-02, 6.25111222e-01, 9.53129113e-01, 9.26196456e-01,
       9.72656250e-01, 1.75975874e-01, 9.53126967e-01, 7.31868148e-01,
       9.88281250e-01, 1.18378878e-01, 5.51946104e-01, 9.72656250e-01,
       6.56578004e-01, 7.34416366e-01, 9.80468750e-01, 3.03956419e-01,
       7.08360314e-01, 7.61959612e-01, 8.09335828e-01, 9.88575637e-01,
       6.86276019e-01, 6.93455637e-01, 7.68851042e-01, 6.22675538e-01,
       4.43837434e-01, 8.68086159e-01, 2.36477908e-02, 5.00675082e-01,
       9.26009297e-01, 9.76562738e-01, 8.24343801e-01, 7.04391956e-01,
       9.25984621e-01, 2.92849123e-01, 4.63242471e-01, 9.37180161e-01,
       9.64843750e-01, 4.57297564e-01, 9.57031250e-01, 8.18132997e-01,
       7.30552733e-01, 9.37500000e-01, 2.01298624e-01, 8.35937500e-01,
      

In [None]:
dual_vars.psi_X, dual_vars.psi_Y, dual_vars.diff = sinkhorn_dual_potentials_np(y_batch, y_sample, epsilon, CRITIC_ITERS)
dual_vars_x.psi_X, dual_vars_x.diff = sinkhorn_dual_potentials_symm_np(y_batch, epsilon, CRITIC_ITERS)
dual_vars_x.psi_Y = dual_vars_x.psi_X
dual_vars_y.psi_X, dual_vars_y.diff = sinkhorn_dual_potentials_symm_np(y_sample, epsilon, CRITIC_ITERS)
dual_vars_y.psi_Y = dual_vars_y.psi_X
#out = generator.train_on_batch(x = x_batch,y = y_batch)
#res_train['loss'][i] = out[0]
#res_train['cov_diff'][i] = out[1]
#y_batch = next(data_gen)
#x_batch = np.random.normal(size=(y_batch.shape[0], LATENT_DIM))
#validate on the next data
#y_sample = np.array(generator.predict(x_batch, batch_size=BATCH_SIZE))

In [None]:
model_loss(tf.cast(y_batch, 'float32'),tf.cast(y_sample, 'float32'))

In [None]:
out = generator.train_on_batch(x = x_batch,y = y_batch)

In [None]:
np.diag(squared_l2_matrix_np(y_batch, y_batch))

In [None]:
res_train['loss']

In [None]:
Y = tf.cast(y_sample, 'float32')
X = tf.cast(y_batch, 'float32')
dual_vars.psi_X, dual_vars.psi_Y, dual_vars.diff = sinkhorn_dual_potentials_np(y_batch, y_sample, epsilon, CRITIC_ITERS)
dual_vars_x.psi_X, dual_vars_x.diff = sinkhorn_dual_potentials_symm_np(y_batch, epsilon, CRITIC_ITERS)
dual_vars_x.psi_Y = dual_vars_x.psi_X
dual_vars_y.psi_X, dual_vars_y.diff = sinkhorn_dual_potentials_symm_np(y_sample, epsilon, CRITIC_ITERS)
dual_vars_y.psi_Y = dual_vars_y.psi_X
model_loss(X, Y)

In [None]:
out = generator.train_on_batch(x = x_batch,y = y_batch)

In [None]:
model_loss(X, Y)

In [None]:
sinkhorn_loss_from_potentials(Y, X, epsilon, CRITIC_ITERS, dual_vars)

In [None]:
cost_matrix = squared_l2_matrix(X,Y)

In [None]:
log_X_prob = - tf.math.log(tf.cast(tf.shape(X)[0], tf.float32))
log_Y_prob = - tf.math.log(tf.cast(tf.shape(Y)[0], tf.float32))

In [None]:
log_pi = log_coupling(tf.cast(dual_vars.psi_X, 'float32'), tf.cast(dual_vars.psi_Y, 'float32'), log_X_prob, log_Y_prob, cost_matrix, epsilon)

In [None]:
X = y_sample
Y = y_batch

cost_matrix = squared_l2_matrix_np(X,Y)

In [None]:
psi_X, psi_Y, _ = sinkhorn_dual_potentials_np(X, Y, epsilon, 100)
ln_pi = log_coupling_np(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)

In [None]:

log_X_prob = - np.log(X.shape[0])
log_Y_prob = - np.log(Y.shape[0])
psi_Y = np.zeros(Y.shape[0])

In [None]:
#psi_X, psi_Y = sinkhorn_step_np(None, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)
for i in range(100):
    psi_X_upd =  -epsilon*logsumexp(l+(np.expand_dims(psi_Y, axis=0) - cost_matrix) / epsilon, axis = 1)
    psi_Y_upd =  -epsilon*logsumexp(l+(np.expand_dims(psi_X_upd, axis=1) - cost_matrix)/epsilon,axis = 0)
    psi_X = psi_X_upd
    psi_Y = psi_Y_upd

In [None]:
l = -np.log(BATCH_SIZE)

In [None]:
psi_Y -= epsilon*np.log(BATCH_SIZE)/2
psi_X -= epsilon*np.log(BATCH_SIZE)/2

In [None]:
np.sum(np.exp((- cost_matrix + np.expand_dims(psi_Y, axis=0) + np.expand_dims(psi_X, axis=1))/epsilon), axis = 1)/(BATCH_SIZE**2)

In [None]:
for i in range(100):
    psi_X, psi_Y, _ = sinkhorn_step_np(psi_X, psi_Y, log_X_prob, log_Y_prob, \
                                       cost_matrix, epsilon, return_diff = True)
log_pi = log_coupling_np(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)

In [None]:
np.sum(np.exp(log_pi), axis = 1)

In [None]:
tf.reduce_min(log_pi)

In [None]:
pi = tf.exp(log_pi)

In [None]:
pi

In [None]:
pi = tf.exp(log_pi)
loss = tf.reduce_sum(pi*(cost_matrix+epsilon*log_pi)) - epsilon*(log_X_prob + log_Y_prob)