In [1]:
import numpy as np
import tensorflow as tf
import matplotlib
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pickle as pkl


import tensorflow_probability as tfp

np.random.seed(0)
tf.random.set_seed(0)

In [2]:
DIM = 64 # Model dimensionality (number of neurons in the hidden layer(s))
CRITIC_ITERS = 10 # How many critic iterations (Sinkhorn iterations) per generator iteration#was 50
BATCH_SIZE = 1024#256 # Batch size
ITERS = 2500#100000 # how many generator iterations to train for
DATA_DIM = 32
LATENT_DIM = 4
INITIALIZATION = 'he'#'glorot'
COVARIANCE_SCALE = np.sqrt(DATA_DIM)
INITIALIZE_LAST = True
SAMPLE_SIZE = 100000
LAMBDA = 0.3#2/(COVARIANCE_SCALE)
MODE = 'divergence' #'loss'

In [3]:
from tensorflow import keras
from tensorflow.keras import layers

In [4]:
if INITIALIZATION == 'he':
    weight_initializer = keras.initializers.he_uniform()
if INITIALIZATION == 'glorot':
    weight_initializer = keras.initializers.glorot_uniform()
bias_initializer = keras.initializers.zeros()

In [5]:
#class WGAN:
 #   def __init__()self
latent_sample = keras.Input(shape=(LATENT_DIM,))
hidden_layer = latent_sample
for i in range(3):
    hidden_layer = layers.Dense(DIM, activation="relu", kernel_initializer=weight_initializer,
        bias_initializer=bias_initializer)(hidden_layer)
output = layers.Dense(DATA_DIM, kernel_initializer=weight_initializer,
                      bias_initializer=bias_initializer)(hidden_layer)
generator = keras.Model(inputs=latent_sample, outputs = output, name="generator")
print('Generator:')
print(generator.summary())

Generator:
Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 4)]               0         
_________________________________________________________________
dense (Dense)                (None, 64)                320       
_________________________________________________________________
dense_1 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_3 (Dense)              (None, 32)                2080      
Total params: 10,720
Trainable params: 10,720
Non-trainable params: 0
_________________________________________________________________
None


In [6]:
from scipy.special import logsumexp

In [7]:
def squared_l2_matrix_np(X, Y):
    squared_l2_X = np.sum(np.square(X), axis = 1, keepdims=True)
    squared_l2_Y = np.sum(np.square(Y), axis = 1, keepdims=True)
    XY = X.dot(Y.T)
    return squared_l2_X + tf.transpose(squared_l2_Y) - 2 * XY

def log_coupling_np(psi_X, psi_Y, cost_matrix, epsilon):
    C_tild = cost_matrix - np.expand_dims(psi_X, axis=1) - np.expand_dims(psi_Y, axis=0)
    return -C_tild/epsilon

def sinkhorn_step_np(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon, return_diff = False):
    psi_X_upd = epsilon * (log_X_prob - \
                       logsumexp((np.expand_dims(psi_Y, axis=0) - cost_matrix) / epsilon, axis = 1))
    psi_Y_upd = epsilon * (log_Y_prob - \
                       logsumexp((np.expand_dims(psi_X_upd, axis=1) - cost_matrix)/epsilon,axis = 0))
    if return_diff:
        diff = np.linalg.norm(psi_X_upd - psi_X) + np.linalg.norm(psi_Y_upd - psi_Y)
    psi_X = psi_X_upd
    psi_Y = psi_Y_upd
    if return_diff:
        return psi_X, psi_Y, diff
    return psi_X, psi_Y

def sinkhorn_loss_np(X, Y, epsilon, num_steps):
    cost_matrix = squared_l2_matrix_np(X,Y)
    log_X_prob = - np.log(X.shape[0])
    log_Y_prob = - np.log(Y.shape[0])
    psi_Y = np.zeros(Y.shape[0])
    for l in range(num_steps-1):
        psi_X, psi_Y = sinkhorn_step_np(None, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)
    psi_X, psi_Y, diff = sinkhorn_step_np(psi_X, psi_Y, log_X_prob, log_Y_prob, \
                                       cost_matrix, epsilon, return_diff = True)
    #return psi_X, psi_Y, diff
    log_pi = log_coupling_np(psi_X, psi_Y, cost_matrix, epsilon)
    pi = np.exp(log_pi)
    loss = np.sum(pi*(cost_matrix+epsilon*log_pi)) - epsilon*(log_X_prob + log_Y_prob)
    #X and Y are uniform, so entropy = -log(probability)
    return loss, diff

In [8]:
def squared_l2_matrix(X, Y):
    squared_l2_X = tf.reduce_sum(tf.square(X), axis = 1, keepdims=True)
    squared_l2_Y = tf.reduce_sum(tf.square(Y), axis = 1, keepdims=True)
    XY = tf.matmul(X,tf.transpose(Y))
    return squared_l2_X + tf.transpose(squared_l2_Y) - 2 * XY

def log_coupling(psi_X, psi_Y, cost_matrix, epsilon):
    C_tild = cost_matrix - tf.expand_dims(psi_X, axis=1) - tf.expand_dims(psi_Y, axis=0)
    return -C_tild/epsilon

def sinkhorn_step(psi_X, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon, return_diff = False):
    psi_X_upd = epsilon * (log_X_prob - \
                       tf.reduce_logsumexp((tf.expand_dims(psi_Y, axis=0) - cost_matrix) / epsilon, axis = 1))
    psi_Y_upd = epsilon * (log_Y_prob - \
                       tf.reduce_logsumexp((tf.expand_dims(psi_X_upd, axis=1) - cost_matrix)/epsilon,axis = 0))
    if return_diff:
        diff = tf.norm(psi_X_upd - psi_X) + tf.norm(psi_Y_upd - psi_Y)
    psi_X = psi_X_upd
    psi_Y = psi_Y_upd
    if return_diff:
        return psi_X, psi_Y, diff
    return psi_X, psi_Y

def sinkhorn_loss(X, Y, epsilon, num_steps, return_diff = False, return_diff_only = False):
    cost_matrix = squared_l2_matrix(X,Y)
    log_X_prob = - tf.math.log(tf.cast(tf.shape(X)[0], tf.float32))
    log_Y_prob = - tf.math.log(tf.cast(tf.shape(Y)[0], tf.float32))
    psi_Y = tf.zeros([tf.shape(Y)[0]])
    for l in range(num_steps-1):
        psi_X, psi_Y = sinkhorn_step(None, psi_Y, log_X_prob, log_Y_prob, cost_matrix, epsilon)
    psi_X, psi_Y, diff = sinkhorn_step(psi_X, psi_Y, log_X_prob, log_Y_prob, \
                                       cost_matrix, epsilon, return_diff = True)
    #return psi_X, psi_Y, diff
    log_pi = log_coupling(psi_X, psi_Y, cost_matrix, epsilon)
    pi = tf.exp(log_pi)
    loss = tf.reduce_sum(pi*(cost_matrix+epsilon*log_pi)) - epsilon*(log_X_prob + log_Y_prob)
    #X and Y are uniform, so entropy = -log(probability)
    if return_diff_only:
        return diff
    if return_diff:
        return loss, diff
    return loss

In [9]:
epsilon = LAMBDA
if MODE == 'divergence':
    model_loss = lambda x, y: sinkhorn_loss(x, y, epsilon, CRITIC_ITERS) - 0.5*\
        (sinkhorn_loss(x, x, epsilon, CRITIC_ITERS)-sinkhorn_loss(y, y, epsilon, CRITIC_ITERS))
else:
    model_loss = lambda x, y: sinkhorn_loss(x, y, epsilon, CRITIC_ITERS)

if MODE == 'divergence':
    args = [epsilon, CRITIC_ITERS, True, True]
    sinkhorn_diff = lambda x, y: (sinkhorn_loss(x, y, *args)\
        +sinkhorn_loss(x, x, *args)+sinkhorn_loss(y, y, *args))/3
else:
    sinkhorn_diff = lambda x, y: sinkhorn_loss(x, y, epsilon, CRITIC_ITERS)

metrics = lambda x, y: tf.norm(tfp.stats.covariance(x) - tfp.stats.covariance(y))
generator.compile(optimizer="Adam", loss=model_loss, metrics=metrics)

In [10]:
def inf_train_gen():
    np.random.seed(1)
    full_dataset = np.random.randn(SAMPLE_SIZE,DATA_DIM) / np.sqrt(COVARIANCE_SCALE) 
    i = 0
    offset = 0
    while True:
        dataset = full_dataset[i*BATCH_SIZE+offset:(i+1)*BATCH_SIZE+offset,:]
        if (i+1)*BATCH_SIZE+offset > SAMPLE_SIZE: 
            offset = (i+1)*BATCH_SIZE+offset - SAMPLE_SIZE
            np.random.shuffle(full_dataset)
            dataset = np.concatenate([dataset,full_dataset[:offset,:]], axis = 0)
            i = -1 
        i+=1
        yield dataset
data_gen = inf_train_gen()

