In [None]:
# Set parameters
# Builds the command-line interface for the experiment.  Every option below
# becomes an attribute of the namespace produced when `parser` is parsed.
import argparse


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse's ``type=bool`` calls ``bool(token)``, which is True for every
    non-empty string -- including ``'False'`` -- so a flag could never be
    switched off from the command line.  This converter parses the text.

    Args:
        value: the raw token (or an already-converted bool).

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``bool('')``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()

# Data
parser.add_argument(
    '-N_Q', '--N_samples_Q', type=int, default=200, help='total number of target samples',
)
parser.add_argument(
    '-N_P', '--N_samples_P', type=int, default=600, help='total number of prior samples',
)
parser.add_argument(
    '-N_dim', type=int, help='dimension of input data',
)
parser.add_argument(
    '-N_latent_dim', type=int, help='dimension of latent space',
)
parser.add_argument(
    '-N_project_dim', type=int, help='dimension of PCA projected space on input',
)
parser.add_argument(
    '-sample_latent', type=_str2bool, default=False, help='True: sample in the latent space, False: sample in the physical space',
)
# Dataset property
parser.add_argument(
    '--dataset', type=str, default='Lorenz63', choices=['Lorenz63', 'Learning_gaussian', 'Mixture_of_gaussians', 'Mixture_of_gaussians2','Mixture_of_gaussians3','Mixture_of_gaussians4', 'Stretched_exponential', 'Learning_student_t', 'Mixture_of_student_t', 'Mixture_of_student_t_submnfld', 'Mixture_of_gaussians_submnfld','MNIST', 'CIFAR10', 'MNIST_switch', 'CIFAR10_switch', 'MNIST_ae', 'MNIST_ae_switch','CIFAR10_ae',  'Mixture_of_gaussians_submnfld_ae','BreastCancer', '1D_pts', '2D_pts','1D_dirac2gaussian', '1D_dirac2uniform',]
)
parser.add_argument(
    '-y0', type=float, nargs="+", default=[1.0,2.0, 2.0]
)
parser.add_argument(
    # Backslash escaped: '\b' inside a normal string literal is a backspace character.
    '-beta', type=float, help='gibbs distribution of -|x|^\\beta',
)
parser.add_argument(
    '-sigma_P', type=float, default=0.5, help='std of initial gaussian distribution',
)
parser.add_argument(
    '-sigma_Q', type=float, default=0.5, help='std of target gaussian distribution',
)
parser.add_argument(
    '-nu', type=float, help='df of target student-t distribution',
)
parser.add_argument(
    '-interval_length', type=float, help='interval length of the uniform distribution',
)
parser.add_argument(
    '-label', type=int, nargs="+", help='class label of image data',
)
parser.add_argument(
    '-pts_P', type=float, nargs="+", default=[10.0,]
)
parser.add_argument(
    '-pts_Q', type=float, nargs="+", default=[0.0,]
)
parser.add_argument(
    '-pts_P_2', type=float, nargs="+", default=[0.0,]
)
parser.add_argument(
    '-pts_Q_2', type=float, nargs="+", default=[0.0,]
)
parser.add_argument(
    '--random_seed', type=int, default=0, help='random seed for data generator',
)


# (f, Gamma)-divergence
parser.add_argument(
    '--f', type=str, default='KL', choices=['KL', 'alpha', 'reverse_KL', 'JS'],
)
parser.add_argument(
    '-alpha', type=float, help='parameter value for alpha divergence',
)
parser.add_argument(
    '--formulation', type=str, default='DV', choices=['LT', 'DV'], help='LT or DV in case of f=KL, otherwise, keep LT',
)
parser.add_argument(
    '--Gamma', type=str, default='Lipshitz', choices=['Lipshitz'],
)
parser.add_argument(
    # NOTE(review): the help text says "default=inf" but the default is 1.0 -- confirm which is intended.
    '-L', type=float, default=1.0, help='Lipshitz constant: default=inf w/o constraint',
)
parser.add_argument(
    '--reverse', type=_str2bool, default=False, help='True -> D(Q|P), False -> D(P|Q)',
)
parser.add_argument(
    '--constraint', type=str, default='hard', choices=['hard', 'soft'],
)
parser.add_argument(
    '-lamda', type=float, default=100.0, help='coefficient on soft constraint',
)


