# Consensus Optimization

This notebook contains the code for the toy experiment in the paper [The Numerics of GANs](https://arxiv.org/abs/1705.10461).

In [10]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
from tensorflow.contrib import slim
import numpy as np
import scipy as sp
from scipy import stats
from matplotlib import pyplot as plt
import sys, os
from tqdm import tqdm as tqdm_notebook
tf.reset_default_graph()


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
def kde(mu, tau, bbox=[-5, 5, -5, 5], save_file="", xlabel="", ylabel="", cmap='Blues'):
    values = np.vstack([mu, tau])
    kernel = sp.stats.gaussian_kde(values)

    fig, ax = plt.subplots()
    ax.axis(bbox)
    ax.set_aspect(abs(bbox[1]-bbox[0])/abs(bbox[3]-bbox[2]))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    plt.tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom='off',      # ticks along the bottom edge are off
        top='off',         # ticks along the top edge are off
        labelbottom='off') # labels along the bottom edge are off
    plt.tick_params(
        axis='y',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        left='off',      # ticks along the bottom edge are off
        right='off',         # ticks along the top edge are off
        labelleft='off') # labels along the bottom edge are off
    
    xx, yy = np.mgrid[bbox[0]:bbox[1]:300j, bbox[2]:bbox[3]:300j]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    f = np.reshape(kernel(positions).T, xx.shape)
    cfset = ax.contourf(xx, yy, f, cmap=cmap)

    if save_file != "":
        plt.savefig(save_file, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()
        
        
def complex_scatter(points, bbox=None, save_file="", xlabel="real part", ylabel="imaginary part", cmap='Blues'):
    fig, ax = plt.subplots()

    if bbox is not None:
        ax.axis(bbox)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    xx = [p.real for p in points]
    yy = [p.imag for p in points]
    
    plt.plot(xx, yy, 'X')
    plt.grid()

    if save_file != "":
        plt.savefig(save_file, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

In [12]:
# Parameters
learning_rate = 1e-4
reg_param = 10.
batch_size = 512
z_dim = 16
sigma = 0.01
method = 'conopt'
divergence = 'standard'
outdir = os.path.join('gifs', method)
niter = 50000
n_save = 500
bbox = [-1.6, 1.6, -1.6, 1.6]
do_eigen = True

In [13]:
# Target distribution
mus = np.vstack([np.cos(2*np.pi*k/8), np.sin(2*np.pi*k/8)] for k in range(batch_size))
x_real = mus + sigma*tf.random_normal([batch_size, 2])

In [14]:
# Model
def generator_func(z):
    net = slim.fully_connected(z, 16)
    net = slim.fully_connected(net, 16)
    net = slim.fully_connected(net, 16)
    net = slim.fully_connected(net, 16)
    x = slim.fully_connected(net, 2, activation_fn=None)
    return x
        

def discriminator_func(x):
    # Network
    net = slim.fully_connected(x, 16)
    net = slim.fully_connected(net, 16)
    net = slim.fully_connected(net, 16)
    net = slim.fully_connected(net, 16)
    logits = slim.fully_connected(net, 1, activation_fn=None)
    out = tf.squeeze(logits, -1)

    return out

generator = tf.make_template('generator', generator_func)
discriminator = tf.make_template('discriminator', discriminator_func)


In [15]:
z = tf.random_normal([batch_size, z_dim])
x_fake = generator(z)
d_out_real = discriminator(x_real)
d_out_fake = discriminator(x_fake)

# Loss
if divergence == 'standard':
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_out_real, labels=tf.ones_like(d_out_real)
    ))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_out_fake, labels=tf.zeros_like(d_out_fake)
    ))
    d_loss = d_loss_real + d_loss_fake

    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_out_fake, labels=tf.ones_like(d_out_fake)
    ))