In [12]:
y_batch = next(data_gen)
x_batch = np.random.normal(size=(y_batch.shape[0], LATENT_DIM))
res_train = {'loss':{}, 'cov_diff' : {}}
res_test = {'loss':{}, 'cov_diff' : {}, 'sample':{}, 'sink_eps':{}}
for i in range(ITERS):
    out = generator.train_on_batch(x = x_batch,y = y_batch)
    res_train['loss'][i] = out[0]
    res_train['cov_diff'][i] = out[1]
    y_batch = next(data_gen)
    x_batch = np.random.normal(size=(y_batch.shape[0], LATENT_DIM))
    #validate on the next data
    y_sample = generator.predict(x_batch, batch_size=BATCH_SIZE)
    res_test['loss'][i], res_test['sink_eps'][i] = sinkhorn_loss_np(y_sample, y_batch, epsilon, CRITIC_ITERS)
    res_test['sample'][i] = y_sample
    res_test['cov_diff'][i] = np.linalg.norm(np.cov(y_sample[:256,:]) - np.cov(y_batch[:256,:]))
    print('Iteration ', i)
    print('Training: loss {}, covariance difference {}'.format(res_train['loss'][i], res_train['cov_diff'][i]))
    print('Validation: loss {}, covariance difference {}, sinkhorn epsilon {}'.format(
        res_test['loss'][i], res_test['cov_diff'][i], res_test['sink_eps'][i]))
    if i % 10 == 0:
        with open('logs.pkl', 'wb') as f:
            pkl.dump([res_test, res_train], f)

Iteration  0
Training: loss 4.678339958190918, covariance difference 1.0191571712493896
Validation: loss 5.60212407967657, covariance difference 8.62967076942069, sinkhorn epsilon 0.00026994421971048867
Iteration  1
Training: loss 4.577845573425293, covariance difference 1.0029255151748657
Validation: loss 5.58790516611306, covariance difference 8.460720479334046, sinkhorn epsilon 3.248205632779033e-11
Iteration  2
Training: loss 4.562339782714844, covariance difference 0.9992423057556152
Validation: loss 5.581281204060176, covariance difference 8.395882684734616, sinkhorn epsilon 8.801552625700866e-08
Iteration  3
Training: loss 4.552487373352051, covariance difference 0.9983996748924255
Validation: loss 5.6549487054529095, covariance difference 8.494231741067493, sinkhorn epsilon 3.8858213705074945e-10
Iteration  4
Training: loss 4.627856254577637, covariance difference 1.0117124319076538
Validation: loss 5.718280632689711, covariance difference 8.689667058771258, sinkhorn epsilon 2.

Iteration  41
Training: loss 4.669532775878906, covariance difference 1.0246940851211548
Validation: loss 5.604947636871195, covariance difference 8.266364896379814, sinkhorn epsilon 3.4610760049220496e-14
Iteration  42
Training: loss 4.5683135986328125, covariance difference 1.0040119886398315
Validation: loss 5.616511702321495, covariance difference 8.543063073719352, sinkhorn epsilon 1.2189317731430368e-14
Iteration  43
Training: loss 4.579812049865723, covariance difference 1.0052196979522705
Validation: loss 5.656889041374058, covariance difference 8.915047515354901, sinkhorn epsilon 2.346716343773844e-13
Iteration  44
Training: loss 4.620477199554443, covariance difference 1.0139461755752563
Validation: loss 5.708619079343099, covariance difference 8.713153250121515, sinkhorn epsilon 2.0353743076872016e-14
Iteration  45
Training: loss 4.671381950378418, covariance difference 1.0230929851531982
Validation: loss 5.748954251049458, covariance difference 8.776916813466269, sinkhorn e

Iteration  81
Training: loss 4.637916564941406, covariance difference 1.0177866220474243
Validation: loss 5.695385863917938, covariance difference 8.727246345844677, sinkhorn epsilon 3.417184930139891e-14
Iteration  82
Training: loss 4.656761646270752, covariance difference 1.0209609270095825
Validation: loss 5.687597215367942, covariance difference 8.636706367598565, sinkhorn epsilon 2.590385200531446e-15
Iteration  83
Training: loss 4.6490960121154785, covariance difference 1.020076870918274
Validation: loss 5.686344790644112, covariance difference 8.613278395725294, sinkhorn epsilon 1.334607041491032e-15
Iteration  84
Training: loss 4.647243976593018, covariance difference 1.0221680402755737
Validation: loss 5.638856645331748, covariance difference 8.513427927873858, sinkhorn epsilon 3.088795318704387e-14
Iteration  85
Training: loss 4.600470542907715, covariance difference 1.012421727180481
Validation: loss 5.570053248105544, covariance difference 8.530423130247806, sinkhorn epsilo

Iteration  122
Training: loss 4.49701452255249, covariance difference 0.9938242435455322
Validation: loss 5.681553071730638, covariance difference 8.577952459460251, sinkhorn epsilon 0.0
Iteration  123
Training: loss 4.642490386962891, covariance difference 1.0172357559204102
Validation: loss 5.715524533258426, covariance difference 8.90490937148199, sinkhorn epsilon 0.0
Iteration  124
Training: loss 4.676111221313477, covariance difference 1.024778127670288
Validation: loss 5.661212502941945, covariance difference 8.758997182520309, sinkhorn epsilon 4.044954707664469e-14
Iteration  125
Training: loss 4.621967792510986, covariance difference 1.0155813694000244
Validation: loss 5.631757943524001, covariance difference 8.612479413222795, sinkhorn epsilon 3.3864672621763756e-14
Iteration  126
Training: loss 4.592863082885742, covariance difference 1.008902907371521
Validation: loss 5.79035310409523, covariance difference 8.729532365236244, sinkhorn epsilon 2.7755575615628914e-16
Iteration

Iteration  163
Training: loss 4.615317344665527, covariance difference 1.0150104761123657
Validation: loss 5.660746729425448, covariance difference 8.77425165742974, sinkhorn epsilon 3.571861599547592e-14
Iteration  164
Training: loss 4.621325492858887, covariance difference 1.0168911218643188
Validation: loss 5.729930953070834, covariance difference 8.980283172904025, sinkhorn epsilon 0.0
Iteration  165
Training: loss 4.690426826477051, covariance difference 1.0270447731018066
Validation: loss 5.682201181822675, covariance difference 8.577154178310003, sinkhorn epsilon 1.4790793018918976e-15
Iteration  166
Training: loss 4.642974376678467, covariance difference 1.0194635391235352
Validation: loss 5.632280210565149, covariance difference 8.694287109425526, sinkhorn epsilon 0.0
Iteration  167
Training: loss 4.592820167541504, covariance difference 1.0085008144378662
Validation: loss 5.648914007794456, covariance difference 8.663784009390676, sinkhorn epsilon 3.808532904460161e-15
Iterat

Iteration  205
Training: loss 4.638387203216553, covariance difference 1.0186784267425537
Validation: loss 5.631267731330201, covariance difference 8.565605957946905, sinkhorn epsilon 0.0
Iteration  206
Training: loss 4.59186315536499, covariance difference 1.0104889869689941
Validation: loss 5.671927685031521, covariance difference 8.723095612964235, sinkhorn epsilon 0.0
Iteration  207
Training: loss 4.632365703582764, covariance difference 1.0160123109817505
Validation: loss 5.594455173023865, covariance difference 8.523721884270346, sinkhorn epsilon 1.6883935406531138e-14
Iteration  208
Training: loss 4.554947376251221, covariance difference 1.003982424736023
Validation: loss 5.717117167364895, covariance difference 8.588390175761752, sinkhorn epsilon 1.509597619914649e-14
Iteration  209
Training: loss 4.677537441253662, covariance difference 1.0265024900436401
Validation: loss 5.671721115350661, covariance difference 8.57556652490919, sinkhorn epsilon 0.0
Iteration  210
Training: l

Iteration  248
Training: loss 4.6318769454956055, covariance difference 1.0186843872070312
Validation: loss 5.654061651197998, covariance difference 8.783468862733022, sinkhorn epsilon 0.0
Iteration  249
Training: loss 4.614522933959961, covariance difference 1.0138154029846191
Validation: loss 5.6815384812616205, covariance difference 8.82592615553488, sinkhorn epsilon 0.0
Iteration  250
Training: loss 4.641951560974121, covariance difference 1.0189980268478394
Validation: loss 5.69497785202163, covariance difference 8.660216480978043, sinkhorn epsilon 2.7755575615628914e-16
Iteration  251
Training: loss 4.655397415161133, covariance difference 1.021240472793579
Validation: loss 5.597419341068727, covariance difference 8.335463259036404, sinkhorn epsilon 0.0
Iteration  252
Training: loss 4.557807445526123, covariance difference 1.0023467540740967
Validation: loss 5.640736576528682, covariance difference 8.650619968226842, sinkhorn epsilon 0.0
Iteration  253
Training: loss 4.6011838912

