# Q1. ax, Autograd and Neural Networks (40 points + 14 Bonus points)

## (a) Univariate Function Gradient (5 points)

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np

# Q1a.  Example for taking the derivative of a simple function

def fun(x):
    return jnp.cos(x**3 - 2*x**2 + 5)

def manual_grad(x):
    # Implement the derivative w.r.t. x using the chain rule
    # f(x) = cos(u(x)) where u(x) = x^3 - 2*x^2 + 5
    # f'(x) = -sin(u(x)) * u'(x)
    u = x**3 - 2*x**2 + 5
    du_dx = 3*x**2 - 4*x
    gradient = -jnp.sin(u) * du_dx
    return gradient

grad_fun = grad(fun)

x = 3.5

print("Gradient:", grad_fun(x),  manual_grad(x))


Gradient: 22.353619 22.353619


## (b) Multivariate Function Gradient (5 points)

In [2]:
import jax.numpy as jnp
from jax import grad

def cubic_l2_norm(x):
    return jnp.linalg.norm(x)**3

# Automatic differentiation gradient using JAX
grad_fun = grad(cubic_l2_norm)

x = jnp.array([5.0, 0.3, 2.0])

def manual_grad(x):
    # Compute the L2 norm
    r = jnp.linalg.norm(x)
    # Applying the chain rule:
    # f(x) = (r)^3, so f'(r) = 3 * r^2 and ∇r = x / r.
    # Thus, gradient = 3 * r^2 * (x / r) = 3 * r * x.
    gradient = 3 * r * x
    return gradient

print("Gradient:", grad_fun(x), manual_grad(x))


Gradient: [80.902725  4.854164 32.36109 ] [80.90272   4.854163 32.361088]


### Implement a function that computes the Hessian matrix for the given function (2 Bonus points)

In [13]:
import jax.numpy as jnp
from jax import grad

def cubic_l2_norm(x):
    return jnp.linalg.norm(x)**3

def hessian_manual(f, x):
    n = x.shape[0]
    hess = jnp.zeros((n, n))

    for i in range(n):
        grad_i = grad(f, argnums=0)
        def partial_derivative(x):
            return grad_i(x)[i]

        hess_i = grad(partial_derivative, argnums=0)(x)
        hess = hess.at[i].set(hess_i)

    return hess

x = jnp.array([5.0, 0.3, 2.0])
hessian_matrix = hessian_manual(cubic_l2_norm, x)
print("Hessian Matrix:", hessian_matrix)

Hessian Matrix: [[30.086134    0.8343353   5.5622354 ]
 [ 0.8343354  16.230606    0.33373415]
 [ 5.562236    0.33373415 18.40544   ]]


## (c) Soft-Max Regression and Neural Networks (15 points)

In [4]:

# Q1c. Neural Network from Scratch
# Example implementation of simple soft-max regression with square loss
#[Nothing to implement in this block]

# define soft argmax
def soft_argmax(x):
    y = jnp.exp(x)
    return y/jnp.sum(y)

# Specify the how to caculate the

def predict(params, inputs):
    W = params[0]
    b = params[1]
    scores = jnp.dot( W, inputs) + b
    outputs = soft_argmax(scores)
    return outputs

def loss_fun(params, inputs, targets):
    preds = predict(params, inputs)
    # You can use cross-entropy loss instead
    return jnp.sum((preds - targets)**2)

grad_fun = grad(loss_fun)  # gradient evaluation function

# Lets' cook up some input for this
W = np.random.randn(3,4)
b = np.random.randn(3,1)
inputs = np.random.randn(4,1)
targets = np.array([1,0,0])
params=[W,b]



In [5]:
# Example usage of above functions
# [Nothing to implement in this block]

# How to Generate a prediction for this inputs
print(predict(params, inputs))

# How to Compute the gradient from this imput
print(grad_fun(params, inputs,targets))