elif divergence == 'JS':
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_out_real, labels=tf.ones_like(d_out_real)
    ))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_out_fake, labels=tf.zeros_like(d_out_fake)
    ))
    d_loss = d_loss_real + d_loss_fake

    g_loss = -d_loss
elif divergence == 'indicator':
    d_loss = tf.reduce_mean(d_out_real - d_out_fake)
    g_loss = -d_loss 
else:
    raise NotImplementedError

In [16]:
g_vars =  tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
d_vars =  tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
optimizer = tf.train.RMSPropOptimizer(learning_rate, use_locking=True)
# optimizer = tf.train.GradientDescentOptimizer(learning_rate, use_locking=True)

# Compute gradients
d_grads = tf.gradients(d_loss, d_vars)
g_grads = tf.gradients(g_loss, g_vars)
# Merge variable and gradient lists
variables = d_vars + g_vars
grads = d_grads + g_grads

    
if method == 'simga':
    apply_vec = list(zip(grads, variables))
elif method == 'conopt':
    # Reguliarizer
    reg = 0.5 * sum(
        tf.reduce_sum(tf.square(g)) for g in grads
    )
    # Jacobian times gradiant
    Jgrads = tf.gradients(reg, variables)
    
    apply_vec = [
         (g + reg_param * Jg, v)
         for (g, Jg, v) in zip(grads, Jgrads, variables) if Jg is not None
    ]
    
else:
    raise NotImplementedError

with tf.control_dependencies([g for (g, v) in apply_vec]):
    train_op = optimizer.apply_gradients(apply_vec)

In [None]:
if do_eigen:
    jacobian_rows = []
    g_grads = tf.gradients(g_loss, g_vars)
    g_grads = [-g for g in g_grads]
    d_grads = tf.gradients(d_loss, d_vars)
    d_grads = [-g for g in d_grads]
    
    for g_idx, g in enumerate(g_grads + d_grads):
        print 'Doing gradient {}/{}'.format(g_idx, len(g_grads + d_grads))
        g = tf.reshape(g, [-1])
        len_g = int(g.get_shape()[0])
        for i in range(len_g):
            if i%10 == 0:
                print '\rDimension {}/{}'.format(i, len_g)
            g_row = tf.gradients(g[i], g_vars)
            d_row = tf.gradients(g[i], d_vars)
        jacobian_rows.append(g_row + d_row)

Doing gradient 1
Dimension 0/256
Dimension 1/256
Dimension 2/256
Dimension 3/256
Dimension 4/256
Dimension 5/256
Dimension 6/256
Dimension 7/256
Dimension 8/256
Dimension 9/256
Dimension 10/256
Dimension 11/256
Dimension 12/256
Dimension 13/256
Dimension 14/256
Dimension 15/256
Dimension 16/256
Dimension 17/256
Dimension 18/256
Dimension 19/256
Dimension 20/256
Dimension 21/256
Dimension 22/256
Dimension 23/256
Dimension 24/256
Dimension 25/256
Dimension 26/256
Dimension 27/256
Dimension 28/256
Dimension 29/256
Dimension 30/256
Dimension 31/256
Dimension 32/256
Dimension 33/256
Dimension 34/256
Dimension 35/256
Dimension 36/256
Dimension 37/256
Dimension 38/256
Dimension 39/256
Dimension 40/256
Dimension 41/256
Dimension 42/256
Dimension 43/256
Dimension 44/256
Dimension 45/256
Dimension 46/256
Dimension 47/256
Dimension 48/256
Dimension 49/256
Dimension 50/256
Dimension 51/256
Dimension 52/256
Dimension 53/256
Dimension 54/256
Dimension 55/256
Dimension 56/256
Dimension 57/256
Dimensi

In [23]:
i = 3
tf.gradients(g[i], g_vars)

ValueError: slice index 3 of dimension 0 out of bounds. for 'strided_slice_1' (op: 'StridedSlice') with input shapes: [1], [1], [1], [1] and with computed input tensors: input[1] = <3>, input[2] = <4>, input[3] = <1>.

