In [1]:
import os, sys
sys.path.append(os.getcwd())
import time

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sklearn.datasets
import tensorflow as tf

In [2]:


import tflib as lib
import tflib.ops.linear
import tflib.ops.conv2d
import tflib.ops.batchnorm
import tflib.ops.deconv2d
import tflib.save_images
import tflib.mnist
import tflib.plot

In [12]:
MODE = 'wgan-gp'    # One of 'dcgan', 'wgan', 'wgan-gp'; batchnorm layers are only added when MODE == 'wgan'
DIM = 64            # Model dimensionality (base channel count of both networks)
BATCH_SIZE = 50     # Batch size
CRITIC_ITERS = 5    # For WGAN and WGAN-GP, number of critic iters per gen iter
LAMBDA = 10         # Gradient penalty lambda hyperparameter
ITERS = 200000      # How many generator iterations to train for
OUTPUT_DIM = 4096   # Pixels per flattened single-channel square image (64*64).
                    # Native MNIST is 28*28 = 784; 4096 presumably assumes the
                    # data is resized to 64x64 — TODO confirm against the data pipeline.

# Fail fast on a mistyped MODE (which would otherwise silently skip batchnorm)
# or an OUTPUT_DIM that cannot be reshaped into a square image.
assert MODE in ('dcgan', 'wgan', 'wgan-gp'), 'unknown MODE: %r' % (MODE,)
assert round(OUTPUT_DIM ** 0.5) ** 2 == OUTPUT_DIM, 'OUTPUT_DIM must be side*side pixels'

In [3]:
# Log the UPPERCASE settings defined so far (a shallow snapshot of the namespace).
# NOTE(review): this cell was executed as In [3], before the config cell (In [12]),
# which is why the recorded output lists no variables — run the config cell first.
lib.print_model_settings(dict(locals()))

Uppercase local vars:


In [4]:
def LeakyReLU(x, alpha=0.2):
    """Element-wise leaky ReLU computed as max(alpha * x, x).

    For 0 < alpha < 1 this keeps non-negative values unchanged and scales
    negative values by ``alpha``.
    """
    leaked = alpha * x
    # Argument order kept as (alpha*x, x) so tie-breaking at x == 0 is unchanged.
    return tf.maximum(leaked, x)

In [5]:
def ReLULayer(name, n_in, n_out, inputs):
    """He-initialised fully connected layer followed by a ReLU.

    Args:
        name: prefix; the linear layer is created as ``name + '.Linear'``.
        n_in: input feature count passed to the linear layer.
        n_out: output feature count passed to the linear layer.
        inputs: input tensor.
    """
    layer_name = name + '.Linear'
    pre_activation = lib.ops.linear.Linear(layer_name, n_in, n_out, inputs, initialization='he')
    return tf.nn.relu(pre_activation)


In [6]:
def LeakyReLULayer(name, n_in, n_out, inputs):
    """He-initialised fully connected layer followed by ``LeakyReLU`` (default alpha).

    Args:
        name: prefix; the linear layer is created as ``name + '.Linear'``.
        n_in: input feature count passed to the linear layer.
        n_out: output feature count passed to the linear layer.
        inputs: input tensor.
    """
    layer_name = name + '.Linear'
    pre_activation = lib.ops.linear.Linear(layer_name, n_in, n_out, inputs, initialization='he')
    return LeakyReLU(pre_activation)

In [7]:
def Generator(n_samples, noise=None):
    if noise is None:
        noise = tf.random.normal([n_samples, 128])

    output = lib.ops.linear.Linear('Generator.Input', 128, 4*4*4*DIM, noise)
    if MODE == 'wgan':
        output = lib.ops.batchnorm.Batchnorm('Generator.BN1', [0], output)
    output = tf.nn.relu(output)
    output = tf.reshape(output, [-1, 4*DIM, 4, 4])

    output = lib.ops.deconv2d.Deconv2D('Generator.2', 4*DIM, 2*DIM, 5, output)
    if MODE == 'wgan':
        output = lib.ops.batchnorm.Batchnorm('Generator.BN2', [0,2,3], output)
    output = tf.nn.relu(output)

    output = output[:,:,:7,:7]

    output = lib.ops.deconv2d.Deconv2D('Generator.3', 2*DIM, DIM, 5, output)
    if MODE == 'wgan':
        output = lib.ops.batchnorm.Batchnorm('Generator.BN3', [0,2,3], output)
    output = tf.nn.relu(output)

    output = lib.ops.deconv2d.Deconv2D('Generator.5', DIM, 1, 5, output)
    output = tf.nn.sigmoid(output)

    return tf.reshape(output, [-1, OUTPUT_DIM])


