# Lab 8: (Single-Layer) Neural Networks and Gradient Descent (50 Pts)

In this lab, we are going to implement gradient descent to train a simple (single-layer) neural network. We will use the Jax software library for doing this:
https://github.com/google/jax

Before we begin, we will first install Jax. To do this, run the following commands in your terminal:

`pip3 install --upgrade pip`

`pip3 install --user --upgrade jax jaxlib`

The main functionality that Jax offers is the ability to automatically compute gradients (so you don't have to do this by hand!). You can write mathematical functions using NumPy syntax, and Jax will compute their gradients for you. As an example, take a look at the block of code below.


In [131]:
from jax import grad
import jax.numpy as np

# Define a function using numpy operations 
def my_function(w,x,y):
    return np.power(w,2)*x + y

# Gradient function (partial derivative of my_function w.r.t. the first argument, i.e., w)
gradient_fun = grad(my_function) 

# Evaluate the gradient (partial derivative) of my_function with respect to w at w=w0, x=x0, y=y0
w0 = 0.5
x0 = 0.2
y0 = 0.3
print(gradient_fun(w0, x0, y0))

0.2


You can manually verify that `gradient_fun(w0, x0, y0)` evaluates the gradient of `my_function` w.r.t. the first argument of `my_function` (which is `w`) at `w0, x0, y0`. The beauty of Jax is that `my_function` can be extremely complicated; manually computing the gradient might be very annoying, but Jax saves you the trouble.
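
As a quick sanity check, the sketch below compares Jax's result with the hand-computed derivative $2wx$. It also uses the `argnums` argument of `grad` to differentiate with respect to `x` and `y` instead of `w`. The names `auto_dw`, `manual_dw`, `grad_x`, and `grad_y` are illustrative only and are not used elsewhere in this lab.

```python
from jax import grad
import jax.numpy as np

def my_function(w, x, y):
    return np.power(w, 2) * x + y

w0, x0, y0 = 0.5, 0.2, 0.3

# Automatic derivative w.r.t. w (grad differentiates argument 0 by default)
auto_dw = grad(my_function)(w0, x0, y0)
# Hand-computed derivative: d/dw (w^2 * x + y) = 2 * w * x
manual_dw = 2 * w0 * x0

# argnums selects which positional argument to differentiate
grad_x = grad(my_function, argnums=1)(w0, x0, y0)  # d/dx = w^2
grad_y = grad(my_function, argnums=2)(w0, x0, y0)  # d/dy = 1

print(auto_dw, manual_dw, grad_x, grad_y)
```

The automatic and manual values for `w` should agree up to floating-point precision.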

Notice also that we are not actually using numpy above. We are using jax.numpy as if it were numpy. Jax has overloaded a large number of numpy operations. As long as `my_function` uses these basic numpy operations, Jax will allow you to compute gradients. You can take a look at the Jax documentation for all the numpy functions Jax has overloaded (but we won't need anything fancy for this assignment -- you can just pretend like jax.numpy is numpy).
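
The same mechanism works when the argument is an array rather than a scalar, which is the case we will need for the weights of the network. The sketch below uses a hypothetical function `squared_norm` (not part of the lab) to show that `grad` returns an array with the same shape as the input.

```python
from jax import grad
import jax.numpy as np

# Hypothetical example function: squared Euclidean norm of a vector
def squared_norm(v):
    return np.dot(v, v)

v0 = np.array([1.0, -2.0, 3.0])

# The gradient of v . v is 2 * v, one partial derivative per entry of v
g = grad(squared_norm)(v0)
print(g)
```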

Now, we will train a single-layer neural network using gradient descent. Your task is to fill in the portions marked "TODO". 

First, let's define some helper functions.

In [132]:
from jax import grad, jit, ops
import jax.numpy as np

# Sigmoid nonlinearity (i.e., activation)
def sigmoid(z):
    return 0.5 * (np.tanh(z / 2.) + 1)

######## TODO: Fill in these functions ################################

# Single-layer neural network (with no bias) [10 pts]
def neural_network_prediction(weights, x):
    '''This function should implement a single-layer neural network. We will
    ignore the bias term here to keep things simple. This function
    should take in an input (x) and output sigmoid(w'*x). Recall that
    matrix multiplication in numpy can be done using np.dot().'''
    return 

# Loss function: Binary Cross Entropy [5 pts]
def loss_function(y_pred, y_true):
    '''Loss function: this function takes in a predicted label (scalar)
    and a true label (scalar) and outputs a scalar that quantifies how
    good the prediction is. Implement the binary cross entropy loss
    discussed in Lecture 20.'''
    return 

# Training loss [20 pts]
def training_loss(weights, inputs, labels):
    '''Compute the total training loss here (i.e., the loss summed
    over all the inputs in the dataset)'''   
    return 

#######################################################################
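
The `sigmoid` helper above is written in terms of `tanh` rather than the more familiar $1/(1+e^{-z})$. The two forms are algebraically identical; the short check below compares them numerically. The function `logistic` is introduced here only for comparison and is not needed for the assignment.

```python
import jax.numpy as np

def sigmoid(z):
    return 0.5 * (np.tanh(z / 2.) + 1)

# Textbook logistic form, for comparison only
def logistic(z):
    return 1.0 / (1.0 + np.exp(-z))

z = np.linspace(-5.0, 5.0, 11)

# Largest absolute difference between the two forms on the test grid
max_gap = np.max(np.abs(sigmoid(z) - logistic(z)))
print(max_gap)
```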

Next, we will define some training data (inputs (x) and labels (y)). Each column of `inputs` corresponds to a datapoint (i.e., a single x). The array `labels` contains the corresponding labels.

In [None]:
############ DO NOT MODIFY ############################################
# Build a toy dataset
inputs = np.array([[1.0,  0.0, 0.25, -1.0,  0.0, -0.25],
                   [0.0,  1.0, 0.3,   0.0, -1.0, -0.4]])

labels = np.array([1, 1, 1, 0, 0, 0])
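
Because each datapoint is stored as a column, a single input is obtained by slicing along the second axis. The sketch below uses a hypothetical array `demo_inputs` that mirrors the first three columns of the dataset, so it can be run on its own.

```python
import jax.numpy as np

# Hypothetical 2 x 3 array laid out like `inputs`: one datapoint per column
demo_inputs = np.array([[1.0, 0.0, 0.25],
                        [0.0, 1.0, 0.3]])
demo_labels = np.array([1, 1, 1])

num_features, num_points = demo_inputs.shape
for i in range(num_points):
    x_i = demo_inputs[:, i]  # i-th datapoint, shape (2,)
    y_i = demo_labels[i]     # its label
    print(i, x_i, y_i)
```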

Finally, we will optimize our weights to minimize the training loss using gradient descent.
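
Gradient descent repeatedly moves the parameters a small step against the gradient: $w \leftarrow w - \eta \nabla f(w)$, where $\eta$ is the step size. The sketch below applies this rule to a hypothetical toy objective, `toy_loss`, rather than to the training loss of this lab. The target vector, the step size of 0.1, and all names are assumptions chosen for illustration.

```python
from jax import grad
import jax.numpy as np

# Hypothetical toy objective (not the lab's training loss):
# squared distance from a fixed target vector
target = np.array([3.0, -1.0])

def toy_loss(w):
    return np.sum((w - target) ** 2)

toy_grad = grad(toy_loss)

w = np.zeros(2)
eta = 0.1  # illustrative step size
for _ in range(100):
    # Move against the gradient: w <- w - eta * grad
    w = w - eta * toy_grad(w)

print(w, toy_loss(w))
```

On this quadratic, each step shrinks the distance to `target` by a constant factor, so `w` ends up very close to `target`. A step size that is too large can make the iterates diverge instead, so it is worth experimenting with a few values.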

In [None]:
############ DO NOT MODIFY #############################################
# Define a function that returns gradients of the training loss.
# This is identical to the simple gradient example we considered at the
# very beginning of this notebook. The "jit" command does "just-in-time
# compilation". Don't worry too much about what jit is. This helps speed
# up the code. 
training_gradient_fun = jit(grad(training_loss))
########################################################################

# Optimize weights using gradient descent
# Initialize the weights to the all-ones vector. The exact choice doesn't
# really matter; a random initialization would be ok too.
weights = np.ones(2)
# Print the training loss with the initial weights
print("Initial loss:", training_loss(weights, inputs, labels))

########### TODO: Fill in code to do gradient descent [15 pts] ########
step_size = # Choose a step size
# Take gradient steps
for i in range(100):
    # Write gradient descent step here


    
# Print the training loss with the optimized weights    
print("Trained loss:", training_loss(weights, inputs, labels))

The block below visualizes the outputs of the trained neural network. The true labels are either 1 (plus) or 0 (dot), and the output of your single-layer neural network also lies between 0 and 1. The contour plot shows level sets of the neural network's predictions. The level set corresponding to 0.5 is the "decision boundary". If the network is trained correctly, the 0.5-level set should separate the points with label 1 from the points with label 0.
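
Since the sigmoid equals 0.5 exactly when its argument is 0, the 0.5-level set of sigmoid(w'*x) is the line w'*x = 0, which passes through the origin and is perpendicular to the weight vector. The sketch below illustrates this with hypothetical weights `demo_weights`; these are not the weights your training will produce.

```python
import jax.numpy as np

def sigmoid(z):
    return 0.5 * (np.tanh(z / 2.) + 1)

# Hypothetical weights, chosen only for illustration
demo_weights = np.array([2.0, 1.0])

# A point orthogonal to demo_weights lies on the 0.5-level set
on_boundary = np.array([1.0, -2.0])
# Points on either side of the line w . x = 0
positive_side = np.array([1.0, 1.0])
negative_side = np.array([-1.0, -1.0])

for point in (on_boundary, positive_side, negative_side):
    print(point, sigmoid(np.dot(demo_weights, point)))
```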

In [None]:
############ DO NOT MODIFY #############################################
import matplotlib.pyplot as plt
%matplotlib notebook
import numpy as npp

fig, ax = plt.subplots()
plt.scatter(inputs[0,0:3], inputs[1,0:3], c='r',marker='+')
plt.scatter(inputs[0,3:6], inputs[1,3:6], c='b',marker='.')

delta = 0.2
xs = np.arange(-1.5, 1.5, delta)
ys = np.arange(-1.5, 1.5, delta)
X, Y = np.meshgrid(xs, ys)

Z = npp.zeros(np.shape(X))
for i in range(len(ys)):      # rows of the meshgrid follow ys
    for j in range(len(xs)):  # columns of the meshgrid follow xs
        xi = np.transpose(np.array([X[i,j], Y[i,j]]))
        Z[i,j] = neural_network_prediction(weights, xi)


CS = ax.contour(X, Y, Z) 

ax.clabel(CS, inline=1, fontsize=10)
ax.set_title('Neural network predictions')

# Submission Instructions

As usual, you can submit your completed notebook [here](https://www.dropbox.com/request/L7DAefVsTWLLd1sBniQJ).