In [1]:
from __future__ import division, print_function, unicode_literals

import os
import sys
import numpy as np
import math

In [2]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

In [3]:
import tensorflow as tf
from ops import *
from utils import *

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [6]:
### Number of coordinates per sample.
dim = 2

### Graph inputs: two batches of `dim`-dimensional points with an open batch size.
### (The session cell feeds X with pos_imgs and Y with neg_imgs.)
X = tf.placeholder(tf.float32, [None, dim], name="X")
Y = tf.placeholder(tf.float32, [None, dim], name="Y")

### energy_func is not defined in the visible cells -- presumably it comes from
### the `ops` / `utils` star imports; TODO confirm.
energy_pos, energy_neg = energy_func(X), energy_func(Y)

In [15]:
### Number of particles processed together; the kernel and update code below refer to this global.
batchsize = 150

def get_median(v):
    """Per-row, per-coordinate (upper) median of a pairwise tensor.

    Args:
        v: tensor of shape [N, N, dim], e.g. coordinate-wise squared distances.

    Returns:
        Tensor of shape [N, dim]. Entry (i, d) is the (N // 2)-th largest
        value of v[i, :, d], obtained as the minimum of the top N // 2 values.

    The rank N // 2 is taken from the tensor itself rather than from the
    global `batchsize`, so callers that pass a different `n` to the kernel
    functions still get a median instead of a wrong quantile.
    """
    v = tf.transpose(v, [0, 2, 1]) ### [N,dim,N] -- top_k works on the last axis
    n_static = v.get_shape()[-1].value
    if n_static is not None:
        m = n_static // 2
    else:
        ### Size only known at run time; top_k accepts a scalar int32 tensor for k.
        m = tf.shape(v)[-1] // 2
    top_half = tf.nn.top_k(v, m, sorted=False).values ### [N,dim,m]
    return tf.reduce_min(top_half, axis=-1) ### [N,dim]

### Euclidean Kernel:
def RBF_kernel(x_in, scale=1., bandwidth=-1, n=batchsize):
    """Coordinate-wise Gaussian (RBF) kernel exponent for Euclidean coordinates.

    Note that this returns the *exponent* -d^2 / (2 h^2) per coordinate, not
    the exponentiated kernel; compute_kernel applies tf.exp afterwards.

    Args:
        x_in: points; reshaped to [n, -1] so the row count is static.
        scale: divides the median-based bandwidth (ignored when an explicit
            bandwidth is given).
        bandwidth: if negative, the bandwidth is derived from the median of
            the squared differences; otherwise the value is used as given.
            The test is a plain Python comparison, so this is expected to be
            a Python number.
        n: number of rows to reshape x_in to.

    Returns:
        Tensor of shape [n, n, d] holding -diff**2 / bandwidth**2 / 2.
    """
    pts = tf.reshape(x_in, [n, -1]) ### [N,dim], static N needed by the median

    ### Pairwise coordinate differences via broadcasting: [N,1,dim] - [1,N,dim].
    diff = tf.expand_dims(pts, axis=1) - tf.expand_dims(pts, axis=0)
    sq_diff = diff**2 ### [N,N,dim], kept per coordinate (no sum over dim)

    if bandwidth < 0:
        ### Median trick: h = sqrt(0.5 * median / log(N + 1)) / scale, one value per row and coordinate.
        n_pts = tf.cast(tf.shape(pts)[0], dtype=tf.float32)
        med = get_median(sq_diff) ### [N,dim]
        bandwidth = tf.sqrt( 0.5 * med / tf.log(n_pts + 1) ) / scale
        bandwidth = tf.expand_dims(bandwidth, axis=1) ### [N,1,dim]

    return - sq_diff / bandwidth**2 / 2. ### [N,N,dim]

######################################################

