<a href="https://colab.research.google.com/github/benwtks/machine-learning/blob/master/lab2a.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import jax.numpy as jnp
from jax import grad
from jax import random

# Linear function for score

The score for class $k$ at data point $\mathbf{x}$ is given by $$a_k = w_{k0} + \sum_j w_{kj}x_j = w_{k0}+\mathbf{w}_k^\top\mathbf{x}$$

The predicted probability of class $k$ is given by the softmax of the scores, $$\hat{y}_k = \exp(a_k)/\sum_i\exp(a_i)$$

## For loop implementation of probability

In [None]:
def softmax_prob_forloop(W, b, inputs): # output is datalen-by-C (NumPy, no JAX here)
    """Softmax class probabilities computed with explicit Python loops (NumPy only).

    W      : (C, dim) array of class weights.
    b      : length-C sequence of class biases (one per class).
    inputs : (dim, datalen) array, one data point per column.

    Returns a (datalen, C) array whose row n holds the probability of every
    class for data point n.
    """
    n_feat, n_pts = np.shape(inputs)
    n_cls = len(b)  # the number of classes is read off the bias vector
    score = np.zeros((n_cls, n_pts))

    # score[k, n] = b[k] + sum_d W[k, d] * inputs[d, n]
    for k in range(n_cls):
        for n in range(n_pts):
            score[k, n] = b[k]
            for d in range(n_feat):
                score[k, n] += W[k, d] * inputs[d, n]

    # Shift every column by its own largest score (bias included). Softmax is
    # invariant to this shift, and it keeps np.exp from overflowing/underflowing.
    for n in range(n_pts):
        col_max = np.max(score[:, n])
        for k in range(n_cls):
            score[k, n] -= col_max

    expscore = np.exp(score)
    # Right-multiplying by diag(1 / column sums) normalises each column to 1.
    inv_col_sums = np.diag(1/np.sum(expscore, axis=0))
    return np.dot(expscore, inv_col_sums).T

In [None]:
# Twenty evenly spaced values from 1 to 5, both endpoints included.
W = np.linspace(start=1, stop=5, num=20)
print(W)

[1.         1.21052632 1.42105263 1.63157895 1.84210526 2.05263158
 2.26315789 2.47368421 2.68421053 2.89473684 3.10526316 3.31578947
 3.52631579 3.73684211 3.94736842 4.15789474 4.36842105 4.57894737
 4.78947368 5.        ]


## Vector implementation of probability (w/ JAX)

In [None]:
def softmax_prob1(W, b, inputs):  # output is datalen-by-C
    """Softmax class probabilities with an explicit bias vector (JAX).

    Args:
        W: (C, dim) array of class weights.
        b: length-C vector of class biases.
        inputs: (dim, datalen) array, one data point per column.

    Returns:
        (datalen, C) array; row n holds the probability of each class for
        data point n, so every row sums to 1.
    """
    # All numerical operations use jnp (not np) so the function stays
    # traceable by JAX transformations such as grad.
    b = jnp.asarray(b)
    # Full class scores, bias included: (C, dim) @ (dim, datalen) -> (C, datalen),
    # with b broadcast down each column.
    score = jnp.dot(W, inputs) + b[:, None]
    # Subtract each column's largest score *including the bias*. Softmax is
    # invariant to this shift, and afterwards every exponent is <= 0, so exp()
    # cannot overflow even when the biases are large.
    score = score - jnp.max(score, axis=0, keepdims=True)
    expscore = jnp.exp(score)
    # Normalise each column by broadcasting; this avoids building C-by-C and
    # datalen-by-datalen diagonal matrices.
    probs = expscore / jnp.sum(expscore, axis=0, keepdims=True)
    return probs.T

In what follows, the trick of setting the zeroth feature to 1 is used to absorb the constant $w_{k0}$ into the dot product. Redefine the input data to be $$\mathbf{x}=(x_1,\ldots,x_p)\longrightarrow\mathbf{x}=(1,x_1,\ldots,x_p).$$

Correspondingly, redefine each class's weight vector to be $\mathbf{w}_k=(w_{k0}, w_{k1}, \ldots, w_{kp})$, so that the score is simply $a_k=\mathbf{w}_k^\top\mathbf{x}$.

In [None]:
def softmax_prob(W, inputs):
    """Softmax class probabilities with the bias folded into W (JAX).

    A row of ones is prepended to ``inputs``, so column 0 of ``W`` acts as
    the per-class bias.

    Args:
        W: (C, dim+1) array; W[:, 0] holds the biases, W[:, 1:] the weights.
        inputs: (dim, datalen) array, one data point per column.

    Returns:
        (datalen, C) array; row n holds the probability of each class for
        data point n, so every row sums to 1.
    """
    # All numerical operations use jnp (not np) so the function stays
    # traceable by JAX transformations such as grad.
    datalen = jnp.shape(inputs)[1]  # number of data points
    # Augment inputs to (dim+1, datalen) with a leading row of ones.
    augmented = jnp.concatenate((jnp.ones((1, datalen)), inputs), axis=0)
    # (C, dim+1) @ (dim+1, datalen) -> (C, datalen) class scores.
    score = jnp.dot(W, augmented)
    # Shift each column by its largest score. Softmax is invariant to this
    # shift, and it keeps exp() from overflowing.
    score = score - jnp.max(score, axis=0, keepdims=True)
    expscore = jnp.exp(score)
    # Normalise columns by broadcasting instead of multiplying by
    # datalen-by-datalen diagonal matrices (quadratic memory in datalen).
    return (expscore / jnp.sum(expscore, axis=0, keepdims=True)).T