# Neural Network definition <phi>
parser.add_argument(
    '-NN', '--NN_model', type=str, default='fnn', choices=['fnn', 'cnn', 'cnn-fnn'],
)
parser.add_argument(
    '-N_fnn_layers', type=int, nargs='+', default=[32,32,32], help='list of the number of FNN hidden layer units / the number of CNN feed-forward hidden layer units',
)
parser.add_argument(
    '-N_cnn_layers', type=int, nargs='+', help='list of the number of CNN channels',
)
parser.add_argument(
    '--activation_ftn', type=str, nargs='+', default=['relu',], choices=['relu', 'mollified_relu_cos3','mollified_relu_poly3','mollified_relu_cos3_shift','softplus', 'leaky_relu','elu', 'bounded_relu', 'bounded_elu'], help='[0]: for the fnn/convolutional layer, [1]: for the cnn feed-forward layer, [2]: for the LAST cnn feed-forward layer',
)
parser.add_argument(
    '-eps', type=float, default=0.5, help='Mollifier shape adjusting parameter when using mollified relu3 activations',
)
parser.add_argument(
    '--N_conditions', type=int, default=1, help='number of classes for the conditional setting',
)


# training parameters
parser.add_argument(
    '-ep', '--epochs', type=int, default=1000, help='# updates for P',
)
parser.add_argument(
    '-ep_nn', '--epochs_nn', type=int, default=3, help='# updates for NN to find phi*',
)
parser.add_argument(
    '--optimizer', type=str, default='adam', choices=['sgd', 'adam',], help='optimizer for NN',
)
parser.add_argument(
    '--ode_solver', type=str, choices=['forward_euler', 'AB2', 'AB3', 'AB4', 'AB5', 'ABM1', 'Heun', 'ABM2', 'ABM3', 'ABM4', 'ABM5', 'RK4', 'ode45' ], default='forward_euler', help='ode solver for particle ode',
)
parser.add_argument(
    '-mobility', type=str, help='problem dependent mobility function\nRecommendation: MNIST - bounded',
)
parser.add_argument(
    '-lr_P_decay', type=str, choices=['rational', 'step',], help='delta t decay',
)
parser.add_argument(
    '--lr_P', type=float, default=1.0, help='lr for P',
)
parser.add_argument(
    '--lr_NN', type=float, default=0.001, help='lr for NN',
)
parser.add_argument(
    '--exp_no', type=str, default='first_run', help='short experiment name under the same data',
)
parser.add_argument(
    '--mb_size_P', type=int, default=200, help='mini batch size for the moving distribution P',
)
parser.add_argument(
    '--mb_size_Q', type=int, default=200, help='mini batch size for the target distribution Q',
)


# save/display
parser.add_argument(
    '--save_iter', type=int, default=10, help='save results per each save_iter',
)
parser.add_argument(
    '--plot_result', type=_str2bool, default=True, help='True -> show plots',
)
parser.add_argument(
    '--plot_intermediate_result', type=_str2bool, default=False, help='True -> save intermediate plots',
)

In [None]:
import os
# Set before the TensorFlow import below so the value is already in the
# environment when TensorFlow is loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # avoid tensorflow warning
import tensorflow as tf
import numpy as np
import re

In [None]:
# input parameters --------------------------------------
# parse_known_args (instead of parse_args) collects unrecognised argv entries
# in `unknown` rather than aborting.
p, unknown = parser.parse_known_args()
param = vars(p)  # the namespace's own attribute dict, so later writes to `p` are visible here

# Experiment type tag: "<f>[=<alpha>]-<Gamma>".
# Test against None explicitly so that an explicit `-alpha 0` is not silently
# treated as "alpha not given".
if p.alpha is not None:
    par = [p.alpha]
    p.exptype = '%s=%05.2f-%s' % (p.f, p.alpha, p.Gamma)
else:
    par = []
    p.exptype = '%s-%s' % (p.f, p.Gamma)

# Experiment name: the type tag followed by the Lipschitz constant L ('inf' when L is unset).
if p.L is None:
    p.expname = '%s_%s' % (p.exptype, 'inf')
else:
    p.expname = '%s_%.4f' % (p.exptype, p.L)

In [None]:
# Data generation ----------------------------------------
from util.generate_data import generate_data
p, X_, Y_, X_label, Y_label = generate_data(p)

# Q is held as a constant tensor, P as a Variable (P is reassigned/updated
# by the training loop).
if p.dataset not in ['BreastCancer',]:
    Q = tf.constant(X_, dtype=tf.float32)
    P = tf.Variable(Y_, dtype=tf.float32)