### Torus Kernel:
def RBF_kernel_torus(x_in, scale=1., bandwidth=-1, n=batchsize):
    """Coordinate-wise RBF kernel exponent for angular (periodic) coordinates.

    Each coordinate is treated as an angle; the squared distance between two
    angles a and b is (cos a - cos b)**2 + (sin a - sin b)**2, so values that
    differ by a multiple of 2*pi are at distance zero.

    As with RBF_kernel, the *exponent* is returned; compute_kernel applies
    tf.exp afterwards.

    Args:
        x_in: angles; reshaped to [n, -1] so the row count is static.
        scale: divides the median-based bandwidth (ignored when an explicit
            bandwidth is given).
        bandwidth: if negative, the bandwidth is derived from the median of
            the squared distances; otherwise the value is used as given.
            The test is a plain Python comparison, so this is expected to be
            a Python number.
        n: number of rows to reshape x_in to.

    Returns:
        Tensor of shape [n, n, d] holding -dist2 / bandwidth**2 / 2.
    """
    angles = tf.reshape(x_in, [n, -1]) ### [N,dim], static N needed by the median

    rows = tf.expand_dims(angles, axis=1) ### [N,1,dim]
    cols = tf.expand_dims(angles, axis=0) ### [1,N,dim]

    ### Differences of the unit-circle embeddings (cos, sin) of each angle.
    d_cos = tf.cos(rows) - tf.cos(cols)
    d_sin = tf.sin(rows) - tf.sin(cols)
    chord2 = d_cos**2 + d_sin**2 ### [N,N,dim], kept per coordinate (no sum over dim)

    if bandwidth < 0:
        ### Median trick: h = sqrt(0.5 * median / log(N + 1)) / scale, one value per row and coordinate.
        n_pts = tf.cast(tf.shape(angles)[0], dtype=tf.float32)
        med = get_median(chord2) ### [N,dim]
        bandwidth = tf.sqrt( 0.5 * med / tf.log(n_pts + 1) ) / scale
        bandwidth = tf.expand_dims(bandwidth, axis=1) ### [N,1,dim]

    return - chord2 / bandwidth**2 / 2. ### [N,N,dim]


In [25]:
def compute_kernel(x):
    """Build the [N, N] kernel matrix for particles with mixed coordinates.

    Column 0 of x goes through the Euclidean kernel and column 1 through the
    torus (periodic) kernel. The per-coordinate exponents are concatenated,
    averaged over the coordinate axis and exponentiated.

    Args:
        x: particle tensor; columns 0 and 1 are read (assumes shape [N, 2] --
           TODO confirm against callers).

    Returns:
        Tensor of shape [N, N].
    """
    x_t = tf.transpose(x) ### [dim,N], so gather picks whole coordinates
    col_euclid = tf.transpose( tf.gather(x_t, [0]) ) ### [N,1]
    col_torus = tf.transpose( tf.gather(x_t, [1]) ) ### [N,1]

    exponents = tf.concat(
        [RBF_kernel(col_euclid), RBF_kernel_torus(col_torus)],
        axis=-1,
    ) ### [N,N,dim]

    return tf.exp( tf.reduce_mean(exponents, axis=-1) ) ### [N,N]

In [26]:
#### Define BT Loops
### Run-time knobs, supplied through feed_dict.
num_steps = tf.placeholder(tf.int32, [], name="bt.steps") ### number of loop iterations, e.g. 100
stepsize = tf.placeholder(tf.float32, [], name="ss") ### e.g. 1e-1
beta = tf.placeholder(tf.float32, [], name="beta") ### inverse temperature; KbT: 2.494339 kJ/mol

### Initial loop state: iteration counter and running average of squared updates.
steps = tf.constant(0)
E_g2 = tf.zeros(tf.shape(Y))

def c(i, x, eg2):
    """Loop condition: keep iterating while the counter is below num_steps."""
    return tf.less(i, num_steps)