Iteration  290
Training: loss 4.557942867279053, covariance difference 1.0034726858139038
Validation: loss 5.708275632379922, covariance difference 8.685317176686425, sinkhorn epsilon 0.0
Iteration  291
Training: loss 4.668619632720947, covariance difference 1.0245782136917114
Validation: loss 5.65262525742923, covariance difference 8.43408835069422, sinkhorn epsilon 2.020662463048932e-14
Iteration  292
Training: loss 4.612985134124756, covariance difference 1.0135142803192139
Validation: loss 5.726120526287836, covariance difference 8.752543376608795, sinkhorn epsilon 0.0
Iteration  293
Training: loss 4.686507225036621, covariance difference 1.0283892154693604
Validation: loss 5.543553467073899, covariance difference 8.531758017688926, sinkhorn epsilon 7.02758602885423e-14
Iteration  294
Training: loss 4.503936767578125, covariance difference 0.9947043657302856
Validation: loss 5.76126001072569, covariance difference 8.833920030718106, sinkhorn epsilon 0.0
Iteration  295
Training: los

Iteration  333
Training: loss 4.641175270080566, covariance difference 1.018979787826538
Validation: loss 5.6315552528710215, covariance difference 8.35673373411183, sinkhorn epsilon 0.0
Iteration  334
Training: loss 4.591884613037109, covariance difference 1.0118480920791626
Validation: loss 5.69756154863766, covariance difference 8.657832775534692, sinkhorn epsilon 0.0
Iteration  335
Training: loss 4.657880783081055, covariance difference 1.0222474336624146
Validation: loss 5.656895129250253, covariance difference 8.513175572059916, sinkhorn epsilon 0.0
Iteration  336
Training: loss 4.617242813110352, covariance difference 1.0152517557144165
Validation: loss 5.708651447765379, covariance difference 8.793573406942299, sinkhorn epsilon 3.0087887424417924e-14
Iteration  337
Training: loss 4.669002532958984, covariance difference 1.024549961090088
Validation: loss 5.628297718484702, covariance difference 8.875060853697088, sinkhorn epsilon 0.0
Iteration  338
Training: loss 4.588665008544

Iteration  376
Training: loss 4.576467037200928, covariance difference 1.0076779127120972
Validation: loss 5.638075502776429, covariance difference 8.359610556371349, sinkhorn epsilon 0.0
Iteration  377
Training: loss 4.598459243774414, covariance difference 1.0105631351470947
Validation: loss 5.659557350333561, covariance difference 8.67711748132375, sinkhorn epsilon 1.2604317752140005e-14
Iteration  378
Training: loss 4.619948387145996, covariance difference 1.0136280059814453
Validation: loss 5.683719491359874, covariance difference 8.585086915499096, sinkhorn epsilon 0.0
Iteration  379
Training: loss 4.644008159637451, covariance difference 1.0210132598876953
Validation: loss 5.6613586274156855, covariance difference 8.39235800234239, sinkhorn epsilon 2.235463423612763e-14
Iteration  380
Training: loss 4.621697425842285, covariance difference 1.0162386894226074
Validation: loss 5.692022619870388, covariance difference 8.764671549627604, sinkhorn epsilon 2.2270758453788372e-14
Itera

Iteration  419
Training: loss 4.5656962394714355, covariance difference 1.0055696964263916
Validation: loss 5.604607486700438, covariance difference 8.416787883206146, sinkhorn epsilon 6.797396266227845e-15
Iteration  420
Training: loss 4.564939022064209, covariance difference 1.0051424503326416
Validation: loss 5.662804182052431, covariance difference 8.77162787949235, sinkhorn epsilon 0.0
Iteration  421
Training: loss 4.623089790344238, covariance difference 1.0161088705062866
Validation: loss 5.755954083997851, covariance difference 8.672152218058038, sinkhorn epsilon 0.0
Iteration  422
Training: loss 4.716254711151123, covariance difference 1.0317517518997192
Validation: loss 5.6672862891937426, covariance difference 8.790623434566793, sinkhorn epsilon 0.0
Iteration  423
Training: loss 4.627608776092529, covariance difference 1.0174018144607544
Validation: loss 5.684734375826005, covariance difference 8.720558819005296, sinkhorn epsilon 3.6573900349046364e-14
Iteration  424
Trainin

Iteration  462
Training: loss 4.6353349685668945, covariance difference 1.0186821222305298
Validation: loss 5.665833072220584, covariance difference 8.584670710224618, sinkhorn epsilon 0.0
Iteration  463
Training: loss 4.626162528991699, covariance difference 1.0168852806091309
Validation: loss 5.713864720562167, covariance difference 8.760242753268546, sinkhorn epsilon 4.763515635410573e-14
Iteration  464
Training: loss 4.674139976501465, covariance difference 1.0249325037002563
Validation: loss 5.694390686850541, covariance difference 8.699009520293156, sinkhorn epsilon 0.0
Iteration  465
Training: loss 4.654669761657715, covariance difference 1.020561695098877
Validation: loss 5.6371355296812, covariance difference 8.786578274027917, sinkhorn epsilon 0.0
Iteration  466
Training: loss 4.597466945648193, covariance difference 1.0099207162857056
Validation: loss 5.578419110333762, covariance difference 8.460721328916598, sinkhorn epsilon 0.0
Iteration  467
Training: loss 4.538693904876

Iteration  505
Training: loss 4.648674964904785, covariance difference 1.01896071434021
Validation: loss 5.5780977720477924, covariance difference 8.464779849132116, sinkhorn epsilon 0.0
Iteration  506
Training: loss 4.538417816162109, covariance difference 0.9997079372406006
Validation: loss 5.64992700234092, covariance difference 8.462382307449936, sinkhorn epsilon 0.0
Iteration  507
Training: loss 4.610258102416992, covariance difference 1.0121285915374756
Validation: loss 5.713109637360956, covariance difference 8.70979034451732, sinkhorn epsilon 0.0
Iteration  508
Training: loss 4.673429489135742, covariance difference 1.0248637199401855
Validation: loss 5.613147180930433, covariance difference 8.363876334731849, sinkhorn epsilon 0.0
Iteration  509
Training: loss 4.573431015014648, covariance difference 1.0079879760742188
Validation: loss 5.600694270815746, covariance difference 8.605542765635356, sinkhorn epsilon 0.0
Iteration  510
Training: loss 4.561015605926514, covariance dif

Iteration  548
Training: loss 4.622625350952148, covariance difference 1.015315294265747
Validation: loss 5.718545000259541, covariance difference 8.756841621122877, sinkhorn epsilon 1.9403103940978842e-14
Iteration  549
Training: loss 4.678881645202637, covariance difference 1.0254535675048828
Validation: loss 5.618414321028862, covariance difference 8.489007717441478, sinkhorn epsilon 2.7755575615628914e-16
Iteration  550
Training: loss 4.578693389892578, covariance difference 1.0080944299697876
Validation: loss 5.682071824480224, covariance difference 8.699044007700392, sinkhorn epsilon 0.0
Iteration  551
Training: loss 4.642379283905029, covariance difference 1.0197913646697998
Validation: loss 5.648935943675383, covariance difference 8.50240321480629, sinkhorn epsilon 0.0
Iteration  552
Training: loss 4.609225749969482, covariance difference 1.0139063596725464
Validation: loss 5.7037929541307975, covariance difference 8.841863620880117, sinkhorn epsilon 0.0
Iteration  553
Training

Iteration  591
Training: loss 4.644547462463379, covariance difference 1.020066499710083
Validation: loss 5.607548423641105, covariance difference 8.535403497014173, sinkhorn epsilon 0.0
Iteration  592
Training: loss 4.567818641662598, covariance difference 1.0061454772949219
Validation: loss 5.693768827505476, covariance difference 8.692009140345114, sinkhorn epsilon 0.0
Iteration  593
Training: loss 4.654041290283203, covariance difference 1.020741581916809
Validation: loss 5.664134757219619, covariance difference 8.392552415955748, sinkhorn epsilon 0.0
Iteration  594
Training: loss 4.624410629272461, covariance difference 1.0197947025299072
Validation: loss 5.652229337606033, covariance difference 8.503686814890486, sinkhorn epsilon 0.0
Iteration  595
Training: loss 4.612526893615723, covariance difference 1.0102659463882446
Validation: loss 5.672295268028156, covariance difference 8.740303688812844, sinkhorn epsilon 0.0
Iteration  596
Training: loss 4.632616996765137, covariance di

Iteration  634
Training: loss 4.662838935852051, covariance difference 1.0224896669387817
Validation: loss 5.65793251115006, covariance difference 8.489466766056559, sinkhorn epsilon 6.786623833034424e-14
Iteration  635
Training: loss 4.618200778961182, covariance difference 1.0153669118881226
Validation: loss 5.645265026936673, covariance difference 8.653193007894714, sinkhorn epsilon 0.0
Iteration  636
Training: loss 4.605588436126709, covariance difference 1.0121148824691772
Validation: loss 5.588026652120675, covariance difference 8.87560630814199, sinkhorn epsilon 3.982691198160024e-15
Iteration  637
Training: loss 4.548305988311768, covariance difference 1.0046846866607666
Validation: loss 5.531289339878679, covariance difference 8.369632822454488, sinkhorn epsilon 0.0
Iteration  638
Training: loss 4.491608142852783, covariance difference 0.9918979406356812
Validation: loss 5.666717929306085, covariance difference 8.613922415291121, sinkhorn epsilon 5.1982670611394335e-15
Iterati