else:
    # BreastCancer samples are divided by 10 here; the training loop
    # multiplies by 10 again when it stores trajectories.
    Q = tf.constant(X_/10.0, dtype=tf.float32)
    P = tf.Variable(Y_/10.0, dtype=tf.float32)

if p.N_conditions <= 1:
    # Unconditional setting: no labels, no per-class index masks.
    Q_label = P_label = None
    label_idx_Q = label_idx_P = None
else:
    Q_label = tf.constant(X_label, dtype=tf.float32)
    P_label = tf.constant(Y_label, dtype=tf.float32)

    # One boolean mask per class: entries of label column k that equal 1.
    label_idx_Q = [Q_label[:, k] == 1 for k in range(p.N_conditions)]
    label_idx_P = [P_label[:, k] == 1 for k in range(p.N_conditions)]


# Bundle of data-related settings handed to the training / transport helpers.
data_par = {
    'P_label': P_label,
    'Q_label': Q_label,
    'mb_size_P': p.mb_size_P,
    'mb_size_Q': p.mb_size_Q,
    'N_samples_P': p.N_samples_P,
    'N_samples_Q': p.N_samples_Q,
    'label_idx_Q': label_idx_Q,
    'label_idx_P': label_idx_P,
}

In [None]:
# Discriminator learning  -----------------------------------------
# Build the neural-network discriminator from the parsed options.
from util.construct_NN import check_nn_topology, initialize_NN, model

# check_nn_topology returns the layer lists and activation list that are used
# from here on; the activation list is written back onto `p`.
N_fnn_layers, N_cnn_layers, p.activation_ftn = check_nn_topology(
    p.NN_model, p.N_fnn_layers, p.N_cnn_layers, p.N_dim, p.activation_ftn
)

# Network description shared by initialisation, model construction and training.
NN_par = {
    'NN_model': p.NN_model,
    'activation_ftn': p.activation_ftn,
    'N_dim': p.N_dim,
    'N_cnn_layers': N_cnn_layers,
    'N_fnn_layers': N_fnn_layers,
    'N_conditions': p.N_conditions,
    'constraint': p.constraint,
    'L': p.L,
    'eps': p.eps,
}

W, b = initialize_NN(NN_par)
phi = model(NN_par)  # discriminator

# Extra learnable scalar used in the f-divergence optimisation.
nu = tf.Variable(0.0, dtype=tf.float32)

# Learnable parameters for the discriminator phi.
parameters = dict(W=W, b=b, nu=nu)

# Train setting
from util.train_NN import train_disc
# Learning rate for the discriminator, kept as a non-trainable Variable.
lr_NN = tf.Variable(p.lr_NN, trainable=False)

# (Discriminator) Loss ----------------------------------------------
loss_par = dict(f=p.f, formulation=p.formulation, par=par, reverse=p.reverse, lamda=p.lamda)

In [None]:
# Transporting particles --------------------------------------------
# ODE solver setting
from util.transport_particles import calc_vectorfield, solve_ode
dPs = []  # the training loop appends one vector field per iteration

# These solvers are given an empty list as auxiliary input; every other
# solver receives the full discriminator training context.
solvers_without_aux = ('forward_euler', 'AB2', 'AB3', 'AB4', 'AB5')
if p.ode_solver not in solvers_without_aux:
    aux_params = {
        'parameters': parameters,
        'phi': phi,
        'Q': Q,
        'lr_NN': lr_NN,
        'epochs_nn': p.epochs_nn,
        'loss_par': loss_par,
        'NN_par': NN_par,
        'data_par': data_par,
        'optimizer': p.optimizer,
    }
else:
    aux_params = []

# Applying mobility to particles
if p.mobility == 'bounded':
    # mobility that bounds particles (for image data)
    from util.construct_NN import bounded_relu

# Train setting
# By default the step size is a single scalar shared by all particles.
# NOTE(review): "DOPRI5" is not among the --ode_solver choices visible in the
# argument parser, so this per-particle branch looks unreachable from the
# command line -- confirm whether 'ode45' was meant.
per_particle_step = p.ode_solver == "DOPRI5"
lr_P_init = [p.lr_P] * p.N_samples_P if per_particle_step else p.lr_P
if per_particle_step:
    # Append singleton axes until the step-size array has the same rank as P
    # (low dimensional example => rank 2, image example => rank 4).
    for axis in range(1, tf.rank(P)):
        lr_P_init = np.expand_dims(lr_P_init, axis=axis)
