In [1]:
"""
Illustration of MMD loss
"""
import numpy as np
import matplotlib.pyplot as plt

do_plot = True

# data
# Seed the global NumPy RNG so the two point sets (and every figure derived
# from them) are identical after a kernel restart.
np.random.seed(0)
xr = 0.15*np.random.randn(2, 6)  # 2-by-6: one column per 2-D point, plotted as 'R'
xg = 0.15*np.random.randn(2, 6)  # 2-by-6: one column per 2-D point, plotted as 'G'


In [2]:
# plot
if do_plot:
    fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
    # Draw both point sets on the same axes: first row is x, second row is y.
    for pts, colour in ((xr, 'tab:gray'), (xg, 'tab:red')):
        ax.scatter(
            pts[0], pts[1], marker='o', c=colour,
            s=40, linewidths=20, alpha=0.5)
    # Labels are given positionally, in the order the scatter sets were drawn.
    ax.legend(
        ['R', 'G'], frameon=True, fontsize=15, labelspacing=1, borderpad=0.5)
    plt.show()


In [4]:
import tensorflow as tf
from GeneralTools.math_func import matrix_mean_wo_diagonal
from GeneralTools.graph_func import MySession

# --- optimisation settings ---
lr = 1e-2          # factor applied to the gradient in each variable update
max_step = 450     # number of update iterations
query_step = 100   # print (and optionally plot) every `query_step` iterations

# --- experiment / output switches ---
use_attractive_loss = True   # selects which loss terms are built below
do_plot = True               # draw a figure at every queried step
do_save = False              # True: write figures to `folder`; False: show them
folder = '../animation_mmd/'  # prefix prepended to the saved figure names


def pd2(m1, m2):
    """ Squared pair-wise Euclidean distance between the columns of two matrices.

    Uses the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 a.b and clips the result
    to [0, 10000] so that round-off cannot produce negative distances.

    :param m1: 2-by-N1, one point per column
    :param m2: 2-by-N2, one point per column
    :return: N1-by-N2 tensor of squared distances
    """
    sq_norm_1 = tf.reduce_sum(tf.square(m1), axis=0, keepdims=True)  # 1-by-N1
    sq_norm_2 = tf.reduce_sum(tf.square(m2), axis=0, keepdims=True)  # 1-by-N2
    cross = tf.matmul(m1, m2, transpose_a=True)  # N1-by-N2

    # Column vector + row vector broadcasts to N1-by-N2.
    dist2 = tf.transpose(sq_norm_1) + sq_norm_2 - 2.0*cross
    return tf.clip_by_value(
        dist2, clip_value_min=0.0, clip_value_max=10000.0)


def kernel(m, sigma=1.0):
    """ Element-wise kernel exp(-m / sigma).

    :param m: N1-by-N2 tensor; the callers in this file pass squared distances
    :param sigma: divisor applied to m before exponentiation
    :return: N1-by-N2 tensor of kernel values
    """
    scaled = -m/sigma
    return tf.exp(scaled)  # N1-by-N2


def e_kernel(m):
    """ Reduce a kernel matrix to a scalar via matrix_mean_wo_diagonal.

    NOTE(review): judging by the helper's name this is presumably the mean of
    m with the diagonal entries excluded -- confirm against GeneralTools.

    :param m: tensor with a statically known 2-D shape
    :return: whatever matrix_mean_wo_diagonal returns for m and its two sizes
    """
    # The static shape is passed on as float32 scalars (rows, then columns).
    dims = tf.cast(m.shape.as_list(), tf.float32)
    n_rows, n_cols = dims[0], dims[1]
    return matrix_mean_wo_diagonal(m, n_rows, n_cols)


