# ECON622: Problem Set 2

Jesse Perla, UBC

## Student Name/Number: Wan-Chiao Chang / 17065012

# Packages

Add whatever packages you wish here

In [18]:
import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import jax
import jax.numpy as jnp
from jax import grad, hessian
from jax import random
import optax
import optimistix as otx
import time

## Question 1

The trace of the Hessian matrix is useful in a variety of applications
in statistics, econometrics, and stochastic processes. It can also be
used to regularize a loss function.

For a function $f:\mathbb{R}^N\to\mathbb{R}$, denote the Hessian as
$\nabla^2 f(x) \in \mathbb{R}^{N\times N}$.

It can be shown that for some mean zero, unit variance random vectors
$v\in\mathbb{R}^N$ with $\mathbb{E}(v) = 0$ and
$\mathbb{E}(v v^{\top}) = I$ the trace of the Hessian fulfills

$$
\mathrm{Tr}(\nabla^2 f(x)) = \mathbb{E}\left[v^{\top} \nabla^2 f(x)\, v\right]
$$

Which leads to a random algorithm by sampling $M$ vectors
$v_1,\ldots,v_M$ and using the Monte Carlo approximation of the
expectation, called the [Hutchinson Trace
Estimator](https://www.tandfonline.com/doi/abs/10.1080/03610918908812806)

$$
\mathrm{Tr}(\nabla^2 f(x)) \approx \frac{1}{M} \sum_{m=1}^M v_m^{\top} \nabla^2 f(x)\, v_m
$$

### Question 1.1

Now, let’s take the function $f(x) = \frac{1}{2}x^{\top} P x$, which is
a quadratic form and where we know that $\nabla^2 f(x) = P$.

The following code finds the trace of the Hessian, which is equivalently
just the sum of the diagonal of $P$ in this simple function.

In [11]:
key = jax.random.PRNGKey(0)

N = 100  # Dimension of the matrix
A = jax.random.normal(key, (N, N))
# Create a positive-definite matrix P by forming A^T * A
P = jnp.dot(A.T, A)
def f(x):
    return 0.5 * jnp.dot(x.T, jnp.dot(P, x))
x = jax.random.normal(key, (N,))
print(jnp.trace(jax.hessian(f)(x)))
print(jnp.diag(P).sum())

10240.816
10240.816


Now, instead of calculating the whole Hessian, use a [Hessian-vector
product in
JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#hessian-vector-products-using-both-forward-and-reverse-mode)
and the approximation above with $M$ draws of random vectors to
calculate an approximation of the trace of the Hessian. Increase the
numbers of $M$ to see what the variance of the estimator is, comparing
to the above closed-form solution for this quadratic.

Hint: you will want to do Forward-over-Reverse mode differentiation for
this (i.e. the `vjp` gives a pullback function for first derivative,
then differentiate that new function. Given that it would then be
$\mathbb{R}^N \to \mathbb{R}^N$, it makes sense to use forward mode with
a `jvp`)

In [12]:
# jax.grad (Reverse) computes the gradient
# jax.jvp (Forward) differentiates the gradient in direction v
def hessian_vector_product(f, x, v):
    # jvp returns (primals, tangents). We only need the tangent (the derivative part).
    primals, tangent = jax.jvp(jax.grad(f), (x,), (v,))
    return tangent

# Hutchinson Trace Estimator Function
def estimate_trace(f, x, M, key):
    estimates = []
    
    # Generate M random vectors (standard normal distribution)
    vs = jax.random.normal(key, (M, N))
    
    for i in range(M):
        v = vs[i]
        # Compute Hv (Hessian-vector product)
        Hv = hessian_vector_product(f, x, v)
        # Compute v^T * Hv
        val = jnp.dot(v.T, Hv)
        estimates.append(val)
    
    return jnp.mean(jnp.array(estimates))


In [14]:
# We try increasing values of M to see the variance decrease
true_trace = jnp.diag(P).sum()
M_values = [10, 100, 500, 1000, 5000]

print("Estimating Trace using Hutchinson's Method:")
for M in M_values:
    # Use a new key for the sampling step
    key, subkey = jax.random.split(key)
    
    approx_trace = estimate_trace(f, x, M, subkey)
    error = abs(approx_trace - true_trace)
    percent_error = (error / true_trace) * 100
    
    print(f"M = {M:<5} | Est: {approx_trace:.3f} | Error: {percent_error:.2f}%")

Estimating Trace using Hutchinson's Method:
M = 10    | Est: 10379.485 | Error: 1.35%
M = 100   | Est: 10214.259 | Error: 0.26%
M = 500   | Est: 10146.236 | Error: 0.92%
M = 1000  | Est: 10223.255 | Error: 0.17%
M = 5000  | Est: 10229.800 | Error: 0.11%


### Question 1.2 (Bonus)

If you wish, you can play around with radically increase the size of the
`N` and change the function itself. One suggestion is to move towards a
sparse or even matrix-free $f(x)$ calculation so that the $P$ doesn’t
itself need to materialize.

In [16]:
# --- 1. SETUP: RADICAL SCALE ---
# We use N = 100,000. 
N = 100_000 
key = jax.random.PRNGKey(42)

# --- 2. DEFINE MATRIX-FREE FUNCTION ---
# Use weighted sum of square function
# Effectively, our "Matrix" P is a diagonal matrix where P_ii = i + 1
def f(x):
    # Create the "diagonal" weights on the fly (1, 2, ..., N)
    weights = jnp.arange(1, N + 1, dtype=jnp.float32)
    return 0.5 * jnp.sum(weights * (x ** 2))

# Analytical True Trace: Sum of 1 to N -> N*(N+1)/2
true_trace = (N * (N + 1)) / 2
print(f"Dimension N: {N}")
print(f"True Trace (Analytical): {true_trace:.0f}")
print("-" * 40)

# --- 3. HUTCHINSON ESTIMATOR (Matrix-Free) ---

# Using the efficient Forward-over-Reverse (HvP) from the previous step
def hessian_vector_product(f, x, v):
    # jax.grad computes gradient (Reverse mode)
    # jax.jvp computes directional derivative (Forward mode)
    # This computes H*v without building H
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hutchinson_estimate(f, x, M, key):
    # Generate M random vectors (Rademacher or Normal)
    # Normal (mean=0, var=1) is standard for Hutchinson
    vs = jax.random.normal(key, (M, N))
    
    estimates = []
    for i in range(M):
        v = vs[i]
        Hv = hessian_vector_product(f, x, v)
        # v^T * Hv
        estimates.append(jnp.dot(v, Hv))
        
    return jnp.mean(jnp.array(estimates))

# --- 4. RUN THE ESTIMATION ---
# Arbitrary input x (value doesn't matter for quadratic Hessian)
x_input = jnp.ones(N)

# Let's try with 1000 samples
M = 1000
key, subkey = jax.random.split(key)

print(f"Running Hutchinson Estimator with M={M} samples...")
est_trace = hutchinson_estimate(f, x_input, M, subkey)

error_perc = 100 * abs(est_trace - true_trace) / true_trace

print(f"Estimated Trace:       {est_trace:.0f}")
print(f"Error:                 {error_perc:.4f}%")

Dimension N: 100000
True Trace (Analytical): 5000050000
----------------------------------------
Running Hutchinson Estimator with M=1000 samples...
Estimated Trace:       4999923200
Error:                 0.0025%


## Question 2

This section gives some hints on how to setup a differentiable
likelihood function with implicit functions

### Question 2.1

The following code uses scipy to find the equilibrium price and demand
for some simple supply and demand functions with embedded parameters

In [17]:
from scipy.optimize import root_scalar

# Define the demand function with power c
def demand(P, c_d):
    return 100 - 2 * P**c_d

# Define the supply function with power f
def supply(P, c_s):
    return 5 * 3**(c_s * P)

# Define the function to find the root of, including c and f
def equilibrium(P, c_d, c_s):
    return demand(P, c_d) - supply(P, c_s)

# Use root_scalar to find the equilibrium price
def find_equilibrium(c_d, c_s):
    result = root_scalar(equilibrium, args=(c_d, c_s), bracket=[0, 100], method='brentq')
    return result.root, demand(result.root, c_d)

# Example usage
c_d = 0.5
c_s = 0.15
equilibrium_price, equilibrium_quantity = find_equilibrium(c_d, c_s)
print(f"Equilibrium Price: {equilibrium_price:.2f}")
print(f"Equilibrium Quantity: {equilibrium_quantity:.2f}")

Equilibrium Price: 17.65
Equilibrium Quantity: 91.60


First, convert this to use JAX and
[Optimistix](https://docs.kidger.site/optimistix/) for finding the root
using `optimistix.root_find()`. Make sure you can jit the whole
`find_equilibrium` function

In [22]:
def demand(P, c_d):
    return 100 - 2 * jnp.power(P, c_d)

def supply(P, c_s):
    return 5 * jnp.power(3.0, c_s * P)

def equilibrium_fn(P, args):
    c_d, c_s = args
    return demand(P, c_d) - supply(P, c_s)

@jax.jit
def find_equilibrium(c_d, c_s):
    solver = otx.Bisection(rtol=1e-6, atol=1e-6)
    
    # Bisection in Optimistix requires the bracket in the options dict
    # y0 is used as a dummy scalar to define the shape of the output (a scalar price)
    options = dict(lower=0.0, upper=100.0)
    
    sol = otx.root_find(
        fn=equilibrium_fn,
        solver=solver,
        y0=0.0,  # Tells Optimistix we are looking for a scalar root
        args=(c_d, c_s),
        options=options,
        throw=False 
    )
    
    price = sol.value
    quantity = demand(price, c_d)
    
    return price, quantity

# --- Test Usage ---
c_d_val = 0.5
c_s_val = 0.15

price, quantity = find_equilibrium(c_d_val, c_s_val)

print(f"Equilibrium Price: {price:.4f}")
print(f"Equilibrium Quantity: {quantity:.4f}")

Equilibrium Price: 17.6464
Equilibrium Quantity: 91.5985


### Question 2.2

Now, assume that you get a noisy signal on the price that fulfills that
demand system.

$$
\hat{p} \sim \mathcal{N}(p, \sigma^2)
$$

In that case, the log likelihood for the Gaussian is

$$
\log \mathcal{L}(\hat{p}\,|\,c_d, c_s, p) = -\frac{1}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} (\hat{p} - p)^2
$$

Or, if $p$ was implicitly defined by the equilibrium conditions as some
$p(c_d, c_s)$ from above,

$$
\log \mathcal{L}(\hat{p}\,|\,c_d, c_s) = -\frac{1}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} (\hat{p} - p(c_d, c_s))^2
$$

Then for some $\sigma = 0.01$ we can calculate this log likelihood the
above as

In [26]:
def log_likelihood(p_hat, c_d, c_s, sigma):
    p, x = find_equilibrium(c_d, c_s)
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * (p_hat - p)**2 / sigma**2

c_d = 0.5
c_s = 0.15
sigma = 0.01
p, x = find_equilibrium(c_d, c_s) # get the true value for simulation
p_hat = p + np.random.normal(0, sigma) # simulate a noisy signal
log_likelihood(p_hat, c_d, c_s, sigma)

Array(2.2389274, dtype=float32)

Now, take this code for the likelihood and convert it to JAX and jit.
Use your function from Question 2.1

In [25]:
@jax.jit
def log_likelihood(p_hat, c_d, c_s, sigma):
    """
    Calculates the Gaussian log-likelihood for an observed price p_hat.
    The true price p is implicitly defined by (c_d, c_s).
    """
    # 1. Solve for the theoretical equilibrium price p
    p, _ = find_equilibrium(c_d, c_s)
    
    # 2. Compute Gaussian log-likelihood components
    # Formula: -0.5 * log(2πσ²) - 0.5 * (p_hat - p)² / σ²
    term1 = -0.5 * jnp.log(2 * jnp.pi * jnp.power(sigma, 2))
    term2 = -0.5 * jnp.power(p_hat - p, 2) / jnp.power(sigma, 2)
    
    return term1 + term2

# --- Execution & Simulation ---
# Parameters from the screenshot
c_d_true = 0.5
c_s_true = 0.15
sigma_val = 0.01

# Find true price for simulation
p_true, _ = find_equilibrium(c_d_true, c_s_true)

# Simulate a noisy signal using JAX's random keys
key = jax.random.PRNGKey(0)
p_hat = p_true + jax.random.normal(key) * sigma_val

# Calculate Likelihood
ll_value = log_likelihood(p_hat, c_d_true, c_s_true, sigma_val)

print(f"Simulated Observed Price (p_hat): {p_hat:.4f}")
print(f"Log Likelihood: {ll_value:.4f}")

Simulated Observed Price (p_hat): 17.6626
Log Likelihood: 2.3698


### Question 2.3

Use the function from the previous part and calculate the gradient with
respect to `params` (i.e., `c_d` and `c_s`) using `grad` and JAX.

In [None]:
# Define the gradient function with respect to the first argument (params)
grad_ll = jax.grad(log_likelihood, argnums=0)

# Setup test values
params_init = (0.5, 0.15)  # (c_d, c_s)
p_hat_observed = 18.51     # Simulated observation
sigma_val = 0.01

# Calculate the gradient
gradients = grad_ll(p_hat_observed, params_init, sigma_val)

print(f"Gradient with respect to c_d: {gradients[0]:.6f}")
print(f"Gradient with respect to c_s: {gradients[1]:.6f}")

TypeError: log_likelihood() missing 1 required positional argument: 'sigma'

### Question 2.4 (Bonus)

You could try to run maximum likelihood estimation by using a
gradient-based optimizer. You can use either
[Optax](https://optax.readthedocs.io/) (standard for ML optimization) or
[Optimistix](https://docs.kidger.site/optimistix/) with
`optimistix.minimise()`.

If you attempt this:

-   Consider starting your optimization at the “pseudo-true” values with
    the `c_s, c_d, sigma` you used to simulate the data and even start
    with `p_hat = p`.
-   You may find that it is a little too noisy with only the one
    observation. If so, you could adapt your likelihood to take a vector
    of $\hat{p}$ instead. The likelihood of IID gaussians is a simple
    variation on the above.

In [None]:
# Question 2.4 (Bonus): Maximum Likelihood Estimation with Gradient-Based Optimizer

# --- 1. Define log-likelihood for multiple IID observations ---
# For N IID Gaussian observations, the log-likelihood is:
# LL = sum_i [ -0.5 * log(2πσ²) - 0.5 * (p_hat_i - p)² / σ² ]
#    = -N/2 * log(2πσ²) - 1/(2σ²) * sum_i (p_hat_i - p)²

@jax.jit
def neg_log_likelihood_multi(params, p_hat_vec, sigma):
    """
    Negative log-likelihood for multiple IID observations.
    params: (c_d, c_s) - parameters to estimate
    p_hat_vec: array of observed prices
    sigma: noise standard deviation (known)
    """
    c_d, c_s = params
    
    # Solve for equilibrium price
    p_eq, _ = find_equilibrium(c_d, c_s)
    
    # Number of observations
    N = p_hat_vec.shape[0]
    
    # Log-likelihood for IID Gaussians
    term1 = -0.5 * N * jnp.log(2 * jnp.pi * sigma**2)
    term2 = -0.5 * jnp.sum((p_hat_vec - p_eq)**2) / sigma**2
    
    # Return NEGATIVE log-likelihood for minimization
    return -(term1 + term2)


# --- 2. Simulate data ---
# True parameters
c_d_true = 0.5
c_s_true = 0.15
sigma_true = 0.1  # Slightly larger sigma for more realistic noise

# Get true equilibrium price
p_true, q_true = find_equilibrium(c_d_true, c_s_true)
print(f"True parameters: c_d = {c_d_true}, c_s = {c_s_true}")
print(f"True equilibrium price: {p_true:.4f}")

# Generate multiple noisy observations
key = jax.random.PRNGKey(123)
N_obs = 100  # Number of observations
p_hat_data = p_true + sigma_true * jax.random.normal(key, (N_obs,))
print(f"Number of observations: {N_obs}")
print(f"Sample mean of observations: {jnp.mean(p_hat_data):.4f}")


# --- 3. MLE using Optax (gradient descent) ---
# Initialize at values close to but not exactly at true values
params_init = jnp.array([0.6, 0.2])  # Starting guess: (c_d, c_s)

# Setup optimizer (Adam with learning rate)
learning_rate = 0.001
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params_init)

# Gradient of negative log-likelihood w.r.t. params
grad_nll = jax.grad(neg_log_likelihood_multi, argnums=0)

# Training loop
@jax.jit
def update_step(params, opt_state, p_hat_vec, sigma):
    grads = grad_nll(params, p_hat_vec, sigma)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, grads

# Run optimization
params = params_init
n_iterations = 2000
print_every = 500

print("\n--- Starting Optax Optimization ---")
print(f"Initial params: c_d = {params[0]:.4f}, c_s = {params[1]:.4f}")
print(f"Initial NLL: {neg_log_likelihood_multi(params, p_hat_data, sigma_true):.4f}")

for i in range(n_iterations):
    params, opt_state, grads = update_step(params, opt_state, p_hat_data, sigma_true)
    
    if (i + 1) % print_every == 0:
        nll = neg_log_likelihood_multi(params, p_hat_data, sigma_true)
        print(f"Iter {i+1:4d} | c_d: {params[0]:.4f}, c_s: {params[1]:.4f} | NLL: {nll:.4f} | grad: [{grads[0]:.4f}, {grads[1]:.4f}]")

# Final results
print("\n--- Final Results ---")
print(f"Estimated c_d: {params[0]:.4f} (true: {c_d_true})")
print(f"Estimated c_s: {params[1]:.4f} (true: {c_s_true})")

# Check equilibrium price from estimated parameters
p_estimated, _ = find_equilibrium(params[0], params[1])
print(f"Estimated equilibrium price: {p_estimated:.4f} (true: {p_true:.4f})")

# Compare to sample mean (which is MLE for Gaussian mean)
print(f"Sample mean of observations: {jnp.mean(p_hat_data):.4f}")