# License

Copyright 2019 Hamaad Musharaf Shah

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Automatic feature engineering using Generative Adversarial Networks: Application to computer vision and synthetic financial transactions data
## Author: Hamaad Shah

---

The purpose of deep learning is to learn a representation of high dimensional and noisy data using a sequence of differentiable functions, i.e., geometric transformations, that can then be used for supervised learning tasks, among others. It has had great success in discriminative models, while generative models have fared worse due to the limitations of explicit maximum likelihood estimation (MLE). Adversarial learning, as presented in the Generative Adversarial Network (GAN), aims to overcome these problems by using implicit MLE.

We will use the MNIST computer vision dataset and a synthetic financial transactions dataset for an insurance task for these experiments. A GAN is a markedly different method of learning compared to explicit MLE. Our purpose is to show that the representation learnt by a GAN can be used for supervised learning tasks such as image recognition and insurance loss risk prediction. In this manner we avoid the manual process of handcrafted feature engineering by learning a set of features automatically, i.e., representation learning.

```python
# Author: Hamaad Musharaf Shah.
from PIL import Image

from six.moves import range

import os
# Both environment variables must be set before Keras / TensorFlow are imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Hide the GPUs so that this notebook runs on the CPU.
os.environ["KERAS_BACKEND"] = "tensorflow"
import math
import sys
import importlib

import numpy as np

import pandas as pd

from sklearn import linear_model
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, LabelBinarizer
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

import keras
from keras import backend as bkend
from keras.datasets import cifar10, mnist
from keras import layers
from keras.layers import Input, Dense, BatchNormalization, Dropout, Flatten, convolutional, pooling, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras import metrics
from keras.models import Model
from keras.utils.generic_utils import Progbar

import tensorflow as tf
from tensorflow.python.client import device_lib

import matplotlib.pyplot as plt

from plotnine import *
import plotnine

get_ipython().run_line_magic("matplotlib", "inline")

importlib.reload(bkend)

print(device_lib.list_local_devices())

# Load MNIST without rebinding the imported `mnist` module.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Flatten each 28 x 28 image into a 784-dimensional vector.
x_train = np.reshape(x_train, [x_train.shape[0], x_train.shape[1] * x_train.shape[2]])
x_test = np.reshape(x_test, [x_test.shape[0], x_test.shape[1] * x_test.shape[2]])
y_train = y_train.ravel()
y_test = y_test.ravel()
# Scale the pixel intensities to [0, 1].
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /= 255.0
x_test /= 255.0

scaler_classifier = MinMaxScaler(feature_range=(0.0, 1.0))
logistic = linear_model.LogisticRegression(random_state=666, verbose=1)
lb = LabelBinarizer()
lb = lb.fit(y_train.reshape(y_train.shape[0], 1))

dl4j_res_path = "/Users/samson/Projects/gan_deeplearning4j/Java/src/main/resources/"

# Export the 784 features followed by the label (last column) as CSV for the Deeplearning4j code.
np.savetxt(fname=dl4j_res_path + "mnist_train.csv",
           X=np.hstack([x_train, 
                        y_train.reshape([-1, 1])]),
           fmt="%.2f",
           delimiter=",",
           newline="\n",
           header="")

np.savetxt(fname=dl4j_res_path + "mnist_test.csv",
           X=np.hstack([x_test,
                        y_test.reshape([-1, 1])]),
           fmt="%.2f",
           delimiter=",",
           newline="\n",
           header="")
```

## Generative Adversarial Network

---

There are 2 main components to a GAN, the generator and the discriminator, that play an adversarial game against each other. In doing so the generator learns how to create realistic synthetic samples from noise, i.e., the latent space $z$, while the discriminator learns how to distinguish between a real sample and a synthetic sample. 

The representation learnt by the discriminator can later on be used for other supervised learning tasks, i.e., automatic feature engineering or representation learning. This can also be viewed through the lens of transfer learning. A GAN can also be used for semi-supervised learning which we will get to in another paper where we will look into using variational autoencoders, ladder networks and adversarial autoencoders for this purpose.

### Computer Vision

---

We will use the MNIST dataset for this purpose where the raw data is a 2 dimensional tensor of pixel intensities per image. The image is our unit of analysis: We will predict the probability of each class for each image. This is a multiclass classification task and we will use the accuracy score to assess model performance on the test fold.

![](pixel_lattice.png)

One example of handcrafted feature engineering for the computer vision task is the use of Gabor filters.

### Insurance

---

We will use a synthetic dataset where the raw data is a 2 dimensional tensor of historical policy level information per policy-period combination: Per unit this will be a 4 by 3 dimensional tensor, i.e., 4 historical time periods and 3 transaction types. The policy-period combination is our unit of analysis: We will predict the probability of loss for time period 5 in the future. Think of this as a potential renewal of the policy, for which we need to predict whether it would make a loss for us or not, and hence whether we decide to renew the policy and / or adjust the renewal premium to take into account the additional risk. This is a binary classification task and we will use the AUROC score to assess model performance.

![](trans_lattice.png)

Examples of handcrafted feature engineering for the insurance task are the column or row averages of the transactions lattice.
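As a minimal sketch of such handcrafted features, the snippet below computes row and column averages for a single 4 by 3 lattice. The lattice values are made up for illustration and are not taken from the synthetic dataset.

```python
import numpy as np

# Hypothetical 4 x 3 transactions lattice for one policy-period unit:
# rows are the 4 historical time periods, columns are the 3 transaction types.
lattice = np.array([
    [0.10, 0.00, 0.30],
    [0.20, 0.05, 0.25],
    [0.15, 0.10, 0.40],
    [0.30, 0.00, 0.35],
])

row_means = lattice.mean(axis=1)  # one average per time period
col_means = lattice.mean(axis=0)  # one average per transaction type

# Concatenate into a single handcrafted feature vector of length 4 + 3 = 7.
handcrafted = np.concatenate([row_means, col_means])
print(handcrafted.shape)
```

Such summaries discard the interaction between time period and transaction type, which is the kind of structure a learnt representation can retain.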

The synthetic insurance financial transactions dataset was coded in R. All the rest of the work is done in Python.

Please note the similarities between the raw data for the computer vision task and the raw data for the insurance task. Our main goal here is to learn a good representation of this raw data using automatic feature engineering via GANs.

### Distributed Machine Learning via Deeplearning4j and Apache Spark

---

- We will use the Java library Deeplearning4j for the GANs along with Apache Spark for distributed machine learning via synchronous parameter averaging. 
- We use 2 external GPUs, i.e., GTX 1070, on a MacBook Pro. 
- We use IntelliJ for the Java IDE and Apache Maven for build automation.

### Synchronous Parameter Averaging

---

Assume we have $n$ workers, $m$ observations per worker, a learning rate $\eta$ and the Mean Squared Error (MSE) as our objective function. Synchronous parameter averaging can be defined as follows.
- **Map**: A local copy of the global model, with weights $\theta$, is sent to each worker and the local model is trained on the subset of $m$ observations held by each of the $n$ workers.
- **Reduce**: The global model at the parameter server, with weights $\Theta$, is updated periodically by the average of the model weights taken from each of the $n$ workers.

When each worker takes a single gradient descent step between averaging rounds, this is equivalent to the following update.

$$
\Theta \leftarrow \Theta - \frac{\eta}{n}\sum_{i=1}^{n} \left[\nabla_{\theta}\text{MSE}\right]_{i}
$$
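The map and reduce steps can be sketched in plain NumPy for a linear model with an MSE objective. This is an illustration under stated assumptions, not the Deeplearning4j and Spark implementation: the data, the helper names `mse_gradient` and `parameter_averaging_round`, and the choice of one local gradient step per round are all hypothetical.

```python
import numpy as np

def mse_gradient(theta, x, y):
    """Gradient of mean((x @ theta - y) ** 2) with respect to theta."""
    residual = x @ theta - y
    return 2.0 * x.T @ residual / x.shape[0]

def parameter_averaging_round(theta_global, shards, eta):
    """One map/reduce round: each worker takes one local step, then average."""
    local_weights = []
    for x, y in shards:                      # Map: one data shard per worker
        theta_local = theta_global.copy()
        theta_local -= eta * mse_gradient(theta_local, x, y)
        local_weights.append(theta_local)
    return np.mean(local_weights, axis=0)    # Reduce: average the weights

rng = np.random.default_rng(0)
n, m, k, eta = 4, 50, 3, 0.05                # workers, rows per worker, features
true_theta = np.array([1.0, -2.0, 0.5])
shards = []
for _ in range(n):
    x = rng.normal(size=(m, k))
    shards.append((x, x @ true_theta + 0.01 * rng.normal(size=m)))

theta = np.zeros(k)
averaged = parameter_averaging_round(theta, shards, eta)

# With a single local step per round, averaging the weights gives the same
# result as one global step along the averaged gradient.
grad_mean = np.mean([mse_gradient(theta, x, y) for x, y in shards], axis=0)
direct = theta - eta * grad_mean

for _ in range(200):
    theta = parameter_averaging_round(theta, shards, eta)
```

With more than one local step between averaging rounds the two updates generally differ, which is the trade-off between communication cost and fidelity to centralized training.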

### Generator

---

Assume that we have a prior belief on where the latent space $z$ lies: $p(z)$. Given a draw from this latent space the generator $G$, a deep learner parameterized by $\theta_{G}$, outputs a synthetic sample.

$$
G(z|\theta_{G}): z \rightarrow x_{synthetic}
$$ 

### Discriminator

---

The discriminator $D$ is another deep learner parameterized by $\theta_{D}$ and it aims to classify whether a sample is real or synthetic, i.e., whether a sample is from the real data distribution $P_{\text{data}}$ or the synthetic data distribution $P_{G}$.

Let us denote the discriminator $D$ as follows.

$$
D(x|\theta_{D}): x \rightarrow [0, 1]
$$ 

Here we assume that the positive examples are from the real data distribution while the negative examples are from the synthetic data distribution.
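The two mappings can be illustrated with single-layer stand-ins for $G$ and $D$. This is only a sketch of the input and output types: the weights are random and untrained, the prior $p(z)$ is assumed to be a standard normal, and the real networks used in this work are deep convolutional models.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

rng = np.random.default_rng(0)
z_size, x_size = 2, 784           # latent size and flattened 28 x 28 image

# theta_G and theta_D: randomly initialised, untrained weights (illustrative).
theta_g = rng.normal(scale=0.1, size=(z_size, x_size))
theta_d = rng.normal(scale=0.1, size=(x_size, 1))

def generator(z):
    """G(z | theta_G): latent draw -> synthetic sample with values in (0, 1)."""
    return sigmoid(z @ theta_g)

def discriminator(x):
    """D(x | theta_D): sample -> probability that the sample is real."""
    return sigmoid(x @ theta_d)

z = rng.normal(size=(5, z_size))  # 5 draws from the assumed prior N(0, I)
x_synthetic = generator(z)
p_real = discriminator(x_synthetic)
```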

### Game: Optimality

---

A GAN trains the discriminator to correctly classify real and synthetic examples while training the generator to create synthetic examples that the discriminator incorrectly classifies as real. This 2 player minimax game has the following objective function.

$$
\min_{G(z|\theta_{G})} \max_{D(x|\theta_{D})} V(D(x|\theta_{D}), G(z|\theta_{G})) = \mathbb{E}_{x \sim p_{\text{data}}(x)} \log{D(x|\theta_{D})} + \mathbb{E}_{z \sim p(z)} \log{(1 - D(G(z|\theta_{G})|\theta_{D}))}
$$

Please note that the above expression is basically the objective function of the discriminator.

$$
\mathbb{E}_{x \sim p_{\text{data}}(x)} \log{D(x|\theta_{D})} + \mathbb{E}_{x \sim p_{G}(x)} \log{(1 - D(x|\theta_{D}))}
$$
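In practice the two expectations are replaced by sample averages. The sketch below estimates $V(D, G)$ by Monte Carlo for two hand-written discriminators on one dimensional toy data. The distributions and the discriminators `coin_flip` and `sharp` are assumptions chosen for illustration.

```python
import numpy as np

def value_function(d, x_real, x_synthetic, eps=1e-12):
    """Monte Carlo estimate of V(D, G) from real and synthetic samples."""
    real_term = np.mean(np.log(d(x_real) + eps))
    synthetic_term = np.mean(np.log(1.0 - d(x_synthetic) + eps))
    return real_term + synthetic_term

rng = np.random.default_rng(0)
x_real = rng.normal(loc=2.0, size=1000)        # stand-in for draws from P_data
x_synthetic = rng.normal(loc=-2.0, size=1000)  # stand-in for draws from P_G

def coin_flip(x):
    return np.full_like(x, 0.5)                # D that cannot tell them apart

def sharp(x):
    return 1.0 / (1.0 + np.exp(-4.0 * x))      # D that separates them well

v_coin = value_function(coin_flip, x_real, x_synthetic)
v_sharp = value_function(sharp, x_real, x_synthetic)
```

A discriminator that always outputs 0.5 yields $2\log(0.5) = -\log(4)$, while a discriminator that separates the two samples yields a larger value, which is why $D$ maximizes $V$.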

It is clear from how the game has been set up that we are trying to obtain a solution $\theta_{D}$ for $D$ such that it maximizes $V(D, G)$ while simultaneously we are trying to obtain a solution $\theta_{G}$ for $G$ such that it minimizes $V(D, G)$.

In practice we do not update $D$ and $G$ at the same time. We train them alternately: Train $D$ and then train $G$ while freezing $D$. We repeat this for a fixed number of steps.
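The alternating schedule can be sketched on a one dimensional toy problem with hand-derived gradients. Everything here is an assumption made for illustration: the data distribution $N(3, 1)$, the generator $G(z) = z + \theta$, the logistic discriminator $D(x) = \sigma(wx + b)$, the learning rates and the step count. The sketch shows the order of the updates and makes no claim about convergence.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def discriminator_step(w, b, theta, x_real, z, lr):
    """Gradient ascent on V with respect to (w, b); the generator is fixed."""
    x_fake = z + theta
    d_real, d_fake = sigmoid(w * x_real + b), sigmoid(w * x_fake + b)
    grad_w = np.mean((1.0 - d_real) * x_real) - np.mean(d_fake * x_fake)
    grad_b = np.mean(1.0 - d_real) - np.mean(d_fake)
    return w + lr * grad_w, b + lr * grad_b

def generator_step(w, b, theta, z, lr):
    """Gradient descent on V with respect to theta; (w, b) are frozen."""
    d_fake = sigmoid(w * (z + theta) + b)
    grad_theta = np.mean(-d_fake * w)      # d/dtheta of log(1 - D(z + theta))
    return theta - lr * grad_theta

rng = np.random.default_rng(0)
w, b, theta = 0.0, 0.0, 0.0                # D(x) = sigmoid(w x + b), G(z) = z + theta
for step in range(500):
    x_real = rng.normal(loc=3.0, size=64)  # toy P_data = N(3, 1)
    z = rng.normal(size=64)                # assumed prior p(z) = N(0, 1)
    w, b = discriminator_step(w, b, theta, x_real, z, lr=0.05)
    theta = generator_step(w, b, theta, z, lr=0.05)
```

Note that `generator_step` only returns a new `theta`: the discriminator parameters are read but never updated, which is what freezing $D$ means here.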

If the synthetic samples taken from the generator $G$ are realistic then implicitly we have learnt the distribution $P_{G}$. In other words, $P_{G}$ can be seen as a good estimation of $P_{\text{data}}$. The optimal solution will be as follows.

$$
P_{G}=P_{\text{data}}
$$

To show this let us find the optimal discriminator $D^\ast$ given a generator $G$ and sample $x$. 

\begin{align*}
V(D, G) &= \mathbb{E}_{x \sim p_{\text{data}}(x)} \log{D(x|\theta_{D})} + \mathbb{E}_{x \sim p_{G}(x)} \log{(1 - D(x|\theta_{D}))} \\
&= \int_{x} p_{\text{data}}(x) \log{D(x|\theta_{D})} dx + \int_{x} p_{G}(x) \log{(1 - D(x|\theta_{D}))} dx \\
&= \int_{x} \underbrace{p_{\text{data}}(x) \log{D(x|\theta_{D})} + p_{G}(x) \log{(1 - D(x|\theta_{D}))}}_{J(D(x|\theta_{D}))} dx
\end{align*}

Let us take a closer look at the discriminator's objective function for a sample $x$.

\begin{align*}
J(D(x|\theta_{D})) &= p_{\text{data}}(x) \log{D(x|\theta_{D})} + p_{G}(x) \log{(1 - D(x|\theta_{D}))} \\
\frac{\partial J(D(x|\theta_{D}))}{\partial D(x|\theta_{D})} &= \frac{p_{\text{data}}(x)}{D(x|\theta_{D})} - \frac{p_{G}(x)}{(1 - D(x|\theta_{D}))} \\
0 &= \frac{p_{\text{data}}(x)}{D^\ast(x|\theta_{D^\ast})} - \frac{p_{G}(x)}{(1 - D^\ast(x|\theta_{D^\ast}))} \\
p_{\text{data}}(x)(1 - D^\ast(x|\theta_{D^\ast})) &= p_{G}(x)D^\ast(x|\theta_{D^\ast}) \\
p_{\text{data}}(x) - p_{\text{data}}(x)D^\ast(x|\theta_{D^\ast}) &= p_{G}(x)D^\ast(x|\theta_{D^\ast}) \\
p_{G}(x)D^\ast(x|\theta_{D^\ast}) + p_{\text{data}}(x)D^\ast(x|\theta_{D^\ast}) &= p_{\text{data}}(x) \\
D^\ast(x|\theta_{D^\ast}) &= \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{G}(x)} 
\end{align*}
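The closed form for $D^\ast$ can be checked numerically by maximizing $J(D(x))$ separately for every outcome $x$ over a grid of candidate values. The two discrete distributions below are hypothetical.

```python
import numpy as np

# Hypothetical discrete distributions over 4 outcomes.
p_data = np.array([0.1, 0.4, 0.3, 0.2])
p_g = np.array([0.3, 0.3, 0.2, 0.2])

d_star = p_data / (p_data + p_g)         # closed-form optimum

# Brute force: maximise J(D(x)) separately for every outcome x.
grid = np.linspace(0.001, 0.999, 999)    # candidate values of D(x)
j = (p_data[:, None] * np.log(grid)[None, :]
     + p_g[:, None] * np.log(1.0 - grid)[None, :])
d_grid = grid[np.argmax(j, axis=1)]
```

Up to the grid resolution of 0.001, `d_grid` agrees with `d_star`.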

We have found the optimal discriminator given a generator. Let us focus now on the generator's objective function which is essentially to minimize the discriminator's objective function.

\begin{align*}
J(G(x|\theta_{G})) &= \mathbb{E}_{x \sim p_{\text{data}}(x)} \log{D^\ast(x|\theta_{D^\ast})} + \mathbb{E}_{x \sim p_{G}(x)} \log{(1 - D^\ast(x|\theta_{D^\ast}))} \\
&= \mathbb{E}_{x \sim p_{\text{data}}(x)} \log{\bigg( \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{G}(x)}} \bigg) + \mathbb{E}_{x \sim p_{G}(x)} \log{\bigg(1 - \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{G}(x)}\bigg)} \\
&= \mathbb{E}_{x \sim p_{\text{data}}(x)} \log{\bigg( \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{G}(x)}} \bigg) + \mathbb{E}_{x \sim p_{G}(x)} \log{\bigg(\frac{p_{G}(x)}{p_{\text{data}}(x) + p_{G}(x)}\bigg)} \\
&= \int_{x} p_{\text{data}}(x) \log{\bigg( \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{G}(x)}} \bigg) dx + \int_{x} p_{G}(x) \log{\bigg(\frac{p_{G}(x)}{p_{\text{data}}(x) + p_{G}(x)}\bigg)} dx
\end{align*}

Note that the above objective function for the generator resembles a sum of two Kullback–Leibler (KL) divergences, where the KL divergence is defined as follows.

$$
D_{KL}(P||Q) = \int_{x} p(x) \log\bigg(\frac{p(x)}{q(x)}\bigg) dx
$$

Recall the definition of a $\lambda$ divergence.

$$
D_{\lambda}(P||Q) = \lambda D_{KL}(P||\lambda P + (1 - \lambda) Q) + (1 - \lambda) D_{KL}(Q||\lambda P + (1 - \lambda) Q)
$$

If $\lambda$ takes the value of 0.5 this is then called the Jensen-Shannon (JS) divergence. This divergence is symmetric and non-negative.

$$
D_{JS}(P||Q) = 0.5 D_{KL}\bigg(P\bigg|\bigg|\frac{P + Q}{2}\bigg) + 0.5 D_{KL}\bigg(Q\bigg|\bigg|\frac{P + Q}{2}\bigg)
$$
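For discrete distributions the integrals become sums, and the definitions above translate directly into code. This is a small sketch with hypothetical distributions `p` and `q`; it assumes $0 < \lambda < 1$.

```python
import numpy as np

def kl_divergence(p, q):
    """D_KL(P || Q) for discrete distributions; terms with p = 0 contribute 0."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def lambda_divergence(p, q, lam):
    """Lambda divergence for 0 < lam < 1."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    mix = lam * p + (1.0 - lam) * q
    return lam * kl_divergence(p, mix) + (1.0 - lam) * kl_divergence(q, mix)

def js_divergence(p, q):
    """Jensen-Shannon divergence: the lambda divergence with lam = 0.5."""
    return lambda_divergence(p, q, 0.5)

p = [0.1, 0.4, 0.3, 0.2]
q = [0.3, 0.3, 0.2, 0.2]
print(js_divergence(p, q), js_divergence(q, p))
```

The two printed values coincide because the JS divergence is symmetric, unlike the KL divergence.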

Keeping this in mind let us take a look again at the objective function of the generator.

\begin{align*}
J(G(x|\theta_{G})) &= \int_{x} p_{\text{data}}(x) \log{\bigg( \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{G}(x)}} \bigg) dx + \int_{x} p_{G}(x) \log{\bigg(\frac{p_{G}(x)}{p_{\text{data}}(x) + p_{G}(x)}\bigg)} dx \\
&= \int_{x} p_{\text{data}}(x) \log{\bigg(\frac{2}{2}\frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{G}(x)}} \bigg) dx + \int_{x} p_{G}(x) \log{\bigg(\frac{2}{2}\frac{p_{G}(x)}{p_{\text{data}}(x) + p_{G}(x)}\bigg)} dx \\
&= \int_{x} p_{\text{data}}(x) \log{\bigg(\frac{1}{2}\frac{1}{0.5}\frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{G}(x)}} \bigg) dx + \int_{x} p_{G}(x) \log{\bigg(\frac{1}{2}\frac{1}{0.5}\frac{p_{G}(x)}{p_{\text{data}}(x) + p_{G}(x)}\bigg)} dx \\
&= \int_{x} p_{\text{data}}(x) \bigg[ \log(0.5) + \log{\bigg(\frac{p_{\text{data}}(x)}{0.5 (p_{\text{data}}(x) + p_{G}(x))}} \bigg) \bigg] dx \\ &+ \int_{x} p_{G}(x) \bigg[\log(0.5) + \log{\bigg(\frac{p_{G}(x)}{0.5 (p_{\text{data}}(x) + p_{G}(x))}\bigg) \bigg] } dx \\
&= \log\bigg(\frac{1}{4}\bigg) + \int_{x} p_{\text{data}}(x) \bigg[\log{\bigg(\frac{p_{\text{data}}(x)}{0.5 (p_{\text{data}}(x) + p_{G}(x))}} \bigg) \bigg] dx \\ 
&+ \int_{x} p_{G}(x) \bigg[\log{\bigg(\frac{p_{G}(x)}{0.5 (p_{\text{data}}(x) + p_{G}(x))}\bigg) \bigg] } dx \\
&= -\log(4) + D_{KL}\bigg(P_{\text{data}}\bigg|\bigg|\frac{P_{\text{data}} + P_{G}}{2}\bigg) + D_{KL}\bigg(P_{G}\bigg|\bigg|\frac{P_{\text{data}} + P_{G}}{2}\bigg) \\
&= -\log(4) + 2 \bigg(0.5 D_{KL}\bigg(P_{\text{data}}\bigg|\bigg|\frac{P_{\text{data}} + P_{G}}{2}\bigg) + 0.5 D_{KL}\bigg(P_{G}\bigg|\bigg|\frac{P_{\text{data}} + P_{G}}{2}\bigg)\bigg) \\
&= -\log(4) + 2D_{JS}(P_{\text{data}}||P_{G}) 
\end{align*}

It is clear from the objective function of the generator above that the global minimum value attained is $-\log(4)$ which occurs when the following holds.

$$
P_{G}=P_{\text{data}}
$$

When the above holds the Jensen-Shannon divergence, i.e., $D_{JS}(P_{\text{data}}||P_{G})$, will be zero. Hence we have shown that the optimal solution is as follows.

$$
P_{G}=P_{\text{data}}
$$
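The identity $V(D^\ast, G) = -\log(4) + 2D_{JS}(P_{\text{data}}||P_{G})$ and its minimum can be verified numerically. The sketch assumes discrete distributions with strictly positive probabilities; the specific values are hypothetical.

```python
import numpy as np

def kl_divergence(p, q):
    return float(np.sum(p * np.log(p / q)))

def js_divergence(p, q):
    mix = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, mix) + 0.5 * kl_divergence(q, mix)

def generator_objective(p_data, p_g):
    """V(D*, G) with the optimal discriminator plugged in."""
    d_star = p_data / (p_data + p_g)
    return float(np.sum(p_data * np.log(d_star))
                 + np.sum(p_g * np.log(1.0 - d_star)))

p_data = np.array([0.1, 0.4, 0.3, 0.2])
p_g = np.array([0.3, 0.3, 0.2, 0.2])

lhs = generator_objective(p_data, p_g)
rhs = -np.log(4.0) + 2.0 * js_divergence(p_data, p_g)
at_optimum = generator_objective(p_data, p_data)  # the case P_G = P_data
```

Here `lhs` and `rhs` agree, and `at_optimum` equals $-\log(4)$, which is smaller than the value for any $P_{G} \neq P_{\text{data}}$.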

### Game: Convergence

---

Assuming that the discriminator is allowed to reach its optimum given a generator, then $P_{G}$ can be shown to converge to $P_{\text{data}}$. 

Consider the following objective function. For a fixed discriminator $D$, $U(D, P_{G})$ is linear, and hence convex, in $P_{G}$, and we have previously found the global minimum of $U(D^\ast, P_{G})$ at $-\log(4)$.

$$
U(D^\ast, P_{G}) = \mathbb{E}_{x \sim p_{\text{data}}(x)} \log{D^\ast(x|\theta_{D^\ast})} + \mathbb{E}_{x \sim p_{G}(x)} \log{(1 - D^\ast(x|\theta_{D^\ast}))}
$$

Gradient descent is used by the generator to move towards the global minimum given an optimal discriminator. We will show that the gradient of the generator exists given an optimal discriminator, i.e., $\nabla_{P_{G}} U(D^\ast, P_{G})$, such that convergence of $P_{G}$ to $P_{\text{data}}$ is guaranteed.

Note that the following is a supremum of a set of convex functions where the set is indexed by the discriminator $D$: $U(D^\ast, P_{G})=\sup_{D} U(D, P_{G})$. Remember that the supremum is the least upper bound.

Let us recall a few definitions regarding gradients and subgradients. A vector $g \in \mathbb{R}^K$ is a subgradient of a function $f: \mathbb{R}^K \rightarrow \mathbb{R}$ at a point $x \in \text{dom}(f)$ if $\forall z \in \text{dom}(f)$, the following relationship holds:

$$
f(z) \geq f(x) + g^{T}(z - x)
$$

If $f$ is convex and differentiable then its gradient at a point $x$ is also the subgradient. Most importantly, a subgradient can exist even if $f$ is not differentiable.
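A standard example is $f(x) = |x|$, which is convex but not differentiable at $x = 0$. The sketch below checks the subgradient inequality on a finite grid of points $z$, so it can refute a candidate $g$ but only suggests, rather than proves, that a candidate is valid. The helper `is_subgradient` is hypothetical.

```python
import numpy as np

def is_subgradient(f, g, x, z_values, tol=1e-12):
    """Check f(z) >= f(x) + g * (z - x) on a finite set of test points z."""
    z_values = np.asarray(z_values, dtype=float)
    return bool(np.all(f(z_values) >= f(x) + g * (z_values - x) - tol))

z_values = np.linspace(-5.0, 5.0, 1001)

# At the kink x = 0 every g in [-1, 1] satisfies the inequality.
at_kink = [is_subgradient(np.abs, g, 0.0, z_values) for g in (-1.0, 0.0, 0.5, 1.0)]
too_steep = is_subgradient(np.abs, 1.5, 0.0, z_values)

# Away from the kink the ordinary derivative is a subgradient.
smooth_point = is_subgradient(np.abs, 1.0, 2.0, z_values)
```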

The subdifferential, i.e., the set of subgradients, of the supremum of a set of convex functions includes the gradient of the function at which the supremum is attained. The supremum of a set of convex functions is itself convex, so $U(D^\ast, P_{G})$ is convex in $P_{G}$. Writing $\partial_{P_{G}}$ for the subdifferential with respect to $P_{G}$, we have the following.

\begin{align*}
&U(D^\ast, P_{G})=\sup_{D} U(D, P_{G}) \\
&\nabla_{P_{G}} U(D, P_{G})\big|_{D = D^\ast} \in \partial_{P_{G}} \sup_{D} U(D, P_{G})
\end{align*}
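The same property can be seen on a small finite family of convex functions. In the sketch below the family consists of four affine functions with made-up slopes and intercepts: the slope of the piece that attains the maximum at a point is a subgradient of the pointwise maximum at that point.

```python
import numpy as np

# Hypothetical family of convex (here affine) functions f_i(x) = a_i * x + b_i.
slopes = np.array([-2.0, -0.5, 1.0, 3.0])
intercepts = np.array([0.0, 1.0, 0.5, -4.0])

def pieces(x):
    return slopes * x + intercepts

def supremum(x):
    return float(np.max(pieces(x)))

def active_gradient(x):
    """Gradient of the piece that attains the supremum at x."""
    return float(slopes[np.argmax(pieces(x))])

x0 = 1.0
g = active_gradient(x0)

# Subgradient inequality for the supremum, checked on a grid of points z.
z_values = np.linspace(-10.0, 10.0, 2001)
sup_values = np.array([supremum(z) for z in z_values])
holds = bool(np.all(sup_values >= supremum(x0) + g * (z_values - x0) - 1e-12))
```

In the GAN setting the family is indexed by the discriminator $D$ and the active piece is the one with $D = D^\ast$.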

The gradient of the generator, $\nabla_{P_{G}} U(D^\ast, P_{G})$, is used to make incremental improvements to the objective function of the generator, $U(D^\ast, P_{G})$, given an optimal discriminator, $D^\ast$. Therefore, provided the generator and discriminator have enough capacity and the updates to $P_{G}$ are sufficiently small, $P_{G}$ converges to $P_{\text{data}}$.

### Results

---

In these experiments we show the ability of the generator to create realistic synthetic examples for the MNIST dataset and the insurance dataset. We use a 2-dimensional latent manifold.

Finally, we show that the representation learnt by the discriminator attains results on the MNIST dataset and the insurance dataset that are competitive with other representation learning methods, such as a wide variety of autoencoders.

### Results: Generating new data

---

![](DCGAN_Generated_Images.png)

![](DCGAN_Generated_Lattices.png)

### Results: GAN for representation learning

---

* The accuracy score for the MNIST classification task with DCGAN: 97.070000%.
* The AUROC score for the insurance classification task with DCGAN: 91.632719%.
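For reference, the two metrics can be computed as in the following NumPy sketch. The labels and scores are hypothetical and unrelated to the results reported above; in practice a library routine such as `sklearn.metrics.roc_auc_score`, which is imported in the notebook, would be used for the AUROC.

```python
import numpy as np

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true class."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))

def auroc(y_true, scores):
    """P(score of a positive > score of a negative), ties counted as 1/2."""
    y_true, scores = np.asarray(y_true), np.asarray(scores, dtype=float)
    pos, neg = scores[y_true == 1], scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]   # all positive-negative score pairs
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

# Hypothetical labels and predicted probabilities of the positive class.
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(auroc(labels, scores))
print(accuracy([1, 2, 3, 3], [1, 2, 0, 3]))
```

The pairwise form makes the interpretation of the AUROC explicit: it is the probability that a randomly chosen loss-making policy receives a higher score than a randomly chosen policy without a loss.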

### The insurance data: A closer look

With image data we can perhaps judge qualitatively whether the generated data makes sense. For financial transactions data this is not possible. However, let us look at an example of a generated transactions lattice. Please note that all financial transactions data has been transformed to lie between 0 and 1.

![](DCGAN_Generated_Lattice_Example.png)

If we use the same matplotlib code as applied to the image data to plot the above generated transactions lattice we get the following image. Where a transaction takes the maximum possible value, i.e., 1, it is colored black, and where it takes the minimum possible value, i.e., 0, it is colored white. Transaction values in between are shown in shades of gray.

![](DCGAN_Generated_Lattice_Example_Plotted.png)

### Java code for computer vision task

```java
package org.deeplearning4j;

import java.io.*;
import java.util.*;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.*;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.*;
import org.deeplearning4j.nn.weights.*;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.deeplearning4j.util.ModelSerializer;

import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.*;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.*;
import org.nd4j.linalg.lossfunctions.*;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class dl4jGANComputerVision {
    private static final Logger log = LoggerFactory.getLogger(dl4jGANComputerVision.class);

    private static final int batchSizePerWorker = 200;
    private static final int batchSizePred = 500;
    private static final int labelIndex = 784;
    private static final int numClasses = 10; // Using Softmax.
    private static final int numClassesDis = 1; // Using Sigmoid.
    private static final int numFeatures = 784;
    private static final int numIterations = 10000;
    private static final int numGenSamples = 10; // This will be a grid so effectively we get {numGenSamples * numGenSamples} samples.
    private static final int numLinesToSkip = 0;
    private static final int numberOfTheBeast = 666;
    private static final int printEvery = 10;
    private static final int saveEvery = 100;
    private static final int tensorDimOneSize = 28;
    private static final int tensorDimTwoSize = 28;
    private static final int tensorDimThreeSize = 1;
    private static final int zSize = 2;

    private static final double dis_learning_rate = 0.002;
    private static final double frozen_learning_rate = 0.0;
    private static final double gen_learning_rate = 0.004;

    private static final String delimiter = ",";
    private static final String resPath = "/Users/samson/Projects/gan_deeplearning4j/outputs/computer_vision/";
    private static final String newLine = "\n";
    private static final String dataSetName = "mnist";

    private static final boolean useGpu = true;

    public static void main(String[] args) throws Exception {
        new dl4jGANComputerVision().GAN(args);
    }

    private void GAN(String[] args) throws Exception {
        for (int i = 0; i < args.length; i++) {
            System.out.println(args[i]);
        }

        if (useGpu) {
            System.out.println("Setting up CUDA environment!");
            Nd4j.setDataType(DataBuffer.Type.FLOAT);

            CudaEnvironment.getInstance().getConfiguration()
                    .allowMultiGPU(true)
                    .setMaximumDeviceCache(2L * 1024L * 1024L * 1024L)
                    .allowCrossDeviceAccess(true)
                    .setVerbose(true);
        }

        System.out.println(Nd4j.getBackend());
        Nd4j.getMemoryManager().setAutoGcWindow(5000);

        // Standalone discriminator: every trainable layer uses dis_learning_rate.
        log.info("Unfrozen discriminator!");
        ComputationGraph dis = new ComputationGraph(new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .seed(numberOfTheBeast)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)
                .l2(0.0001)
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs("dis_input_layer_0")
                .setInputTypes(InputType.convolutionalFlat(tensorDimOneSize, tensorDimTwoSize, tensorDimThreeSize))
                .addLayer("dis_batch_layer_1", new BatchNormalization.Builder()
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .build(), "dis_input_layer_0")
                .addLayer("dis_conv2d_layer_2", new ConvolutionLayer.Builder(5, 5)
                        .stride(2, 2)
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .nIn(1)
                        .nOut(64)
                        .build(), "dis_batch_layer_1")
                .addLayer("dis_maxpool_layer_3", new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(1, 1)
                        .build(), "dis_conv2d_layer_2")
                .addLayer("dis_conv2d_layer_4", new ConvolutionLayer.Builder(5, 5)
                        .stride(2, 2)
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .nIn(64)
                        .nOut(128)
                        .build(), "dis_maxpool_layer_3")
                .addLayer("dis_maxpool_layer_5", new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(1, 1)
                        .build(), "dis_conv2d_layer_4")
                .addLayer("dis_dense_layer_6", new DenseLayer.Builder()
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .nOut(1024)
                        .build(), "dis_maxpool_layer_5")
                .addLayer("dis_output_layer_7", new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .nOut(numClassesDis)
                        .activation(Activation.SIGMOID)
                        .build(), "dis_dense_layer_6")
                .setOutputs("dis_output_layer_7")
                .build());
        dis.init();
        System.out.println(dis.summary());
        System.out.println(Arrays.toString(dis.output(Nd4j.randn(numGenSamples, numFeatures))[0].shape()));

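        // Standalone generator: every trainable layer uses frozen_learning_rate (0.0),
        // so fitting this graph directly would leave its weights unchanged.
        log.info("Frozen generator!");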
        ComputationGraph gen = new ComputationGraph(new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .seed(numberOfTheBeast)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)
                .l2(0.0001)
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs("gen_input_layer_0")
                .setInputTypes(InputType.feedForward(zSize))
                .addLayer("gen_batch_1", new BatchNormalization.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .build(), "gen_input_layer_0")
                .addLayer("gen_dense_layer_2", new DenseLayer.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nOut(1024)
                        .build(), "gen_batch_1")
                .addLayer("gen_dense_layer_3", new DenseLayer.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nOut(7 * 7 * 128)
                        .build(), "gen_dense_layer_2")
                .addLayer("gen_batch_4", new BatchNormalization.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .build(), "gen_dense_layer_3")
                .inputPreProcessor("gen_deconv2d_5", new FeedForwardToCnnPreProcessor(7, 7, 128))
                .addLayer("gen_deconv2d_5", new Upsampling2D.Builder(2)
                        .build(), "gen_batch_4")
                .addLayer("gen_conv2d_6", new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .padding(2, 2)
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nIn(128)
                        .nOut(64)
                        .build(), "gen_deconv2d_5")
                .addLayer("gen_deconv2d_7", new Upsampling2D.Builder(2)
                        .build(), "gen_conv2d_6")
                .addLayer("gen_conv2d_8", new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .padding(2, 2)
                        .activation(Activation.SIGMOID)
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nIn(64)
                        .nOut(1)
                        .build(), "gen_deconv2d_7")
                .setOutputs("gen_conv2d_8")
                .build());
        gen.init();
        System.out.println(gen.summary());
        System.out.println(Arrays.toString(gen.output(Nd4j.randn(numGenSamples, zSize))[0].shape()));

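        // Stacked GAN graph: layers gan_batch_1 to gan_conv2d_8 mirror the generator and use
        // gen_learning_rate; the gan_dis_* layers mirror the discriminator and use
        // frozen_learning_rate (0.0), which freezes the discriminator part.
        log.info("GAN with unfrozen generator and frozen discriminator!");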
        ComputationGraph gan = new ComputationGraph(new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .seed(numberOfTheBeast)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .l2(0.0001)
                .graphBuilder()
                .addInputs("gan_input_layer_0")
                .setInputTypes(InputType.feedForward(zSize))
                .addLayer("gan_batch_1", new BatchNormalization.Builder()
                        .updater(new RmsProp(gen_learning_rate, 1e-8, 1e-8))
                        .build(), "gan_input_layer_0")
                .addLayer("gan_dense_layer_2", new DenseLayer.Builder()
                        .updater(new RmsProp(gen_learning_rate, 1e-8, 1e-8))
                        .nOut(1024)
                        .build(), "gan_batch_1")
                .addLayer("gan_dense_layer_3", new DenseLayer.Builder()
                        .updater(new RmsProp(gen_learning_rate, 1e-8, 1e-8))
                        .nOut(7 * 7 * 128)
                        .build(), "gan_dense_layer_2")
                .addLayer("gan_batch_4", new BatchNormalization.Builder()
                        .updater(new RmsProp(gen_learning_rate, 1e-8, 1e-8))
                        .build(), "gan_dense_layer_3")
                .inputPreProcessor("gan_deconv2d_5", new FeedForwardToCnnPreProcessor(7, 7, 128))
                .addLayer("gan_deconv2d_5", new Upsampling2D.Builder(2)
                        .build(), "gan_batch_4")
                .addLayer("gan_conv2d_6", new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .padding(2, 2)
                        .updater(new RmsProp(gen_learning_rate, 1e-8, 1e-8))
                        .nIn(128)
                        .nOut(64)
                        .build(), "gan_deconv2d_5")
                .addLayer("gan_deconv2d_7", new Upsampling2D.Builder(2)
                        .build(), "gan_conv2d_6")
                .addLayer("gan_conv2d_8", new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .padding(2, 2)
                        .activation(Activation.SIGMOID)
                        .updater(new RmsProp(gen_learning_rate, 1e-8, 1e-8))
                        .nIn(64)
                        .nOut(1)
                        .build(), "gan_deconv2d_7")

                .addLayer("gan_dis_batch_layer_9", new BatchNormalization.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .build(), "gan_conv2d_8")
                .addLayer("gan_dis_conv2d_layer_10", new ConvolutionLayer.Builder(5, 5)
                        .stride(2, 2)
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nIn(1)
                        .nOut(64)
                        .build(), "gan_dis_batch_layer_9")
                .addLayer("gan_dis_maxpool_layer_11", new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(1, 1)
                        .build(), "gan_dis_conv2d_layer_10")
                .addLayer("gan_dis_conv2d_layer_12", new ConvolutionLayer.Builder(5, 5)
                        .stride(2, 2)
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nIn(64)
                        .nOut(128)
                        .build(), "gan_dis_maxpool_layer_11")
                .addLayer("gan_dis_maxpool_layer_13", new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(1, 1)
                        .build(), "gan_dis_conv2d_layer_12")
                .addLayer("gan_dis_dense_layer_14", new DenseLayer.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nOut(1024)
                        .build(), "gan_dis_maxpool_layer_13")
                .addLayer("gan_dis_output_layer_15", new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nOut(numClassesDis)
                        .activation(Activation.SIGMOID)
                        .build(), "gan_dis_dense_layer_14")
                .setOutputs("gan_dis_output_layer_15")
                .build());
        gan.init();
        System.out.println(gan.summary());
        System.out.println(Arrays.toString(gan.output(Nd4j.randn(numGenSamples, zSize))[0].shape()));

        log.info("Setting up Spark configuration!");
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[4]");
        sparkConf.setAppName("Deeplearning4j on Apache Spark: Generative Adversarial Network!");
        sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
        sparkConf.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        log.info("Setting up Synchronous Parameter Averaging!");
        TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker)
                .averagingFrequency(10)
                .rngSeed(numberOfTheBeast)
                .workerPrefetchNumBatches(0)
                .batchSizePerWorker(batchSizePerWorker)
                .build();

        SparkComputationGraph sparkDis = new SparkComputationGraph(sc, dis, tm);
        SparkComputationGraph sparkGan = new SparkComputationGraph(sc, gan, tm);

        log.info("Computer vision deep learning model with pre-trained layers from the GAN's discriminator!");
        ComputationGraph computerVision = new TransferLearning.GraphBuilder(sparkDis.getNetwork())
                .fineTuneConfiguration(new FineTuneConfiguration.Builder()
                        .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                        .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                        .gradientNormalizationThreshold(1.0)
                        .activation(Activation.TANH)
                        .l2(0.0001)
                        .weightInit(WeightInit.XAVIER)
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .seed(numberOfTheBeast)
                        .build())
                .setFeatureExtractor("dis_dense_layer_6")
                .removeVertexKeepConnections("dis_output_layer_7")
                .addLayer("dis_batch", new BatchNormalization.Builder()
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .nIn(1024)
                        .nOut(1024)
                        .build(), "dis_dense_layer_6")
                .addLayer("dis_output_layer_7", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .nIn(1024)
                        .nOut(numClasses)
                        .activation(Activation.SOFTMAX)
                        .build(), "dis_batch")
                .build();
        System.out.println(computerVision.summary());
        System.out.println(Arrays.toString(computerVision.output(Nd4j.randn(numGenSamples, numFeatures))[0].shape()));

        SparkComputationGraph sparkCV = new SparkComputationGraph(sc, computerVision, tm);

        RecordReader recordReaderTrain = new CSVRecordReader(numLinesToSkip, delimiter);
        recordReaderTrain.initialize(new FileSplit(new ClassPathResource(dataSetName + "_train.csv").getFile()));

        DataSetIterator iterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSizePerWorker, labelIndex, numClasses);
        List<DataSet> trainDataList = new ArrayList<>();

        JavaRDD<DataSet> trainDataDis, trainDataGen, trainData;

        INDArray grid = Nd4j.linspace(-1.0, 1.0, numGenSamples);
        Collection<INDArray> z = new ArrayList<>();
        log.info("Creating some noise!");
        for (int i = 0; i < numGenSamples; i++) {
            for (int j = 0; j < numGenSamples; j++) {
                z.add(Nd4j.create(new double[]{grid.getDouble(0, i), grid.getDouble(0, j)}));
            }
        }

        int batch_counter = 0;

        DataSet trDataSet;

        RecordReader recordReaderTest = new CSVRecordReader(numLinesToSkip, delimiter);
        recordReaderTest.initialize(new FileSplit(new ClassPathResource(dataSetName + "_test.csv").getFile()));

        DataSetIterator iterTest = new RecordReaderDataSetIterator(recordReaderTest, batchSizePred, labelIndex, numClasses);

        Collection<INDArray> outFeat;

        INDArray out;
        INDArray soften_labels_fake = Nd4j.randn(batchSizePerWorker, 1).muli(0.05);
        INDArray soften_labels_real = Nd4j.randn(batchSizePerWorker, 1).muli(0.05);

        while (iterTrain.hasNext() && batch_counter < numIterations) {
            trainDataList.clear();
            trDataSet = iterTrain.next();

            // Real examples: target 1 plus a small noise term (label convention: 0 = fake, 1 = real).
            trainDataList.add(new DataSet(trDataSet.getFeatures(), Nd4j.ones(batchSizePerWorker, 1).addi(soften_labels_real)));

            // Fake examples produced by the frozen generator: target 0 plus a small noise term.
            trainDataList.add(new DataSet(gen.output(Nd4j.rand(batchSizePerWorker, zSize).muli(2.0).subi(1.0))[0], Nd4j.zeros(batchSizePerWorker, 1).addi(soften_labels_fake)));

            // The unfrozen discriminator learns to separate real from fake examples while the generator stays frozen.
            log.info("Training discriminator!");
            trainDataDis = sc.parallelize(trainDataList);
            sparkDis.fit(trainDataDis);

            // Update GAN's frozen discriminator with unfrozen discriminator.
            sparkGan.getNetwork().getLayer("gan_dis_batch_layer_9").setParam("gamma", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("gamma"));
            sparkGan.getNetwork().getLayer("gan_dis_batch_layer_9").setParam("beta", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("beta"));
            sparkGan.getNetwork().getLayer("gan_dis_batch_layer_9").setParam("mean", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("mean"));
            sparkGan.getNetwork().getLayer("gan_dis_batch_layer_9").setParam("var", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("var"));

            sparkGan.getNetwork().getLayer("gan_dis_conv2d_layer_10").setParam("W", sparkDis.getNetwork().getLayer("dis_conv2d_layer_2").getParam("W"));
            sparkGan.getNetwork().getLayer("gan_dis_conv2d_layer_10").setParam("b", sparkDis.getNetwork().getLayer("dis_conv2d_layer_2").getParam("b"));

            sparkGan.getNetwork().getLayer("gan_dis_conv2d_layer_12").setParam("W", sparkDis.getNetwork().getLayer("dis_conv2d_layer_4").getParam("W"));
            sparkGan.getNetwork().getLayer("gan_dis_conv2d_layer_12").setParam("b", sparkDis.getNetwork().getLayer("dis_conv2d_layer_4").getParam("b"));

            sparkGan.getNetwork().getLayer("gan_dis_dense_layer_14").setParam("W", sparkDis.getNetwork().getLayer("dis_dense_layer_6").getParam("W"));
            sparkGan.getNetwork().getLayer("gan_dis_dense_layer_14").setParam("b", sparkDis.getNetwork().getLayer("dis_dense_layer_6").getParam("b"));

            sparkGan.getNetwork().getLayer("gan_dis_output_layer_15").setParam("W", sparkDis.getNetwork().getLayer("dis_output_layer_7").getParam("W"));
            sparkGan.getNetwork().getLayer("gan_dis_output_layer_15").setParam("b", sparkDis.getNetwork().getLayer("dis_output_layer_7").getParam("b"));

            trainDataList.clear();
            // Label every generated example as real (1) so that the generator is trained to fool the frozen discriminator.
            trainDataList.add(new DataSet(Nd4j.rand(batchSizePerWorker, zSize).muli(2.0).subi(1.0), Nd4j.ones(batchSizePerWorker, 1)));

            // Unfrozen generator is trying to fool the frozen discriminator.
            log.info("Training generator!");
            trainDataGen = sc.parallelize(trainDataList);
            sparkGan.fit(trainDataGen);

            // Update frozen generator with GAN's unfrozen generator.
            gen.getLayer("gen_batch_1").setParam("gamma", sparkGan.getNetwork().getLayer("gan_batch_1").getParam("gamma"));
            gen.getLayer("gen_batch_1").setParam("beta", sparkGan.getNetwork().getLayer("gan_batch_1").getParam("beta"));
            gen.getLayer("gen_batch_1").setParam("mean", sparkGan.getNetwork().getLayer("gan_batch_1").getParam("mean"));
            gen.getLayer("gen_batch_1").setParam("var", sparkGan.getNetwork().getLayer("gan_batch_1").getParam("var"));

            gen.getLayer("gen_dense_layer_2").setParam("W", sparkGan.getNetwork().getLayer("gan_dense_layer_2").getParam("W"));
            gen.getLayer("gen_dense_layer_2").setParam("b", sparkGan.getNetwork().getLayer("gan_dense_layer_2").getParam("b"));

            gen.getLayer("gen_dense_layer_3").setParam("W", sparkGan.getNetwork().getLayer("gan_dense_layer_3").getParam("W"));
            gen.getLayer("gen_dense_layer_3").setParam("b", sparkGan.getNetwork().getLayer("gan_dense_layer_3").getParam("b"));

            gen.getLayer("gen_batch_4").setParam("gamma", sparkGan.getNetwork().getLayer("gan_batch_4").getParam("gamma"));
            gen.getLayer("gen_batch_4").setParam("beta", sparkGan.getNetwork().getLayer("gan_batch_4").getParam("beta"));
            gen.getLayer("gen_batch_4").setParam("mean", sparkGan.getNetwork().getLayer("gan_batch_4").getParam("mean"));
            gen.getLayer("gen_batch_4").setParam("var", sparkGan.getNetwork().getLayer("gan_batch_4").getParam("var"));

            gen.getLayer("gen_conv2d_6").setParam("W", sparkGan.getNetwork().getLayer("gan_conv2d_6").getParam("W"));
            gen.getLayer("gen_conv2d_6").setParam("b", sparkGan.getNetwork().getLayer("gan_conv2d_6").getParam("b"));

            gen.getLayer("gen_conv2d_8").setParam("W", sparkGan.getNetwork().getLayer("gan_conv2d_8").getParam("W"));
            gen.getLayer("gen_conv2d_8").setParam("b", sparkGan.getNetwork().getLayer("gan_conv2d_8").getParam("b"));

            trainDataList.clear();
            trainDataList.add(trDataSet);

            log.info("Training computer vision model!");
            sparkCV.getNetwork().getLayer("dis_batch_layer_1").setParam("gamma", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("gamma"));
            sparkCV.getNetwork().getLayer("dis_batch_layer_1").setParam("beta", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("beta"));
            sparkCV.getNetwork().getLayer("dis_batch_layer_1").setParam("mean", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("mean"));
            sparkCV.getNetwork().getLayer("dis_batch_layer_1").setParam("var", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("var"));

            sparkCV.getNetwork().getLayer("dis_conv2d_layer_2").setParam("W", sparkDis.getNetwork().getLayer("dis_conv2d_layer_2").getParam("W"));
            sparkCV.getNetwork().getLayer("dis_conv2d_layer_2").setParam("b", sparkDis.getNetwork().getLayer("dis_conv2d_layer_2").getParam("b"));

            sparkCV.getNetwork().getLayer("dis_conv2d_layer_4").setParam("W", sparkDis.getNetwork().getLayer("dis_conv2d_layer_4").getParam("W"));
            sparkCV.getNetwork().getLayer("dis_conv2d_layer_4").setParam("b", sparkDis.getNetwork().getLayer("dis_conv2d_layer_4").getParam("b"));

            sparkCV.getNetwork().getLayer("dis_dense_layer_6").setParam("W", sparkDis.getNetwork().getLayer("dis_dense_layer_6").getParam("W"));
            sparkCV.getNetwork().getLayer("dis_dense_layer_6").setParam("b", sparkDis.getNetwork().getLayer("dis_dense_layer_6").getParam("b"));

            trainData = sc.parallelize(trainDataList);
            sparkCV.fit(trainData);

            batch_counter++;
            log.info("Completed Batch {}!", batch_counter);

            if ((batch_counter % printEvery) == 0) {
                out = gen.output(Nd4j.vstack(z))[0].reshape(numGenSamples * numGenSamples, numFeatures);

                // try-with-resources flushes and closes the writer even if a write fails.
                try (FileWriter fileWriter = new FileWriter(String.format("%s%s_out_%d.csv", resPath, dataSetName, batch_counter))) {
                    for (int i = 0; i < out.shape()[0]; i++) {
                        for (int j = 0; j < out.shape()[1]; j++) {
                            fileWriter.append(String.valueOf(out.getDouble(i, j)));
                            if (j != out.shape()[1] - 1) {
                                fileWriter.append(delimiter);
                            }
                        }
                        if (i != out.shape()[0] - 1) {
                            fileWriter.append(newLine);
                        }
                    }
                }
            }

            if ((batch_counter % saveEvery) == 0) {
                log.info("Ensemble of deep learners for estimation of uncertainty!");

                outFeat = new ArrayList<>();
                iterTest.reset();
                while (iterTest.hasNext()) {
                    outFeat.add(sparkCV.getNetwork().output(iterTest.next().getFeatures())[0]);
                }

                INDArray toWrite = Nd4j.vstack(outFeat);
                // try-with-resources flushes and closes the writer even if a write fails.
                try (FileWriter fileWriter = new FileWriter(String.format("%s%s_test_predictions_%d.csv", resPath, dataSetName, batch_counter))) {
                    for (int i = 0; i < toWrite.shape()[0]; i++) {
                        for (int j = 0; j < toWrite.shape()[1]; j++) {
                            fileWriter.append(String.valueOf(toWrite.getDouble(i, j)));
                            if (j != toWrite.shape()[1] - 1) {
                                fileWriter.append(delimiter);
                            }
                        }
                        if (i != toWrite.shape()[0] - 1) {
                            fileWriter.append(newLine);
                        }
                    }
                }
            }

            if (!iterTrain.hasNext()) {
                iterTrain.reset();
            }
        }

        log.info("Saving models!");
        ModelSerializer.writeModel(sparkDis.getNetwork(), new File(resPath + dataSetName + "_dis_model.zip"), true);
        ModelSerializer.writeModel(sparkGan.getNetwork(), new File(resPath + dataSetName + "_gan_model.zip"), true);
        ModelSerializer.writeModel(gen, new File(resPath + dataSetName + "_gen_model.zip"), true);
        ModelSerializer.writeModel(sparkCV.getNetwork(), new File(resPath + dataSetName + "_CV_model.zip"), true);

        tm.deleteTempFiles(sc);
    }
}
```
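
The training loop in the listing above builds the discriminator's targets by adding Gaussian noise, scaled by 0.05, to a column of ones (real examples) or zeros (fake examples). The following is a minimal sketch of that label softening in plain Python so that it runs without any of the notebook's dependencies. The helper name `soft_labels`, its parameters, the batch size and the seed are illustrative assumptions and do not appear in the Java code.

```python
import random


def soft_labels(hard_value, batch_size, scale=0.05, rng=None):
    """Return `batch_size` targets equal to `hard_value` plus Gaussian noise.

    With hard_value=1.0 this plays the role of the noisy "real" targets,
    with hard_value=0.0 the role of the noisy "fake" targets.
    """
    rng = rng or random.Random()
    return [hard_value + scale * rng.gauss(0.0, 1.0) for _ in range(batch_size)]


rng = random.Random(666)  # Arbitrary seed (an assumption), used only for reproducibility.
real_targets = soft_labels(1.0, batch_size=50, rng=rng)
fake_targets = soft_labels(0.0, batch_size=50, rng=rng)

# The noise is symmetric and not clipped, so targets above 1 or below 0 are expected.
print(min(real_targets), max(real_targets))
print(min(fake_targets), max(fake_targets))
```

Note that the Java listing draws the two noise vectors once, before the loop, and reuses them for every batch, whereas this sketch draws fresh noise on each call.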

In [None]:
output_path = "/Users/samson/Projects/gan_deeplearning4j/outputs/computer_vision/"
out_feat = pd.read_csv(filepath_or_buffer=output_path + "mnist_test_predictions_100.csv",
                       header=None)
acc_dcgan = sum(out_feat.idxmax(axis=1) == y_test) / y_test.shape[0]
print("The accuracy score for the MNIST classification task with DCGAN: %.6f%%." % (acc_dcgan * 100))

out = pd.read_csv(filepath_or_buffer=output_path + "mnist_out_10.csv", header=None)
digit_size = 28
n = 10
figure = np.zeros((digit_size * n, digit_size * n))

counter = 0
for i in range(n):
    for j in range(n):
        figure[i * digit_size: (i + 1) * digit_size, j * digit_size: (j + 1) * digit_size] = out.iloc[counter].values.reshape(digit_size, digit_size)
        counter = counter + 1
        
plt.figure(figsize=(20, 20))
plt.title("Deep Convolutional Generative Adversarial Network (DCGAN) with a 2-dimensional latent manifold\nGenerating new images on the 2-dimensional latent manifold", fontsize=20)
plt.xlabel("Latent dimension 1", fontsize=24)
plt.ylabel("Latent dimension 2", fontsize=24)
plt.imshow(figure, cmap="Greys_r")
plt.savefig(fname="DCGAN_Generated_Images.png")
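
The Java code generates the images for this figure from a regular grid over the 2-dimensional latent space: it takes evenly spaced values in [-1, 1] and loops over all pairs, with the first coordinate in the outer loop. The sketch below reproduces that ordering in plain Python and shows which tile of the figure each row of the CSV file lands in. The helper names `linspace`, `latent_grid` and `tile_slices` are hypothetical, and the sketch assumes that the grid size equals the `n` used for plotting.

```python
def linspace(start, stop, num):
    """Evenly spaced values from start to stop, both endpoints included."""
    if num == 1:
        return [start]
    return [start + (stop - start) * k / (num - 1) for k in range(num)]


def latent_grid(n):
    """All (z1, z2) pairs, with z1 in the outer loop and z2 in the inner loop."""
    grid = linspace(-1.0, 1.0, n)
    return [(grid[i], grid[j]) for i in range(n) for j in range(n)]


def tile_slices(counter, n, tile_size):
    """Row and column pixel ranges of the tile that receives CSV row `counter`."""
    i, j = divmod(counter, n)
    return (i * tile_size, (i + 1) * tile_size), (j * tile_size, (j + 1) * tile_size)


z = latent_grid(10)
print(len(z))      # 100 latent points for a 10 x 10 grid
print(z[0], z[-1])  # (-1.0, -1.0) (1.0, 1.0)
print(tile_slices(13, n=10, tile_size=28))  # ((28, 56), (84, 112))
```

Under this ordering the first latent coordinate changes from tile row to tile row (top to bottom) and the second from tile column to tile column (left to right).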

In [None]:
claim_risk = pd.read_csv(filepath_or_buffer="data/claim_risk.csv")
claim_risk.drop(columns="policy.id", inplace=True)
claim_risk = np.asarray(claim_risk).ravel()

transactions = pd.read_csv(filepath_or_buffer="data/transactions.csv")
transactions.drop(columns="policy.id", inplace=True)

n_policies = 1000
n_transaction_types = 3
n_time_periods = 4

transactions = np.reshape(np.asarray(transactions), (n_policies, n_time_periods * n_transaction_types))

X_train, X_test, y_train, y_test = train_test_split(transactions, claim_risk, test_size=0.3, random_state=666)

min_X_train = X_train.min(axis=0)
max_X_train = X_train.max(axis=0)
range_X_train = max_X_train - min_X_train + sys.float_info.epsilon
X_train = (X_train - min_X_train) / range_X_train
X_test = (X_test - min_X_train) / range_X_train
transactions = (transactions - min_X_train) / range_X_train

X_train = np.reshape(np.asarray(X_train), (X_train.shape[0], n_time_periods, n_transaction_types, 1))
X_test = np.reshape(np.asarray(X_test), (X_test.shape[0], n_time_periods, n_transaction_types, 1))
transactions = np.reshape(np.asarray(transactions), (n_policies, n_time_periods, n_transaction_types, 1))

np.savetxt(fname=dl4j_res_path + "insurance_train.csv",
           X=np.hstack([X_train.reshape([X_train.shape[0], n_time_periods * n_transaction_types]), 
                        y_train.reshape([-1, 1])]),
           fmt="%.2f",
           delimiter=",",
           newline="\n",
           header="")

np.savetxt(fname=dl4j_res_path + "insurance_test.csv",
           X=np.hstack([X_test.reshape([X_test.shape[0], n_time_periods * n_transaction_types]),
                        y_test.reshape([-1, 1])]),
           fmt="%.2f",
           delimiter=",",
           newline="\n",
           header="")
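
The scaling in the cell above uses the minimum and range of the training split only, and then applies those same statistics to the test split and to the full data. The following plain-Python sketch, with made-up toy numbers and hypothetical helper names `fit_min_max` and `apply_min_max`, illustrates two consequences: adding `sys.float_info.epsilon` to the range avoids a division by zero for constant columns, and test values outside the training range are mapped outside [0, 1].

```python
import sys


def fit_min_max(rows):
    """Per-column minimum and range (plus machine epsilon) of the training rows."""
    columns = list(zip(*rows))
    mins = [min(col) for col in columns]
    ranges = [max(col) - min(col) + sys.float_info.epsilon for col in columns]
    return mins, ranges


def apply_min_max(rows, mins, ranges):
    """Scale rows with previously fitted statistics."""
    return [[(v - m) / r for v, m, r in zip(row, mins, ranges)] for row in rows]


# Toy data (an assumption): the second column is constant in the training split.
train = [[0.0, 10.0], [5.0, 10.0], [10.0, 10.0]]
test = [[12.0, 10.0]]

mins, ranges = fit_min_max(train)
scaled_train = apply_min_max(train, mins, ranges)
scaled_test = apply_min_max(test, mins, ranges)

print(scaled_train)  # First column spans 0 to roughly 1, constant column maps to 0.
print(scaled_test)   # First value exceeds 1 because 12 lies above the training maximum.
```

Fitting the statistics on the training split alone keeps information from the test split out of the preprocessing step.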

### Java code for insurance task

```java
package org.deeplearning4j;

import java.io.*;
import java.util.*;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.*;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.*;
import org.deeplearning4j.nn.weights.*;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.deeplearning4j.util.ModelSerializer;

import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.*;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.*;
import org.nd4j.linalg.lossfunctions.*;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class dl4jGANInsurance {
    private static final Logger log = LoggerFactory.getLogger(dl4jGANInsurance.class);

    private static final int batchSizePerWorker = 50;
    private static final int batchSizePred = 700;
    private static final int labelIndex = 12;
    private static final int numClasses = 1; // Using Sigmoid.
    private static final int numClassesDis = 1; // Using Sigmoid.
    private static final int numFeatures = 12;
    private static final int numIterations = 5000;
    private static final int numGenSamples = 50; // This will be a grid so effectively we get {numGenSamples * numGenSamples} samples.
    private static final int numLinesToSkip = 0;
    private static final int numberOfTheBeast = 666;
    private static final int printEvery = 100;
    private static final int saveEvery = 100;
    private static final int tensorDimOneSize = 4;
    private static final int tensorDimTwoSize = 3;
    private static final int tensorDimThreeSize = 1;
    private static final int zSize = 2;

    private static final double dis_learning_rate = 0.0002;
    private static final double frozen_learning_rate = 0.0;
    private static final double gen_learning_rate = 0.0004;

    private static final String delimiter = ",";
    private static final String resPath = "/Users/samson/Projects/gan_deeplearning4j/outputs/insurance/";
    private static final String newLine = "\n";
    private static final String dataSetName = "insurance";

    private static final boolean useGpu = false;

    public static void main(String[] args) throws Exception {
        new dl4jGANInsurance().GAN(args);
    }

    private void GAN(String[] args) throws Exception {
        for (int i = 0; i < args.length; i++) {
            System.out.println(args[i]);
        }

        if (useGpu) {
            System.out.println("Setting up CUDA environment!");
            Nd4j.setDataType(DataBuffer.Type.FLOAT);

            CudaEnvironment.getInstance().getConfiguration()
                    .allowMultiGPU(true)
                    .setMaximumDeviceCache(2L * 1024L * 1024L * 1024L)
                    .allowCrossDeviceAccess(true)
                    .setVerbose(true);
        }

        System.out.println(Nd4j.getBackend());
        Nd4j.getMemoryManager().setAutoGcWindow(5000);

        log.info("Unfrozen discriminator!");
        ComputationGraph dis = new ComputationGraph(new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .seed(numberOfTheBeast)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)
                .l2(0.0001)
                .activation(Activation.ELU)
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs("dis_input_layer_0")
                .addLayer("dis_batch_layer_1", new BatchNormalization.Builder()
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .nIn(tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize)
                        .nOut(tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize)
                        .build(), "dis_input_layer_0")
                .addLayer("dis_dense_layer_2", new DenseLayer.Builder()
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .nIn(tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize)
                        .nOut(100)
                        .build(),"dis_batch_layer_1")
                .addLayer("dis_dropout_layer_3", new DropoutLayer(),
                        "dis_dense_layer_2")
                .addLayer("dis_output_layer_4", new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .nIn(100)
                        .nOut(numClassesDis)
                        .activation(Activation.SIGMOID)
                        .build(), "dis_dropout_layer_3")
                .setOutputs("dis_output_layer_4")
                .build());
        dis.init();
        System.out.println(dis.summary());
        System.out.println(Arrays.toString(dis.output(Nd4j.randn(numGenSamples, numFeatures))[0].shape()));

        log.info("Frozen generator!");
        ComputationGraph gen = new ComputationGraph(new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .seed(numberOfTheBeast)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)
                .l2(0.0001)
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs("gen_input_layer_0")
                .setInputTypes(InputType.feedForward(zSize))
                .addLayer("gen_batch_1", new BatchNormalization.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .build(), "gen_input_layer_0")
                .addLayer("gen_dense_layer_2", new DenseLayer.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nOut(100)
                        .build(), "gen_batch_1")
                .addLayer("gen_dense_layer_3", new DenseLayer.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nOut(100)
                        .build(), "gen_dense_layer_2")
                .addLayer("gen_dense_layer_4", new DenseLayer.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nOut(100)
                        .build(), "gen_dense_layer_3")
                .addLayer("gen_dense_layer_5", new DenseLayer.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nIn(100)
                        .nOut(tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize)
                        .activation(Activation.SIGMOID)
                        .build(), "gen_dense_layer_4")
                .setOutputs("gen_dense_layer_5")
                .build());
        gen.init();
        System.out.println(gen.summary());
        System.out.println(Arrays.toString(gen.output(Nd4j.randn(numGenSamples, zSize))[0].reshape(numGenSamples, tensorDimOneSize, tensorDimTwoSize, tensorDimThreeSize).shape()));

        log.info("GAN with unfrozen generator and frozen discriminator!");
        ComputationGraph gan = new ComputationGraph(new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .seed(numberOfTheBeast)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .l2(0.0001)
                .graphBuilder()
                .addInputs("gan_input_layer_0")
                .setInputTypes(InputType.feedForward(zSize))
                .addLayer("gan_batch_1", new BatchNormalization.Builder()
                        .updater(new RmsProp(gen_learning_rate, 1e-8, 1e-8))
                        .build(), "gan_input_layer_0")
                .addLayer("gan_dense_layer_2", new DenseLayer.Builder()
                        .updater(new RmsProp(gen_learning_rate, 1e-8, 1e-8))
                        .nOut(100)
                        .build(), "gan_batch_1")
                .addLayer("gan_dense_layer_3", new DenseLayer.Builder()
                        .updater(new RmsProp(gen_learning_rate, 1e-8, 1e-8))
                        .nOut(100)
                        .build(), "gan_dense_layer_2")
                .addLayer("gan_dense_layer_4", new DenseLayer.Builder()
                        .updater(new RmsProp(gen_learning_rate, 1e-8, 1e-8))
                        .nOut(100)
                        .build(), "gan_dense_layer_3")
                .addLayer("gan_dense_layer_5", new DenseLayer.Builder()
                        .updater(new RmsProp(gen_learning_rate, 1e-8, 1e-8))
                        .nOut(tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize)
                        .activation(Activation.SIGMOID)
                        .build(), "gan_dense_layer_4")

                .addLayer("gan_dis_batch_layer_6", new BatchNormalization.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .activation(Activation.ELU)
                        .build(), "gan_dense_layer_5")
                .addLayer("gan_dis_dense_layer_7", new DenseLayer.Builder()
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .activation(Activation.ELU)
                        .nIn(tensorDimOneSize * tensorDimTwoSize * tensorDimThreeSize)
                        .nOut(100)
                        .build(),"gan_dis_batch_layer_6")
                .addLayer("gan_dis_dropout_layer_8", new DropoutLayer(),
                        "gan_dis_dense_layer_7")
                .addLayer("gan_dis_output_layer_9", new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .updater(new RmsProp(frozen_learning_rate, 1e-8, 1e-8))
                        .nOut(numClassesDis)
                        .activation(Activation.SIGMOID)
                        .build(), "gan_dis_dropout_layer_8")
                .setOutputs("gan_dis_output_layer_9")
                .build());
        gan.init();
        System.out.println(gan.summary());
        System.out.println(Arrays.toString(gan.output(Nd4j.randn(numGenSamples, zSize))[0].shape()));

        log.info("Setting up Spark configuration!");
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[4]");
        sparkConf.setAppName("Deeplearning4j on Apache Spark: Generative Adversarial Network!");
        sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
        sparkConf.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        log.info("Setting up Synchronous Parameter Averaging!");
        TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker)
                .averagingFrequency(5)
                .rngSeed(numberOfTheBeast)
                .workerPrefetchNumBatches(0)
                .batchSizePerWorker(batchSizePerWorker)
                .build();

        SparkComputationGraph sparkDis = new SparkComputationGraph(sc, dis, tm);
        SparkComputationGraph sparkGan = new SparkComputationGraph(sc, gan, tm);

        log.info("Insurance deep learning model with pre-trained layers from the GAN's discriminator!");
        ComputationGraph insurance = new TransferLearning.GraphBuilder(sparkDis.getNetwork())
                .fineTuneConfiguration(new FineTuneConfiguration.Builder()
                        .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                        .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                        .gradientNormalizationThreshold(1.0)
                        .activation(Activation.ELU)
                        .l2(0.0001)
                        .weightInit(WeightInit.XAVIER)
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .seed(numberOfTheBeast)
                        .build())
                .setFeatureExtractor("dis_dropout_layer_3")
                .removeVertexKeepConnections("dis_output_layer_4")
                .addLayer("dis_batch", new BatchNormalization.Builder()
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .nIn(100)
                        .nOut(100)
                        .build(), "dis_dropout_layer_3")
                .addLayer("dis_output_layer_4", new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .updater(new RmsProp(dis_learning_rate, 1e-8, 1e-8))
                        .nIn(100)
                        .nOut(numClasses)
                        .activation(Activation.SIGMOID)
                        .build(), "dis_batch")
                .build();
        System.out.println(insurance.summary());
        System.out.println(Arrays.toString(insurance.output(Nd4j.randn(numGenSamples, numFeatures))[0].shape()));

        SparkComputationGraph sparkInsurance = new SparkComputationGraph(sc, insurance, tm);

        RecordReader recordReaderTrain = new CSVRecordReader(numLinesToSkip, delimiter);
        recordReaderTrain.initialize(new FileSplit(new ClassPathResource(dataSetName + "_train.csv").getFile()));

        DataSetIterator iterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSizePerWorker, labelIndex, numClasses);
        List<DataSet> trainDataList = new ArrayList<>();

        JavaRDD<DataSet> trainDataDis, trainDataGen, trainData;

        INDArray grid = Nd4j.linspace(-1.0, 1.0, numGenSamples);
        Collection<INDArray> z = new ArrayList<>();
        log.info("Creating some noise!");
        for (int i = 0; i < numGenSamples; i++) {
            for (int j = 0; j < numGenSamples; j++) {
                z.add(Nd4j.create(new double[]{grid.getDouble(0, i), grid.getDouble(0, j)}));
            }
        }

        int batch_counter = 0;

        DataSet trDataSet;

        RecordReader recordReaderTest = new CSVRecordReader(numLinesToSkip, delimiter);
        recordReaderTest.initialize(new FileSplit(new ClassPathResource(dataSetName + "_test.csv").getFile()));

        DataSetIterator iterTest = new RecordReaderDataSetIterator(recordReaderTest, batchSizePred, labelIndex, numClasses);

        Collection<INDArray> outFeat;

        INDArray out, outPred;
        INDArray soften_labels_fake = Nd4j.randn(batchSizePerWorker, 1).muli(0.05);
        INDArray soften_labels_real = Nd4j.randn(batchSizePerWorker, 1).muli(0.05);
        
        while (iterTrain.hasNext() && batch_counter < numIterations) {
            trainDataList.clear();
            trDataSet = iterTrain.next();

            // Real data: single label column (0 = fake, 1 = real), so the target is 1 plus small Gaussian noise.
            trainDataList.add(new DataSet(trDataSet.getFeatures(), Nd4j.ones(batchSizePerWorker, 1).addi(soften_labels_real)));

            // Fake data from the frozen generator: the target is 0 plus small Gaussian noise.
            trainDataList.add(new DataSet(gen.output(Nd4j.rand(batchSizePerWorker, zSize).muli(2.0).subi(1.0))[0], Nd4j.zeros(batchSizePerWorker, 1).addi(soften_labels_fake)));

            // Train the unfrozen discriminator on real and generated samples while the generator stays fixed.
            log.info("Training discriminator!");
            trainDataDis = sc.parallelize(trainDataList);
            sparkDis.fit(trainDataDis);

            // Update GAN's frozen discriminator with unfrozen discriminator.
            sparkGan.getNetwork().getLayer("gan_dis_batch_layer_6").setParam("gamma", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("gamma"));
            sparkGan.getNetwork().getLayer("gan_dis_batch_layer_6").setParam("beta", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("beta"));
            sparkGan.getNetwork().getLayer("gan_dis_batch_layer_6").setParam("mean", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("mean"));
            sparkGan.getNetwork().getLayer("gan_dis_batch_layer_6").setParam("var", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("var"));

            sparkGan.getNetwork().getLayer("gan_dis_dense_layer_7").setParam("W", sparkDis.getNetwork().getLayer("dis_dense_layer_2").getParam("W"));
            sparkGan.getNetwork().getLayer("gan_dis_dense_layer_7").setParam("b", sparkDis.getNetwork().getLayer("dis_dense_layer_2").getParam("b"));

            sparkGan.getNetwork().getLayer("gan_dis_output_layer_9").setParam("W", sparkDis.getNetwork().getLayer("dis_output_layer_4").getParam("W"));
            sparkGan.getNetwork().getLayer("gan_dis_output_layer_9").setParam("b", sparkDis.getNetwork().getLayer("dis_output_layer_4").getParam("b"));

            trainDataList.clear();
            // Label every generated example as real (1) so that the generator learns to fool the frozen discriminator.
            trainDataList.add(new DataSet(Nd4j.rand(batchSizePerWorker, zSize).muli(2.0).subi(1.0), Nd4j.ones(batchSizePerWorker, 1)));

            // Unfrozen generator is trying to fool the frozen discriminator.
            log.info("Training generator!");
            trainDataGen = sc.parallelize(trainDataList);
            sparkGan.fit(trainDataGen);

            // Update frozen generator with GAN's unfrozen generator.
            gen.getLayer("gen_batch_1").setParam("gamma", sparkGan.getNetwork().getLayer("gan_batch_1").getParam("gamma"));
            gen.getLayer("gen_batch_1").setParam("beta", sparkGan.getNetwork().getLayer("gan_batch_1").getParam("beta"));
            gen.getLayer("gen_batch_1").setParam("mean", sparkGan.getNetwork().getLayer("gan_batch_1").getParam("mean"));
            gen.getLayer("gen_batch_1").setParam("var", sparkGan.getNetwork().getLayer("gan_batch_1").getParam("var"));

            gen.getLayer("gen_dense_layer_2").setParam("W", sparkGan.getNetwork().getLayer("gan_dense_layer_2").getParam("W"));
            gen.getLayer("gen_dense_layer_2").setParam("b", sparkGan.getNetwork().getLayer("gan_dense_layer_2").getParam("b"));

            gen.getLayer("gen_dense_layer_3").setParam("W", sparkGan.getNetwork().getLayer("gan_dense_layer_3").getParam("W"));
            gen.getLayer("gen_dense_layer_3").setParam("b", sparkGan.getNetwork().getLayer("gan_dense_layer_3").getParam("b"));

            gen.getLayer("gen_dense_layer_4").setParam("W", sparkGan.getNetwork().getLayer("gan_dense_layer_4").getParam("W"));
            gen.getLayer("gen_dense_layer_4").setParam("b", sparkGan.getNetwork().getLayer("gan_dense_layer_4").getParam("b"));

            gen.getLayer("gen_dense_layer_5").setParam("W", sparkGan.getNetwork().getLayer("gan_dense_layer_5").getParam("W"));
            gen.getLayer("gen_dense_layer_5").setParam("b", sparkGan.getNetwork().getLayer("gan_dense_layer_5").getParam("b"));

            trainDataList.clear();
            trainDataList.add(trDataSet);

            log.info("Training insurance model!");
            sparkInsurance.getNetwork().getLayer("dis_batch_layer_1").setParam("gamma", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("gamma"));
            sparkInsurance.getNetwork().getLayer("dis_batch_layer_1").setParam("beta", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("beta"));
            sparkInsurance.getNetwork().getLayer("dis_batch_layer_1").setParam("mean", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("mean"));
            sparkInsurance.getNetwork().getLayer("dis_batch_layer_1").setParam("var", sparkDis.getNetwork().getLayer("dis_batch_layer_1").getParam("var"));

            sparkInsurance.getNetwork().getLayer("dis_dense_layer_2").setParam("W", sparkDis.getNetwork().getLayer("dis_dense_layer_2").getParam("W"));
            sparkInsurance.getNetwork().getLayer("dis_dense_layer_2").setParam("b", sparkDis.getNetwork().getLayer("dis_dense_layer_2").getParam("b"));

            trainData = sc.parallelize(trainDataList);
            sparkInsurance.fit(trainData);

            batch_counter++;
            log.info("Completed Batch {}!", batch_counter);

            if ((batch_counter % printEvery) == 0) {
                out = gen.output(Nd4j.vstack(z))[0].reshape(numGenSamples * numGenSamples, numFeatures);

                // try-with-resources flushes and closes each writer even if a write fails.
                try (FileWriter fileWriter = new FileWriter(String.format("%s%s_out_%d.csv", resPath, dataSetName, batch_counter))) {
                    for (int i = 0; i < out.shape()[0]; i++) {
                        for (int j = 0; j < out.shape()[1]; j++) {
                            fileWriter.append(String.valueOf(out.getDouble(i, j)));
                            if (j != out.shape()[1] - 1) {
                                fileWriter.append(delimiter);
                            }
                        }
                        if (i != out.shape()[0] - 1) {
                            fileWriter.append(newLine);
                        }
                    }
                }

                // Score the generated lattice with the insurance model.
                outPred = sparkInsurance.getNetwork().output(out)[0];

                try (FileWriter fileWriter = new FileWriter(String.format("%s%s_out_pred_%d.csv", resPath, dataSetName, batch_counter))) {
                    for (int i = 0; i < outPred.shape()[0]; i++) {
                        for (int j = 0; j < outPred.shape()[1]; j++) {
                            fileWriter.append(String.valueOf(outPred.getDouble(i, j)));
                            if (j != outPred.shape()[1] - 1) {
                                fileWriter.append(delimiter);
                            }
                        }
                        if (i != outPred.shape()[0] - 1) {
                            fileWriter.append(newLine);
                        }
                    }
                }
            }

            if ((batch_counter % saveEvery) == 0) {
                log.info("Ensemble of deep learners for estimation of uncertainty!");

                outFeat = new ArrayList<>();
                iterTest.reset();
                while (iterTest.hasNext()) {
                    outFeat.add(sparkInsurance.getNetwork().output(iterTest.next().getFeatures())[0]);
                }

                INDArray toWrite = Nd4j.vstack(outFeat);
                // try-with-resources flushes and closes the writer even if a write fails.
                try (FileWriter fileWriter = new FileWriter(String.format("%s%s_test_predictions_%d.csv", resPath, dataSetName, batch_counter))) {
                    for (int i = 0; i < toWrite.shape()[0]; i++) {
                        for (int j = 0; j < toWrite.shape()[1]; j++) {
                            fileWriter.append(String.valueOf(toWrite.getDouble(i, j)));
                            if (j != toWrite.shape()[1] - 1) {
                                fileWriter.append(delimiter);
                            }
                        }
                        if (i != toWrite.shape()[0] - 1) {
                            fileWriter.append(newLine);
                        }
                    }
                }
            }

            if (!iterTrain.hasNext()) {
                iterTrain.reset();
            }
        }

        log.info("Saving models!");
        ModelSerializer.writeModel(sparkDis.getNetwork(), new File(resPath + dataSetName + "_dis_model.zip"), true);
        ModelSerializer.writeModel(sparkGan.getNetwork(), new File(resPath + dataSetName + "_gan_model.zip"), true);
        ModelSerializer.writeModel(gen, new File(resPath + dataSetName + "_gen_model.zip"), true);
        ModelSerializer.writeModel(sparkInsurance.getNetwork(), new File(resPath + dataSetName + "_insurance_model.zip"), true);

        tm.deleteTempFiles(sc);
    }
}
```
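
The latent lattice built in the Java listing above is an evenly spaced grid over [-1, 1] in each of the two latent dimensions, enumerated with the first coordinate in the outer loop. That enumeration fixes the row order of the generated CSV files. The following is a minimal NumPy sketch of the same enumeration; the function name `latent_lattice` is hypothetical and not part of the original code, and it assumes a 2-dimensional latent space as in the listing.

```python
import numpy as np


def latent_lattice(n_points):
    """Return an (n_points**2, 2) array of latent points on a square grid.

    Hypothetical helper mirroring the nested loop in the Java listing:
    row i * n_points + j holds (grid[i], grid[j]).
    """
    grid = np.linspace(-1.0, 1.0, n_points)
    # indexing="ij" makes the first coordinate vary slowest, as in the outer loop.
    first, second = np.meshgrid(grid, grid, indexing="ij")
    return np.column_stack([first.ravel(), second.ravel()])


z = latent_lattice(3)
print(z.shape)
print(z[:4])
```

Row `i * n_points + j` of the generator output therefore corresponds to the latent point `(grid[i], grid[j])`, which is the order the plotting cell below relies on when it steps `counter` through the rows.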

In [None]:
output_path = "/Users/samson/Projects/gan_deeplearning4j/outputs/insurance/"
out_feat = pd.read_csv(filepath_or_buffer=output_path + "insurance_test_predictions_1000.csv",
                       header=None)
auroc_dcgan_ins = roc_auc_score(y_true=y_test,
                                y_score=out_feat, 
                                average="weighted")
print("The AUROC score for the insurance classification task with DCGAN: %.6f%%." % (auroc_dcgan_ins * 100))

out = pd.read_csv(filepath_or_buffer=output_path + "insurance_out_1000.csv", header=None)
out_pred = pd.read_csv(filepath_or_buffer=output_path + "insurance_out_pred_1000.csv", header=None)

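n = 50  # Lattice side length (numGenSamples in the Java code above).
figure = np.zeros((n_time_periods * n, n_transaction_types * n))
plot_preds = np.zeros((n, n))
counter = 0
for i in range(n):
    for j in range(n):
        # Convert the row to a NumPy array before reshaping: Series.reshape is not available in current pandas.
        sample = out.iloc[counter].values.reshape(n_time_periods, n_transaction_types)
        figure[i * n_time_periods: (i + 1) * n_time_periods,
               j * n_transaction_types: (j + 1) * n_transaction_types] = sample
        # Each row of out_pred holds a single predicted probability.
        plot_preds[i, j] = out_pred.iloc[counter, 0]
        counter = counter + 1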
        
fig = plt.figure(figsize=(20, 30))
fig.add_subplot(2, 1, 1)
plt.imshow(plot_preds, cmap="hot_r")
plt.title("Deep Convolutional Generative Adversarial Network (DCGAN) with a 2-dimensional latent manifold for the insurance data\nPredicted probability of insurance loss: Darker means higher probability", fontsize=20)
plt.xlabel("Latent dimension 1", fontsize=24)
plt.ylabel("Latent dimension 2", fontsize=24)
fig.add_subplot(2, 1, 2)
plt.imshow(figure, cmap="Greys")
plt.title("Deep Convolutional Generative Adversarial Network (DCGAN) with a 2-dimensional latent manifold for the insurance data\nGenerating new transactions data on the 2-dimensional latent manifold", fontsize=20)
plt.xlabel("Latent dimension 1", fontsize=24)
plt.ylabel("Latent dimension 2", fontsize=24)
plt.savefig(fname="DCGAN_Generated_Lattices.png")
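
The nested loop in the cell above places generated sample `counter = i * n + j` into tile `(i, j)` of the figure. Assuming each row of the generator output is a flattened `n_time_periods` by `n_transaction_types` array, the same tiling can be sketched with one reshape and one transpose. The function name `tile_lattice` and the small sizes below are illustrative assumptions, not part of the original code.

```python
import numpy as np


def tile_lattice(samples, n, n_rows, n_cols):
    """Tile n * n flattened samples into one (n * n_rows, n * n_cols) image."""
    tiles = np.asarray(samples).reshape(n, n, n_rows, n_cols)
    # Reorder axes to (tile row, pixel row, tile column, pixel column) before merging.
    return tiles.transpose(0, 2, 1, 3).reshape(n * n_rows, n * n_cols)


# Illustrative sizes only; the notebook uses n = 50 and the insurance data dimensions.
n, n_rows, n_cols = 3, 4, 2
samples = np.arange(n * n * n_rows * n_cols, dtype=float).reshape(n * n, n_rows * n_cols)
image = tile_lattice(samples, n, n_rows, n_cols)

# Reference result from the explicit double loop, for comparison.
expected = np.zeros((n * n_rows, n * n_cols))
for i in range(n):
    for j in range(n):
        tile = samples[i * n + j].reshape(n_rows, n_cols)
        expected[i * n_rows:(i + 1) * n_rows, j * n_cols:(j + 1) * n_cols] = tile

print(np.array_equal(image, expected))
```

The vectorized form avoids the Python-level loop, at the cost of being less explicit about which sample lands in which tile.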

## Conclusion

---

We have shown how to use GANs to learn a good representation of raw data, i.e., 1- or 2-dimensional tensors per unit of analysis, that can then perhaps be used for supervised learning tasks in the domains of computer vision and insurance. This moves us away from manual, handcrafted feature engineering towards automatic feature engineering, i.e., representation learning. GANs can perhaps also be used for semi-supervised learning, which will be the topic of another paper.

## References

---

1. Goodfellow, I., Bengio, Y. and Courville, A. (2016). Deep Learning (MIT Press).
2. Géron, A. (2017). Hands-On Machine Learning with Scikit-Learn & TensorFlow (O'Reilly).
3. Radford, A., Metz, L. and Chintala, S. (2015). Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks (https://arxiv.org/abs/1511.06434).
4. Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A. and Bengio, Y. (2014). Generative Adversarial Networks (https://arxiv.org/abs/1406.2661).
5. http://scikit-learn.org/stable/#
6. https://towardsdatascience.com/learning-rate-schedules-and-adaptive-learning-rate-methods-for-deep-learning-2c8f433990d1
7. https://stackoverflow.com/questions/42177658/how-to-switch-backend-with-keras-from-tensorflow-to-theano
8. https://blog.keras.io/building-autoencoders-in-keras.html
9. https://keras.io
10. https://github.com/fchollet/keras/blob/master/examples/mnist_acgan.py#L24
11. https://en.wikipedia.org/wiki/Kullback–Leibler_divergence
12. https://see.stanford.edu/materials/lsocoee364b/01-subgradients_notes.pdf
13. https://blog.skymind.ai/distributed-deep-learning-part-1-an-introduction-to-distributed-training-of-neural-networks/
14. https://deeplearning4j.org
15. https://github.com/hamaadshah/gan_keras
16. https://towardsdatascience.com/automatic-feature-engineering-using-generative-adversarial-networks-8e24b3c16bf3