# Differentially Private Moment Discrepancy

## CONFIG.

Configs that worked:
- datadim=1, kmmd, k=100, width=3, depth=2, zdim=2, opt=rmsprop, elu, lr=5e-3, bs=64
- datadim=1, cmd, k=5, width=3, depth=2, zdim=2, opt=rmsprop, elu, lr=5e-3, bs=64
- datadim=2, cmd, k=5, width=3, depth=2, zdim=2, opt=rmsprop, elu, lr=5e-3, bs=64, "two_gaussians"
- datadim=2, cmd, k=10, width=3, depth=2, zdim=2, opt=rmsprop, elu, lr=5e-3, bs=64, "two_gaussians"
- data_file='yangmed.csv', kmmd, k=5, width=100, depth=3, zdim=64, opt=rmsprop, elu, lr=5e-3, bs=256

Configs that worked, but only for marginals:
- data_file='yangmed.csv', cmd, k=10, width=100, depth=3, zdim=64, opt=rmsprop, elu, lr=5e-3, bs=256

In [None]:
# Jupyter config.
# Verbosity switches for notebook output.
jupyter_verbose = True
extra_verbose = False

import numpy as np
import tensorflow as tf

# Model config.
model_type = 'cmd_gan'
cmd_variation = 'cmd'  # One of {'onetime_noisy', 'onetime_noisy_joint', 'dp_sgd', 'mmd', 'kmmd', 'cmd', 'ncmd', 'ncmd_jmd'}.
do_cmd_taylor_weights = False
k_moments = 5  # Number of moments (k) used by the moment-based settings.
width = 3
depth = 2
z_dim = 2
optimizer = 'rmsprop'
activation = tf.nn.elu
learning_rate = 5e-3
lr_update_step = 10000
lr_minimum = 1e-6

# Data source: an empty data_file means no file is loaded; data_dim then
# picks which synthetic data set load_normed_data generates.
#data_file = 'yangmed.csv'
data_file = ''
data_dim = 2
# Optional per-dimension [low, high] clipping bounds on the raw
# (un-normalized) scale; None disables clipping.
#clip_unnormed = np.array([[2., 8.5], [-1., 6.]])  # Must be 2d array of floats.
clip_unnormed = None
if clip_unnormed is not None:
    assert clip_unnormed.shape == (data_dim, 2)

# Privacy settings. The trailing comment on each line repeats a reference
# value (presumably the usual default -- TODO confirm).
sigma = 1
laplace_eps = 1.0  # 1.0
default_gradient_l2norm_bound = 4.0  # 4.0
sgd_target_eps = [0.125, 0.25, 0.5, 1., 2., 4., 8.]  # [0.125, 0.25, 0.5, 1., 2., 4., 8.]
sgd_eps = 1.0  # 1.0
# NOTE(review): a differential-privacy delta is normally a small probability;
# 4.0 looks like it was copied from sgd_sigma -- confirm.
sgd_delta = 4.0  # 4.0
sgd_sigma = 4.0  # 4.0

# Data set size, train/test split, and training-loop settings.
data_num_init = 5000
percent_train = 0.9
batch_size = 64
gen_num = 64
log_step = 100
max_step = 200000

# Run tag (used to name the log directory) and whether to resume a run.
tag = 'test'
load_existing = False

# Collect every setting in a fixed order, then convert each to its string form.
args = [model_type, cmd_variation, default_gradient_l2norm_bound,
        laplace_eps, sgd_target_eps, sgd_eps, sgd_delta, sgd_sigma,
        data_num_init, data_dim, percent_train, batch_size, gen_num,
        width, depth, z_dim, log_step, max_step, learning_rate,
        lr_update_step, lr_minimum, optimizer, data_file, k_moments,
        sigma, tag, load_existing, activation]
args = [str(a) for a in args]

### Imports.

In [None]:
""" This script runs differential privacy SGD (moment accountant) on a GAN
    with moment discrepancy loss.
"""
import argparse
from time import time
import os
import pdb
import shutil
import sys
import numpy as np
from numpy.linalg import norm

from IPython.display import Image, display
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('ggplot')
from matplotlib.gridspec import GridSpec
from scipy.spatial.distance import pdist
from scipy.stats import truncnorm, pearsonr
import tensorflow as tf
layers = tf.layers

#sys.path.append('/home/maurice/mmd')
from mmd_utils import (compute_mmd, compute_kmmd, compute_cmd,
                       compute_joint_moment_discrepancy,
                       compute_noncentral_moment_discrepancy,
                       compute_moments, compute_central_moments,
                       compute_kth_central_moment, MMD_vs_Normal_by_filter,
                       dp_sensitivity_to_expectation)

from dp_optimizer import DPGradientDescentOptimizer
from sanitizer import AmortizedGaussianSanitizer, ClipOption
from accountant import GaussianMomentsAccountant, EpsDelta
from utils import NetworkParameters, LayerParameters, BuildNetwork

### Helper functions.

In [None]:
def mm(arr):
    """Print the smallest and the largest value of `arr` on one line."""
    lowest, highest = np.min(arr), np.max(arr)
    print('  {}, {}'.format(lowest, highest))


def get_random_z(gen_num, z_dim, for_training=True):
    """Sample a [gen_num, z_dim] array of generator input noise.

    Args:
      gen_num: Number of noise vectors (rows).
      z_dim: Dimension of each noise vector (columns).
      for_training: If True, draw from a standard normal. Otherwise draw
        from a standard normal truncated to [-3, 3].

    Returns:
      NumPy array of shape [gen_num, z_dim].
    """
    # Earlier alternative, kept for reference:
    #   np.random.uniform(size=[gen_num, z_dim], low=-1.0, high=1.0)
    noise_shape = [gen_num, z_dim]
    if not for_training:
        return truncnorm.rvs(-3, 3, size=noise_shape)
    return np.random.normal(size=noise_shape)


def reduce_var(x, axis=None, keepdims=False):
    """Population variance of a tensor along `axis`.

    Computed as the mean of squared deviations from the mean (no
    Bessel correction).

    Args:
        x: A tensor or variable.
        axis: Axis (or axes) to reduce over; None reduces over everything.
        keepdims: If True, reduced axes are kept with length 1; if False
            they are dropped from the result.

    Returns:
        A tensor holding the variance of the elements of `x`.
    """
    # The mean always keeps its dims so it broadcasts against `x`.
    centered = x - tf.reduce_mean(x, axis=axis, keepdims=True)
    return tf.reduce_mean(tf.square(centered), axis=axis, keepdims=keepdims)


def reduce_std(x, axis=None, keepdims=False):
    """Population standard deviation of a tensor along `axis`.

    Args:
        x: A tensor or variable.
        axis: Axis (or axes) to reduce over; None reduces over everything.
        keepdims: If True, reduced axes are kept with length 1; if False
            they are dropped from the result.

    Returns:
        A tensor holding the square root of `reduce_var` of `x`.
    """
    variance = reduce_var(x, axis=axis, keepdims=keepdims)
    return tf.sqrt(variance)

### Generator and autoencoder TF functions.

In [None]:
def dense(x, width, activation, batch_residual=False, use_bias=False, name=None,
          bounds=None):
    """Wrapper on fully connected TensorFlow layer.

    Builds a dense layer, optionally clips its output, applies batch
    normalization, and optionally adds the layer input back (residual).

    Args:
      x: Layer input.
      width: Width of output layer.
      activation: TensorFlow activation function.
      batch_residual: Flag to use batch residual output.
      use_bias: Flag to use bias in fully connected layer.
      name: Layer name. Only forwarded to the layers of the `option == 2`
        branch, which is currently unreachable; ignored otherwise.
      bounds: List of min and max scalar values to clip to. [min, max]

    Returns:
      The output tensor of the (batch-normalized) layer.
    """
    if batch_residual:
        # `option` is hard-coded, so only the first residual variant runs.
        option = 1
        if option == 1:
            x_ = layers.dense(x, width, activation=activation, use_bias=use_bias)
            if bounds is not None:
                x_ = tf.clip_by_value(x_, bounds[0], bounds[1])
            # Residual sum: presumably requires `x` to already have `width`
            # units -- TODO confirm against callers.
            r = layers.batch_normalization(x_) + x
        elif option == 2:
            # Dead code while option == 1.
            # NOTE(review): both dense layers below receive the same `name`.
            x_newdim = layers.dense(x, width, activation=activation,
                                    use_bias=use_bias, name=name)
            x_newdim = layers.batch_normalization(x_newdim)
            x_newdim = tf.nn.relu(x_newdim)
            x = layers.dense(x_newdim, width, activation=activation,
                              use_bias=use_bias, name=name)
            x = layers.batch_normalization(x)
            x = tf.nn.relu(x)
            r = x + x_newdim
    else:
        # Plain path: dense -> optional clip -> batch normalization.
        x_ = layers.dense(x, width, activation=activation, use_bias=use_bias)
        if bounds is not None:
            x_ = tf.clip_by_value(x_, bounds[0], bounds[1])
        r = layers.batch_normalization(x_)
    return r


def fully_connected(x, num_units, scope):
    """Affine layer `x @ weights + biases` with variables under `scope`.

    Variables are fetched with `tf.AUTO_REUSE`, so calling this again with
    the same scope reuses the existing "weights" and "biases".

    Args:
      x: 2-D input tensor; its second dimension sets the weight rows.
      num_units: Number of output units.
      scope: Variable scope (name) that owns the layer's variables.

    Returns:
      Tensor `tf.matmul(x, weights) + biases`.
    """
    weight_init = tf.truncated_normal_initializer(
        stddev=1.0 / np.sqrt(num_units))
    bias_init = tf.constant_initializer(0.0)
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        weights = tf.get_variable(
            "weights", [x.shape[1], num_units], initializer=weight_init)
        biases = tf.get_variable(
            "biases", [num_units], initializer=bias_init)
        affine = tf.matmul(x, weights) + biases
    return affine


def generator(z_in, width=3, depth=3, activation=tf.nn.elu, out_dim=2,
              reuse=False, bounds=None):
    """Decodes. Generates output, given noise input.

    Stacks `depth` hidden `dense` layers of size `width` followed by a
    linear `dense` output layer of size `out_dim`, all inside the
    'generator' variable scope.

    Args:
      z_in: Noise input tensor.
      width: Width of each hidden layer.
      depth: Number of hidden layers.
      activation: Activation used by the hidden layers.
      out_dim: Width of the output layer.
      reuse: Passed to `tf.variable_scope` for the 'generator' scope.
      bounds: Accepted for interface compatibility but currently ignored
        (see note below).

    Returns:
      (out, vars_g): the generator output tensor and the variables
      collected from the 'generator' scope.
    """
    # The caller-supplied `bounds` is overwritten here, so generator layers
    # are never clipped. NOTE(review): looks like an intentional switch-off;
    # confirm before honoring the argument again.
    bounds = None
    if bounds is None:
        print('No bounds on Generator')
    with tf.variable_scope('generator', reuse=reuse) as vs_g:
        x = dense(z_in, width, activation=activation, bounds=bounds,
                  batch_residual=False, use_bias=False)

        for _ in range(depth - 1):
            # TODO: Should this use batch resid, and is it defined properly?
            x = dense(x, width, activation=activation, bounds=bounds,
                      batch_residual=False, use_bias=False)

        out = dense(x, out_dim, activation=None, bounds=bounds,
                    batch_residual=False, use_bias=False)
    vars_g = tf.contrib.framework.get_variables(vs_g)
    return out, vars_g


def generator_v2(inputs, network_parameters):
    """Builds a generator from a layer specification.

    The first entry of `network_parameters.layer_parameters` is always used
    as a hidden layer. Later entries are used as hidden layers only when
    their name contains 'hidden'. The last entry is finally applied as a
    plain linear output layer.

    Args:
      inputs: Input tensor fed to the first layer.
      network_parameters: Object exposing `layer_parameters`, a sequence
        whose items have `num_units` and `name` attributes.

    Returns:
      Output tensor of the final fully connected layer.
    """
    # Earlier variants, kept for reference:
    #with tf.variable_scope('generator', reuse=reuse) as vs_g:
    #with tf.variable_scope('generator_v2', reuse=tf.AUTO_REUSE):
        #out, _, training_params = BuildNetwork(inputs, network_parameters)

    layer_params = network_parameters.layer_parameters

    # First hidden layer.
    outputs = _standardized_elu_layer(inputs, layer_params[0])

    # Remainder of hidden layers.
    for h_params in layer_params[1:]:
        if 'hidden' in h_params.name:
            outputs = _standardized_elu_layer(outputs, h_params)
            # Batch-residual alternative, currently disabled:
            #outputs = outputs_bn + outputs

    # Do final layer (linear, no standardization or activation).
    hl_params = layer_params[-1]
    return fully_connected(outputs, hl_params.num_units, hl_params.name)