lr_P = tf.Variable(lr_P_init, trainable=False)
lr_Ps = []  # history of step sizes, one entry per iteration

In [None]:
# Save & plot settings -----------------------------------------------
# Metrics to calculate
from util.evaluate_metric import calc_fid, calc_ke, calc_grad_phi
# Per-run histories filled in by the training loop.
trajectories = []
vectorfields = []
divergences = []
KE_Ps = []
FIDs = []

# saving/plotting parameters
# A save interval that is not smaller than the number of epochs is reset to 1
# (save every iteration).
if p.save_iter >= p.epochs:
    p.save_iter = 1

# plot_result is called for the final plot (p.plot_result) and for the
# intermediate plots (p.plot_intermediate_result), so import it when either is
# requested; importing only under p.plot_result raised NameError when just the
# intermediate plots were enabled.
if p.plot_result or p.plot_intermediate_result:
    from plot_result import plot_result

p.expname = p.expname+'_%04d_%04d_%02d_%s' % (p.N_samples_Q, p.N_samples_P, p.random_seed, p.exp_no)
filename = p.dataset+'/%s.pickle' % (p.expname)

# Dataset-specific parameter forwarded to the intermediate plots.
r_param = None
if p.plot_intermediate_result:
    if 'gaussian' in p.dataset and 'Extension' not in p.dataset:
        r_param = p.sigma_Q
    elif 'student_t' in p.dataset:
        r_param = p.nu
    elif p.dataset == 'Extension_of_gaussian':
        # NOTE(review): no `a` option is declared by the argument parser;
        # presumably generate_data sets p.a for this dataset -- confirm.
        r_param = p.a

# additional plots for simple low dimensional dynamics
if p.N_dim == 1:
    # 300 evenly spaced points on [-10, 10], as a column.
    xx = np.linspace(-10, 10, 300)
    xx = tf.constant(np.reshape(xx, (-1,1)), dtype=tf.float32)
    phis = []
elif p.N_dim == 2:#'2D' in p.dataset:
    # 40 x 40 grid on [-10, 10]^2, flattened to one (x, y) pair per row.
    xx = np.linspace(-10, 10, 40)
    yy = np.linspace(-10, 10, 40)
    XX, YY = np.meshgrid(xx, yy)
    xx = np.concatenate((np.reshape(XX, (-1,1)), np.reshape(YY, (-1,1))), axis=1)
    xx = tf.constant(xx, dtype=tf.float32)
    phis = []

In [None]:
# Train ---------------------------------------------------------------
import time
t0 = time.time()