Iteration  677
Training: loss 4.71079158782959, covariance difference 1.0335195064544678
Validation: loss 5.5877938237111895, covariance difference 8.617967904283434, sinkhorn epsilon 0.0
Iteration  678
Training: loss 4.548069000244141, covariance difference 1.0023391246795654
Validation: loss 5.634497502016245, covariance difference 8.692365940944084, sinkhorn epsilon 0.0
Iteration  679
Training: loss 4.5947723388671875, covariance difference 1.0088133811950684
Validation: loss 5.677589618476334, covariance difference 8.67566282506595, sinkhorn epsilon 0.0
Iteration  680
Training: loss 4.637866020202637, covariance difference 1.0179439783096313
Validation: loss 5.652958114610573, covariance difference 8.64565542210968, sinkhorn epsilon 0.0
Iteration  681
Training: loss 4.613238334655762, covariance difference 1.013990879058838
Validation: loss 5.724999399593418, covariance difference 9.070997798398805, sinkhorn epsilon 0.0
Iteration  682
Training: loss 4.685274124145508, covariance di

Iteration  720
Training: loss 4.577469825744629, covariance difference 1.0071486234664917
Validation: loss 5.6953209866961965, covariance difference 9.022040680119439, sinkhorn epsilon 3.6594278101599136e-14
Iteration  721
Training: loss 4.655596733093262, covariance difference 1.024678349494934
Validation: loss 5.669439687965394, covariance difference 8.553174063973639, sinkhorn epsilon 0.0
Iteration  722
Training: loss 4.629772186279297, covariance difference 1.0170814990997314
Validation: loss 5.716074009494311, covariance difference 8.645833344575834, sinkhorn epsilon 1.5611908171216087e-14
Iteration  723
Training: loss 4.6763482093811035, covariance difference 1.0251007080078125
Validation: loss 5.6193635986565456, covariance difference 8.358245535156243, sinkhorn epsilon 0.0
Iteration  724
Training: loss 4.579641342163086, covariance difference 1.007895827293396
Validation: loss 5.6353131466551565, covariance difference 8.841804821364958, sinkhorn epsilon 0.0
Iteration  725
Train

Iteration  763
Training: loss 4.592255592346191, covariance difference 1.0125632286071777
Validation: loss 5.652887839129089, covariance difference 8.57956132676302, sinkhorn epsilon 0.0
Iteration  764
Training: loss 4.613192558288574, covariance difference 1.013480544090271
Validation: loss 5.597637562757393, covariance difference 8.637414122661776, sinkhorn epsilon 0.0
Iteration  765
Training: loss 4.557921409606934, covariance difference 1.003028392791748
Validation: loss 5.716067323693102, covariance difference 8.859470534700307, sinkhorn epsilon 0.0
Iteration  766
Training: loss 4.676337242126465, covariance difference 1.0258656740188599
Validation: loss 5.5384208286234236, covariance difference 8.729822910785884, sinkhorn epsilon 0.0
Iteration  767
Training: loss 4.4986982345581055, covariance difference 0.9943745136260986
Validation: loss 5.740454755758909, covariance difference 8.780625811782476, sinkhorn epsilon 0.0
Iteration  768
Training: loss 4.70072603225708, covariance di

Iteration  807
Training: loss 4.5766472816467285, covariance difference 1.006906509399414
Validation: loss 5.758333156687272, covariance difference 9.108014981170673, sinkhorn epsilon 0.0
Iteration  808
Training: loss 4.718606948852539, covariance difference 1.034880518913269
Validation: loss 5.6986514132227795, covariance difference 8.543886756685819, sinkhorn epsilon 3.603097807722504e-14
Iteration  809
Training: loss 4.6589226722717285, covariance difference 1.0234177112579346
Validation: loss 5.644554264046159, covariance difference 8.449469334108594, sinkhorn epsilon 0.0
Iteration  810
Training: loss 4.604824066162109, covariance difference 1.011985421180725
Validation: loss 5.644088443473453, covariance difference 8.473394741157755, sinkhorn epsilon 0.0
Iteration  811
Training: loss 4.604368209838867, covariance difference 1.0114836692810059
Validation: loss 5.611257129939268, covariance difference 8.504760705504252, sinkhorn epsilon 0.0
Iteration  812
Training: loss 4.5715308189

Iteration  850
Training: loss 4.621614456176758, covariance difference 1.015218734741211
Validation: loss 5.634169089374057, covariance difference 8.674514315770754, sinkhorn epsilon 0.0
Iteration  851
Training: loss 4.594452857971191, covariance difference 1.0114357471466064
Validation: loss 5.673894531174916, covariance difference 8.58177844511117, sinkhorn epsilon 2.6970454582184114e-14
Iteration  852
Training: loss 4.634173393249512, covariance difference 1.0163490772247314
Validation: loss 5.585145484658286, covariance difference 8.585097226059442, sinkhorn epsilon 6.842870371144435e-14
Iteration  853
Training: loss 4.545416355133057, covariance difference 1.0022469758987427
Validation: loss 5.661196648686355, covariance difference 8.489750994932953, sinkhorn epsilon 0.0
Iteration  854
Training: loss 4.62147331237793, covariance difference 1.0147473812103271
Validation: loss 5.621094120995642, covariance difference 8.44497551607029, sinkhorn epsilon 0.0
Iteration  855
Training: lo

Iteration  893
Training: loss 4.622145652770996, covariance difference 1.0165680646896362
Validation: loss 5.705790089303063, covariance difference 8.748813336926345, sinkhorn epsilon 2.2814500189555958e-14
Iteration  894
Training: loss 4.666060447692871, covariance difference 1.0235645771026611
Validation: loss 5.662135629587406, covariance difference 8.7827383177305, sinkhorn epsilon 0.0
Iteration  895
Training: loss 4.622405052185059, covariance difference 1.0166373252868652
Validation: loss 5.730009144237516, covariance difference 8.67756779694609, sinkhorn epsilon 0.0
Iteration  896
Training: loss 4.690277099609375, covariance difference 1.0281686782836914
Validation: loss 5.671115240901193, covariance difference 8.859500023230126, sinkhorn epsilon 0.0
Iteration  897
Training: loss 4.631386756896973, covariance difference 1.01646888256073
Validation: loss 5.648204024855258, covariance difference 8.317287849408553, sinkhorn epsilon 0.0
Iteration  898
Training: loss 4.60847520828247

Iteration  936
Training: loss 4.5516510009765625, covariance difference 1.0045084953308105
Validation: loss 5.7029256391449525, covariance difference 8.900027854786192, sinkhorn epsilon 0.0
Iteration  937
Training: loss 4.6631975173950195, covariance difference 1.0254300832748413
Validation: loss 5.658807935049827, covariance difference 8.600156315936976, sinkhorn epsilon 0.0
Iteration  938
Training: loss 4.619077682495117, covariance difference 1.0165013074874878
Validation: loss 5.639578987280611, covariance difference 8.511683644462916, sinkhorn epsilon 0.0
Iteration  939
Training: loss 4.599847793579102, covariance difference 1.011644959449768
Validation: loss 5.65482185856853, covariance difference 8.667030506397992, sinkhorn epsilon 0.0
Iteration  940
Training: loss 4.61510705947876, covariance difference 1.0149191617965698
Validation: loss 5.6668040485672195, covariance difference 8.76300692931336, sinkhorn epsilon 4.847062132791625e-14
Iteration  941
Training: loss 4.6270718574

Iteration  980
Training: loss 4.578217029571533, covariance difference 1.0084819793701172
Validation: loss 5.5400989050203595, covariance difference 8.222862533261877, sinkhorn epsilon 0.0
Iteration  981
Training: loss 4.500354290008545, covariance difference 0.9928858280181885
Validation: loss 5.692817496878149, covariance difference 8.708092123081945, sinkhorn epsilon 0.0
Iteration  982
Training: loss 4.6530842781066895, covariance difference 1.0210998058319092
Validation: loss 5.604693090980367, covariance difference 8.384117940462884, sinkhorn epsilon 0.0
Iteration  983
Training: loss 4.564962863922119, covariance difference 1.005588173866272
Validation: loss 5.696536259170599, covariance difference 8.855763826225946, sinkhorn epsilon 0.0
Iteration  984
Training: loss 4.656809329986572, covariance difference 1.0222846269607544
Validation: loss 5.677610215696717, covariance difference 8.564993569250673, sinkhorn epsilon 0.0
Iteration  985
Training: loss 4.6378960609436035, covarianc