In [27]:
def bt_step(counter, x, E_g2, stepsize=stepsize, beta=beta ):
    """One particle update; used as the body of the tf.while_loop.

    The raw update is the sum of
      * a drift term: kernel-weighted average of force = -beta * dE/dx, and
      * a kernel-gradient term: d(sum of kernel entries)/dx,
    both divided by the global `batchsize` and wrapped in tf.stop_gradient.
    The raw update is then rescaled by a running average of its square before
    being added to x.

    Args:
        counter: scalar int tensor, current iteration index.
        x: particle positions (assumes [N, dim] -- TODO confirm).
        E_g2: running average of the squared update, same shape as x.
        stepsize: scalar step size; defaults to the module-level placeholder.
        beta: scalar multiplier on the energy gradient; defaults to the
            module-level placeholder.

    Returns:
        [counter + 1, updated x, updated E_g2].
    """
    force = - beta * tf.gradients(energy_func(x), [x])[0]

    kernel = compute_kernel(x)
    ### NOTE(review): this differentiates the sum of all kernel entries w.r.t. x
    ### (including through the median bandwidth). Standard SVGD uses the gradient
    ### w.r.t. the *other* kernel argument, whose distance term has the opposite
    ### sign -- confirm this is intended.
    kernel_grad = tf.gradients(kernel, [x])[0]

    drift = tf.stop_gradient( tf.matmul(kernel, force) / batchsize ) ### [N,dim]
    kernel_grad = tf.stop_gradient( kernel_grad / batchsize )

    update = drift + kernel_grad

    ### RMSProp-style scaling: exponential moving average of update**2,
    ### seeded with update**2 itself on the first iteration.
    decay_rate = 0.9
    fudge_factor = 1e-6 ### keeps the square root away from zero

    avg_sq = tf.cond(
        tf.equal(counter, tf.constant(0)),
        lambda: update**2,
        lambda: decay_rate * E_g2 + (1. - decay_rate) * (update ** 2),
    )

    scaled_update = update / tf.sqrt(avg_sq + fudge_factor)

    ################################
    x_next = x + stepsize * scaled_update

    return [counter + 1, x_next, avg_sq]

In [29]:
### Start from a fresh zero counter instead of `steps`: `steps` is rebound to the
### loop's final counter on the next line, so feeding it back in would make a
### re-run of this cell start from the previous loop's output.
steps, Y_BT, _ = tf.while_loop(c, bt_step, [tf.constant(0), Y, E_g2])

In [None]:

### Session with on-demand GPU memory growth (instead of a plain tf.Session()).
### NOTE(review): neg_imgs, pos_imgs, pos_weight, is_training, wt, lr, reg,
### train_op, loss, loss_lh and loss_drift are not defined in the visible cells,
### and no variable initialiser is run here -- confirm they come from elsewhere.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:

    # ---------------------
    #  Boltzmann Transformer
    # ---------------------
    beta_set = 1.
    stepsize_set = 5e-2 ### 1e-1
    num_steps_set = 250

    if_annealing = True
    anneal_steps = 50
    beta_anneal = 0.3 * beta_set

    ### Sampling phases as (beta, number of steps): an optional short run at a
    ### reduced beta, followed by the main run at the target beta.
    bt_phases = [(beta_set, num_steps_set)]
    if if_annealing:
        bt_phases.insert(0, (beta_anneal, anneal_steps))

    for phase_beta, phase_steps in bt_phases:
        neg_imgs = sess.run(Y_BT, feed_dict={
            Y: neg_imgs,
            beta: phase_beta,
            num_steps: phase_steps,
            stepsize: stepsize_set,
            is_training: False,
        }) ### [N,dims=2]
        ### Wrap every coordinate into [-pi, pi]: arccos(cos(.)) gives the magnitude,
        ### sign(sin(.)) the sign.
        ### NOTE(review): this also wraps column 0, which compute_kernel treats as
        ### Euclidean -- confirm.
        neg_imgs = np.sign(np.sin(neg_imgs)) * np.arccos(np.cos(neg_imgs))
    #############################################################

    # ---------------------
    #  Train Discriminator
    # ---------------------
    feed_dict = {
        X: pos_imgs,
        Y: neg_imgs,
        wt: pos_weight,
        is_training: True,
        lr: 1e-4,
        reg: 1e-2,
    }

    fetches = [ train_op, loss, loss_lh, loss_drift, energy_pos, energy_neg ]
    _, _loss, _loss_lh, _loss_drift, _energy_pos, _energy_neg = sess.run(
        fetches, feed_dict=feed_dict)
        