In [1]:
# example of fitting an auxiliary classifier gan (ac-gan) on fashion mnsit
from numpy import zeros
from numpy import ones
from numpy import expand_dims
from numpy.random import randn
from numpy.random import randint
from tensorflow.keras.datasets.fashion_mnist import load_data
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.initializers import RandomNormal
from matplotlib import pyplot
import os
import torch
import numpy as np

In [2]:
def load_real_samples():
	# load dataset
	train_x = []
	train_y = []
	for client in range(10):
		for item in torch.load('../data/MNIST.pth')[client]:
			img = item[0].swapaxes(0,1)
			img = img.swapaxes(1,2)
			train_x.append(np.array(img))
			train_y.append(np.array(item[1]))

	# (trainX, trainy), (_, _) = load_data()
	# expand to 3d, e.g. add channels
	X = expand_dims(train_x, axis=-1)
	# convert from ints to floats
	X = X.astype('float32')
	# scale from [0,255] to [-1,1]
	X = (X - 127.5) / 127.5
	# print(train_y[0])
	return [np.array(X), np.array(train_y)]

load_real_samples()

(1000, 28, 28, 1) (1000,)


[array([[[[[-1.]],
 
          [[-1.]],
 
          [[-1.]],
 
          ...,
 
          [[-1.]],
 
          [[-1.]],
 
          [[-1.]]],
 
 
         [[[-1.]],
 
          [[-1.]],
 
          [[-1.]],
 
          ...,
 
          [[-1.]],
 
          [[-1.]],
 
          [[-1.]]],
 
 
         [[[-1.]],
 
          [[-1.]],
 
          [[-1.]],
 
          ...,
 
          [[-1.]],
 
          [[-1.]],
 
          [[-1.]]],
 
 
         ...,
 
 
         [[[-1.]],
 
          [[-1.]],
 
          [[-1.]],
 
          ...,
 
          [[-1.]],
 
          [[-1.]],
 
          [[-1.]]],
 
 
         [[[-1.]],
 
          [[-1.]],
 
          [[-1.]],
 
          ...,
 
          [[-1.]],
 
          [[-1.]],
 
          [[-1.]]],
 
 
         [[[-1.]],
 
          [[-1.]],
 
          [[-1.]],
 
          ...,
 
          [[-1.]],
 
          [[-1.]],
 
          [[-1.]]]],
 
 
 
        [[[[-1.]],
 
          [[-1.]],
 
          [[-1.]],
 
          ...,
 
          [[-1.]],
 
  

In [3]:
# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1), n_classes=10):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=in_shape)
	# downsample to 14x14
	fe = Conv2D(32, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
	fe = LeakyReLU(alpha=0.2)(fe)
	fe = Dropout(0.5)(fe)
	# normal
	fe = Conv2D(64, (3,3), padding='same', kernel_initializer=init)(fe)
	fe = BatchNormalization()(fe)
	fe = LeakyReLU(alpha=0.2)(fe)
	fe = Dropout(0.5)(fe)
	# downsample to 7x7
	fe = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(fe)
	fe = BatchNormalization()(fe)
	fe = LeakyReLU(alpha=0.2)(fe)
	fe = Dropout(0.5)(fe)
	# normal
	fe = Conv2D(256, (3,3), padding='same', kernel_initializer=init)(fe)
	fe = BatchNormalization()(fe)
	fe = LeakyReLU(alpha=0.2)(fe)
	fe = Dropout(0.5)(fe)
	# flatten feature maps
	fe = Flatten()(fe)
	# real/fake output
	out1 = Dense(1, activation='sigmoid')(fe)
	# class label output
	out2 = Dense(n_classes, activation='softmax')(fe)
	# define model
	model = Model(in_image, [out1, out2])
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss=['binary_crossentropy', 'sparse_categorical_crossentropy'], optimizer=opt)
	return model

# define the standalone generator model
def define_generator(latent_dim, n_classes=10):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# label input
	in_label = Input(shape=(1,))
	# embedding for categorical input
	li = Embedding(n_classes, 50)(in_label)
	# linear multiplication
	n_nodes = 7 * 7
	li = Dense(n_nodes, kernel_initializer=init)(li)
	# reshape to additional channel
	li = Reshape((7, 7, 1))(li)
	# image generator input
	in_lat = Input(shape=(latent_dim,))
	# foundation for 7x7 image
	n_nodes = 384 * 7 * 7
	gen = Dense(n_nodes, kernel_initializer=init)(in_lat)
	gen = Activation('relu')(gen)
	gen = Reshape((7, 7, 384))(gen)
	# merge image gen and label input
	merge = Concatenate()([gen, li])
	# upsample to 14x14
	gen = Conv2DTranspose(192, (5,5), strides=(2,2), padding='same', kernel_initializer=init)(merge)
	gen = BatchNormalization()(gen)
	gen = Activation('relu')(gen)
	# upsample to 28x28
	gen = Conv2DTranspose(1, (5,5), strides=(2,2), padding='same', kernel_initializer=init)(gen)
	out_layer = Activation('tanh')(gen)
	# define model
	model = Model([in_lat, in_label], out_layer)
	return model

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
	# make weights in the discriminator not trainable
	for layer in d_model.layers:
		if not isinstance(layer, BatchNormalization):
			layer.trainable = False
	# connect the outputs of the generator to the inputs of the discriminator
	gan_output = d_model(g_model.output)
	# define gan model as taking noise and label and outputting real/fake and label outputs
	model = Model(g_model.input, gan_output)
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss=['binary_crossentropy', 'sparse_categorical_crossentropy'], optimizer=opt)
	return model

# load images
# def load_real_samples():
# 	# load dataset
# 	(trainX, trainy), (_, _) = load_data()
# 	# expand to 3d, e.g. add channels
# 	X = expand_dims(trainX, axis=-1)
# 	# convert from ints to floats
# 	X = X.astype('float32')
# 	# scale from [0,255] to [-1,1]
# 	X = (X - 127.5) / 127.5
# 	print(X.shape, trainy.shape)
# 	return [X, trainy]

# select real samples
def generate_real_samples(dataset, n_samples):
	# split into images and labels
	images, labels = dataset
	# choose random instances
	ix = randint(0, images.shape[0], n_samples)
	# select images and labels
	X, labels = images[ix], labels[ix]
	# generate class labels
	y = ones((n_samples, 1))
	return [X, labels], y

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples, n_classes=10):
	# generate points in the latent space
	x_input = randn(latent_dim * n_samples)
	# reshape into a batch of inputs for the network
	z_input = x_input.reshape(n_samples, latent_dim)
	# generate labels
	labels = randint(0, n_classes, n_samples)
	return [z_input, labels]

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
	# generate points in latent space
	z_input, labels_input = generate_latent_points(latent_dim, n_samples)
	# predict outputs
	images = generator.predict([z_input, labels_input])
	# create class labels
	y = zeros((n_samples, 1))
	return [images, labels_input], y

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, latent_dim, n_samples=100):
	# prepare fake examples
	[X, _], _ = generate_fake_samples(g_model, latent_dim, n_samples)
	# scale from [-1,1] to [0,1]
	X = (X + 1) / 2.0
	# plot images
	for i in range(100):
		# define subplot
		pyplot.subplot(10, 10, 1 + i)
		# turn off axis
		pyplot.axis('off')
		# plot raw pixel data
		pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
	# save plot to file
	filename1 = 'results/ACGAN/generated_plot_%04d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	# save the generator model
	filename2 = 'results/ACGAN/model_%04d.h5' % (step+1)
	g_model.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

# train the generator and discriminator
# def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=100, n_batch=64):
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=5, n_batch=20):
	# calculate the number of batches per training epoch
	bat_per_epo = int(dataset[0].shape[0] / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# calculate the size of half a batch of samples
	half_batch = int(n_batch / 2)
	# manually enumerate epochs
	for i in range(n_steps):
		# get randomly selected 'real' samples
		[X_real, labels_real], y_real = generate_real_samples(dataset, half_batch)
		# update discriminator model weights
		_,d_r1,d_r2 = d_model.train_on_batch(X_real, [y_real, labels_real])
		# generate 'fake' examples
		[X_fake, labels_fake], y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
		# update discriminator model weights
		_,d_f,d_f2 = d_model.train_on_batch(X_fake, [y_fake, labels_fake])
		# prepare points in latent space as input for the generator
		[z_input, z_labels] = generate_latent_points(latent_dim, n_batch)
		# create inverted labels for the fake samples
		y_gan = ones((n_batch, 1))
		# update the generator via the discriminator's error
		_,g_1,g_2 = gan_model.train_on_batch([z_input, z_labels], [y_gan, z_labels])
		# summarize loss on this batch
		print('>%d, dr[%.3f,%.3f], df[%.3f,%.3f], g[%.3f,%.3f]' % (i+1, d_r1,d_r2, d_f,d_f2, g_1,g_2))
		# evaluate the model performance every 'epoch'
		if (i+1) % (bat_per_epo * 10) == 0:
			summarize_performance(i, g_model, latent_dim)

# make folder for results
os.makedirs('results/ACGAN', exist_ok=True)
# size of the latent space
latent_dim = 100
# create the discriminator
discriminator = define_discriminator()
# create the generator
generator = define_generator(latent_dim)
# create the gan
gan_model = define_gan(generator, discriminator)
# load image data
dataset = load_real_samples()
# train model
train(generator, discriminator, gan_model, dataset, latent_dim)

  super(Adam, self).__init__(name, **kwargs)


(1000, 28, 28, 1) (1000,)
>1, dr[0.511,3.253], df[1.110,3.134], g[1.003,3.384]
>2, dr[1.065,3.273], df[0.983,3.509], g[0.962,3.253]
>3, dr[1.016,3.382], df[0.816,2.653], g[0.765,2.843]
>4, dr[0.634,2.160], df[1.213,2.136], g[1.151,3.642]
>5, dr[0.881,4.090], df[1.044,3.154], g[1.439,3.420]
>6, dr[0.896,2.956], df[0.501,2.780], g[0.840,3.105]
>7, dr[0.549,3.241], df[0.995,2.992], g[1.303,2.718]
>8, dr[0.857,2.822], df[0.838,3.777], g[1.194,3.283]
>9, dr[0.720,2.752], df[0.564,3.582], g[0.864,3.470]
>10, dr[0.733,3.150], df[0.344,2.397], g[0.815,3.650]
>11, dr[0.539,2.602], df[0.811,3.779], g[0.980,3.616]
>12, dr[0.430,4.171], df[0.685,3.183], g[1.176,3.368]
>13, dr[0.729,3.230], df[1.145,3.486], g[1.192,3.135]
>14, dr[0.354,3.281], df[0.549,3.346], g[1.469,2.893]
>15, dr[0.959,2.791], df[0.660,3.558], g[1.241,2.881]
>16, dr[0.877,2.784], df[1.090,2.992], g[1.373,3.286]
>17, dr[0.412,2.246], df[0.824,3.419], g[1.321,3.034]
>18, dr[0.352,2.927], df[0.558,3.060], g[2.029,2.756]
>19, dr[0.4

In [10]:
print(len(discriminator.get_weights()))

24