Iteration  1023
Training: loss 4.636580944061279, covariance difference 1.019531011581421
Validation: loss 5.601540508376425, covariance difference 8.774568438443074, sinkhorn epsilon 0.0
Iteration  1024
Training: loss 4.56181001663208, covariance difference 1.006799340248108
Validation: loss 5.699966008622253, covariance difference 8.670658495159044, sinkhorn epsilon 0.0
Iteration  1025
Training: loss 4.66023588180542, covariance difference 1.0224114656448364
Validation: loss 5.577486693665994, covariance difference 8.503767484561449, sinkhorn epsilon 0.0
Iteration  1026
Training: loss 4.537755489349365, covariance difference 1.0012966394424438
Validation: loss 5.621157956614452, covariance difference 8.48003541314291, sinkhorn epsilon 5.898662910140472e-14
Iteration  1027
Training: loss 4.581439018249512, covariance difference 1.0076892375946045
Validation: loss 5.622528522316425, covariance difference 8.572985038364596, sinkhorn epsilon 0.0
Iteration  1028
Training: loss 4.582797050

Iteration  1066
Training: loss 4.624094486236572, covariance difference 1.0163718461990356
Validation: loss 5.70084101123124, covariance difference 8.845200174103859, sinkhorn epsilon 0.0
Iteration  1067
Training: loss 4.6611223220825195, covariance difference 1.024141788482666
Validation: loss 5.633729576484183, covariance difference 8.457465121455481, sinkhorn epsilon 0.0
Iteration  1068
Training: loss 4.594005107879639, covariance difference 1.0101025104522705
Validation: loss 5.595378693960927, covariance difference 8.641425210666416, sinkhorn epsilon 0.0
Iteration  1069
Training: loss 4.555655479431152, covariance difference 1.0040931701660156
Validation: loss 5.643443042345563, covariance difference 8.651667247117388, sinkhorn epsilon 0.0
Iteration  1070
Training: loss 4.603716850280762, covariance difference 1.0124560594558716
Validation: loss 5.665334459051882, covariance difference 8.635645789925398, sinkhorn epsilon 0.0
Iteration  1071
Training: loss 4.625603199005127, covari

Iteration  1110
Training: loss 4.604394912719727, covariance difference 1.012613296508789
Validation: loss 5.686685676389507, covariance difference 8.577983990659597, sinkhorn epsilon 0.0
Iteration  1111
Training: loss 4.646954536437988, covariance difference 1.0215296745300293
Validation: loss 5.6188506761534, covariance difference 8.576957396613532, sinkhorn epsilon 0.0
Iteration  1112
Training: loss 4.579127788543701, covariance difference 1.008419156074524
Validation: loss 5.66100841070106, covariance difference 8.668377095340999, sinkhorn epsilon 0.0
Iteration  1113
Training: loss 4.621282577514648, covariance difference 1.0159430503845215
Validation: loss 5.672412199186362, covariance difference 8.592215499544036, sinkhorn epsilon 6.286383977041911e-14
Iteration  1114
Training: loss 4.632684230804443, covariance difference 1.0173498392105103
Validation: loss 5.616968810576147, covariance difference 8.394047957208754, sinkhorn epsilon 0.0
Iteration  1115
Training: loss 4.577238082

Iteration  1153
Training: loss 4.6614885330200195, covariance difference 1.0225567817687988
Validation: loss 5.728278812146346, covariance difference 8.856427340845944, sinkhorn epsilon 0.0
Iteration  1154
Training: loss 4.688549995422363, covariance difference 1.0274441242218018
Validation: loss 5.633172130762212, covariance difference 8.580349877034667, sinkhorn epsilon 0.0
Iteration  1155
Training: loss 4.593445301055908, covariance difference 1.01112699508667
Validation: loss 5.677700062548228, covariance difference 8.592948121773476, sinkhorn epsilon 0.0
Iteration  1156
Training: loss 4.637969493865967, covariance difference 1.0187788009643555
Validation: loss 5.743401034919797, covariance difference 8.8364278796227, sinkhorn epsilon 1.5277724705362747e-14
Iteration  1157
Training: loss 4.703670501708984, covariance difference 1.0306365489959717
Validation: loss 5.659024686077693, covariance difference 8.579837076409884, sinkhorn epsilon 0.0
Iteration  1158
Training: loss 4.619292

Iteration  1196
Training: loss 4.533526420593262, covariance difference 1.0002994537353516
Validation: loss 5.718500282746829, covariance difference 8.833450283855043, sinkhorn epsilon 0.0
Iteration  1197
Training: loss 4.678768157958984, covariance difference 1.0257781744003296
Validation: loss 5.681251382685727, covariance difference 8.802909648004977, sinkhorn epsilon 0.0
Iteration  1198
Training: loss 4.641514778137207, covariance difference 1.0193532705307007
Validation: loss 5.6730154113155855, covariance difference 8.597692796878285, sinkhorn epsilon 0.0
Iteration  1199
Training: loss 4.633269309997559, covariance difference 1.0176273584365845
Validation: loss 5.585898425729566, covariance difference 8.664925956976617, sinkhorn epsilon 0.0
Iteration  1200
Training: loss 4.546170234680176, covariance difference 1.0027780532836914
Validation: loss 5.5742196485671345, covariance difference 8.528289345679843, sinkhorn epsilon 0.0
Iteration  1201
Training: loss 4.534493923187256, cov

Iteration  1240
Training: loss 4.719635963439941, covariance difference 1.032249927520752
Validation: loss 5.7026895476476875, covariance difference 8.65081739168725, sinkhorn epsilon 0.0
Iteration  1241
Training: loss 4.66295862197876, covariance difference 1.0237220525741577
Validation: loss 5.722107089185521, covariance difference 8.841368504672234, sinkhorn epsilon 0.0
Iteration  1242
Training: loss 4.682374954223633, covariance difference 1.027064561843872
Validation: loss 5.696233058529697, covariance difference 8.719805068921582, sinkhorn epsilon 0.0
Iteration  1243
Training: loss 4.656500816345215, covariance difference 1.021913766860962
Validation: loss 5.612383144372556, covariance difference 8.724016743066274, sinkhorn epsilon 0.0
Iteration  1244
Training: loss 4.572652816772461, covariance difference 1.0078604221343994
Validation: loss 5.670637978243449, covariance difference 8.524755529120151, sinkhorn epsilon 0.0
Iteration  1245
Training: loss 4.6309099197387695, covarian

Iteration  1284
Training: loss 4.674665451049805, covariance difference 1.0250548124313354
Validation: loss 5.691184023975983, covariance difference 8.530213656108462, sinkhorn epsilon 0.0
Iteration  1285
Training: loss 4.651453018188477, covariance difference 1.0206142663955688
Validation: loss 5.707798038420781, covariance difference 8.636442667399969, sinkhorn epsilon 0.0
Iteration  1286
Training: loss 4.6680588722229, covariance difference 1.024991512298584
Validation: loss 5.6391399856375966, covariance difference 8.582133289044359, sinkhorn epsilon 0.0
Iteration  1287
Training: loss 4.599412441253662, covariance difference 1.0122942924499512
Validation: loss 5.6495941618310415, covariance difference 8.773504949286808, sinkhorn epsilon 0.0
Iteration  1288
Training: loss 4.60986328125, covariance difference 1.0127339363098145
Validation: loss 5.668287743948622, covariance difference 8.750490986849401, sinkhorn epsilon 0.0
Iteration  1289
Training: loss 4.628558158874512, covariance

Iteration  1328
Training: loss 4.655744552612305, covariance difference 1.0223195552825928
Validation: loss 5.680485923960903, covariance difference 8.688009860046146, sinkhorn epsilon 0.0
Iteration  1329
Training: loss 4.6407551765441895, covariance difference 1.0179040431976318
Validation: loss 5.710331864877638, covariance difference 8.661615963747714, sinkhorn epsilon 0.0
Iteration  1330
Training: loss 4.670601844787598, covariance difference 1.0234870910644531
Validation: loss 5.692428904273669, covariance difference 8.569126237246303, sinkhorn epsilon 0.0
Iteration  1331
Training: loss 4.652683258056641, covariance difference 1.0225465297698975
Validation: loss 5.612845578129038, covariance difference 8.725792045905804, sinkhorn epsilon 0.0
Iteration  1332
Training: loss 4.5731611251831055, covariance difference 1.0068295001983643
Validation: loss 5.579456108643734, covariance difference 8.369681685899122, sinkhorn epsilon 2.222568850219952e-14
Iteration  1333
Training: loss 4.53

Iteration  1371
Training: loss 4.508522987365723, covariance difference 0.9964213371276855
Validation: loss 5.609913782345211, covariance difference 8.56969477975499, sinkhorn epsilon 0.0
Iteration  1372
Training: loss 4.570181846618652, covariance difference 1.0063717365264893
Validation: loss 5.681676314850429, covariance difference 8.85160564478641, sinkhorn epsilon 0.0
Iteration  1373
Training: loss 4.641946315765381, covariance difference 1.0182560682296753
Validation: loss 5.673004390420015, covariance difference 8.776597039723363, sinkhorn epsilon 0.0
Iteration  1374
Training: loss 4.633272171020508, covariance difference 1.016733169555664
Validation: loss 5.732431859091318, covariance difference 8.858499905126596, sinkhorn epsilon 3.838880173028942e-14
Iteration  1375
Training: loss 4.692708492279053, covariance difference 1.0306190252304077
Validation: loss 5.618670756673192, covariance difference 8.585186860233486, sinkhorn epsilon 0.0
Iteration  1376
Training: loss 4.5789403