# How to implement an SGD udpate
# Take a minibatch of data X,y
def sgd_update(params,X,y):
    lr = 0.01
    gradient = grad_fun(params,X,y)
    for param, g in zip(params,gradients):
        param -= lr*g


[[0.7718097]
 [0.023276 ]
 [0.2049143]]
[Array([[ 0.44290727, -1.1176932 ,  0.34664053, -0.39820448],
       [-0.06148668,  0.15516396, -0.04812244,  0.05528081],
       [-0.3814207 ,  0.9625295 , -0.29851818,  0.34292376]],      dtype=float32), Array([[ 0.61862624],
       [-0.0858809 ],
       [-0.5327455 ]], dtype=float32)]


**Now you are expected to implement a two-layer neural network from scratch by modifying the above code for soft-max regression.**

**Very Important the only function you need to modify is** `predict_nn`

In [6]:
# Q1c.  continues

#  You may use jnp.tanh to implement a hyperbolic tangent activation fuonction.
#

def soft_argmax(x):
    y = jnp.exp(x)
    return y/jnp.sum(y)

def predict_nn(params, inputs):
    W1 = params[0]
    b1 = params[1]
    W2 = params[2]
    b2 = params[3]

    # First layer: linear transformation then tanh activation
    hidden = jnp.tanh(jnp.dot(W1, inputs) + b1)
    # Second layer: linear transformation to obtain scores
    scores = jnp.dot(W2, hidden) + b2
    outputs = soft_argmax(scores)
    return outputs

def loss_fun_nn(params, inputs, targets):
    preds = predict_nn(params, inputs)
    # Using squared error loss
    return jnp.sum((preds - targets)**2)


grad_fun_nn = grad(loss_fun_nn)

## (d) Neural Network Parameters and Gradient (15 points)
Now let's say the neural network is supposed to classify

In [7]:
# Q1d Lets' cook up some input:

inputs = np.random.randn(64,1)
targets = np.array([1,0,0])

# What could be a valid shape of the following parameters?  Replace the question marks with valid numbers

W1 = jnp.array(np.random.randn(6, 64), dtype=jnp.float32)
b1 = jnp.array(np.random.randn(6, 1), dtype=jnp.float32)
W2 = jnp.array(np.random.randn(3, 6), dtype=jnp.float32)
b2 = jnp.array(np.random.randn(3, 1), dtype=jnp.float32)


params=[W1,b1,W2,b2]



In [12]:

predict_nn(params, inputs)

grad_fun_nn(params, inputs,targets)

[Array([[ 2.3542825e-06, -3.6968760e-07, -4.9402229e-06, -4.1394696e-06,
         -9.8418186e-06,  8.0379522e-07,  5.4087349e-07,  4.8220763e-06,
          4.6904765e-06,  3.5718469e-06, -3.4993764e-06, -3.2033065e-06,
          1.7704331e-06,  1.7595565e-06, -1.9591282e-06, -3.9403026e-06,
         -6.5035924e-06,  1.3518017e-07, -9.2539295e-07, -1.8957248e-07,
          7.4537111e-06,  6.5348318e-06,  6.9858943e-06, -2.0507632e-06,
          1.0439821e-06,  2.4046362e-06, -5.5359474e-06, -2.8329914e-06,
          5.0868121e-06,  2.3051925e-06,  6.4688902e-07,  4.9020241e-07,
         -4.4519857e-06,  3.5559385e-06,  1.5751663e-06, -2.8684735e-06,
         -4.5185211e-06, -6.9223955e-07, -4.1694479e-06, -5.4795537e-06,
         -1.8865945e-06, -2.5612119e-06, -3.6040640e-06,  2.8282750e-06,
         -3.6670867e-06,  1.1330413e-06,  2.2795441e-06, -1.2136221e-06,
         -7.9058336e-06,  6.0834745e-06, -5.1512361e-06,  4.0811137e-06,
          1.2462170e-06,  2.4620606e-06,  5.6822437