with tf.Graph().as_default():
    # Both point sets are variables (2-by-N, one column per point) so that
    # they can be moved along the gradient of the loss.
    xr_tf = tf.Variable(xr, name='r', dtype=tf.float32)
    xg_tf = tf.Variable(xg, name='g', dtype=tf.float32)

    # Kernel matrices within R, within G, and between R and G.
    d2r = pd2(xr_tf, xr_tf)
    d2g = pd2(xg_tf, xg_tf)
    d2rg = pd2(xr_tf, xg_tf)
    kr = kernel(d2r, sigma=1.0)
    kg = kernel(d2g, sigma=1.0)
    krg = kernel(d2rg, sigma=1.0)

    if use_attractive_loss:
        mmd_att = e_kernel(kr) + e_kernel(kg)
        mmd_rep = - 2*e_kernel(krg)

        gr_att = tf.gradients(mmd_att, xr_tf)[0]
        gg_att = tf.gradients(mmd_att, xg_tf)[0]
        gr_rep = tf.gradients(mmd_rep, xr_tf)[0]
        gg_rep = tf.gradients(mmd_rep, xg_tf)[0]
    else:
        mmd_att = e_kernel(kg)
        mmd_rep = - e_kernel(kr)

        # In this mode mmd_att does not involve R and mmd_rep does not involve
        # G, so those gradients are explicit zeros shaped like the variable
        # (instead of a hard-coded [2, 6]).
        gr_att = tf.zeros_like(xr_tf)
        gg_att = tf.gradients(mmd_att, xg_tf)[0]
        gr_rep = tf.gradients(mmd_rep, xr_tf)[0]
        gg_rep = tf.zeros_like(xg_tf)

    mmd = mmd_att + mmd_rep
    gr = tf.gradients(mmd, xr_tf)[0]
    gg = tf.gradients(mmd, xg_tf)[0]
    # Keep the Python float `lr` intact; use a separate name for the tensor.
    lr_tf = tf.constant(lr, dtype=tf.float32)
    # Gradient *ascent* on both point sets. Both gradients are forced to be
    # evaluated before either variable is written, so the two updates always
    # use the same (pre-update) state.
    with tf.control_dependencies([gr, gg]):
        opr = tf.compat.v1.assign(xr_tf, xr_tf + gr*lr_tf)
        opg = tf.compat.v1.assign(xg_tf, xg_tf + gg*lr_tf)

    # Halved gradients used for the arrows; built once here so that the graph
    # does not grow on every queried step.
    arrow_fetches = [gr_att/2.0, gg_att/2.0, gr_rep/2.0, gg_rep/2.0]

    init_op = tf.group(
        tf.compat.v1.global_variables_initializer(),
        tf.compat.v1.local_variables_initializer())

    # The context manager closes the session even if plotting/saving raises.
    with tf.compat.v1.Session() as sess:
        sess.run(init_op)

        fig_index = 0
        for step in range(max_step):
            # Read the current state in its own run() call: the order of an
            # assign and a read of the same variable inside one run() is not
            # defined, so fetching them together gives ambiguous values.
            mmd_np, xr_np, xg_np = sess.run([mmd, xr_tf, xg_tf])

            if step % query_step == 0:
                print('step {}, mmd {}'.format(step, mmd_np))
                if do_plot:
                    fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
                    ax.scatter(
                        xr_np[0], xr_np[1], marker='o', c='tab:gray',
                        s=40, linewidths=20, alpha=0.5)
                    ax.scatter(
                        xg_np[0], xg_np[1], marker='o', c='tab:red',
                        s=40, linewidths=20, alpha=0.5)
                    # Legend is created before the arrows are drawn so the two
                    # positional labels go to the two scatter sets.
                    ax.legend(
                        ['R', 'G'], frameon=True, fontsize=15, labelspacing=1, borderpad=0.5)

                    # Evaluated at the same state as the points plotted above
                    # (no update has been applied in between).
                    gr_att_np, gg_att_np, gr_rep_np, gg_rep_np = sess.run(
                        arrow_fetches)
                    for i in range(xr_np.shape[1]):
                        # gr_att is identically zero without the attractive
                        # loss, so its arrows are skipped in that mode.
                        if use_attractive_loss:
                            ax.arrow(
                                xr_np[0, i], xr_np[1, i], gr_att_np[0, i], gr_att_np[1, i],
                                color='tab:blue', width=0.01)
                        ax.arrow(
                            xr_np[0, i], xr_np[1, i], gr_rep_np[0, i], gr_rep_np[1, i],
                            color='tab:orange', width=0.01)
                    for i in range(xg_np.shape[1]):
                        ax.arrow(
                            xg_np[0, i], xg_np[1, i], gg_att_np[0, i], gg_att_np[1, i],
                            color='tab:blue', width=0.01)
                        # gg_rep is identically zero without the attractive loss.
                        if use_attractive_loss:
                            ax.arrow(
                                xg_np[0, i], xg_np[1, i], gg_rep_np[0, i], gg_rep_np[1, i],
                                color='tab:orange', width=0.01)

                    # Fixed axis limits keep successive frames comparable.
                    ax.axis([-1.0, 1.0, -1.0, 1.0])

                    if do_save:
                        fig_index = fig_index+1
                        if use_attractive_loss:
                            figurename = 'mmd_att_{:03d}.png'.format(fig_index)
                        else:
                            figurename = 'mmd_rep_{:03d}.png'.format(fig_index)
                        fig.savefig(
                            folder + figurename, format='png', bbox_inches='tight')
                    else:
                        plt.show()
                    # Release the figure right away instead of accumulating
                    # one open figure per queried step.
                    plt.close(fig)

            # Apply one ascent step to both point sets.
            sess.run([opr, opg])

        # Final positions after all updates; the next cell starts from these.
        xr_np, xg_np = sess.run([xr_tf, xg_tf])

    plt.close('all')


step 0, mmd -0.024196505546569824


step 30, mmd -0.015836477279663086
step 60, mmd -0.0023784637451171875


step 90, mmd 0.02160811424255371
step 120, mmd 0.06640219688415527


step 150, mmd 0.1494077444076538
step 180, mmd 0.2934000492095947


step 210, mmd 0.5120537281036377
step 240, mmd 0.7853901386260986


step 270, mmd 1.0607848167419434
step 300, mmd 1.2927263975143433


