# Training a Neural Network
In this example, we'll train a neural network using particle swarm optimization (PSO). We'll use the standard global-best PSO, `pyswarms.single.GlobalBestPSO`, to optimize the network's weights and biases. The goal is to demonstrate that the API can handle custom-defined objective functions.

For this example, we'll try to classify the three iris species in the Iris Dataset.

In [1]:
import sys
# Add the parent directory to the import path to access the pyswarms module
sys.path.append('../')

In [2]:
print('Running on Python version: {}'.format(sys.version))

Running on Python version: 3.6.8 |Anaconda, Inc.| (default, Dec 30 2018, 01:22:34) 
[GCC 7.3.0]


In [3]:
# Import modules
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris


# Import PySwarms
import pyswarms as ps

# Some more magic so that the notebook will reload external python modules;
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

First, we'll load the dataset from `scikit-learn`. The Iris Dataset contains three classes, one for each iris species (_iris setosa_, _iris virginica_, and _iris versicolor_). It has 50 samples per class and 150 samples in total, so the classes are perfectly balanced. Each sample is characterized by four features (or dimensions): sepal length, sepal width, petal length, and petal width.

In [4]:
# Load the iris dataset
data = load_iris()

# Store the features as X and the labels as y
X = data.data
y = data.target

## Constructing a custom objective function
Recall that neural networks can simply be seen as a mapping function from one space to another. For now, we'll build a simple neural network with the following characteristics:
* Input layer size: 4
* Hidden layer size: 20 (activation: $\tanh(x)$)
* Output layer size: 3 (activation: $\mathrm{softmax}(x)$)

Things we'll do:
1. Create a `forward_prop()` function that performs forward propagation for one particle.
2. Create a wrapper objective function `f()` that calls `forward_prop()` for every particle in the swarm.

We will then create a swarm whose number of dimensions equals the total number of weights and biases. We will __unroll__ these parameters into a single flat vector, and have each particle take on different values for it. Thus, each particle represents a candidate neural network with its own weights and biases. Before evaluating the network, we reconstruct the weight and bias arrays from the particle's position.

When rolling the parameters back into weights and biases, it is useful to recall the shapes of the weight and bias arrays:
* Shape of input-to-hidden weight matrix: (4, 20)
* Shape of input-to-hidden bias array: (20, )
* Shape of hidden-to-output weight matrix: (20, 3)
* Shape of hidden-to-output bias array: (3, )

By unrolling them together, we have $(4 \times 20) + (20 \times 3) + 20 + 3 = 163$ parameters, or 163 dimensions for each particle in the swarm.
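
The unrolling and rolling back described above can be sketched with plain NumPy. This is only an illustration: the `shapes` list and the `roll_back` helper are hypothetical names that do not appear elsewhere in this notebook, and the random arrays stand in for real weights.

```python
import numpy as np

# Layer sizes taken from the architecture described above
n_inputs, n_hidden, n_classes = 4, 20, 3

# Shapes of the four parameter arrays, in the order they are unrolled
shapes = [
    (n_inputs, n_hidden),   # W1
    (n_hidden,),            # b1
    (n_hidden, n_classes),  # W2
    (n_classes,),           # b2
]

# Number of entries in each array and the total particle dimension
sizes = [int(np.prod(s)) for s in shapes]
dimensions = sum(sizes)


def roll_back(params):
    """Split a flat parameter vector into W1, b1, W2, b2 (hypothetical helper)."""
    # Split points are the cumulative sizes: 80, 100, 160
    chunks = np.split(params, np.cumsum(sizes)[:-1])
    return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]


# Round trip: unroll random stand-in arrays, then recover them
rng = np.random.default_rng(0)
originals = [rng.standard_normal(s) for s in shapes]
flat = np.concatenate([a.ravel() for a in originals])
recovered = roll_back(flat)

print(sizes, dimensions)
```

The split points (80, 100, 160) are the same slice boundaries that a hand-written roll-back would use for this architecture.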


We'll use the negative log-likelihood to measure the error between the ground-truth labels and the predictions. Also, because PSO doesn't rely on gradients, we won't perform backpropagation (which can be an advantage or a drawback, depending on the problem).
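
As a small worked example of this loss, the sketch below computes the mean negative log-likelihood for two made-up samples. The function names and the toy logits are assumptions for illustration. Subtracting the row maximum before exponentiating is a common numerical-stability trick; it leaves the softmax output unchanged but avoids overflow for large logits.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax with the row maximum subtracted for stability."""
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / np.sum(exp_scores, axis=1, keepdims=True)


def negative_log_likelihood(logits, labels):
    """Mean negative log-probability assigned to the correct class."""
    probs = softmax(logits)
    n_samples = logits.shape[0]
    correct_logprobs = -np.log(probs[np.arange(n_samples), labels])
    return np.sum(correct_logprobs) / n_samples


# Toy data (made up for illustration): 2 samples, 3 classes
toy_logits = np.array([[0.0, 0.0, 0.0],      # uniform prediction
                       [1000.0, 0.0, 0.0]])  # very confident in class 0
toy_labels = np.array([2, 0])

loss = negative_log_likelihood(toy_logits, toy_labels)
print(loss)
```

The first sample is predicted uniformly and contributes $\ln 3 \approx 1.10$; the second is predicted correctly with probability 1 and contributes 0, so the mean is $\ln 3 / 2$. A network that predicts all three classes with equal probability on every sample would therefore score about 1.10.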

Now, let's write the forward propagation procedure as our objective function. Let $X$ be the input, $z_l$ the pre-activation at layer $l$, and $a_l$ the activation for layer $l$:

In [5]:
# Forward propagation
def forward_prop(params):
    """Forward propagation as objective function
    
    This computes the forward propagation of the neural network, as
    well as the loss. It receives a set of parameters that must be
    rolled back into the corresponding weights and biases.
    
    Inputs
    ------
    params: np.ndarray
        The dimensions should include an unrolled version of the 
        weights and biases.
        
    Returns
    -------
    float
        The computed negative log-likelihood loss given the parameters
    """
    # Neural network architecture
    n_inputs = 4
    n_hidden = 20
    n_classes = 3
    
    # Roll-back the weights and biases
    W1 = params[0:80].reshape((n_inputs,n_hidden))
    b1 = params[80:100].reshape((n_hidden,))
    W2 = params[100:160].reshape((n_hidden,n_classes))
    b2 = params[160:163].reshape((n_classes,))
    
    # Perform forward propagation
    z1 = X.dot(W1) + b1  # Pre-activation in Layer 1
    a1 = np.tanh(z1)     # Activation in Layer 1
    z2 = a1.dot(W2) + b2 # Pre-activation in Layer 2
    logits = z2          # Logits for Layer 2
    
    # Compute the softmax of the logits
    exp_scores = np.exp(logits)
    probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

    # Compute the negative log-likelihood
    N = 150 # Number of samples
    correct_logprobs = -np.log(probs[range(N), y])
    loss = np.sum(correct_logprobs) / N
    
    return loss


Now that we have a function that performs forward propagation for one particle (that is, for one position vector), we can create a higher-level function that calls `forward_prop()` for the whole swarm:

In [6]:
def f(x):
    """Higher-level method to do forward_prop in the 
    whole swarm.
    
    Inputs
    ------
    x: numpy.ndarray of shape (n_particles, dimensions)
        The swarm that will perform the search
        
    Returns
    -------
    numpy.ndarray of shape (n_particles, )
        The computed loss for each particle
    """
    n_particles = x.shape[0]
    j = [forward_prop(x[i]) for i in range(n_particles)]
    return np.array(j)
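
The shape contract of such a wrapper (a swarm matrix of shape `(n_particles, dimensions)` in, a cost vector of shape `(n_particles,)` out) can be checked in isolation. The sketch below is hypothetical: `sphere` is a toy stand-in for `forward_prop()`, and `evaluate_swarm` and the random swarm are made up for illustration.

```python
import numpy as np


def sphere(position):
    """Toy per-particle objective (stand-in for forward_prop): sum of squares."""
    return float(np.sum(position ** 2))


def evaluate_swarm(swarm, per_particle_fn):
    """Apply a per-particle objective to every row of the swarm matrix."""
    return np.array([per_particle_fn(particle) for particle in swarm])


# A made-up swarm of 5 particles in 163 dimensions
rng = np.random.default_rng(42)
swarm = rng.uniform(-1.0, 1.0, size=(5, 163))

costs = evaluate_swarm(swarm, sphere)
print(costs.shape)
```

Iterating over a 2-D array yields its rows, so each particle's position reaches the per-particle function as a flat vector of length 163.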
    

## Performing PSO on the custom function
Now that everything has been set up, we just instantiate our global-best PSO and run the optimizer as usual. For now, we'll set the PSO parameters arbitrarily.

In [8]:
%%time
# Initialize swarm
options = {'c1': 0.5, 'c2': 0.3, 'w':0.9}

# Call instance of PSO
dimensions = (4 * 20) + (20 * 3) + 20 + 3 
optimizer = ps.single.GlobalBestPSO(n_particles=100, dimensions=dimensions, options=options)

# Perform optimization
cost, pos = optimizer.optimize(f, iters=1000)

2019-02-21 00:27:43,567 - pyswarms.single.global_best - INFO - Optimize for 1000 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}

pyswarms.single.global_best:   0%|          |0/1000, best_cost=1.11
pyswarms.single.global_best:   9%|▉         |90/1000, best_cost=0.111
pyswarms.single.global_best:  17%|█▋        |171/1000, best_cost=0.0438
pyswarms.single.global_best:  25%|██▌       |252/1000, best_cost=0.0402
pyswarms.single.global_best:  33%|███▎      |333/1000, best_cost=0.0387
pyswarms.single.global_best:  42%|████▏     |417/1000, best_cost=0.0371
pyswarms.single.global_best:  50%|█████     |501/1000, best_cost=0.0355
pyswarms.single.global_best:  58%|█████▊    |579/1000, best_cost=0.0315
pyswarms.single.global_best:  66%|██████▋   |663/1000, best_cost=0.0302
pyswarms.single.global_best:  74%|███████▍  |744/1000, best_cost=0.0283
pyswarms.single.global_best:  83%|████████▎ |828/1000, best_cost=0.0261
pyswarms.single.global_best:  91%|█████████ |906/1000, best_cost=0.0253
pyswarms.single.global_best:  99%|█████████▉|989/1000, best_cost=0.021
...

CPU times: user 29.2 s, sys: 900 ms, total: 30.1 s
Wall time: 39.1 s
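
To give some intuition for the `options` dictionary passed to the optimizer, `c1` is the cognitive coefficient, `c2` the social coefficient, and `w` the inertia weight. The sketch below shows one textbook global-best PSO update using those three values. It is an illustration under that assumption, not the PySwarms implementation, which may differ in details such as velocity clamping and bounds handling; `pso_step` and the toy swarm state are hypothetical.

```python
import numpy as np


def pso_step(positions, velocities, pbest_pos, gbest_pos, options, rng):
    """One textbook global-best PSO update (illustrative, not PySwarms code)."""
    c1, c2, w = options['c1'], options['c2'], options['w']
    # Independent random factors per particle and per dimension
    r1 = rng.uniform(size=positions.shape)
    r2 = rng.uniform(size=positions.shape)
    cognitive = c1 * r1 * (pbest_pos - positions)  # pull toward own best
    social = c2 * r2 * (gbest_pos - positions)     # pull toward swarm best
    new_velocities = w * velocities + cognitive + social
    return positions + new_velocities, new_velocities


options = {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
rng = np.random.default_rng(0)

# Made-up swarm state: 100 particles, 163 dimensions, starting at rest
positions = rng.uniform(size=(100, 163))
velocities = np.zeros_like(positions)
pbest_pos = positions.copy()  # each particle's best so far is its start
gbest_pos = positions[0]      # pretend particle 0 holds the swarm best

new_positions, new_velocities = pso_step(
    positions, velocities, pbest_pos, gbest_pos, options, rng
)
print(new_positions.shape)
```

In this toy state, every particle starts at its own personal best with zero velocity, so only the social term moves it; particle 0, which already sits at the swarm best, stays where it is.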


## Checking the accuracy
We can then check the accuracy by performing forward propagation once more to produce a set of predictions, and then counting how many of them match the ground-truth labels. To get a prediction, we take the `argmax` of the `logits`. Recall that the softmax function returns probabilities that sum to 1 for each sample. Because softmax preserves the ordering of the logits, the class with the largest logit is also the class with the highest probability, and we treat it as the network's prediction.

Moreover, we let the best position vector found by the swarm be the weight and bias parameters of the network.

In [9]:
def predict(X, pos):
    """
    Use the trained weights to perform class predictions.
    
    Inputs
    ------
    X: numpy.ndarray
        Input Iris dataset
    pos: numpy.ndarray
        Best position vector found by the swarm. Will be rolled
        back into weights and biases.
    """
    # Neural network architecture
    n_inputs = 4
    n_hidden = 20
    n_classes = 3
    
    # Roll-back the weights and biases
    W1 = pos[0:80].reshape((n_inputs,n_hidden))
    b1 = pos[80:100].reshape((n_hidden,))
    W2 = pos[100:160].reshape((n_hidden,n_classes))
    b2 = pos[160:163].reshape((n_classes,))
    
    # Perform forward propagation
    z1 = X.dot(W1) + b1  # Pre-activation in Layer 1
    a1 = np.tanh(z1)     # Activation in Layer 1
    z2 = a1.dot(W2) + b2 # Pre-activation in Layer 2
    logits = z2          # Logits for Layer 2
    
    y_pred = np.argmax(logits, axis=1)
    return y_pred

From this we can compute the accuracy. We perform the predictions, compare them element-wise with the ground-truth labels `y`, and take the mean of the resulting boolean array.

In [10]:
(predict(X, pos) == y).mean()

0.9866666666666667
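
An accuracy of 0.9867 corresponds to 148 of the 150 samples, measured on the same data that was used for the optimization. To see which classes get confused, a confusion matrix can be built with NumPy alone. The sketch below uses made-up labels rather than the notebook's results, and the `confusion_matrix` helper is a hypothetical name.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, n_classes):
    """Count samples for each (true class, predicted class) pair."""
    matrix = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates correctly when index pairs repeat
    np.add.at(matrix, (y_true, y_pred), 1)
    return matrix


# Made-up labels for illustration only (not the notebook's results)
y_true = np.array([0, 0, 1, 1, 2, 2])
y_hat = np.array([0, 0, 1, 2, 2, 2])

# Mean of the boolean comparison is the fraction of correct predictions
accuracy = (y_hat == y_true).mean()
matrix = confusion_matrix(y_true, y_hat, n_classes=3)

print(accuracy)
print(matrix)
```

Rows index the true class and columns the predicted class, so the diagonal holds the correct predictions and its sum divided by the total equals the accuracy.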