# Learning to Learning

In [57]:
import autoroot
import jax
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import seaborn as sns

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Problem

Let's define a function as:

$$
\boldsymbol{y} = \boldsymbol{f}(\mathbf{x})
$$

We can the function as

$$
\boldsymbol{f}(\mathbf{x};\boldsymbol{\theta}) = \mathbf{wx}
$$

where $\boldsymbol{\theta}=\{ \mathbf{w}\}$ are the parameters

### Loss Function

Now, we can define a loss function as the MSE:

$$
\boldsymbol{L}(\mathbf{x};\boldsymbol{\theta}) = 
||\mathbf{wx} 
- \mathbf{y} ||_2^2
$$

$$
\boldsymbol{R}(\mathbf{x};\boldsymbol{\theta}) = 
\alpha||\mathbf{w}||_2^2
$$

$$
\mathbf{x}^*(\boldsymbol{\theta}) = 
\underset{\mathbf{x}}{\text{argmin}} \hspace{2mm}
\boldsymbol{J}(\mathbf{x};\boldsymbol{\theta})
$$

### Optimization

**Parameter Estimation**

$$
\mathbf{x}^{(k+1)} = 
\mathbf{x}^{(k)} +
\boldsymbol{g}_k
$$

The optimizer function, $\boldsymbol{g}(\cdot)$, defined as:

$$
\left[ \boldsymbol{g}_k, \boldsymbol{h}_{k+1}\right] =
\boldsymbol{g}
\left( \boldsymbol{\nabla}_{\mathbf{x}}\boldsymbol{L}(\mathbf{x};\boldsymbol{\theta}), \boldsymbol{h}_k \right)
$$

In [58]:
import typing as tp
import equinox as eqx
from jaxtyping import Array


class LinearModel(eqx.Module):
    weight: Array

    def __init__(self, dim_in, dim_out, key=jrandom.PRNGKey(123)):

        self.weight = jrandom.normal(key=key, shape=(dim_out, dim_in))

    def __call__(self, x):
        return jnp.matmul(self.weight, x)
    



**Objective Function**

$$
\boldsymbol{J}(\mathbf{x};\boldsymbol{\theta}) =
\boldsymbol{L}(\mathbf{x};\boldsymbol{\theta}) +
\lambda
\boldsymbol{R}(\mathbf{x};\boldsymbol{\theta})
$$

**Data Fidelity Term**

$$
\begin{aligned}
\boldsymbol{L}(\mathbf{x};\boldsymbol{\theta}) &=
||\mathbf{y} - \boldsymbol{f}(\mathbf{x};\boldsymbol{\theta})||_2^2 \\
&= ||\mathbf{y} - \mathbf{xw}||_2^2
\end{aligned}
$$

**Regularization Term**

$$
\begin{aligned}
\boldsymbol{R}(\boldsymbol{\theta}) &=
||\boldsymbol{\theta}||_2^2 \\
&= ||\mathbf{w}||_2^2
\end{aligned}
$$

In [None]:
class RidgeLoss(eqx.Module):
    alpha: Array
    model: tp.Callable = LinearModel

    def __init__(self, model: tp.Callable, alpha=0.1):
        self.model = model
        self.alpha = jnp.asarray(alpha)

    def data_loss(self, x, y):
        y_pred = self.model(x)
        return jnp.mean(y_pred - y)
    
    def reg_loss(self):
        return jnp.sum(self.model.weight ** 2)

    def loss(self, x, y, return_losses: bool=False):
        
        # data loss
        data_loss = 0.5 * self.data_loss(x, y)

        # reg loss
        reg_loss = 0.5 * self.alpha * self.reg_loss()

        # total loss
        loss = data_loss + reg_loss

        if return_losses:
            losses = dict(loss=loss, data=data_loss, reg=reg_loss)
            return loss, losses
        else:
            return loss


In [30]:
x = jnp.ones((10,))
y = jnp.ones((1,))

lr_model = LinearModel(10, 1)
y_pred = lr_model(x)
y_pred.shape, y.shape

((1,), (1,))

In [31]:
ridge_loss = RidgeLoss(lr_model, alpha=0.01)

ridge_loss.loss(x, y)

Array(-2.0766206, dtype=float32)

In [38]:
def jaxopt_loss(model, x, y):
    return model.loss(x, y)

$$
\mathbf{x}^*(\boldsymbol{\theta}) =
\underset{\mathbf{x}}{\text{argmin}} \hspace{2mm}
\boldsymbol{J}(\mathbf{x};\boldsymbol{\theta})
$$

In [39]:
from jaxopt import OptaxSolver
import jaxopt
import optax

learning_rate = 1e-3
maxiter = 100

opt = optax.adam(learning_rate)
# solver = OptaxSolver(opt=opt, fun=jaxopt_loss, maxiter=1000)

In [40]:
solver = jaxopt.LBFGS(fun=jaxopt_loss, maxiter=maxiter)
res = solver.run(ridge_loss, x=x, y=y)

# Alternatively, we could have used one of these solvers as well:
# solver = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500)
# solver = jaxopt.ScipyMinimize(fun=ridge_reg_objective, method="L-BFGS-B", maxiter=500)
# solver = jaxopt.NonlinearCG(fun=ridge_reg_objective, method="polak-ribiere", maxiter=500)

In [55]:

opt = optax.adam(learning_rate)

def jaxopt_soln(x, y):
  gd = OptaxSolver(opt=opt, fun=jaxopt_loss, maxiter=1000, implicit_diff=True)
  # gd = jaxopt.LBFGS(fun=jaxopt_loss, maxiter=500, implicit_diff=True, )
  # gd = jaxopt.GradientDescent(fun=jaxopt_loss, maxiter=500, implicit_diff=True, )
  return gd.run(ridge_loss, x=x, y=y).params

In [56]:
soln  = jaxopt_soln(x, y)
soln.alpha

Array(-1.2881405, dtype=float32)

In [46]:
res.params.alpha, res.params.model.weight

(Array(0.01, dtype=float32),
 Array([[-0.10502207, -0.56205004, -0.56485987, -1.7063935 ,  0.56626016,
         -0.42215332,  1.0077653 ,  0.9922631 , -0.61236995, -1.8450408 ]],      dtype=float32))

In [47]:
res.state

LbfgsState(iter_num=Array(100, dtype=int32, weak_type=True), value=Array(-2.0766206, dtype=float32), grad=RidgeLoss(alpha=f32[], model=LinearModel(weight=f32[1,10])), stepsize=Array(0., dtype=float32), error=Array(5.1628613, dtype=float32), s_history=RidgeLoss(alpha=f32[10], model=LinearModel(weight=f32[10,1,10])), y_history=RidgeLoss(alpha=f32[10], model=LinearModel(weight=f32[10,1,10])), rho_history=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), gamma=Array(1., dtype=float32), aux=None, failed_linesearch=Array(True, dtype=bool))