Iteration  1414
Training: loss 4.715359210968018, covariance difference 1.032982587814331
Validation: loss 5.659215230736738, covariance difference 8.596741953117526, sinkhorn epsilon 4.179087702495526e-14
Iteration  1415
Training: loss 4.619483470916748, covariance difference 1.0147199630737305
Validation: loss 5.688613410000378, covariance difference 8.726855876097563, sinkhorn epsilon 0.0
Iteration  1416
Training: loss 4.648879528045654, covariance difference 1.0195742845535278
Validation: loss 5.747129024525366, covariance difference 8.750822891716915, sinkhorn epsilon 0.0
Iteration  1417
Training: loss 4.707401275634766, covariance difference 1.0315800905227661
Validation: loss 5.650600373689193, covariance difference 8.6857678456415, sinkhorn epsilon 0.0
Iteration  1418
Training: loss 4.610876560211182, covariance difference 1.0126385688781738
Validation: loss 5.630240304639792, covariance difference 8.533821357251632, sinkhorn epsilon 0.0
Iteration  1419
Training: loss 4.5905094

Iteration  1458
Training: loss 4.659172534942627, covariance difference 1.0225647687911987
Validation: loss 5.673727863340577, covariance difference 8.620984046848186, sinkhorn epsilon 0.0
Iteration  1459
Training: loss 4.633997440338135, covariance difference 1.017382264137268
Validation: loss 5.666202890608216, covariance difference 8.75091314287816, sinkhorn epsilon 0.0
Iteration  1460
Training: loss 4.62647008895874, covariance difference 1.0166202783584595
Validation: loss 5.692488875138394, covariance difference 8.599001889994073, sinkhorn epsilon 0.0
Iteration  1461
Training: loss 4.652757167816162, covariance difference 1.0210832357406616
Validation: loss 5.663890872020822, covariance difference 8.518720834333518, sinkhorn epsilon 0.0
Iteration  1462
Training: loss 4.6241607666015625, covariance difference 1.0153357982635498
Validation: loss 5.627904394590752, covariance difference 8.697474785538393, sinkhorn epsilon 0.0
Iteration  1463
Training: loss 4.5881853103637695, covari

Iteration  1501
Training: loss 4.603675842285156, covariance difference 1.0116043090820312
Validation: loss 5.613567114935724, covariance difference 8.545890659168247, sinkhorn epsilon 0.0
Iteration  1502
Training: loss 4.5738348960876465, covariance difference 1.0077539682388306
Validation: loss 5.694269015158964, covariance difference 8.539612581963299, sinkhorn epsilon 0.0
Iteration  1503
Training: loss 4.654540061950684, covariance difference 1.0197736024856567
Validation: loss 5.6295393774195785, covariance difference 8.511302237584788, sinkhorn epsilon 0.0
Iteration  1504
Training: loss 4.589807510375977, covariance difference 1.006723165512085
Validation: loss 5.632959445905325, covariance difference 8.618491917597993, sinkhorn epsilon 0.0
Iteration  1505
Training: loss 4.593227386474609, covariance difference 1.010103464126587
Validation: loss 5.602809111560768, covariance difference 8.620256819414848, sinkhorn epsilon 0.0
Iteration  1506
Training: loss 4.563076972961426, covar

Iteration  1545
Training: loss 4.642299175262451, covariance difference 1.0196179151535034
Validation: loss 5.6755768578351535, covariance difference 8.571214664506256, sinkhorn epsilon 0.0
Iteration  1546
Training: loss 4.635846138000488, covariance difference 1.0190683603286743
Validation: loss 5.7313190897659565, covariance difference 8.969478415895932, sinkhorn epsilon 0.0
Iteration  1547
Training: loss 4.691588878631592, covariance difference 1.0297069549560547
Validation: loss 5.723103725605788, covariance difference 8.52457239355053, sinkhorn epsilon 0.0
Iteration  1548
Training: loss 4.683371543884277, covariance difference 1.0278209447860718
Validation: loss 5.596434792879722, covariance difference 8.296061427644059, sinkhorn epsilon 0.0
Iteration  1549
Training: loss 4.556704521179199, covariance difference 1.0038094520568848
Validation: loss 5.607151220968188, covariance difference 8.571613993843801, sinkhorn epsilon 0.0
Iteration  1550
Training: loss 4.567420482635498, cova

Iteration  1589
Training: loss 4.667431354522705, covariance difference 1.0243792533874512
Validation: loss 5.65280038234436, covariance difference 8.713571292267572, sinkhorn epsilon 0.0
Iteration  1590
Training: loss 4.613069534301758, covariance difference 1.0143790245056152
Validation: loss 5.676100952699197, covariance difference 8.72440124107961, sinkhorn epsilon 0.0
Iteration  1591
Training: loss 4.636364936828613, covariance difference 1.0198261737823486
Validation: loss 5.672590056609411, covariance difference 8.488172688492753, sinkhorn epsilon 0.0
Iteration  1592
Training: loss 4.6328630447387695, covariance difference 1.0179377794265747
Validation: loss 5.6516823709552035, covariance difference 8.359227742568006, sinkhorn epsilon 0.0
Iteration  1593
Training: loss 4.6119489669799805, covariance difference 1.0134814977645874
Validation: loss 5.563995506026011, covariance difference 8.309278429269854, sinkhorn epsilon 0.0
Iteration  1594
Training: loss 4.524260520935059, cova

Iteration  1632
Training: loss 4.574638843536377, covariance difference 1.0067594051361084
Validation: loss 5.627238495448227, covariance difference 8.460922128948711, sinkhorn epsilon 6.143291315138666e-14
Iteration  1633
Training: loss 4.587507247924805, covariance difference 1.0105057954788208
Validation: loss 5.709703812440076, covariance difference 8.587565508199566, sinkhorn epsilon 0.0
Iteration  1634
Training: loss 4.669968128204346, covariance difference 1.0244396924972534
Validation: loss 5.661186410235226, covariance difference 8.480172377224875, sinkhorn epsilon 0.0
Iteration  1635
Training: loss 4.621456146240234, covariance difference 1.0136269330978394
Validation: loss 5.59752667475856, covariance difference 8.465316813057775, sinkhorn epsilon 0.0
Iteration  1636
Training: loss 4.557791709899902, covariance difference 1.0038330554962158
Validation: loss 5.667473905965265, covariance difference 8.398496791076825, sinkhorn epsilon 0.0
Iteration  1637
Training: loss 4.62774

Iteration  1676
Training: loss 4.571765899658203, covariance difference 1.0060310363769531
Validation: loss 5.722763016826834, covariance difference 8.950018219084209, sinkhorn epsilon 0.0
Iteration  1677
Training: loss 4.683032989501953, covariance difference 1.0249462127685547
Validation: loss 5.692102265937722, covariance difference 8.636845101928477, sinkhorn epsilon 0.0
Iteration  1678
Training: loss 4.652368068695068, covariance difference 1.0223468542099
Validation: loss 5.668162695241266, covariance difference 8.769165956176662, sinkhorn epsilon 0.0
Iteration  1679
Training: loss 4.628427505493164, covariance difference 1.016663670539856
Validation: loss 5.629677837656743, covariance difference 8.471082600807906, sinkhorn epsilon 0.0
Iteration  1680
Training: loss 4.5899457931518555, covariance difference 1.0097748041152954
Validation: loss 5.649569551556338, covariance difference 8.648457235243278, sinkhorn epsilon 0.0
Iteration  1681
Training: loss 4.609834671020508, covarian

Iteration  1720
Training: loss 4.60148811340332, covariance difference 1.0129574537277222
Validation: loss 5.61729499565605, covariance difference 8.606074572895928, sinkhorn epsilon 0.0
Iteration  1721
Training: loss 4.577563285827637, covariance difference 1.0070041418075562
Validation: loss 5.639539090853134, covariance difference 8.639660829295552, sinkhorn epsilon 0.0
Iteration  1722
Training: loss 4.5998077392578125, covariance difference 1.0120307207107544
Validation: loss 5.607632193747067, covariance difference 8.538947080317644, sinkhorn epsilon 0.0
Iteration  1723
Training: loss 4.567904472351074, covariance difference 1.007652759552002
Validation: loss 5.563813554833359, covariance difference 8.307953961561642, sinkhorn epsilon 0.0
Iteration  1724
Training: loss 4.524082183837891, covariance difference 0.9983347058296204
Validation: loss 5.696609317863687, covariance difference 8.60099013398872, sinkhorn epsilon 0.0
Iteration  1725
Training: loss 4.656878471374512, covarian