step 330, mmd 1.4675564765930176


step 360, mmd 1.5929691791534424
step 390, mmd 1.6819651126861572


step 420, mmd 1.7456884384155273


In [5]:
max_step = 600     # number of update iterations
lr = 1e-2          # factor applied to the gradient in each update
query_step = 100   # print (and optionally plot) every `query_step` iterations
folder = '../animation_mmd/'
do_save = False

with tf.Graph().as_default():
    # Start from the positions left in xr_np / xg_np by the previous cell.
    # Only xg_tf is updated below; xr_tf stays at its initial value.
    xr_tf = tf.Variable(xr_np, name='r', dtype=tf.float32)
    xg_tf = tf.Variable(xg_np, name='g', dtype=tf.float32)

    d2r = pd2(xr_tf, xr_tf)
    d2g = pd2(xg_tf, xg_tf)
    d2rg = pd2(xr_tf, xg_tf)
    kr = kernel(d2r, sigma=1.0)
    kg = kernel(d2g, sigma=1.0)
    krg = kernel(d2rg, sigma=1.0)

    mmd_att = e_kernel(kr) + e_kernel(kg)
    mmd_rep = - 2*e_kernel(krg)

    gg_att = tf.gradients(mmd_att, xg_tf)[0]
    gg_rep = tf.gradients(mmd_rep, xg_tf)[0]

    mmd = mmd_att + mmd_rep
    gg = tf.gradients(mmd, xg_tf)[0]
    # Keep the Python float `lr` intact; use a separate name for the tensor.
    lr_tf = tf.constant(lr, dtype=tf.float32)
    # Gradient *descent* on G only.
    opg = tf.compat.v1.assign(xg_tf, xg_tf - gg*lr_tf)

    # Halved gradients used for the arrows; built once here so that the graph
    # does not grow on every queried step.
    arrow_fetches = [gg_att/2.0, gg_rep/2.0]

    init_op = tf.group(
        tf.compat.v1.global_variables_initializer(),
        tf.compat.v1.local_variables_initializer())

    # The context manager closes the session even if plotting/saving raises.
    with tf.compat.v1.Session() as sess:
        sess.run(init_op)

        fig_index = 0
        for step in range(max_step):
            # Read the current state in its own run() call: the order of an
            # assign and a read of the same variable inside one run() is not
            # defined, so fetching them together gives ambiguous values.
            mmd_np2, xr_np2, xg_np2 = sess.run([mmd, xr_tf, xg_tf])

            if step % query_step == 0:
                print('step {}, mmd {}'.format(step, mmd_np2))
                if do_plot:
                    fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
                    ax.scatter(
                        xr_np2[0], xr_np2[1], marker='o', c='tab:gray',
                        s=40, linewidths=20, alpha=0.5)
                    ax.scatter(
                        xg_np2[0], xg_np2[1], marker='o', c='tab:red',
                        s=40, linewidths=20, alpha=0.5)
                    # Legend is created before the arrows are drawn so the two
                    # positional labels go to the two scatter sets.
                    ax.legend(
                        ['R', 'G'], loc='upper left', frameon=True, fontsize=15, labelspacing=1, borderpad=0.5)

                    # Evaluated at the same state as the points plotted above
                    # (no update has been applied in between).
                    gg_att_np, gg_rep_np = sess.run(arrow_fetches)
                    for i in range(xg_np2.shape[1]):
                        # Arrows point along the negative gradient, i.e. the
                        # direction in which the descent step moves G.
                        ax.arrow(
                            xg_np2[0, i], xg_np2[1, i], -gg_att_np[0, i], -gg_att_np[1, i],
                            color='tab:orange', width=0.01)
                        ax.arrow(
                            xg_np2[0, i], xg_np2[1, i], -gg_rep_np[0, i], -gg_rep_np[1, i],
                            color='tab:blue', width=0.01)

                    # Fixed axis limits keep successive frames comparable.
                    ax.axis([-1.0, 1.0, -1.0, 1.0])

                    if do_save:
                        fig_index = fig_index+1
                        # use_attractive_loss is the flag set in the previous
                        # cell; it only affects the file name here.
                        if use_attractive_loss:
                            figurename = 'g_mmd_att_{:03d}.png'.format(fig_index)
                        else:
                            figurename = 'g_mmd_rep_{:03d}.png'.format(fig_index)
                        fig.savefig(
                            folder + figurename, format='png', bbox_inches='tight')
                    else:
                        plt.show()
                    # Release the figure right away instead of accumulating
                    # one open figure per queried step.
                    plt.close(fig)

            # Apply one descent step to G.
            sess.run(opg)

        # Final positions after all updates.
        xr_np2, xg_np2 = sess.run([xr_tf, xg_tf])

    plt.close('all')


step 0, mmd 1.7907965183258057
step 100, mmd 1.7025774717330933


step 200, mmd 1.533265471458435
step 300, mmd 1.1779274940490723


step 400, mmd 0.636290431022644
step 500, mmd 0.24290156364440918