# Progress is reported about ten times per run.  The interval is an integer
# of at least 1: the previous float test `it % (p.epochs/10) == 0` never fired
# for some epoch counts (e.g. 3 % 0.3 is not exactly 0 in floating point).
display_iter = max(1, p.epochs // 10)

for it in range(1, p.epochs+1): # Loop for updating particles P
    # 1) Train the discriminator against the current particles.
    parameters, current_loss, dW_norm = train_disc(parameters, phi, P, Q, lr_NN, p.epochs_nn, loss_par, NN_par, data_par, p.optimizer, print_vals=True)

    # 2) Evaluate the vector field and keep it in the history passed to the solver.
    dPs.append( calc_vectorfield(phi, P, parameters, NN_par, loss_par, data_par) )

    # 3) Advance the particles; the "DOPRI5" solver additionally returns an adjusted step size.
    if p.ode_solver == "DOPRI5": # deltat adjust
        P, dPs, dP, lr_P = solve_ode(P, lr_P, dPs, p.ode_solver, aux_params) # update P
    else:
        P, dPs, dP = solve_ode(P, lr_P, dPs, p.ode_solver, aux_params) # update P

    if p.mobility == 'bounded':
        P.assign(bounded_relu(P))

    lr_Ps.append(lr_P.numpy())
    # adjust learning rates
    #if it>=100:
    #    lr_P = decay_learning_rate(lr_P, p.lr_P_decay, {'epochs': p.epochs-100, 'epoch': it-100, 'KE_P': KE_P})

    # save results
    divergences.append(current_loss)
    KE_P = calc_ke(dP, p.N_samples_P)
    KE_Ps.append(KE_P)
    grad_phi = calc_grad_phi(dP)
    #print("grad", grad_phi)

    # Snapshots are stored every iteration for short runs (<= 100 epochs),
    # otherwise every save_iter iterations.
    if p.epochs<=100 or it%p.save_iter == 0:
        if p.dataset in ['BreastCancer',]:
            # Multiplies by 10, the inverse of the /10 applied when P was created.
            trajectories.append(P.numpy()*10)
        else:
            trajectories.append(P.numpy())
        # Vector fields are stored only below 500 input dimensions; FID is
        # computed only from 784 dimensions upwards (image data).
        if np.prod(p.N_dim) < 500:
            vectorfields.append(dP.numpy())
        elif np.prod(p.N_dim) >= 784:  # image data
            FIDs.append( calc_fid(pred=P.numpy(), real=Q.numpy()) )

    # display intermediate results
    if it % display_iter == 0:
    #if it in [5, 50, 500, 1000, 2000, 3000, 4000, 5000]:
        display_msg = 'iter %6d: loss = %.10f, norm of dW = %.2f, kinetic energy of P = %.10f, average learning rate for P = %.6f' % (it, current_loss, dW_norm, KE_P, tf.math.reduce_mean(lr_P).numpy())
        if len(FIDs) > 0 :
            display_msg = display_msg + ', FID = %.3f' % FIDs[-1]
        print(display_msg)
        print("grad", grad_phi)

        if p.plot_intermediate_result:
            data = {'trajectories': trajectories, 'divergences': divergences, 'KE_Ps': KE_Ps, 'FIDs':FIDs, 'X_':X_, 'Y_':Y_, 'X_label':X_label, 'Y_label':Y_label, 'dt': lr_Ps, 'dataset': p.dataset, 'r_param': r_param, 'vectorfields': vectorfields, 'save_iter':p.save_iter}
            if p.N_dim ==2:
                data.update({'phi': phi, 'W':W, 'b':b, 'NN_par':NN_par})
            plot_result(filename, intermediate=True, epochs = it, iter_nos = None, data = data, show=False)

        # Disabled: record the discriminator values on the evaluation grid.
        #if np.prod(p.N_dim) <= 2:
        #    zz = phi(xx,None, W,b,NN_par).numpy()
        #    zz = np.reshape(zz, -1)
        #    phis.append(zz)

total_time = time.time() - t0
print(f'total time {total_time:.3f}s')

In [None]:
# Save result ------------------------------------------------------
import pickle
# Create the output directory for the pickle file; exist_ok avoids the
# check-then-create race of a separate os.path.exists test.
os.makedirs(p.dataset, exist_ok=True)

if '1D' in p.dataset:
    # Pad the 1D data with a zero second column so the saved arrays have two columns.
    X_ = np.concatenate((X_, np.zeros(shape=X_.shape)), axis=1)
    Y_ = np.concatenate((Y_, np.zeros(shape=Y_.shape)), axis=1)

    trajectories = [np.concatenate((x, np.zeros(shape=x.shape)), axis=1) for x in trajectories]
    vectorfields = [np.concatenate((x, np.zeros(shape=x.shape)), axis=1) for x in vectorfields]

param.update({'X_': X_, 'Y_': Y_, 'lr_Ps':lr_Ps,})
result = {'trajectories': trajectories, 'vectorfields': vectorfields, 'divergences': divergences, 'KE_Ps': KE_Ps, 'FIDs': FIDs,}

if p.dataset in ['BreastCancer',]:
    # Make sure the CSV directory exists: this write happens before the pickle
    # dump below, so a missing directory would otherwise abort the cell and
    # lose the results of the whole run.
    csv_dir = "gene_expression_example/GPL570/"+p.dataset
    os.makedirs(csv_dir, exist_ok=True)
    # The last stored particle snapshot is exported as CSV.
    np.savetxt(csv_dir+'/output_norm_dataset_dim_%d.csv' % p.N_dim, trajectories[-1], delimiter=",")

# Save trained data
with open(filename,"wb") as fw:
    pickle.dump([param, result] , fw)
print("Results saved at:", filename)

# Plot final result
if p.plot_result:
    plot_result(filename, intermediate=False, show=False)

In [None]:
import matplotlib.pyplot as plt

# 3D scatter of the first stored particle snapshot (columns 0, 1 and 2).
first_snapshot = trajectories[0]
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(first_snapshot[:, 0], first_snapshot[:, 1], first_snapshot[:, 2])
plt.show()



dataset = np.load('lorenzdataset.npy')


# 3D scatter of rows 0, 1 and 2 of the loaded array, taking every 5th entry
# and stopping before the last one.
every_fifth = slice(0, -1, 5)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(dataset[0, every_fifth], dataset[1, every_fifth], dataset[2, every_fifth], s=0.1)
plt.show()