Iteration  1764
Training: loss 4.594948768615723, covariance difference 1.0117985010147095
Validation: loss 5.6460253894074635, covariance difference 8.567335486140749, sinkhorn epsilon 0.0
Iteration  1765
Training: loss 4.60629415512085, covariance difference 1.0123635530471802
Validation: loss 5.674335989161579, covariance difference 8.701626934029097, sinkhorn epsilon 0.0
Iteration  1766
Training: loss 4.634604454040527, covariance difference 1.0168455839157104
Validation: loss 5.570988689168118, covariance difference 8.442675121226126, sinkhorn epsilon 0.0
Iteration  1767
Training: loss 4.531252861022949, covariance difference 0.9997765421867371
Validation: loss 5.6892417830363895, covariance difference 8.611128278666586, sinkhorn epsilon 0.0
Iteration  1768
Training: loss 4.649507522583008, covariance difference 1.020443320274353
Validation: loss 5.666648035189178, covariance difference 8.82243306406138, sinkhorn epsilon 0.0
Iteration  1769
Training: loss 4.626920700073242, covari

Iteration  1808
Training: loss 4.572711944580078, covariance difference 1.0072264671325684
Validation: loss 5.619354683277326, covariance difference 8.616687563964229, sinkhorn epsilon 0.0
Iteration  1809
Training: loss 4.5796284675598145, covariance difference 1.00687837600708
Validation: loss 5.674963460272737, covariance difference 8.748803384052595, sinkhorn epsilon 0.0
Iteration  1810
Training: loss 4.635227680206299, covariance difference 1.0192466974258423
Validation: loss 5.6952613282414815, covariance difference 8.629284713258526, sinkhorn epsilon 0.0
Iteration  1811
Training: loss 4.655529022216797, covariance difference 1.0200239419937134
Validation: loss 5.662501479920897, covariance difference 8.515487453315016, sinkhorn epsilon 0.0
Iteration  1812
Training: loss 4.622769355773926, covariance difference 1.0169646739959717
Validation: loss 5.648294947431637, covariance difference 8.58485631299722, sinkhorn epsilon 0.0
Iteration  1813
Training: loss 4.6085591316223145, covar

Iteration  1852
Training: loss 4.565694332122803, covariance difference 1.0061732530593872
Validation: loss 5.66709801479337, covariance difference 8.524941214399666, sinkhorn epsilon 0.0
Iteration  1853
Training: loss 4.62736701965332, covariance difference 1.0156879425048828
Validation: loss 5.693739671443208, covariance difference 8.534766604696092, sinkhorn epsilon 0.0
Iteration  1854
Training: loss 4.654007911682129, covariance difference 1.0198231935501099
Validation: loss 5.6567438679114765, covariance difference 8.722118609887854, sinkhorn epsilon 0.0
Iteration  1855
Training: loss 4.617013454437256, covariance difference 1.0147794485092163
Validation: loss 5.6048852992206255, covariance difference 8.527091094698573, sinkhorn epsilon 0.0
Iteration  1856
Training: loss 4.565154075622559, covariance difference 1.0053527355194092
Validation: loss 5.691414721218095, covariance difference 8.661349473000861, sinkhorn epsilon 0.0
Iteration  1857
Training: loss 4.651679039001465, covar

Iteration  1896
Training: loss 4.574419021606445, covariance difference 1.0083609819412231
Validation: loss 5.640258561747487, covariance difference 8.727476224251381, sinkhorn epsilon 0.0
Iteration  1897
Training: loss 4.600527286529541, covariance difference 1.0125495195388794
Validation: loss 5.644231941040935, covariance difference 8.596095659502092, sinkhorn epsilon 0.0
Iteration  1898
Training: loss 4.6045002937316895, covariance difference 1.0111945867538452
Validation: loss 5.696844713619514, covariance difference 8.684029845385748, sinkhorn epsilon 0.0
Iteration  1899
Training: loss 4.657108783721924, covariance difference 1.0215485095977783
Validation: loss 5.649069684181719, covariance difference 8.576472154661799, sinkhorn epsilon 0.0
Iteration  1900
Training: loss 4.609339714050293, covariance difference 1.0132323503494263
Validation: loss 5.667635923639642, covariance difference 8.690853038626805, sinkhorn epsilon 0.0
Iteration  1901
Training: loss 4.627904891967773, cova

Iteration  1940
Training: loss 4.663521766662598, covariance difference 1.0242552757263184
Validation: loss 5.635921999423703, covariance difference 8.85861072911596, sinkhorn epsilon 0.0
Iteration  1941
Training: loss 4.596189975738525, covariance difference 1.0102170705795288
Validation: loss 5.595414501684004, covariance difference 8.477794270101672, sinkhorn epsilon 0.0
Iteration  1942
Training: loss 4.555673599243164, covariance difference 1.0038411617279053
Validation: loss 5.617093811699467, covariance difference 8.513779373990292, sinkhorn epsilon 0.0
Iteration  1943
Training: loss 4.577358245849609, covariance difference 1.0074694156646729
Validation: loss 5.64302458401939, covariance difference 8.849205646221753, sinkhorn epsilon 0.0
Iteration  1944
Training: loss 4.603288650512695, covariance difference 1.0122532844543457
Validation: loss 5.614787532027204, covariance difference 8.584938825072143, sinkhorn epsilon 2.9135880773941904e-14
Iteration  1945
Training: loss 4.57505

Iteration  1984
Training: loss 4.6209564208984375, covariance difference 1.0161775350570679
Validation: loss 5.670922458044952, covariance difference 8.579040915812705, sinkhorn epsilon 0.0
Iteration  1985
Training: loss 4.631191730499268, covariance difference 1.0176451206207275
Validation: loss 5.662974557128322, covariance difference 8.688645447569913, sinkhorn epsilon 6.258828645872519e-14
Iteration  1986
Training: loss 4.623233795166016, covariance difference 1.0169435739517212
Validation: loss 5.636645106224884, covariance difference 8.588402034926675, sinkhorn epsilon 0.0
Iteration  1987
Training: loss 4.596909523010254, covariance difference 1.011296033859253
Validation: loss 5.670585917545982, covariance difference 8.914900966552702, sinkhorn epsilon 0.0
Iteration  1988
Training: loss 4.630849838256836, covariance difference 1.0159637928009033
Validation: loss 5.641813214613643, covariance difference 8.658083833596082, sinkhorn epsilon 0.0
Iteration  1989
Training: loss 4.6020

Iteration  2027
Training: loss 4.6402201652526855, covariance difference 1.0186220407485962
Validation: loss 5.634867066979705, covariance difference 8.587531911074846, sinkhorn epsilon 0.0
Iteration  2028
Training: loss 4.595134735107422, covariance difference 1.0105797052383423
Validation: loss 5.636194869920947, covariance difference 8.757658503910832, sinkhorn epsilon 0.0
Iteration  2029
Training: loss 4.596463203430176, covariance difference 1.0107685327529907
Validation: loss 5.579077426716815, covariance difference 8.599176898464972, sinkhorn epsilon 0.0
Iteration  2030
Training: loss 4.539344787597656, covariance difference 1.000700831413269
Validation: loss 5.604026213404852, covariance difference 8.417822501666254, sinkhorn epsilon 0.0
Iteration  2031
Training: loss 4.564291000366211, covariance difference 1.006935477256775
Validation: loss 5.672975650178317, covariance difference 8.64063530307735, sinkhorn epsilon 0.0
Iteration  2032
Training: loss 4.633240699768066, covaria

Iteration  2070
Training: loss 4.591607093811035, covariance difference 1.0107051134109497
Validation: loss 5.588773939828483, covariance difference 8.651427197102613, sinkhorn epsilon 0.0
Iteration  2071
Training: loss 4.549038410186768, covariance difference 1.0048222541809082
Validation: loss 5.640151672338872, covariance difference 8.772387807288945, sinkhorn epsilon 5.5835592318908964e-14
Iteration  2072
Training: loss 4.60042142868042, covariance difference 1.0112918615341187
Validation: loss 5.614025112276953, covariance difference 8.62410309632833, sinkhorn epsilon 0.0
Iteration  2073
Training: loss 4.574295997619629, covariance difference 1.0087511539459229
Validation: loss 5.727446306561171, covariance difference 8.641755292080976, sinkhorn epsilon 0.0
Iteration  2074
Training: loss 4.687716484069824, covariance difference 1.0288290977478027
Validation: loss 5.640925998404486, covariance difference 8.574735644016748, sinkhorn epsilon 0.0
Iteration  2075
Training: loss 4.60119

Iteration  2114
Training: loss 4.703524589538574, covariance difference 1.0325095653533936
Validation: loss 5.669713250961384, covariance difference 8.585662035084624, sinkhorn epsilon 0.0
Iteration  2115
Training: loss 4.629981994628906, covariance difference 1.0165506601333618
Validation: loss 5.651692734377417, covariance difference 8.626755172835239, sinkhorn epsilon 4.914654212295544e-14
Iteration  2116
Training: loss 4.611964702606201, covariance difference 1.0151039361953735
Validation: loss 5.631548938982802, covariance difference 8.791209181330437, sinkhorn epsilon 0.0
Iteration  2117
Training: loss 4.591817378997803, covariance difference 1.0096925497055054
Validation: loss 5.696200100989314, covariance difference 8.458010909499723, sinkhorn epsilon 0.0
Iteration  2118
Training: loss 4.656467914581299, covariance difference 1.0224149227142334
Validation: loss 5.7216347413500594, covariance difference 8.860165500519988, sinkhorn epsilon 0.0
Iteration  2119
Training: loss 4.681

