In [1]:
import keras
from keras import backend as K
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

from keras.layers import Input, Dense, Activation
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D
from keras.layers import Flatten, Reshape, Lambda
from keras.utils import plot_model
from keras import Model

import os
import os.path as osp
import sys
sys.path.append('/home/jcollins')

import ot
import numpy as np

Using TensorFlow backend.


#### Define NN

In [13]:
numpoints = 10

def plan_python(Ms):
    result=[ot.emd(np.ones(len(M[0])),np.ones(len(M[1])),M) for M in Ms]
    return np.array(result,dtype=np.float32)

def myground_dist_func(a_positions, b_positions,numpoints=3):
    tiled_b = tf.tile(b_positions,[1,numpoints,1])
    repeated_a = tf.reshape(a_positions, [-1,numpoints,1,2])    # Convert to a len(yp) x 1 matrix.
    repeated_a = tf.tile(repeated_a, [1,1,numpoints,1])  # Create multiple columns.
    repeated_a = tf.reshape(repeated_a, [-1,numpoints*numpoints,2])  

    ground_dist = tf.norm(repeated_a-tiled_b,axis=-1)
    ground_dist = tf.reshape(ground_dist,[-1,numpoints,numpoints])
    
    return ground_dist

def create_emd_loss(numpoints):
    
    def emd_loss(y_true, y_pred):
        ground_dist_tensor = myground_dist_func(y_true, y_pred,numpoints=numpoints)
        plan_tensor = tf.py_func(func=plan_python, inp=[ground_dist_tensor],Tout=tf.float32)
        my_calc_loss_tensor = tf.linalg.trace(tf.matmul(ground_dist_tensor,tf.transpose(plan_tensor, perm=[0,2,1])))
        return my_calc_loss_tensor
    
    return emd_loss

inputs = Input(batch_shape=(None,numpoints,2),name='encoder_0_input')
layer = Reshape((numpoints*2,),name='encoder_1_reshape')(inputs)
layer = Dense(20,activation='relu',name='encoder_2_dense')(layer)
latent = Dense(10,activation='relu',name='latent')(layer)
layer = Dense(20,activation='relu',name='decoder_1_dense')(latent)
layer = Dense(numpoints*2,activation='relu',name='decoder_2_dense')(layer)
output = Reshape((numpoints, 2),name='decoder_3_reshape')(layer)

model = Model(inputs, output)


model.compile(loss=create_emd_loss(numpoints) , optimizer=keras.optimizers.Adam())

#### Before training

In [14]:
y_true = np.random.rand(1,10,2)
y_pred = model.predict(y_true)
model.evaluate(y_true,y_true)
print("Input:\n", y_true)
print("\nOuput:\n", y_pred)

Input:
 [[[0.32428188 0.2179219 ]
  [0.78628233 0.31047642]
  [0.38002429 0.81332314]
  [0.45056699 0.68212896]
  [0.52725623 0.05088297]
  [0.69765567 0.46895866]
  [0.51589004 0.77805182]
  [0.06194304 0.04234782]
  [0.49944744 0.61550296]
  [0.11565807 0.42962206]]]

Ouput:
 [[[0.06958116 0.        ]
  [0.         0.        ]
  [0.         0.        ]
  [0.05484366 0.06156513]
  [0.         0.13032712]
  [0.03672989 0.        ]
  [0.         0.        ]
  [0.         0.10176778]
  [0.08615629 0.        ]
  [0.         0.00369972]]]


#### Train

In [15]:
y_true = np.random.rand(10000,10,2)
model.fit(y_true,y_true,epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f49a47056d8>

#### After training

In [16]:
y_true = np.random.rand(1,10,2)
y_pred = model.predict(y_true)
print("Input:\n", y_true)
print("\nOuput:\n", y_pred)

Input:
 [[[0.08020526 0.73851116]
  [0.51485739 0.46720266]
  [0.35596592 0.72512363]
  [0.56238889 0.88581388]
  [0.63692942 0.9095551 ]
  [0.47752342 0.90601615]
  [0.72394758 0.26987862]
  [0.23626585 0.88773468]
  [0.81808181 0.44447507]
  [0.69268159 0.73864892]]]

Ouput:
 [[[0.80697715 0.35397422]
  [0.         0.        ]
  [0.         0.7336159 ]
  [0.8368896  0.9302227 ]
  [0.47217798 0.7486334 ]
  [0.49820584 0.3995404 ]
  [0.         0.9501151 ]
  [0.6346078  0.9214463 ]
  [0.7424135  0.71212643]
  [0.37867773 0.9448007 ]]]