def _standardized_elu_layer(x, params):
    """Fully connected layer, standardized over the batch axis, then ELU."""
    outputs_fc = fully_connected(x, params.num_units, params.name)
    # NOTE: no epsilon is added, so a zero batch std divides by zero.
    outputs_bn = (outputs_fc - tf.reduce_mean(outputs_fc, axis=0)) / (
        reduce_std(outputs_fc, axis=0))
    return tf.nn.elu(outputs_bn)


def autoencoder(x, width=3, depth=3, activation=tf.nn.elu, z_dim=3,
                reuse=False, normed_weights=False, normed_encs=False,
                bounds=None):
    """Autoencodes input via a bottleneck layer h.

    Args:
      x: Input tensor; its second dimension sets the reconstruction width.
      width: Width of each hidden layer in encoder and decoder.
      depth: Number of hidden layers in each of encoder and decoder.
      activation: Activation used by the hidden layers.
      z_dim: Width of the bottleneck layer h.
      reuse: Passed to `tf.variable_scope` for both scopes.
      normed_weights: Not used by this function.
      normed_encs: Not used by this function.
      bounds: [min, max] clip values forwarded to every `dense` call.

    Returns:
      (h, ae, vars_enc, vars_dec): the bottleneck tensor, the
      reconstruction, and the variables collected from the 'encoder' and
      'decoder' scopes.
    """
    # Remember the input width before `x` is rebound to hidden activations.
    out_dim = x.shape[1]
    with tf.variable_scope('encoder', reuse=reuse) as vs_enc:
        x = dense(x, width, activation=activation, bounds=bounds)

        for idx in range(depth - 1):
            # TODO: Should this use batch resid, and is it defined properly?
            x = dense(x, width, activation=activation, batch_residual=True,
                      bounds=bounds)

        # Linear bottleneck.
        h = dense(x, z_dim, activation=None, bounds=bounds)

    with tf.variable_scope('decoder', reuse=reuse) as vs_dec:

        # NOTE: `dense` only forwards `name` in a branch that is not taken
        # here, so name='hidden' has no effect on this layer.
        x = dense(h, width, activation=activation, name='hidden',
                  bounds=bounds)

        for idx in range(depth - 1):
            x = dense(x, width, activation=activation, batch_residual=True,
                      bounds=bounds)

        # Linear reconstruction with the same width as the original input.
        ae = dense(x, out_dim, activation=None, bounds=bounds)

    vars_enc = tf.contrib.framework.get_variables(vs_enc)
    vars_dec = tf.contrib.framework.get_variables(vs_dec)

    return h, ae, vars_enc, vars_dec

### Load and normalize raw data.

In [None]:
def five_num(v):
    """Returns a five number summary of the input vector.

    The summary is (minimum, 25th percentile, median, 75th percentile,
    maximum). No whisker / outlier rule is applied: the first and last
    entries are the true min and max of the data.

    Args:
      v: List or array of numbers.

    Returns:
      Tuple (min, q1, median, q3, max).

    Raises:
      TypeError: If `v` contains entries that cannot be summed as numbers.
    """
    # Summing is a cheap probe for non-numeric content. Fail here with a
    # clear message instead of printing and failing later in percentile().
    try:
        np.sum(v)
    except TypeError:
        raise TypeError(
            'Error: you must provide a list or array of only numbers')
    q1 = np.percentile(v, 25)
    md = np.median(v)
    q3 = np.percentile(v, 75)
    return np.min(v), q1, md, q3, np.max(v)


def load_normed_data(data_num, percent_train, log_dir, clip_unnormed=None,
                     clip=None, data_file=None):
    """Generates data, and returns it normalized, along with helper objects.

    Data comes either from `data_file` (npy / txt / csv) or, when no file is
    given, from a synthetic generator selected by the module-level
    `data_dim`. The data is split into train/test by `percent_train`, and
    both parts are standardized with the training mean and std.

    Besides its arguments, this function reads the module-level names
    `data_dim`, `plot_dir` and `jupyter_verbose`. `plot_dir` is not defined
    in the visible part of the file -- presumably set by the caller before
    this runs; TODO confirm.

    Args:
      data_num: Number of synthetic points to generate (ignored for files).
      percent_train: Fraction of rows used for training.
      log_dir: Directory where a copy of the raw data is saved.
      clip_unnormed: Optional array [data_dim, 2] of [low, high] bounds on
        the raw scale. NOTE: modified in place when given.
      clip: Returned unchanged unless `clip_unnormed` is given, in which
        case it is replaced by the bounds on the normalized scale.
      data_file: Path of the data file, or '' / None for synthetic data.

    Returns:
      Tuple (data, data_test, data_num, data_test_num, out_dim,
      train_mean, train_std, clip).
    """

    # Load data.
    if data_file not in ['', None]:
        # NOTE: any other file extension leaves `data_raw` unassigned, which
        # fails later at the np.save call (unless the 'yang' branch returns).
        if data_file.endswith('npy'):
            data_raw = np.load(data_file)
        elif data_file.endswith('txt'):
            data_raw = np.loadtxt(open(data_file, 'rb'), delimiter=' ')
        elif data_file.endswith('csv'):
            data_raw = np.loadtxt(open(data_file, 'rb'), delimiter=',')

        #########################################################################
        # YANG MED DATA (begin)

        # Files with 'yang' in the name get their own pipeline and return
        # early from inside this branch.
        if 'yang' in data_file:
            # The file is read a second time here, independent of data_raw.
            d = np.loadtxt(open(data_file, 'rb'), delimiter=',')
            # NOTE(review): the message says 10k but the slice keeps 20000 rows.
            print('Using first 10k of data.')
            d = d[:20000]
            d = np.delete(d, 226, axis=0)  # Error.
            d = np.random.permutation(d)
            num_rows = d.shape[0]
            num_cols = d.shape[1]

            # Separate train and test data.
            num_train = int(percent_train * num_rows)
            d_train = d[:num_train]
            d_test = d[num_train:]

            print('\nRaw data:')
            for i in range(num_cols):
                print('{: >8},{: >8},{: >8},{: >8},{: >8}'.format(
                    *five_num(d[:,i])))

            # A column is binary when casting it to bool leaves it unchanged,
            # i.e. it only holds 0s and 1s.
            binary_cols = []
            for col in range(num_cols):
                col_data = d[:, col]
                if np.array_equal(col_data, col_data.astype(bool)):
                    binary_cols.append(col)
            print('binary_cols={}'.format(binary_cols))

            # Don't standardize binary vars.
            #mean_mask = np.array([1] * num_cols)
            mean_mask = np.ones(num_cols)
            mean_mask[binary_cols] = 0.  # Do not subtract mean from binary vars.
            mean_vec = d_train.mean(0) * mean_mask
            std_vec = d_train.std(0)
            std_vec[binary_cols] = 1.  # Do not divide by std for binary vars.
            data = (d_train - mean_vec) / std_vec
            data_test = (d_test - mean_vec) / std_vec

            #print('\nStandardized data:')
            #for i in range(data.shape[1]):
            #    print('{: >8},{: >8},{: >8},{: >8},{: >8}'.format(
            #        *five_num(data[:,i])))

            # Save copy of raw data.
            np.save(os.path.join(log_dir, 'data_raw.npy'), d)

            # Set up a few helpful constants.
            data_num = d_train.shape[0]
            data_test_num = d_test.shape[0]
            out_dim = data.shape[1]

            # Adjust clip values to use the tightest interval.
            # For clip lows, use greater of clip_low and min(data).
            # For clip highs, use lesser of clip_high and max(data).
            if clip_unnormed is not None:
                # In-place update: the caller's array is changed too.
                clip_unnormed[:, 0] = np.max(np.vstack(
                    (clip_unnormed[:, 0], np.min(d, axis=0))), axis=0)
                clip_unnormed[:, 1] = np.min(np.vstack(
                    (clip_unnormed[:, 1], np.max(d, axis=0))), axis=0)
                # Express the bounds on the normalized scale.
                clip = ((clip_unnormed - mean_vec.reshape(-1, 1)) / std_vec.reshape(-1, 1))

            return (data, data_test, data_num, data_test_num, out_dim,
                    mean_vec, std_vec, clip)

        # YANG MED DATA (end)
        #################################################################################


    else:
        # Synthetic data. NOTE: `data_dim` is the module-level setting; any
        # value other than 1 or 2 leaves `data_raw` unassigned.
        if data_dim == 1:

            # 1-D mixture of two Gaussians (means 5 and 6.5, std 0.5).
            data_raw = np.zeros((data_num, 1))
            for i in range(data_num):
                # Pick a Gaussian, then generate from that Gaussian.
                i_cluster = np.random.binomial(1, 0.1)  # NOTE: Setting p=0/p=1 chooses one cluster.
                if i_cluster == 0:
                    data_raw[i] = np.random.normal(5, 0.5)
                else:
                    data_raw[i] = np.random.normal(6.5, 0.5)

            # Plot clipped version of data.
            if clip_unnormed is not None:
                data_clipped = np.clip(data_raw,
                                       clip_unnormed[0][0], clip_unnormed[0][1])
                fig, ax = plt.subplots()
                ax.hist(data_raw, density=True, bins=30, color='gray', alpha=0.3,
                        label='data')
                ax.hist(data_clipped, density=True, bins=30, color='blue', alpha=0.3,
                        label='clipped')
                plt.legend()
                filepath = os.path.join(plot_dir, '{}.png'.format('data_clipped'))
                plt.savefig(filepath)
                plt.close(fig)
                if jupyter_verbose:
                    display(Image(filename=filepath))

        elif data_dim == 2:

            # `design` is hard-coded, so only the 'two_gaussians' branch runs.
            design = 'two_gaussians'
            print('Using data set {}'.format(design))
            if design == 'two_gaussians':
                data_raw = np.zeros((data_num, 2))
                for i in range(data_num):
                    if np.random.binomial(1, 0.5):  # 2nd param controls mixture.
                        data_raw[i] = \
                            np.random.multivariate_normal([4., 2.], [[0.5, 0.], [0., 0.5]], 1)
                    else:
                        data_raw[i] = \
                            np.random.multivariate_normal([6., 4.], [[0.5, 0.1], [0.1, 0.5]], 1)

            elif design == 'noisy_sin':
                x = np.linspace(0, 10, data_num)
                y = np.sin(x) + np.random.normal(0, 0.5, len(x))
                data_raw = np.hstack((np.expand_dims(x, axis=1),
                                      np.expand_dims(y, axis=1)))
                # Shuffle rows so the train/test split is not ordered by x.
                data_raw = data_raw[np.random.permutation(data_num)]
            elif design == 'uniform':
                # Uniform on two squares: [0, 1]^2 or [-1, 0]^2, equally likely.
                data_raw = np.zeros((data_num, 2))
                for i in range(data_num):
                    if np.random.binomial(1, 0.5):
                        x_sample = np.random.uniform(0, 1)
                        y_sample = np.random.uniform(0, 1)
                    else:
                        x_sample = np.random.uniform(-1, 0)
                        y_sample = np.random.uniform(-1, 0)
                    data_raw[i] = [x_sample, y_sample]

            # Plot clipped version of data.
            if clip_unnormed is not None:
                data_clipped = np.zeros(data_raw.shape)
                data_clipped[:, 0] = np.clip(data_raw[:, 0],
                                             clip_unnormed[0][0],
                                             clip_unnormed[0][1])
                data_clipped[:, 1] = np.clip(data_raw[:, 1],
                                             clip_unnormed[1][0],
                                             clip_unnormed[1][1])
                fig, ax = plt.subplots()
                ax.scatter(*zip(*data_raw), color='gray',
                                 alpha=0.2, label='raw')
                ax.scatter(*zip(*data_clipped), color='green',
                                 alpha=0.2, label='clipped')
                plt.legend()
                filepath = os.path.join(plot_dir, '{}.png'.format('data_clipped'))
                plt.savefig(filepath)
                plt.close(fig)
                if jupyter_verbose:
                    display(Image(filename=filepath))

    # Save copy of raw data.
    np.save(os.path.join(log_dir, 'data_raw.npy'), data_raw)


    # First split data into Train/Test, then normalize both.
    # Split.
    num_train = int(percent_train * data_raw.shape[0])
    data_raw_train = data_raw[:num_train]
    data_raw_test = data_raw[num_train:]
    # Normalize (based only on training data).
    data_raw_train_mean = np.mean(data_raw_train, axis=0)
    data_raw_train_std = np.std(data_raw_train, axis=0)
    data = (data_raw_train - data_raw_train_mean) / data_raw_train_std  # Normed training data
    data_test = (data_raw_test - data_raw_train_mean) / data_raw_train_std  # Normed test data


    # Set up a few helpful constants.
    data_num = data.shape[0]
    data_test_num = data_test.shape[0]
    out_dim = data.shape[1]

    # Adjust clip values to use the tightest interval.
    # For clip lows, use greater of clip_low and min(data).
    # For clip highs, use lesser of clip_high and max(data).
    if clip_unnormed is not None:
        # In-place update: the caller's array is changed too.
        clip_unnormed[:, 0] = np.max(np.vstack(
            (clip_unnormed[:, 0], np.min(data_raw, axis=0))), axis=0)
        clip_unnormed[:, 1] = np.min(np.vstack(
            (clip_unnormed[:, 1], np.max(data_raw, axis=0))), axis=0)
        # Express the bounds on the normalized scale.
        clip = ((clip_unnormed - data_raw_train_mean.reshape(-1, 1)) /
                data_raw_train_std.reshape(-1, 1))

    return (data, data_test, data_num, data_test_num, out_dim,
            data_raw_train_mean, data_raw_train_std, clip)