Iteration  2158
Training: loss 4.587543487548828, covariance difference 1.009914755821228
Validation: loss 5.729033418665715, covariance difference 8.70100883038233, sinkhorn epsilon 0.0
Iteration  2159
Training: loss 4.689301490783691, covariance difference 1.0259236097335815
Validation: loss 5.612957026048875, covariance difference 8.739695027059364, sinkhorn epsilon 0.0
Iteration  2160
Training: loss 4.573221206665039, covariance difference 1.0058759450912476
Validation: loss 5.632811771677036, covariance difference 8.651565412387734, sinkhorn epsilon 0.0
Iteration  2161
Training: loss 4.593076229095459, covariance difference 1.0101951360702515
Validation: loss 5.606061665344297, covariance difference 8.597200337055172, sinkhorn epsilon 0.0
Iteration  2162
Training: loss 4.5663299560546875, covariance difference 1.0070539712905884
Validation: loss 5.654862063168791, covariance difference 8.442350362588153, sinkhorn epsilon 0.0
Iteration  2163
Training: loss 4.615126132965088, covari

Iteration  2202
Training: loss 4.584681510925293, covariance difference 1.0080007314682007
Validation: loss 5.625427544158638, covariance difference 8.5586774973143, sinkhorn epsilon 0.0
Iteration  2203
Training: loss 4.585691452026367, covariance difference 1.0111452341079712
Validation: loss 5.6479631489113356, covariance difference 8.769545240579964, sinkhorn epsilon 0.0
Iteration  2204
Training: loss 4.608226776123047, covariance difference 1.0133062601089478
Validation: loss 5.590695626291115, covariance difference 8.491242836833996, sinkhorn epsilon 0.0
Iteration  2205
Training: loss 4.550960063934326, covariance difference 1.0033609867095947
Validation: loss 5.607639337799325, covariance difference 8.627540393371476, sinkhorn epsilon 0.0
Iteration  2206
Training: loss 4.567903995513916, covariance difference 1.0060089826583862
Validation: loss 5.6723431065796, covariance difference 8.824282359140808, sinkhorn epsilon 0.0
Iteration  2207
Training: loss 4.6326069831848145, covaria

Iteration  2246
Training: loss 4.607470512390137, covariance difference 1.0130618810653687
Validation: loss 5.66420961192923, covariance difference 8.611176928866874, sinkhorn epsilon 0.0
Iteration  2247
Training: loss 4.624477863311768, covariance difference 1.0183594226837158
Validation: loss 5.596709780400095, covariance difference 8.584411483188331, sinkhorn epsilon 0.0
Iteration  2248
Training: loss 4.556977272033691, covariance difference 1.0023531913757324
Validation: loss 5.528423418532751, covariance difference 8.455594215313468, sinkhorn epsilon 0.0
Iteration  2249
Training: loss 4.488691329956055, covariance difference 0.9910328388214111
Validation: loss 5.676442548169385, covariance difference 8.789643156133327, sinkhorn epsilon 0.0
Iteration  2250
Training: loss 4.636706352233887, covariance difference 1.017971396446228
Validation: loss 5.6812613427864225, covariance difference 8.572875586453456, sinkhorn epsilon 0.0
Iteration  2251
Training: loss 4.641529083251953, covari

Iteration  2290
Training: loss 4.598793983459473, covariance difference 1.012137770652771
Validation: loss 5.6887424237997495, covariance difference 8.60341077858073, sinkhorn epsilon 0.0
Iteration  2291
Training: loss 4.648996353149414, covariance difference 1.0212842226028442
Validation: loss 5.617506468180696, covariance difference 8.541341932672436, sinkhorn epsilon 0.0
Iteration  2292
Training: loss 4.577783107757568, covariance difference 1.0101581811904907
Validation: loss 5.640428222297895, covariance difference 8.659582509974264, sinkhorn epsilon 0.0
Iteration  2293
Training: loss 4.6006927490234375, covariance difference 1.0127990245819092
Validation: loss 5.758347971724494, covariance difference 8.748247219223138, sinkhorn epsilon 0.0
Iteration  2294
Training: loss 4.718611240386963, covariance difference 1.0342893600463867
Validation: loss 5.665646426464321, covariance difference 8.398031368754612, sinkhorn epsilon 0.0
Iteration  2295
Training: loss 4.625913619995117, covar

Iteration  2334
Training: loss 4.662641525268555, covariance difference 1.0231022834777832
Validation: loss 5.701305305814554, covariance difference 8.759137245773104, sinkhorn epsilon 0.0
Iteration  2335
Training: loss 4.661563873291016, covariance difference 1.024367094039917
Validation: loss 5.574679270090948, covariance difference 8.420264494508244, sinkhorn epsilon 0.0
Iteration  2336
Training: loss 4.53494930267334, covariance difference 1.0018068552017212
Validation: loss 5.624599661210737, covariance difference 8.48198704020123, sinkhorn epsilon 3.6559965400591467e-14
Iteration  2337
Training: loss 4.584867477416992, covariance difference 1.0094765424728394
Validation: loss 5.5782196596070905, covariance difference 8.494167829476572, sinkhorn epsilon 0.0
Iteration  2338
Training: loss 4.5384840965271, covariance difference 1.0027610063552856
Validation: loss 5.671544116790415, covariance difference 8.66042480019916, sinkhorn epsilon 0.0
Iteration  2339
Training: loss 4.63182163

Iteration  2378
Training: loss 4.635239601135254, covariance difference 1.0188673734664917
Validation: loss 5.699393717895793, covariance difference 8.740717218226541, sinkhorn epsilon 0.0
Iteration  2379
Training: loss 4.659665107727051, covariance difference 1.0226683616638184
Validation: loss 5.591391546048985, covariance difference 8.321619320971442, sinkhorn epsilon 0.0
Iteration  2380
Training: loss 4.5516510009765625, covariance difference 1.0040675401687622
Validation: loss 5.684167169426539, covariance difference 8.441722748100048, sinkhorn epsilon 0.0
Iteration  2381
Training: loss 4.644435405731201, covariance difference 1.0193636417388916
Validation: loss 5.687574746627506, covariance difference 8.890537106614861, sinkhorn epsilon 0.0
Iteration  2382
Training: loss 4.647839546203613, covariance difference 1.0212602615356445
Validation: loss 5.688210840614331, covariance difference 8.843014788167292, sinkhorn epsilon 0.0
Iteration  2383
Training: loss 4.648478984832764, cova

Iteration  2422
Training: loss 4.647294521331787, covariance difference 1.0192958116531372
Validation: loss 5.684888491415287, covariance difference 8.57615541038302, sinkhorn epsilon 0.0
Iteration  2423
Training: loss 4.6451568603515625, covariance difference 1.0212184190750122
Validation: loss 5.660691554018415, covariance difference 8.707924446002162, sinkhorn epsilon 0.0
Iteration  2424
Training: loss 4.620955467224121, covariance difference 1.0167316198349
Validation: loss 5.640913172004689, covariance difference 8.324815180079534, sinkhorn epsilon 0.0
Iteration  2425
Training: loss 4.6011810302734375, covariance difference 1.0130890607833862
Validation: loss 5.625791612425095, covariance difference 8.629652004427225, sinkhorn epsilon 0.0
Iteration  2426
Training: loss 4.5860595703125, covariance difference 1.0100317001342773
Validation: loss 5.599217198336726, covariance difference 8.700639888754527, sinkhorn epsilon 0.0
Iteration  2427
Training: loss 4.5594868659973145, covarian

Iteration  2466
Training: loss 4.696832656860352, covariance difference 1.027296543121338
Validation: loss 5.647126422009386, covariance difference 8.414685723337527, sinkhorn epsilon 0.0
Iteration  2467
Training: loss 4.607394218444824, covariance difference 1.0132899284362793
Validation: loss 5.685137460283099, covariance difference 8.885702387228484, sinkhorn epsilon 0.0
Iteration  2468
Training: loss 4.645401477813721, covariance difference 1.0212925672531128
Validation: loss 5.667649714894838, covariance difference 8.714643387043603, sinkhorn epsilon 0.0
Iteration  2469
Training: loss 4.6279144287109375, covariance difference 1.0160603523254395
Validation: loss 5.669322039824641, covariance difference 8.745205860435366, sinkhorn epsilon 0.0
Iteration  2470
Training: loss 4.629574775695801, covariance difference 1.0190370082855225
Validation: loss 5.679899005635677, covariance difference 8.747844418415694, sinkhorn epsilon 0.0
Iteration  2471
Training: loss 4.640162467956543, covar

In [None]:
x, y = zip(*list(res_test['cov_diff'].items()))

start = 100
plt.plot(x[start:], y[start:])