In [8]:
def Discriminator(inputs):
    output = tf.reshape(inputs, [-1, 1, 64, 64])

    output = lib.ops.conv2d.Conv2D('Discriminator.1',1,DIM,5,output,stride=2)
    output = LeakyReLU(output)

    output = lib.ops.conv2d.Conv2D('Discriminator.2', DIM, 2*DIM, 5, output, stride=2)
    if MODE == 'wgan':
        output = lib.ops.batchnorm.Batchnorm('Discriminator.BN2', [0,2,3], output)
    output = LeakyReLU(output)

    output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*DIM, 4*DIM, 5, output, stride=2)
    if MODE == 'wgan':
        output = lib.ops.batchnorm.Batchnorm('Discriminator.BN3', [0,2,3], output)
    output = LeakyReLU(output)

    output = tf.reshape(output, [-1, 4*4*4*DIM])
    output = lib.ops.linear.Linear('Discriminator.Output', 4*4*4*DIM, 1, output)

    return tf.reshape(output, [-1])

In [23]:
# Build the generator/critic outputs and collect each network's parameters.
#
# NOTE(review): under TF2 eager execution every line below runs immediately, so
# `disc_real` / `disc_fake` are concrete tensors computed once from the current
# values, not symbolic graph nodes. Assigning a new batch into `real_data` later
# will not update them. The losses and train step presumably need to be
# recomputed per step (e.g. inside a tf.function) — TODO confirm against the
# training loop.
#
# NOTE(review): the recorded `TypeError: Variable is unhashable` is raised while
# tflib handles the 'Generator.Input' parameter. Presumably tflib uses tf.Variable
# objects as dict/set keys, which TF2 forbids. The fix belongs in tflib (key by
# `variable.ref()` or `id(variable)`), not in this cell.

# Non-trainable buffer standing in for the TF1 placeholder that used to feed a
# batch of flattened real images; presumably refilled via `real_data.assign(...)`
# each iteration — TODO confirm.
real_data = tf.Variable(
    tf.zeros([BATCH_SIZE, OUTPUT_DIM], dtype=tf.float32),
    trainable=False,
)
fake_data = Generator(BATCH_SIZE)

disc_real = Discriminator(real_data)
disc_fake = Discriminator(fake_data)

# Parameters of each network, looked up by name via lib.params_with_name.
gen_params = lib.params_with_name('Generator')
disc_params = lib.params_with_name('Discriminator')

TypeError: Variable is unhashable. Instead, use variable.ref() as the key. (Variable: <tf.Variable 'Generator.Input/Generator.Input.W:0' shape=(128, 4096) dtype=float32, numpy=
array([[ 0.02567775, -0.02669399, -0.01812933, ..., -0.03175392,
        -0.02235023, -0.01022566],
       [-0.01484493,  0.00863906, -0.01036157, ...,  0.00751138,
        -0.01818066, -0.03466291],
       [ 0.00817575,  0.0166621 ,  0.03102938, ...,  0.00167163,
        -0.02021211, -0.02805721],
       ...,
       [ 0.03072485, -0.01196431,  0.00519867, ..., -0.00522642,
         0.02819545, -0.00683101],
       [-0.01217334, -0.02364941, -0.00086858, ...,  0.02179544,
        -0.00736717, -0.00091929],
       [ 0.02785378, -0.02328118,  0.02192808, ..., -0.02472678,
         0.02378554,  0.00449341]], dtype=float32)>)