def unnormalize(data_normed, data_raw_mean, data_raw_std):
    """Map standardized values back to the raw scale: x * std + mean."""
    rescaled = data_normed * data_raw_std
    return rescaled + data_raw_mean


def print_baseline_moment_stats(d_normed, d_mean, d_std, k_moments):
    """Print the first `k_moments` moments of the un-normalized data.

    Args:
      d_normed: Normalized data.
      d_mean: Mean used for the normalization.
      d_std: Standard deviation used for the normalization.
      k_moments: Number of moments to compute and print.
    """
    # Undo the standardization before measuring moments.
    raw_scale = d_normed * d_std + d_mean
    baseline_moments = compute_moments(raw_scale, k_moments)
    for order in range(1, k_moments + 1):
        print('Moment {}: {}'.format(order, baseline_moments[order - 1]))

### Sensitivity and noisy moments.

In [None]:
def make_fixed_batches(data, batch_size):
    """Partitions data into consecutive fixed batches of `batch_size` rows.

    Rows left over after the last full batch are dropped.

    Args:
      data: Array whose first axis indexes samples.
      batch_size: Positive integer number of rows per batch.

    Returns:
      Array of shape [num_full_batches, batch_size, ...]. When there is no
      full batch the result is an empty array of length 0.

    Raises:
      ValueError: If `batch_size` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be a positive integer')
    # Only slice out complete batches. Stacking a shorter trailing batch
    # together with full ones creates a ragged array, which newer NumPy
    # versions reject with a ValueError.
    num_full = len(data) // batch_size
    return np.array(
        [data[i * batch_size:(i + 1) * batch_size] for i in range(num_full)])

def compute_sensitivities(data, batch_size, k_moments, clip=None):
    """Computes all forms of sensitivity.

    Sensitivities are based on batch_size, number of moments, and
    clipping.

    Args:
      data: Array [num_rows, data_dim].
      batch_size: Batch size that moment sensitivities are divided by.
      k_moments: Number of moments.
      clip: Optional array of clipping bounds, one row per data dimension:
            [[dim1_low, dim1_high],
             [dim2_low, dim2_high],
              ...,
             [dimD_low, dimD_high]]
        When given, bounds come from `clip`; otherwise from the data itself.

    Returns:
      Tuple of four arrays:
        moment_sensitivities, [k_moments, data_dim]:
              [[m1_dim1, ..., m1_dimD],
               [m2_dim1, ..., m2_dimD],
                ...,
               [mk_dim1, ..., mk_dimD],]
        jmoment_sensitivities, [k_moments, 1].
        global_quantile_values, [data_dim, num_quantiles].
        quantile_sensitivities, [data_dim, num_quantiles].

    Note:
      Without `clip`, the module-level `data_file` is consulted: for files
      with 'yang' in the name only the first 17 columns enter the joint
      product.
    """
    num_rows, num_dims = data.shape[0], data.shape[1]

    # Largest absolute value per dimension, and of the joint product.
    if clip is not None:
        # Verify dimensions of clip. Should have a min and max for each
        # data dim.
        assert clip.shape == (num_dims, 2), 'clip must be shape (data_dim, 2)'
        per_dim_bound = np.max(np.abs(clip), axis=1)  # D-dimensional.
        joint_bound = np.prod(per_dim_bound)  # 1-dimensional.
    else:
        per_dim_bound = np.max(np.abs(data), axis=0)  # D-dimensional.
        # For medical data, only compute joint sensitivity over non-binary
        # cols.
        joint_cols = data[:, :17] if 'yang' in data_file else data
        joint_bound = np.max(np.abs(np.prod(joint_cols, axis=1)), axis=0)

    # MOMENT and JOINT-MOMENT sensitivities for entire data set.
    moment_sensitivities = np.zeros((k_moments, num_dims))
    jmoment_sensitivities = np.zeros((k_moments, 1))
    for k in range(1, k_moments + 1):
        moment_sensitivities[k - 1, :] = (
            np.power(per_dim_bound, k) / batch_size)
        jmoment_sensitivities[k - 1, :] = (
            np.power(joint_bound, k) / batch_size)

    # Compute QUANTILE sensitivities for the entire data set, i.e. the
    # difference between the quantile value and the next value in a sorted
    # array.
    quantiles = [0.01, 0.1, 0.5, 0.9, 0.99]
    # Sized from the data itself so the result always covers every column,
    # independent of any module-level data_dim setting.
    global_quantile_values = np.zeros((num_dims, len(quantiles)))
    quantile_sensitivities = np.zeros((num_dims, len(quantiles)))
    last_index = num_rows - 1
    for dim in range(num_dims):
        component_data = np.sort(data[:, dim])
        for q_i, q_v in enumerate(quantiles):
            # Index associated with quantile. Both the index and its
            # successor are clamped to the last row, so small data sets do
            # not index past the end (the sensitivity is then 0 there).
            index_q = min(int(np.ceil(num_rows * q_v)), last_index)
            index_next = min(index_q + 1, last_index)
            global_quantile_values[dim, q_i] = component_data[index_q]
            quantile_sensitivities[dim, q_i] = (
                component_data[index_next] - component_data[index_q])

    return (moment_sensitivities, jmoment_sensitivities,
            global_quantile_values, quantile_sensitivities)


def make_fbo_noisy_moments(fixed_batches, moment_sensitivities,
                           k_moments, laplace_eps, allocation=None):
    """Computes per-batch moments and Laplace-perturbed copies of them.

    Args:
      fixed_batches: Array of data batches, [batch_num, batch_size, data_dim].
      moment_sensitivities: Array of sensitivities of moments,
        [k_moments, data_dim].
      k_moments: Integer number of moments to compute.
      laplace_eps: Float, overall privacy budget.
      allocation: List of values, to be normalized, that split the privacy
        budget among moments. Defaults to an even split.

    Returns:
      Array of noisy moments for each fixed batch, float32,
        [batch_num, k_moments, data_dim].
    """
    num_dims = fixed_batches.shape[-1]
    result_shape = (len(fixed_batches), k_moments, num_dims)
    true_moments = np.zeros(result_shape, dtype=np.float32)
    perturbed_moments = np.zeros(result_shape, dtype=np.float32)

    # Split the budget across moments in proportion to `allocation`.
    if allocation is None:
        allocation = [1] * k_moments
    print('Privacy budget and allocation: {}, {}\n'.format(
        laplace_eps, allocation))
    budget_share = np.array(allocation) / float(np.sum(allocation))
    eps_per_moment = laplace_eps * budget_share
    assert len(eps_per_moment) == k_moments, 'allocation length must match moment num'

    for b_idx, batch in enumerate(fixed_batches):
        for order in range(1, k_moments + 1):
            row = order - 1
            # One Laplace draw per data dimension; the scale is a vector
            # of sensitivity / epsilon for this moment.
            laplace_scale = moment_sensitivities[row] / eps_per_moment[row]
            noise = np.random.laplace(loc=0, scale=laplace_scale)
            moment = np.mean(np.power(batch, order), axis=0)
            true_moments[b_idx, row] = moment
            perturbed_moments[b_idx, row] = moment + noise

    print(' Sample: RAW moments')
    print(true_moments[0])
    print(' Sample: NOISY moments')
    print(perturbed_moments[0])
    return perturbed_moments


def make_fbo_noisy_jmoments(fixed_batches, jmoment_sensitivities,
                            k_moments, laplace_eps, allocation=None):
    """Sets up true and noisy joint moments per batch.

    The k-th joint moment of a batch is the mean over samples of
    (product of all coordinates) ** k.

    Args:
      fixed_batches: Array of data batches, [batch_num, batch_size, data_dim].
      jmoment_sensitivities: Array of sensitivities of joint moments, one
        value per moment; shape [k_moments] or [k_moments, 1].
      k_moments: Integer number of moments to compute.
      laplace_eps: Float, overall privacy budget.
      allocation: List of values, to be normalized, that allocate the privacy
        budget among moments.

    Returns:
      fixed_batches_onetime_noisy_jmoments: Array of fixed batch noisy joint
        moments, according to sensitivities of each joint moment,
        [batch_num, k_moments].
    """
    fixed_batches_jmoments = np.zeros(
        (len(fixed_batches), k_moments), dtype=np.float32)
    fixed_batches_onetime_noisy_jmoments = np.zeros(
        (len(fixed_batches), k_moments), dtype=np.float32)

    # Choose allocation of budget.
    if allocation is None:
        allocation = [1] * k_moments
    print('Privacy budget and allocation: {}, {}\n'.format(
        laplace_eps, allocation))
    eps_ = laplace_eps * (np.array(allocation) / float(np.sum(allocation)))
    assert len(eps_) == k_moments, 'allocation length must match moment num'

    # Get moments and noisy version for each batch.
    for batch_num, batch in enumerate(fixed_batches):
        # Each moment within that batch.
        raw_jmoments = np.zeros(k_moments)
        noisy_jmoments = np.zeros(k_moments)
        for k in range(1, k_moments + 1):
            # Reduce the sensitivity to a plain scalar. A [k_moments, 1]
            # input would otherwise yield size-1 arrays that get assigned
            # into scalar slots below, which NumPy has deprecated.
            mk_sens = np.asarray(jmoment_sensitivities[k - 1]).item()
            # A joint moment is one number, so a single Laplace draw.
            mk_laplace = np.random.laplace(loc=0, scale=mk_sens/eps_[k-1])
            mk = np.mean(np.power(np.prod(batch, axis=1), k), axis=0)
            mk_noisy = mk + mk_laplace
            raw_jmoments[k - 1] = mk
            noisy_jmoments[k - 1] = mk_noisy
        fixed_batches_jmoments[batch_num] = raw_jmoments
        fixed_batches_onetime_noisy_jmoments[batch_num] = noisy_jmoments
    print(' Sample: RAW jmoments')
    print(fixed_batches_jmoments[0])
    print(' Sample: NOISY jmoments')
    print(fixed_batches_onetime_noisy_jmoments[0])
    return fixed_batches_onetime_noisy_jmoments


def make_noisy_quantiles(global_quantile_values,
                         quantile_sensitivities,
                         laplace_eps):
    """Adds Laplace noise to true, global quantile values.

    Args:
      global_quantile_values: Array of true quantile values,
        [data_dim, num_quantiles].
      quantile_sensitivities: Array of sensitivities of quantiles,
        [data_dim, num_quantiles].
      laplace_eps: Float, overall privacy budget.

    Returns:
      onetime_noisy_quantiles: New array of noisy quantiles, same shape as
        `global_quantile_values`, [data_dim, num_quantiles].
    """
    # Simple case: Just perturb global quantiles with Laplace noise.
    # One draw per entry -- the scale array gives each quantile its own
    # Laplace scale of sensitivity / epsilon.
    scale_params = quantile_sensitivities / laplace_eps
    laplace_noise = np.random.laplace(loc=0, scale=scale_params)
    onetime_noisy_quantiles = global_quantile_values + laplace_noise

    print(' RAW quantiles')
    print(global_quantile_values)
    print(' NOISY quantiles')
    print(onetime_noisy_quantiles)
    return onetime_noisy_quantiles


def get_noisy_mean_cov(fixed_batches, laplace_eps):
    """Returns Laplace-perturbed means and covariances for each fixed batch.

    Sensitivities are taken across all batches: the largest absolute mean
    (per dimension) and the largest absolute covariance entry (per matrix
    cell), each divided by the batch size.

    Args:
      fixed_batches: NumPy array of fixed batches of inputs, of dimension
        [num_batches, batch_size, input_dim].
      laplace_eps: Float, differential privacy epsilon.

    Returns:
      noisy_means: float32 array [num_batches, input_dim].
      noisy_covs: float32 array [num_batches, input_dim, input_dim]. Each
        matrix is symmetrized; if no draw with all-positive eigenvalues is
        found within the try limit, the last draw is kept.
    """
    num_batches, batch_size, input_dim = fixed_batches.shape[:3]
    mean_shape = (num_batches, input_dim)
    cov_shape = (num_batches, input_dim, input_dim)

    means = np.zeros(mean_shape, dtype=np.float32)
    covs = np.zeros(cov_shape, dtype=np.float32)
    noisy_means = np.zeros(mean_shape, dtype=np.float32)
    noisy_covs = np.zeros(cov_shape, dtype=np.float32)

    # True statistics of every batch.
    for idx, batch in enumerate(fixed_batches):
        means[idx] = np.mean(batch, axis=0)
        covs[idx] = np.cov(batch, rowvar=False)

    # Sensitivities over the whole set of batches.
    mean_sensitivity = np.max(np.abs(means), axis=0) / batch_size
    flat_covs = np.reshape(covs, [num_batches, -1])  # Vectorized covariances.
    cov_sensitivity = np.max(np.abs(flat_covs), axis=0) / batch_size

    mean_scale = mean_sensitivity / laplace_eps
    cov_scale = cov_sensitivity / laplace_eps
    lower_triangle = np.tril_indices(input_dim, -1)
    max_tries = 5

    for idx in range(num_batches):
        # Noisy mean.
        true_mean = means[idx]
        mean_noise = np.random.laplace(loc=0, scale=mean_scale)
        noisy_means[idx] = true_mean + mean_noise
        print(true_mean)
        print(true_mean + mean_noise)

        # Noisy covariance: redraw until positive eigenvalues or out of
        # tries.
        true_cov = covs[idx]
        for attempt in range(1, max_tries + 1):
            cov_noise = np.random.laplace(loc=0, scale=cov_scale)
            noisy_flat = np.reshape(true_cov, [1, -1]) + cov_noise
            noisy_cov = np.reshape(noisy_flat, [input_dim, input_dim])
            # Mirror the upper triangle into the lower one.
            noisy_cov[lower_triangle] = noisy_cov.T[lower_triangle]
            print(true_cov)
            print(noisy_cov)
            if np.all(np.linalg.eigvals(noisy_cov) > 0):
                break
            print('----NOT PSD {}----'.format(attempt))
        noisy_covs[idx] = noisy_cov

    # Check that outputs are somewhat close for large eps.
    print(' Sample: natural and noisy means')
    print(means[:3], noisy_means[:3])
    print(' Sample: natural and noisy covs')
    print(covs[:3], noisy_covs[:3])

    return noisy_means, noisy_covs

### Directories for logging.

In [None]:
def prepare_dirs(load_existing):
    """Builds the directory tree used for logs, checkpoints, plots and outputs.

    The root directory is 'logs/logs_<tag>', where `tag` is the module-level
    run tag. An existing root is removed first unless `load_existing` is
    truthy.

    Args:
      load_existing: Bool, if truthy an existing log directory is kept.

    Returns:
      Tuple of path strings (log_dir, checkpoint_dir, plot_dir, g_out_dir).
    """
    log_dir = 'logs/logs_{}'.format(tag)
    checkpoint_dir, plot_dir, g_out_dir = [
        os.path.join(log_dir, subdir)
        for subdir in ('checkpoints', 'plots', 'g_out')]

    # Start from a clean slate unless a previous run is being resumed.
    if os.path.exists(log_dir):
        if not load_existing:
            shutil.rmtree(log_dir)

    for directory in (log_dir, checkpoint_dir, plot_dir, g_out_dir):
        if os.path.exists(directory):
            continue
        os.makedirs(directory)

    return log_dir, checkpoint_dir, plot_dir, g_out_dir


def prepare_logging(log_dir, checkpoint_dir, sess):
    """Creates the saver, summary writer and supervisor used for logging.

    Args:
      log_dir: String, root log directory; summaries are written to the
        'summary' directory inside it.
      checkpoint_dir: String, handed to the Supervisor as its logdir.
      sess: TensorFlow session; its graph is given to the summary writer.

    Returns:
      saver: The tf.train.Saver created here.
      summary_writer: The tf.summary.FileWriter created here.
    """
    # The saver is constructed before the 'step' variable below exists; keep
    # this ordering.
    saver = tf.train.Saver()

    summary_path = os.path.join(log_dir, 'summary')
    summary_writer = tf.summary.FileWriter(summary_path, sess.graph)

    global_step = tf.Variable(0, name='step', trainable=False)

    # The supervisor is constructed here but not returned to the caller.
    supervisor_options = dict(
        logdir=checkpoint_dir,
        is_chief=True,
        saver=saver,
        summary_op=None,
        summary_writer=summary_writer,
        save_model_secs=300,
        global_step=global_step,
        ready_for_local_init_op=None)
    _supervisor = tf.train.Supervisor(**supervisor_options)

    return saver, summary_writer


def load_checkpoint(saver, sess, checkpoint_dir):
    """Restores weights from the checkpoint recorded in `checkpoint_dir`.

    Args:
      saver: Saver object whose `restore` method loads the variables.
      sess: Session to restore the variables into.
      checkpoint_dir: String, directory holding the checkpoint state.

    Returns:
      loaded: Bool, True if a checkpoint was found and restored.
      counter: Int, parsed from the text after the last '-' in the checkpoint
        file name; 0 when no checkpoint was found.
    """
    print(' [*] Reading checkpoints...')
    print('     {}'.format(checkpoint_dir))
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

    # Guard clause: nothing to restore.
    if not (ckpt and ckpt.model_checkpoint_path):
        print(' [*] Failed to find a checkpoint')
        return False, 0

    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
    # The step counter is the suffix after the last dash of the file name.
    counter = int(ckpt_name.split('-')[-1])
    print(' [*] Success to read {}'.format(ckpt_name))
    return True, counter

### Presence and attribute risk.

In [None]:
def avg_nearest_neighbor_distance(candidates, references, flag='noflag'):
    """Builds graph ops for the mean nearest-neighbor distance between sets.

    NOTE: This is not symmetric!

    For every row of `candidates`, the distance to its closest row in
    `references` is computed; the result is the mean of those distances.

    Args:
      candidates: Tensor of candidate points (num_points x point_dim), with a
        statically known number of rows.
      references: Tensor of reference points (num_points x point_dim).
      flag: String, used as the name of each top_k op.

    Returns:
      avg_dist: Scalar tensor, mean over the nearest-neighbor distances.
      distances: List with one tensor per candidate, each holding that
        candidate's nearest-neighbor distance.
    """
    nearest_distances = []
    for row_index in range(candidates.shape[0]):
        candidate_row = tf.gather(candidates, [row_index])
        row_to_refs = tf.norm(candidate_row - references, axis=1)
        row_to_refs_2d = tf.reshape(row_to_refs, [1, -1])
        assert row_to_refs_2d.shape.as_list() == [1, references.shape[0]]

        # top_k returns the largest entries, so negate to obtain the smallest
        # distance, then negate back.
        negated = -1.0 * row_to_refs_2d
        negated_min, _ = tf.nn.top_k(negated, name=flag)
        assert negated_min.shape.as_list() == [1, 1]
        nearest = -1.0 * negated_min[0]

        nearest_distances.append(nearest)

    mean_distance = tf.reduce_mean(nearest_distances)
    return mean_distance, nearest_distances


def evaluate_presence_risk(train, test, sim, ball_radius=1e-2):
    """Assess privacy of simulations.

    Counts how often a record has at least one simulated point strictly
    within `ball_radius` of it, for the first len(test) training records
    (treated as the compromised records) and for all test records.

    Args:
      train: Numpy array of all training data.
      test: Numpy array of all test data (smaller than train).
      sim: Numpy array of simulations.
      ball_radius: Float, distance from point to sim that qualifies as match.

    Returns:
      A 7-tuple:
        sensitivity: Float, TP / (TP + FN).
        precision: Float, (TP + 1e-10) / (TP + FP + 1e-10); the small
          constant avoids division by zero when there are no matches.
        false_positive_rate: Float, FP / (FP + TN).
        tp: Int, compromised training records with a neighbor in sim.
        fn: Int, compromised training records without a neighbor in sim.
        fp: Int, test records with a neighbor in sim.
        tn: Int, test records without a neighbor in sim.
    """
    # TODO: Sensitivity as a loss, rather than just be reported?
    # TODO: Count nearest neighbors, rather than presence in epsilon-ball?
    assert len(test) < len(train), 'test should be smaller than train'
    num_samples = len(test)
    compromised_records = train[:num_samples]

    def _count_with_neighbor(records):
        """Returns how many records have a sim point inside the ball."""
        matches = 0
        for record in records:
            distances_from_record = norm(record - sim, axis=1)
            if np.any(distances_from_record < ball_radius):
                matches += 1
        return matches

    # Training records: a match is a true positive, a miss a false negative.
    tp = _count_with_neighbor(compromised_records)
    fn = len(compromised_records) - tp
    # Test records: a match is a false positive, a miss a true negative.
    fp = _count_with_neighbor(test)
    tn = len(test) - fp

    sensitivity = float(tp) / (tp + fn)
    precision = float(tp + 1e-10) / (tp + fp + 1e-10)
    false_positive_rate = float(fp) / (fp + tn)
    return (sensitivity, precision, false_positive_rate, tp, fn, fp, tn)

### Med data plotting.

In [None]:
def plot_marginals(raw_data_train, data, batch_size, step, g_, g_out, log_dir,
        filename_tag=None, plot_sparse=False):
    """Plots all marginals, and computes MMDs between marginals of real and sim.

    One histogram panel is drawn per column, overlaying the un-standardized
    data and the un-standardized simulation. The figure is saved to `log_dir`
    and, when the module-level `jupyter_verbose` is set, displayed inline.

    Args:
      raw_data_train: Numpy array, un-standardized data.
      data: Numpy array, standardized data.
      batch_size: Int, batch_size for MMD computation.
      step: Int, used for logging.
      g_: Numpy array, standardized simulation.
      g_out: Numpy array, un-standardized simulation.
      log_dir: String, path where plots are saved.
      filename_tag: String, used for naming plot. If falsy, `step` is used.
      plot_sparse: Bool, if True the title, labels, legends and ticks are
        omitted and panels are packed tightly.
    """
    # Random batches (sampled with replacement) used only for the MMD values.
    random_batch_data = np.array(
        [data[i] for i in np.random.choice(len(data), batch_size)])
    random_batch_gen = np.array(
        [g_[i] for i in np.random.choice(len(g_), batch_size)])
    num_cols = raw_data_train.shape[1]
    # Smallest square grid that holds one panel per column.
    sq_dim = int(np.ceil(np.sqrt(num_cols)))
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, so the
    # ravel() below also works for single-column data.
    fig, axs = plt.subplots(sq_dim, sq_dim, figsize=(30, 30), squeeze=False)
    if not plot_sparse:
        fig.subplots_adjust(hspace=0.5, wspace=0.5)
        fig.suptitle('Marginals, it{}'.format(step))
    else:
        fig.subplots_adjust(hspace=0.05, wspace=0.05)
    axs = axs.ravel()
    bins = 40
    for i in range(num_cols):
        # For each marginal, compute mmd between normalized data and
        # normalized simulations.
        mmd_i_data_gen, _ = compute_mmd(
            random_batch_data[:, i], random_batch_gen[:, i], use_tf=False)
        # For each marginal, plot unnormalized data and unnormalized simulations.
        plot_d = raw_data_train[:, i]
        plot_g = g_out[:, i]
        axs[i].hist(plot_d, density=True, alpha=0.3, label='d', bins=bins)
        axs[i].hist(plot_g, density=True, alpha=0.3, label='g', bins=bins)
        if not plot_sparse:
            axs[i].set_xlabel('mmd = {:.3f}'.format(mmd_i_data_gen))
            axs[i].legend()
        else:
            # Booleans are required here: the string 'off' is truthy and
            # would leave the ticks and tick labels visible.
            axs[i].tick_params(axis='both', which='both', bottom=False,
                top=False, labelbottom=False, right=False, left=False,
                labelleft=False)
    # Hide the unused panels of the square grid.
    for i in range(num_cols, sq_dim ** 2):
        axs[i].axis('off')
    if filename_tag:
        filename = 'plot_marginals_{}.png'.format(filename_tag)
    else:
        filename = 'plot_marginals_{}.png'.format(step)
    plt.savefig(os.path.join(log_dir, filename))
    plt.close('all')

    if jupyter_verbose:
        display(Image(filename=os.path.join(log_dir, filename)))


def plot_correlations(raw_data_train, step, g_out, log_dir):
    """Scatter-plots pairwise Pearson correlations: data versus simulations.

    For every column pair (i, j) with j > i, the Pearson correlation is
    computed on the data and on the simulations; each pair becomes one point
    (data coefficient, simulation coefficient). The plot is saved to
    `log_dir` and, when the module-level `jupyter_verbose` is set, displayed
    inline.

    Args:
      raw_data_train: Numpy array, un-standardized data.
      step: Int, used for naming the plot file.
      g_out: Numpy array, un-standardized simulation with the same number of
        columns as `raw_data_train`.
      log_dir: String, path where the plot is saved.
    """
    num_cols = raw_data_train.shape[1]
    # Select the strictly-upper-triangular pairs by position. Filtering a
    # full matrix with `!= 0` instead would also drop genuinely zero
    # coefficients, and could leave the two arrays misaligned.
    pair_rows, pair_cols = np.triu_indices(num_cols, k=1)
    coefs_d = np.array([
        pearsonr(raw_data_train[:, i], raw_data_train[:, j])[0]
        for i, j in zip(pair_rows, pair_cols)])
    coefs_g = np.array([
        pearsonr(g_out[:, i], g_out[:, j])[0]
        for i, j in zip(pair_rows, pair_cols)])
    fig, ax = plt.subplots()
    ax.scatter(coefs_d, coefs_g)
    # Line from the lower-left to the upper-right corner of the current view.
    # NOTE(review): this is the y = x line only if both axes have the same
    # limits - confirm whether an identity line is intended.
    ax.plot(ax.get_xlim(), ax.get_ylim(), ls='-')
    #ax.set_xlabel('Correlation data')
    #ax.set_ylabel('Correlation gens')
    filepath = os.path.join(
        log_dir,'plot_correlations_{}.png'.format(step))
    plt.savefig(filepath)
    plt.close('all')

    if jupyter_verbose:
        display(Image(filename=filepath))

### Network helper function for DP-SGD.

In [None]:
def load_network_parameters(z_dim, default_gradient_l2norm_bound,
                            depth, width, out_dim):
    """Describes the generator network as a NetworkParameters object.

    The description holds `depth` hidden layers of `width` units (relu
    enabled, no bias, trainable) followed by a 'gen' output layer of
    `out_dim` units (relu disabled, no bias).

    Args:
      z_dim: Stored as the network input size.
      default_gradient_l2norm_bound: Stored on the network parameters.
      depth: Int, number of hidden layers.
      width: Number of units per hidden layer.
      out_dim: Number of units in the output layer.

    Returns:
      The populated NetworkParameters instance.
    """
    def _bias_free_layer(name, num_units, relu):
        """Creates a LayerParameters with the shared bias-free settings."""
        layer = LayerParameters()
        layer.name = name
        layer.num_units = num_units
        layer.relu = relu
        layer.with_bias = False
        return layer

    params = NetworkParameters()
    params.input_size = z_dim
    params.default_gradient_l2norm_bound = default_gradient_l2norm_bound

    for layer_index in range(depth):
        hidden_layer = _bias_free_layer("hidden%d" % layer_index, width, True)
        hidden_layer.trainable = True
        params.layer_parameters.append(hidden_layer)

    # Output layer; `trainable` is deliberately left at its default here.
    output_layer = _bias_free_layer('gen', out_dim, False)
    params.layer_parameters.append(output_layer)
    return params

### Build TF graph.

In [None]:
def build_model_cmd_gan(batch_size, gen_num, data_num, data_test_num, out_dim,
                        z_dim, cmd_span_const,
                        moment_sensitivities,
                        quantile_sensitivities,
                        bounds):
    """Builds model for Central Moment Discrepancy as adversary.

    Besides its arguments, this function reads module-level configuration:
    cmd_variation, k_moments, width, depth, activation, optimizer,
    learning_rate, lr_minimum, do_cmd_taylor_weights,
    default_gradient_l2norm_bound and sgd_sigma.

    Args:
      batch_size: Int, number of rows of the `x` and `x_test` placeholders.
      gen_num: Int, number of rows of the `z` placeholder.
      data_num: Int, number of rows of the `z_readonly` placeholder; also
        used to size the DP-SGD accountant.
      data_test_num: Int, number of rows of both precompute placeholders.
      out_dim: Int, dimension of data points and generator outputs.
      z_dim: Int, dimension of the generator input noise.
      cmd_span_const: Passed through to the moment discrepancy helpers.
      moment_sensitivities: Not used inside this function.
      quantile_sensitivities: Only its last-axis size is read, and only for
        cmd_variation == 'onetime_noisy_joint'.
      bounds: Passed to `generator`; only used when the v2 generator is
        disabled.

    Returns:
      A 34-tuple, in this order: x, z, z_readonly, x_test, x_precompute,
      x_test_precompute, avg_dist_x_to_x_test,
      avg_dist_x_to_x_test_precomputed, distances_xt_xp, prog_cmd_coefs,
      mmd_to_cmd_indicator, cmd_k_terms, ncmd_k_terms, g, g_readonly, mmd,
      kmmd, cmd, loss1, loss2, lr_update, lr, g_optim, d_optim, d_loss,
      ae_loss, summary_op, batch_id, fbo_noisy_moments, fbo_noisy_jmoments,
      noisy_quantiles, eps, delta, priv_accountant.
      Entries that do not apply to the chosen cmd_variation are None
      (including d_optim for 'dp_sgd').
    """

    # Placeholders to precompute avg distance from data_test to data.
    x_precompute = tf.placeholder(
        tf.float32, [data_test_num, out_dim], name='x_precompute')
    x_test_precompute = tf.placeholder(
        tf.float32, [data_test_num, out_dim], name='x_test_precompute')
    avg_dist_x_to_x_test_precomputed, distances_xt_xp = \
        avg_nearest_neighbor_distance(x_precompute, x_test_precompute)

    # Placeholders for regular training.
    x = tf.placeholder(tf.float32, [batch_size, out_dim], name='x')
    z = tf.placeholder(tf.float32, [gen_num, z_dim], name='z')
    z_readonly = tf.placeholder(tf.float32, [data_num, z_dim], name='z_readonly')
    x_test = tf.placeholder(tf.float32, [batch_size, out_dim], name='x_test')

    avg_dist_x_to_x_test = tf.placeholder(
        tf.float32, shape=(), name='avg_dist_x_to_x_test')
    # One-element tuple: `(k_moments)` without the comma is just an int.
    prog_cmd_coefs = tf.placeholder(
        tf.float32, shape=(k_moments,), name='prog_cmd_coefs')
    mmd_to_cmd_indicator = tf.placeholder(
        tf.float32, shape=(), name='mmd_to_cmd_indicator')

    # Update learning rate: each run of lr_update multiplies lr by 0.8,
    # floored at lr_minimum.
    lr = tf.Variable(learning_rate, name='lr', trainable=False)
    lr_update = tf.assign(lr, tf.maximum(lr * 0.8, lr_minimum),
                          name='lr_update')

    ########################
    # Generator output.
    use_generator_v2 = 1
    if use_generator_v2:
        # Dense network using Google Research code from Github.
        # https://github.com/tensorflow/models/blob/master/research/
        #   differential_privacy/dp_sgd/dp_optimizer/utils.py
        network_parameters = load_network_parameters(
            z_dim, default_gradient_l2norm_bound, depth, width, out_dim)
        with tf.variable_scope('generator_v2'):
            g = generator_v2(z, network_parameters)
            g_readonly = generator_v2(z_readonly, network_parameters)
    else:
        g, g_vars = generator(
            z, width=width, depth=depth, activation=activation, out_dim=out_dim,
            bounds=bounds)
        g_readonly, _ = generator(
            z_readonly, width=width, depth=depth, activation=activation,
            out_dim=out_dim, reuse=True, bounds=bounds)

    ########################
    # Autoencoder output. Data and generated samples are encoded together and
    # split apart again afterwards.
    h_out, ae_out, enc_vars, dec_vars = autoencoder(
        tf.concat([x, g], 0), width=width, depth=depth, activation=activation,
        z_dim=z_dim, reuse=False)
    enc_x, enc_g = tf.split(h_out, [batch_size, gen_num])
    ae_x, ae_g = tf.split(ae_out, [batch_size, gen_num])

    #######################
    # Moment discrepancies.
    on_encodings = 0
    if on_encodings:
        arr1 = enc_x
        arr2 = enc_g
    else:
        arr1 = x
        arr2 = g

    mmd = compute_mmd(
        arr1, arr2, use_tf=True, slim_output=True, sigma_list=[0.1, 0.5, 1.0, 2.0])  # Added sigma_list.
    kmmd = compute_kmmd(
        arr1, arr2, k_moments=k_moments, use_tf=True,
        slim_output=True, sigma_list=[0.1, 0.5, 1.0, 2.0])
    cmd_k, cmd_k_terms = compute_cmd(
        arr1, arr2, k_moments=k_moments, use_tf=True,
        cmd_span_const=cmd_span_const, return_terms=True, taylor_weight=do_cmd_taylor_weights)
    ncmd_k = compute_noncentral_moment_discrepancy(
        arr1, arr2, k_moments=k_moments, use_tf=True,
        cmd_span_const=cmd_span_const)
    _, ncmd_k_terms = compute_noncentral_moment_discrepancy(
        arr1, arr2, k_moments=k_moments, use_tf=True,
        return_terms=True, cmd_span_const=1)
    jmd_k = compute_joint_moment_discrepancy(
        arr1, arr2, k_moments=k_moments, use_tf=True,
        cmd_span_const=cmd_span_const)

    #######################
    # Losses.

    # Generator loss.
    if cmd_variation == 'onetime_noisy':
        # NoncentralMD on one-time-noised empirical data moments.
        batch_id = tf.placeholder(tf.int32, shape=(), name='batch_id')
        fbo_noisy_moments = tf.placeholder(
            tf.float32, [None, k_moments, out_dim], name='fbo_noisy_moments')
        g_loss = compute_noncentral_moment_discrepancy(
            x, g, k_moments=k_moments, use_tf=True,
            cmd_span_const=cmd_span_const, batch_id=batch_id,
            fbo_noisy_moments=fbo_noisy_moments)

        eps = None
        delta = None
        fbo_noisy_jmoments = None
        noisy_quantiles = None
        priv_accountant = None

    elif cmd_variation == 'onetime_noisy_joint':
        # Joint MD on one-time-noised empirical data moments.
        batch_id = tf.placeholder(tf.int32, shape=(), name='batch_id')
        fbo_noisy_moments = tf.placeholder(
            tf.float32, [None, k_moments, out_dim], name='fbo_noisy_moments')
        fbo_noisy_jmoments = tf.placeholder(
            tf.float32, [None, k_moments, 1], name='fbo_noisy_jmoments')
        num_quantiles = quantile_sensitivities.shape[-1]
        noisy_quantiles = tf.placeholder(
            tf.float32, [out_dim, num_quantiles], name='noisy_quantiles')
        # Noncentral moment discrepancy. This (and jmd_k below) replaces the
        # noise-free tensor defined above, so the summaries report these.
        ncmd_k = compute_noncentral_moment_discrepancy(
            x, g, k_moments=k_moments, use_tf=True,
            cmd_span_const=cmd_span_const, batch_id=batch_id,
            fbo_noisy_moments=fbo_noisy_moments)
        # Joint moment discrepancy.
        jmd_k = compute_joint_moment_discrepancy(
            x, g, k_moments=k_moments, use_tf=True,
            cmd_span_const=cmd_span_const, batch_id=batch_id,
            fbo_noisy_jmoments=fbo_noisy_jmoments)
        # Combine the noncentral and joint discrepancies.
        g_loss = 1 * ncmd_k + .1 * jmd_k

        eps = None
        delta = None
        priv_accountant = None

    elif cmd_variation == 'dp_sgd':
        # eps, delta and priv_accountant are replaced further down, in the
        # DP-SGD optimization section.
        eps = None
        delta = None
        priv_accountant = None
        batch_id = None
        fbo_noisy_moments = None
        fbo_noisy_jmoments = None
        noisy_quantiles = None
        g_loss = mmd  # [mmd, cmd_k, ncmd_k]

    elif cmd_variation in ['mmd', 'kmmd', 'cmd', 'ncmd', 'ncmd_jmd']:
        eps = None
        delta = None
        priv_accountant = None
        batch_id = None
        fbo_noisy_moments = None
        fbo_noisy_jmoments = None
        noisy_quantiles = None
        if cmd_variation == 'mmd':
            g_loss = mmd
        elif cmd_variation == 'kmmd':
            g_loss = kmmd
        elif cmd_variation == 'cmd':
            g_loss = cmd_k
        elif cmd_variation == 'ncmd':
            g_loss = ncmd_k
        elif cmd_variation == 'ncmd_jmd':
            g_loss = ncmd_k + jmd_k

    else:
        sys.exit('Ensure valid cmd_variation.')

    # Autoencoder loss.
    ae_loss = tf.reduce_mean(tf.square(ae_x - x))
    d_loss = ae_loss - 0.1 * g_loss

    ######################
    # Optimization ops.

    # DP-SGD optim nodes.
    if cmd_variation == 'dp_sgd':
        # Begin: Differentially private SGD.
        # In classification example on Github, authors take average of cross-
        # entropy loss across batch, claiming the "actual cost is the average
        # across the examples". Here, the CMD is already an expectation over
        # the samples, so it's unclear whether it should be scaled.
        # TODO: Should the g_loss be scaled down by batch_size?

        # Define effective training size, given fixed batches.
        num_training = (data_num // batch_size) * batch_size

        # Instantiate the accountant.
        priv_accountant = GaussianMomentsAccountant(num_training)

        # Define sigma.
        sigma = sgd_sigma

        # Instantiate the sanitizer.
        # TODO: Should the bound be scaled down by batch size?
        gaussian_sanitizer = AmortizedGaussianSanitizer(
            priv_accountant,
            [default_gradient_l2norm_bound / batch_size, True])  # / batch_size?

        # Setting clip options for each var. For now, this does nothing, as all
        # vars take default option.
        #for var in training_params:
        #    if "gradient_l2norm_bound" in training_params[var]:
        #        l2bound = training_params[var]["gradient_l2norm_bound"] / batch_size
        #        gaussian_sanitizer.set_option(var, sanitizer.ClipOption(l2bound,
        #                                                                True))

        # Constants for optimization step.
        #lr = tf.placeholder(tf.float32)
        eps = tf.placeholder(tf.float32)
        delta = tf.placeholder(tf.float32)

        # Define optimization node.
        g_optim = DPGradientDescentOptimizer(
            lr,
            [eps, delta],
            gaussian_sanitizer,
            sigma=sigma,
            batches_per_lot=1).minimize(g_loss)

        # No autoencoder optimizer exists for DP-SGD yet. Bind the name so
        # the return statement below does not raise UnboundLocalError.
        # TODO: DP-SGD version of autoencoder optim.
        # TODO: Figure whether dp-sgd version needs to have its
        #       var_list restricted to generator vars.
        d_optim = None

    # Moment discrepancy optim nodes.
    else:
        if optimizer == 'adagrad':
            g_opt = tf.train.AdagradOptimizer(lr)
            d_opt = tf.train.AdagradOptimizer(lr)
        elif optimizer == 'adam':
            g_opt = tf.train.AdamOptimizer(lr)
            d_opt = tf.train.AdamOptimizer(lr)
        elif optimizer == 'rmsprop':
            g_opt = tf.train.RMSPropOptimizer(lr)
            d_opt = tf.train.RMSPropOptimizer(lr)
        elif optimizer == 'adadelta':
            g_opt = tf.train.AdadeltaOptimizer(lr)
            d_opt = tf.train.AdadeltaOptimizer(lr)
        else:
            # Any other optimizer name falls back to plain gradient descent.
            g_opt = tf.train.GradientDescentOptimizer(lr)
            d_opt = tf.train.GradientDescentOptimizer(lr)

        d_optim = d_opt.minimize(d_loss, var_list=enc_vars+dec_vars)

        # Generator variables are selected by name.
        g_vars = [v for v in tf.trainable_variables() if 'generator' in v.name]
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # layers.batch_normalization
        with tf.control_dependencies(update_ops):
            clip = 1
            if clip:
                # Element-wise clipping of generator gradients to
                # [-0.01, 0.01] before they are applied.
                #g_vars = tf.trainable_variables()  # NEW to v2
                g_grads_, g_vars_ = zip(*g_opt.compute_gradients(g_loss, var_list=g_vars))
                g_grads_clipped_ = tuple(
                    [tf.clip_by_value(grad, -0.01, 0.01) for grad in g_grads_])
                g_optim = g_opt.apply_gradients(zip(g_grads_clipped_, g_vars_))
            else:
                #g_vars = tf.trainable_variables()  # NEW to v2
                g_optim = g_opt.minimize(g_loss, var_list=g_vars)

    ###########################
    # Diagnostics and summary.

    # Get diagnostics. Compare distances between data, heldouts, and gen.
    avg_dist_g_to_x, distances_g_x = avg_nearest_neighbor_distance(g, x)
    avg_dist_x_to_g, distances_x_g = avg_nearest_neighbor_distance(x, g)
    avg_dist_x_test_to_g, distances_xt_g = avg_nearest_neighbor_distance(x_test, g)
    loss1 = avg_dist_x_to_x_test - avg_dist_x_to_g
    loss2 = avg_dist_x_test_to_g - avg_dist_x_to_g
    cmd = cmd_k  # For reporting, define cmd as cmd_k.

    # Define summary op for reporting. Each scalar gets its own tag so the
    # three losses show up as separate, correctly labeled curves.
    summary_op = tf.summary.merge([
        tf.summary.scalar("loss/g_loss", g_loss),
        tf.summary.scalar("loss/d_loss", d_loss),
        tf.summary.scalar("loss/ae_loss", ae_loss),
        tf.summary.scalar("loss/loss1", loss1),
        tf.summary.scalar("loss/loss2", loss2),
        tf.summary.scalar("loss/mmd", mmd),
        tf.summary.scalar("loss/cmd", cmd),
        tf.summary.scalar("loss/ncmd_k", ncmd_k),
        tf.summary.scalar("loss/jmd_k", jmd_k),
        tf.summary.scalar("misc/lr", lr),
    ])

    return (x, z, z_readonly, x_test, x_precompute, x_test_precompute,
            avg_dist_x_to_x_test, avg_dist_x_to_x_test_precomputed,
            distances_xt_xp, prog_cmd_coefs, mmd_to_cmd_indicator, cmd_k_terms,
            ncmd_k_terms, g, g_readonly, mmd, kmmd, cmd, loss1, loss2, lr_update, lr,
            g_optim, d_optim, d_loss, ae_loss, summary_op, batch_id, fbo_noisy_moments,
            fbo_noisy_jmoments, noisy_quantiles, eps, delta, priv_accountant)


def add_nongraph_summary_items(summary_writer, step, dict_to_add):
    """Logs values computed outside the graph as individual summaries.

    Args:
      summary_writer: Writer that receives one summary per dict entry and is
        flushed once at the end.
      step: Step value attached to every summary.
      dict_to_add: Dict mapping summary tag to the simple value to record.
    """
    for item_tag, item_value in dict_to_add.items():
        scalar_summary = tf.Summary()
        scalar_summary.value.add(tag=item_tag, simple_value=item_value)
        summary_writer.add_summary(scalar_summary, step)
    summary_writer.flush()

# Run.

### Initialize directories.

In [None]:
# Create (or reuse) the run's logging, checkpoint, plotting and output dirs.
log_dir, checkpoint_dir, plot_dir, g_out_dir = prepare_dirs(load_existing)

# Record the run configuration next to the logs, and echo it.
save_tag = str(args)
_save_tag_path = os.path.join(log_dir, 'save_tag.txt')
with open(_save_tag_path, 'w') as save_tag_file:
    save_tag_file.write(save_tag)
print('Save tag: {}'.format(save_tag))

### Load data.

In [None]:
# Load the normalized train/test data together with the raw mean/std needed
# to undo the normalization, and the clip value reported by the loader.
(data, data_test, data_num, data_test_num,
 out_dim, data_raw_mean, data_raw_std, clip) = load_normed_data(
     data_num_init, percent_train, log_dir,
     clip_unnormed=clip_unnormed,
     data_file=data_file)

# From here on, data_dim follows the loaded data rather than the config.
data_dim = data.shape[1]

# Moments of the train and test sets, requested with k_moments + 1.
normed_moments_data = compute_moments(data, k_moments=k_moments+1)
normed_moments_data_test = compute_moments(data_test, k_moments=k_moments+1)

# Indices of train-set moments whose norm (taken along axis 1) is below 0.1.
_normed_moment_norms = norm(np.array(normed_moments_data), axis=1)
nmd_zero_indices = np.argwhere(_normed_moment_norms < 0.1)

# Compute baseline statistics on moments for data set.
#print_baseline_moment_stats(data, data_raw_mean, data_raw_std, k_moments)

### Sensitivities.

In [None]:
# Compute sensitivities for moments.
# The data is partitioned into fixed batches once; the same batches are reused
# for the one-time noisy moments below.
fixed_batches = make_fixed_batches(data, batch_size)

(moment_sensitivities,
 jmoment_sensitivities,
 global_quantile_values,
 quantile_sensitivities) = compute_sensitivities(data, batch_size, k_moments, clip=clip)
print('Global quantile values:')
print(global_quantile_values)

# Report how much of the data the fixed batches cover.
print('\n\nData size: {}, Batch size: {}, Num batches: {}, '
      'Effective data size: {}\n\n'.format(
           len(data), batch_size, len(fixed_batches),
           batch_size * len(fixed_batches)))

# NOTE(review): the three noise-adding calls below presumably draw from the
# global NumPy random state, so their relative order likely affects the noise
# each one receives - confirm before reordering.

# Allocation can apply more of privacy budget to certain moments.
# e.g. allocation = [1, 1, 5] applies 5/7 of the eps budget to the third moment.
# More budget leads to lower noise applied to that moment.
allocation = [1] * k_moments

# Add onetime noise to MOMENT in each batch, according to 
# moment_sensitivities.
fixed_batches_onetime_noisy_moments = \
    make_fbo_noisy_moments(
        fixed_batches, moment_sensitivities, k_moments,
        laplace_eps, allocation=allocation)

# Add onetime noise to JOINT MOMENT in each batch, according to 
# fixed_batches_jmoment_sensitivities.
# (Same uniform allocation as above, restated for the joint moments.)
allocation = [1] * k_moments
fixed_batches_onetime_noisy_jmoments = \
    make_fbo_noisy_jmoments(
        fixed_batches, jmoment_sensitivities, k_moments,
        laplace_eps, allocation=allocation)
# Add a trailing axis so the joint moments are 3-D, as asserted below.
fixed_batches_onetime_noisy_jmoments = np.expand_dims(
    fixed_batches_onetime_noisy_jmoments, axis=2)
assert (len(fixed_batches_onetime_noisy_moments.shape) ==
        len(fixed_batches_onetime_noisy_jmoments.shape) == 3), (
            'fbo inputs must be 3d tensors')

# Add onetime noise to quantiles in each batch, according to 
# quantile_sensitivities.
onetime_noisy_quantiles = \
    make_noisy_quantiles(
        global_quantile_values,
        quantile_sensitivities,
        laplace_eps)

# Get compact interval bounds for CMD computations.
# data_raw reverses the standardization (data * std + mean); cmd_a_raw and
# cmd_b_raw are its overall min and max, cmd_a and cmd_b those of the
# standardized data.
data_raw = data * data_raw_std + data_raw_mean
cmd_a_raw = np.min(data_raw)
cmd_b_raw = np.max(data_raw)
# Reciprocal of the largest value returned by pdist(data) - presumably the
# largest pairwise distance between standardized points; confirm which pdist
# is imported.
cmd_span_const = 1.0 / np.max(pdist(data))  # TODO: Is it a problem that this is on [0, 1]?
#print('OVERWROTE SPAN CONSTANT to 1.')
#cmd_span_const = 1.
cmd_a = np.min(data)
cmd_b = np.max(data)
print('cmd_span_const: {:.2f}'.format(cmd_span_const))

### Store data and build model.

In [None]:
# Persist the un-normalized train and test sets used for this run.
data_train_unnormed = data * data_raw_std + data_raw_mean
data_test_unnormed = data_test * data_raw_std + data_raw_mean
_data_train_path = os.path.join(log_dir, 'data_train.npy')
_data_test_path = os.path.join(log_dir, 'data_test.npy')
np.save(_data_train_path, data_train_unnormed)
np.save(_data_test_path, data_test_unnormed)

# Text file for generator outputs (also saved later as npy). A stale file
# left in place is removed first.
g_out_file = os.path.join(g_out_dir, 'g_out.txt')
if os.path.isfile(g_out_file):
    os.remove(g_out_file)

# build_all()
# Build the model graph, then unpack every handle it returns.
_model_handles = build_model_cmd_gan(
    batch_size, gen_num, data_num, data_test_num,
    out_dim, z_dim, cmd_span_const,
    moment_sensitivities,
    quantile_sensitivities, bounds=[cmd_a, cmd_b])
(x, z, z_readonly, x_test, x_precompute, x_test_precompute,
 avg_dist_x_to_x_test, avg_dist_x_to_x_test_precomputed, distances_xt_xp,
 prog_cmd_coefs, mmd_to_cmd_indicator, cmd_k_terms, ncmd_k_terms, g, g_readonly,
 mmd, kmmd, cmd, loss1, loss2, lr_update, lr, g_optim, d_optim, d_loss, ae_loss,
 summary_op, batch_id, fbo_noisy_moments, fbo_noisy_jmoments, noisy_quantiles,
 eps, delta, priv_accountant) = _model_handles

### Start session.

In [None]:
################
# Start session.
################
init_op = tf.global_variables_initializer()
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
sess_config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)

with tf.Session(config=sess_config) as sess:
    if cmd_variation == 'dp_sgd':
        # We need to maintain the intialization sequence.
        for v in tf.trainable_variables():
            sess.run(tf.variables_initializer([v]))


    saver, summary_writer = prepare_logging(log_dir, checkpoint_dir, sess)

    sess.run(init_op)

    # load_existing_model().
    if load_existing:
        could_load, checkpoint_counter = load_checkpoint(
            saver, sess, checkpoint_dir)
        if could_load:
            load_step = checkpoint_counter
            print(' [*] Load SUCCESS')
        else:
            print(' [!] Load failed...')
    else:
        load_step = 0

    # Once, compute average distance from heldout data to training data.
    avg_dist_x_to_x_test_precomputed_, _ = sess.run(
        [avg_dist_x_to_x_test_precomputed, distances_xt_xp],
        {x_precompute: data[:len(data_test)],
         x_test_precompute: data_test})

    # Containers to hold empirical and relative errors of moments.
    empirical_moments_gens = np.zeros(
        ((max_step - 0) // log_step, k_moments+1, data_dim))
    relative_error_of_moments = np.zeros(
        ((max_step - 0) // log_step, k_moments+1))
    reom = relative_error_of_moments


    #########
    # train()
    #########
    start_time = time()
    for step in range(load_step, max_step):

        # Set up inputs for all models.
        # OPTION 1: Random batch selection.
        # OPTION 2: Fixed batch selection.
        batch_selection_option = 2
        if batch_selection_option == 1:
            random_batch_data = np.array(
                [data[d] for d in np.random.choice(len(data), batch_size)])
            _batch_id = None
        elif batch_selection_option == 2:
            _epoch, _batch_id = np.divmod(step, len(fixed_batches))
            if _batch_id == 0 and _epoch % 10 == 0:
                print(' {}'.format(_epoch))
            random_batch_data = fixed_batches[_batch_id]

        # Fetch test data and z.
        random_batch_data_test = np.array(
            [data_test[d] for d in np.random.choice(
                len(data_test), batch_size)])
        random_batch_z = get_random_z(gen_num, z_dim)

        # Update shared dict for chosen model.
        #if cmd_variation in ['', None]:
        #    feed_dict = {
        #        x: random_batch_data,
        #        z: random_batch_z,
        #        x_test: random_batch_data_test,
        #        avg_dist_x_to_x_test: avg_dist_x_to_x_test_precomputed_,
        #        batch_id: _batch_id}
        if cmd_variation == 'dp_sgd':
            feed_dict = {
                x: random_batch_data,
                z: random_batch_z,
                x_test: random_batch_data_test,
                avg_dist_x_to_x_test: avg_dist_x_to_x_test_precomputed_,
                eps: sgd_eps,
                delta: sgd_delta}
        elif cmd_variation == 'onetime_noisy': 
            feed_dict = {
                x: random_batch_data,
                z: random_batch_z,
                x_test: random_batch_data_test,
                avg_dist_x_to_x_test: avg_dist_x_to_x_test_precomputed_,
                batch_id: _batch_id,
                fbo_noisy_moments: fixed_batches_onetime_noisy_moments}
        elif cmd_variation == 'onetime_noisy_joint': 
            feed_dict = {
                x: random_batch_data,
                z: random_batch_z,
                x_test: random_batch_data_test,
                avg_dist_x_to_x_test: avg_dist_x_to_x_test_precomputed_,
                batch_id: _batch_id,
                fbo_noisy_moments: fixed_batches_onetime_noisy_moments,
                fbo_noisy_jmoments: fixed_batches_onetime_noisy_jmoments}
        else:
            feed_dict = {
                x: random_batch_data,
                z: random_batch_z,
                x_test: random_batch_data_test,
                avg_dist_x_to_x_test: avg_dist_x_to_x_test_precomputed_}

        
        # Run optimization step.
        #sess.run([d_optim, g_optim], feed_dict)
        sess.run([g_optim], feed_dict)

        # Occasionally update learning rate.
        if step % lr_update_step == lr_update_step - 1:
            _, lr_ = sess.run([lr_update, lr])
            print('Updated learning rate to {}'.format(lr_))


        ###########
        # logging()
        ###########
        # Occasionally log/plot results.
        if step % log_step == 0 and step > 0:
            print('\nIter: {}'.format(step))

            if cmd_variation == 'dp_sgd':
                # Report privacy loss.
                spent_eps_deltas = priv_accountant.get_privacy_spent(
                    sess, target_eps=sgd_target_eps)
                for spent_eps, spent_delta in spent_eps_deltas:
                    print('  spent privacy: eps {:.4f} delta {:.5g}'.format(
                          spent_eps, spent_delta))

            # Read off from graph: discrepancy terms, losses and the merged
            # summary, all evaluated on the current feed_dict.
            fetches = [cmd, cmd_k_terms, ncmd_k_terms,
                       mmd, kmmd, d_loss, ae_loss,
                       loss1, loss2,
                       summary_op]
            (cmd_, cmd_k_terms_, ncmd_k_terms_,
             mmd_, kmmd_, d_loss_, ae_loss_,
             loss1_, loss2_,
             summary_result) = sess.run(fetches, feed_dict)

            # Run the read-only generator on fresh latent codes requested
            # with (data_num, z_dim, for_training=False). g_full_ is the whole
            # output; g_batch_ is batch_size rows of it picked at random
            # (with replacement) from indices [0, data_num).
            z_for_eval = get_random_z(data_num, z_dim, for_training=False)
            g_readonly_ = sess.run(g_readonly, {z_readonly: z_for_eval})
            g_batch_ = g_readonly_[np.random.randint(0, data_num, batch_size)]
            g_full_ = g_readonly_

            # TODO: Determine how much budget to put on joint moment.
            # Only 'onetime_noisy_joint' reports twice laplace_eps; every
            # other variation reports laplace_eps unchanged.
            total_laplace_eps = (2 * laplace_eps
                                 if cmd_variation == 'onetime_noisy_joint'
                                 else laplace_eps)

            report_template = ('  LAPLACE_EPS: {:.1f},\n'
                               'MMD: {:.4f}, kmmd: {:.4f}, cmd: {:.4f},\n'
                               'ncmd_k_terms: {}\n'
                               'd_loss: {:.1f}, ae_loss: {:.1f}, '
                               'loss1: {:.4f}, loss2: {:.4f}')
            print(report_template.format(
                total_laplace_eps,
                mmd_, kmmd_, cmd_,
                ncmd_k_terms_,
                d_loss_, ae_loss_,
                loss1_, loss2_))


            # Test joint moment discrepancy.
            # Both helpers are called directly on the current real batch and
            # the sampled generated batch (not fetched through sess.run),
            # with this step's batch id and the fixed-batch one-time noisy
            # (joint) moments passed through. The two scalars are printed
            # below for comparison with the graph-side values.
            md_batch = compute_noncentral_moment_discrepancy(
                random_batch_data, g_batch_, k_moments=k_moments,
                cmd_span_const=cmd_span_const, batch_id=_batch_id,
                fbo_noisy_moments=fixed_batches_onetime_noisy_moments)
            jmd_batch = compute_joint_moment_discrepancy(
                random_batch_data, g_batch_, k_moments=k_moments,
                cmd_span_const=cmd_span_const, batch_id=_batch_id,
                fbo_noisy_jmoments=fixed_batches_onetime_noisy_jmoments)
            print('  MD_batch: {:.4f}'.format(md_batch))
            print('  JMD_batch: {:.4f}'.format(jmd_batch))

            # Test kmmd.
            # NOTE(review): kmmd_test is not read anywhere in the visible
            # code; confirm whether it is used later or is only a smoke test
            # of the non-TF code path (use_tf=False).
            kmmd_test = compute_kmmd(random_batch_data, g_batch_, k_moments=k_moments,
                                     use_tf=False, slim_output=True)

            # Diagnose NaNs.
            # Drops into the interactive debugger when the graph-side MMD is
            # NaN. This blocks until a user responds, so unattended runs
            # will stall here.
            if np.isnan(mmd_):
                pdb.set_trace()

            # Unnormalize data and simulations for all logs and plots.
            # Every array goes through the same unnormalize() call with the
            # raw-data mean and standard deviation, in the order listed.
            (g_batch_unnormed,
             g_full_unnormed,
             data_unnormed,
             data_test_unnormed) = [
                 unnormalize(normed_array, data_raw_mean, data_raw_std)
                 for normed_array in (g_batch_, g_full_, data, data_test)]

            if extra_verbose:
                # Compute disclosure risk from the unnormalized train data,
                # test data and full generated sample. (Passing
                # ball_radius=avg_dist_x_to_x_test_precomputed_ is disabled.)
                presence_stats = evaluate_presence_risk(
                    data_unnormed, data_test_unnormed, g_full_unnormed)
                (sensitivity, precision, false_positive_rate,
                 tp, fn, fp, tn) = presence_stats
                sens_minus_fpr = sensitivity - false_positive_rate

                presence_line = ('  Sens={:.4f}, Prec={:.4f}, Fpr: {:.4f}, '
                                 'tp: {}, fn: {}, fp: {}, tn: {}')
                print(presence_line.format(
                    sensitivity, precision, false_positive_rate,
                    tp, fn, fp, tn))

                # Add presence disclosure stats to summaries.
                # NOTE(review): the graph summary fetched earlier is also
                # only written here, i.e. only when extra_verbose is set;
                # confirm that is intended.
                summary_writer.add_summary(summary_result, step)
                presence_summaries = {
                    'misc/sensitivity': sensitivity,
                    'misc/false_positive_rate': false_positive_rate,
                    'misc/sens_minus_fpr': sens_minus_fpr,
                    'misc/precision': precision}
                add_nongraph_summary_items(
                    summary_writer, step, presence_summaries)

            # Save checkpoint under <log_dir>/checkpoints/<model_type>,
            # tagged with the current step.
            checkpoint_prefix = os.path.join(
                log_dir, 'checkpoints', model_type)
            saver.save(sess, checkpoint_prefix, global_step=step)

            # Save generated data to file: one .npy file per logged step,
            # plus a text rendering appended to the running g_out_file.
            npy_path = os.path.join(g_out_dir, 'g_out_{}.npy'.format(step))
            np.save(npy_path, g_full_unnormed)
            # NOTE(review): if g_full_unnormed is a NumPy array with more
            # than 1000 elements, str() abbreviates it with '...' under
            # NumPy's default print options, so this text log may not hold
            # every row. The .npy file above is unaffected.
            with open(g_out_file, 'a') as f:
                f.write(str(g_full_unnormed) + '\n')

            # Print time performance on every tenth logging pass. The
            # enclosing logging condition already requires step > 0, so the
            # division by step below is safe.
            if step % (10 * log_step) == 0:
                elapsed_time = time() - start_time
                time_per_iter = elapsed_time / step
                # Projected wall time for all max_step iterations at the
                # average pace so far, rendered as H:MM:SS.
                total_est = time_per_iter * max_step
                minutes, s = divmod(total_est, 60)
                h, m = divmod(minutes, 60)
                total_est_str = '{:.0f}:{:02.0f}:{:02.0f}'.format(h, m, s)
                timing_line = ('\n  Time (s): {:.2f}, time/iter: {:.4f},'
                               ' Total est.: {}')
                print(timing_line.format(
                    elapsed_time, time_per_iter, total_est_str))

                print('  Save tag: {}\n'.format(save_tag))


            ############################
            # PLOT data and simulations.
            ############################
            # out_dim == 1: overlaid density histograms of the normalized
            #   arrays (data, g_full_).
            # out_dim == 2: joint scatter of the unnormalized arrays with
            #   marginal histograms along the top and right edges.
            # Either way the figure is saved as <plot_dir>/<step>.png and,
            # if jupyter_verbose, displayed inline.
            if out_dim == 1:
                fig, ax = plt.subplots()
                # (values, color, alpha, label) for each overlaid histogram.
                hist_layers = ((data, 'gray', 0.3, 'data'),
                               (g_full_, 'blue', 0.2, 'g_full_readonly'))
                for values, color, alpha, label in hist_layers:
                    ax.hist(values, density=True, bins=30, color=color,
                            alpha=alpha, label=label)

                plt.legend()
                filepath = os.path.join(plot_dir, '{}.png'.format(step))
                plt.savefig(filepath)
                plt.close(fig)

                if jupyter_verbose:
                    display(Image(filename=filepath))

            elif out_dim == 2:
                d_x, d_y = data_unnormed[:, 0], data_unnormed[:, 1]
                g_x, g_y = g_full_unnormed[:, 0], g_full_unnormed[:, 1]

                # 4x4 grid: joint plot in the lower-left 3x3 cells, x-marginal
                # in the top row, y-marginal in the right column.
                fig = plt.figure()
                gs = GridSpec(4, 4)
                ax_joint = fig.add_subplot(gs[1:4, 0:3])
                ax_marg_x = fig.add_subplot(gs[0, 0:3], sharex=ax_joint)
                ax_marg_y = fig.add_subplot(gs[1:4, 3], sharey=ax_joint)

                # Unpack the columns of each array as scatter's positional
                # arguments.
                ax_joint.scatter(*data_unnormed.T, color='gray',
                                 alpha=0.2, label='data')
                ax_joint.scatter(*g_full_unnormed.T, color='green',
                                 alpha=0.2, label='sim')

                # Shared bin edges (width 0.2) spanning both data and
                # generated values on each axis.
                x_lo = np.min([np.min(d_x), np.min(g_x)])
                x_hi = np.max([np.max(d_x), np.max(g_x)])
                y_lo = np.min([np.min(d_y), np.min(g_y)])
                y_hi = np.max([np.max(d_y), np.max(g_y)])
                bins_x = np.arange(x_lo, x_hi, 0.2)
                bins_y = np.arange(y_lo, y_hi, 0.2)

                marginal_style = dict(alpha=0.2,
                                      color=['gray', 'green'],
                                      label=['data', 'gen'],
                                      density=True)
                ax_marg_x.hist([d_x, g_x], bins=bins_x, **marginal_style)
                ax_marg_y.hist([d_y, g_y], bins=bins_y,
                               orientation='horizontal', **marginal_style)

                for axis in (ax_joint, ax_marg_x, ax_marg_y):
                    axis.legend()
                # Tick labels on the marginals duplicate the joint axes.
                plt.setp(ax_marg_x.get_xticklabels(), visible=False)
                plt.setp(ax_marg_y.get_yticklabels(), visible=False)
                plt.suptitle(tag)
                filepath = os.path.join(plot_dir, '{}.png'.format(step))
                plt.savefig(filepath)
                plt.close(fig)

                if jupyter_verbose:
                    display(Image(filename=filepath))

                    
            # Plot moment diagnostics (only for data_dim <= 2).
            #
            # Each logging pass stores the generated sample's empirical
            # moments and their relative error against the data moments in
            # row `log_idx` of empirical_moments_gens / reom, then redraws the
            # history plots from row 0 through row log_idx inclusive.
            if data_dim <= 2:
                log_idx = step // log_step
                num_orders = k_moments + 1

                normed_moments_gens = compute_moments(
                    g_full_, k_moments=num_orders)
                empirical_moments_gens[log_idx] = normed_moments_gens

                # Define colormap used for plotting: one color per moment
                # order. pyplot.get_cmap is used because matplotlib.cm.get_cmap
                # was deprecated in Matplotlib 3.7 and removed in 3.9; the
                # pyplot function takes the same (name, lut) arguments.
                cmap = plt.get_cmap('cool', num_orders)

                if data_dim == 1:
                    # Plot empirical moments throughout training: data moments
                    # as horizontal reference lines, generated moments as a
                    # history over logging passes.
                    fig, (ax_data, ax_gens) = plt.subplots(2, 1)
                    for i in range(num_orders):
                        moment_label = 'm{}'.format(i+1)
                        ax_data.axhline(y=normed_moments_data[i],
                                        label=moment_label, c=cmap(i))
                        # Slice through log_idx inclusive so the row written
                        # just above is drawn; a `[:log_idx]` slice stops one
                        # row short and never shows the newest values.
                        ax_gens.plot(empirical_moments_gens[:log_idx + 1, i],
                                     label=moment_label, c=cmap(i),
                                     alpha=0.8)
                    ax_data.set_ylim(min(normed_moments_data)-0.5,
                                     max(normed_moments_data)+0.5)
                    ax_gens.set_xlabel('Empirical moments, gens')
                    ax_data.set_xlabel('Empirical moments, data')
                    ax_gens.legend()
                    ax_data.legend()
                    plt.suptitle('{}, empirical moments, k={}'.format(
                        tag, k_moments))
                    plt.tight_layout()
                    plt.savefig(os.path.join(
                        plot_dir, 'empirical_moments.png'))
                    plt.close(fig)

                # Relative error of moments, per moment order:
                #   ||m_other - m_data|| / ||m_data||  (norm over axis 1),
                # for the held-out test data and for the generated sample.
                moments_data_arr = np.array(normed_moments_data)
                data_moment_norms = norm(moments_data_arr, axis=1)
                relative_error_of_moments_test = (
                    norm(np.array(normed_moments_data_test) -
                         moments_data_arr, axis=1) /
                    data_moment_norms)
                relative_error_of_moments_gens = (
                    norm(np.array(normed_moments_gens) -
                         moments_data_arr, axis=1) /
                    data_moment_norms)

                # Entries selected by nmd_zero_indices are forced to 0.0
                # (presumably the orders whose data moment is ~zero, where
                # the ratio above is not meaningful -- confirm at its
                # definition).
                relative_error_of_moments_test[nmd_zero_indices] = 0.0
                relative_error_of_moments_gens[nmd_zero_indices] = 0.0
                reom[log_idx] = relative_error_of_moments_gens

                if extra_verbose:
                    print('  Relative_error_of_moments_TEST: {}'.format(list(
                        np.round(relative_error_of_moments_test, 2))))
                    print('  Relative_error_of_moments_GENS: {}'.format(list(
                        np.round(relative_error_of_moments_gens, 2))))

                # Capped copy of reom: values beyond +/- reom_trim_level (the
                # largest magnitude among the first k_moments columns) are
                # replaced by +/- 2 * reom_trim_level.
                # NOTE(review): the plot below draws `reom`, not this capped
                # copy, and uses a fixed y-range instead; reom_trimmed is not
                # read anywhere in the visible code.
                reom_trim_level = np.max(np.abs(reom[:, :k_moments]))
                reom_trimmed = np.copy(reom)
                reom_trimmed[reom_trimmed > reom_trim_level] = \
                    2 * reom_trim_level
                reom_trimmed[reom_trimmed < -reom_trim_level] = \
                    -2 * reom_trim_level

                # Plot relative error of moments over logging passes.
                fig, ax = plt.subplots()
                for i in range(num_orders):
                    # Inclusive of log_idx, so the row stored above is shown.
                    ax.plot(reom[:log_idx + 1, i],
                            label='m{}'.format(i+1), c=cmap(i))
                ax.set_ylim((-2, 2))
                ax.legend()
                plt.suptitle('{}, relative errors of moments, k={}'.format(
                    tag, k_moments))
                plt.savefig(os.path.join(
                    plot_dir, 'reom.png'))
                plt.close(fig)

                # Print normed moments to console.
                if extra_verbose and data_dim == 1:
                    print('    data_normed moments: {}'.format(
                        normed_moments_data))
                    print('    test_normed moments: {}'.format(
                        normed_moments_data_test))
                    print('    gens_normed moments: {}'.format(
                        normed_moments_gens))
            
            # Plot yangmed data.
            if 'yang' in data_file:
                # Columns 17, 18 and 19 are binarized in place before
                # plotting: values below 0.5 become 0.0, everything else
                # (including NaN, which fails the `<` test) becomes 1.0.
                # One vectorized assignment replaces the per-row, per-column
                # Python loop; assumes g_full_unnormed is a 2-D NumPy array.
                binary_cols = [17, 18, 19]
                g_full_unnormed[:, binary_cols] = np.where(
                    g_full_unnormed[:, binary_cols] < 0.5, 0.0, 1.0)
                plot_marginals(data_unnormed, data, batch_size, step,
                               g_full_, g_full_unnormed, log_dir)
                plot_correlations(data_unnormed, step, g_full_unnormed,
                                  log_dir)