In [9]:
def get_J(J_rows):
    J_rows_linear = [np.concatenate([g.flatten() for g in row]) for row in J_rows]
    J = np.array(J_rows_linear)
    return J

def process_J(J, save_file, bbox=None):
    eig, eigv = np.linalg.eig(J)
    eig_real = np.array([p.real for p in eig])
    complex_scatter(eig, save_file=save_file, bbox=bbox)

    
def process_J_conopt(J, reg, save_file, bbox=None):
    J2 = J - reg * np.dot(J.T, J)
    eig, eigv = np.linalg.eig(J2)
    eig_real = np.array([p.real for p in eig])
    complex_scatter(eig, save_file=save_file, bbox=bbox)


In [10]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())


In [11]:
# Real distribution
x_out = np.concatenate([sess.run(x_real) for i in range(5)], axis=0)
kde(x_out[:, 0], x_out[:, 1], bbox=bbox, cmap='Reds', save_file='gt.png')

In [12]:
if not os.path.exists(outdir):
    os.makedirs(outdir)
    
eigrawdir = os.path.join(outdir, 'eigs_raw')
if not os.path.exists(eigrawdir):
    os.makedirs(eigrawdir)
    
eigdir = os.path.join(outdir, 'eigs')
if not os.path.exists(eigdir):
    os.makedirs(eigdir)

        
eigdir_conopt = os.path.join(outdir, 'eigs_conopt')
if not os.path.exists(eigdir_conopt):
    os.makedirs(eigdir_conopt)
    
ztest = [np.random.randn(batch_size, z_dim) for i in range(5)]
progress  = tqdm_notebook(range(niter))

if do_eigen:
    J_rows = sess.run(jacobian_rows)
    J = get_J(J_rows)

for i in progress:
    sess.run(train_op)
    d_loss_out, g_loss_out = sess.run([d_loss, g_loss])
    
    if do_eigen and i % 500 == 0:
        J[:, :] = 0.
        for k in range(10):
            J_rows = sess.run(jacobian_rows)
            J += get_J(J_rows)/10.
        with open(os.path.join(eigrawdir, 'J_%d.npz' % i), 'wb') as f:
            np.save(f, J)

    progress.set_description('d_loss = %.4f, g_loss =%.4f' % (d_loss_out, g_loss_out))
    if i % n_save == 0:
        x_out = np.concatenate([sess.run(x_fake, feed_dict={z: zt}) for zt in ztest], axis=0)
        kde(x_out[:, 0], x_out[:, 1], bbox=bbox, save_file=os.path.join(outdir,'%d.png' % i))




In [13]:
import re
import glob
import matplotlib
matplotlib.rcParams.update({'font.size': 16})

pattern = r'J_(?P<it>0).npz'

bbox = [-3.5, 0.75, -1.2, 1.2]


eigrawdir = os.path.join(outdir, 'eigs_raw')
if not os.path.exists(eigrawdir):
    os.makedirs(eigrawdir)
    
eigdir = os.path.join(outdir, 'eigs')
if not os.path.exists(eigdir):
    os.makedirs(eigdir)

        
eigdir_conopt = os.path.join(outdir, 'eigs_conopt')
if not os.path.exists(eigdir_conopt):
    os.makedirs(eigdir_conopt)
    
out_files = glob.glob(os.path.join(eigrawdir, '*.npz'))
matches = [re.fullmatch(pattern, os.path.basename(s)) for s in out_files]
matches = [m for m in matches if m is not None]

for m in tqdm_notebook(matches):
    it = int(m.group('it'))
    J = np.load(os.path.join(eigrawdir, m.group()))
    process_J(J, save_file=os.path.join(eigdir, '%d.png' % it), bbox=bbox)
    process_J_conopt(J, reg=reg_param, save_file=os.path.join(eigdir_conopt, '%d.png' % it), bbox=bbox)