In [None]:
# Example weights for C=4 classes; column 0 is the bias, columns 1-3 the
# weights for the 3 input features, i.e. a C-by-(dim+1) matrix.
Wb = jnp.array([
    [-3.0, 1.3, 2.0, -1.0],
    [-6.0, -2.0, -3.0, 1.5],
    [1.0, 2.0, 2.0, 2.5],
    [3.0, 4.0, 4.0, -2.5],
])
# Toy dataset: 6 points in 3 dimensions, written one point per row and then
# transposed so that inputs is dim-by-datalen (3-by-6).
inputs = jnp.transpose(jnp.array([
    [0.52, 1.12, 0.77],
    [3.82, -6.11, 3.15],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
    [0.14, -0.43, -1.69],
]))
# Class label (0..3) for each of the 6 points.
targets = jnp.array([0, 1, 3, 2, 1, 2])

In [None]:
# Initialize random model coefficients
# Draw the initial weights from a standard normal with a fixed seed so the
# run is reproducible.
key = random.PRNGKey(0)
key, W_key = random.split(key, num=2)
classes, dim = 4, 3  # C classes, dim input features (+1 bias column below)
Winit = random.normal(W_key, shape=(classes, dim + 1))
print(Winit)

# Automatic differentiation

In [None]:
def softmax_xentropy(Wb, inputs, targets, num_classes):
    """Cross-entropy loss of the softmax model with bias-folded weights ``Wb``.

    ``targets`` holds one integer class label per column of ``inputs``.
    Note that the mean is taken over all datalen * num_classes entries of the
    one-hot-weighted matrix. The returned value is therefore the usual
    per-sample cross-entropy divided by ``num_classes``.
    """
    probs = softmax_prob(Wb, inputs)          # (datalen, C)
    onehot = get_one_hot(targets, num_classes)  # (datalen, C)
    # A small epsilon keeps log() finite if a probability underflows to 0.
    neg_log = -jnp.log(probs + 1e-8)
    return jnp.mean(onehot * neg_log)

In [None]:
def get_one_hot(targets, num_classes):
    """Return a one-hot encoding of integer class labels.

    Args:
        targets: array-like of integer labels in [0, num_classes), of any
            shape. Python lists and NumPy arrays are accepted as well as JAX
            arrays.
        num_classes: number of classes C.

    Returns:
        Array of shape ``targets.shape + (num_classes,)`` with a 1 at each
        label's position and 0 elsewhere.

    NOTE: JAX indexing does not raise on out-of-range labels (it clamps or
    wraps the index), so validate labels beforehand if they are untrusted.
    """
    # Convert once, so both the indexing and the `.shape` lookup below work
    # for list inputs too. Previously only the indexing path was converted.
    targets = jnp.asarray(targets)
    onehot = jnp.eye(num_classes)[targets.reshape(-1)]
    return onehot.reshape(targets.shape + (num_classes,))

In [None]:
def grad_descent(Wb, inputs, targets, num_classes, lrate, nsteps, record_every=5):
    """Fit the softmax model by plain full-batch gradient descent.

    Args:
        Wb: initial (C, dim+1) weight matrix, bias in column 0.
        inputs: (dim, datalen) data array.
        targets: integer class labels, one per data point.
        num_classes: number of classes C.
        lrate: learning rate (step size).
        nsteps: number of gradient steps to take.
        record_every: take a weight/loss snapshot after every
            ``record_every``-th step (must be >= 1).

    Returns:
        (W, Whist, losshist): the final weights, plus snapshots taken at
        step 0, after steps record_every, 2*record_every, ..., and always
        after the last step. Entry k therefore corresponds to step
        k*record_every, except that the final entry is step nsteps.

    Raises:
        ValueError: if ``record_every`` < 1.
    """
    if record_every < 1:
        raise ValueError("record_every must be a positive integer")
    # Build the gradient function once instead of on every iteration.
    loss_grad = grad(softmax_xentropy, argnums=0)
    W1 = Wb
    Whist = [W1]
    losshist = [softmax_xentropy(W1, inputs, targets, num_classes)]
    for step in range(1, nsteps + 1):
        W1 = W1 - lrate * loss_grad(W1, inputs, targets, num_classes)
        # Snapshot after steps record_every, 2*record_every, ... so that a
        # `record_every * index` x-axis lines up. The last step is always kept.
        if step % record_every == 0 or step == nsteps:
            Whist.append(W1)
            losshist.append(softmax_xentropy(W1, inputs, targets, num_classes))
    return W1, Whist, losshist

In [None]:
# Train from the random initial weights: 4 classes, step size 0.75, 200 steps.
W2, Whist, losshist = grad_descent(
    Winit, inputs, targets, num_classes=4, lrate=0.75, nsteps=200
)

## Loss history

In [None]:
# x-axis: snapshot index scaled by 5, matching grad_descent's every-5-steps recording.
plt.plot(list(range(0, 5 * len(losshist), 5)), losshist)