<a href="https://colab.research.google.com/github/elliotpaquette/elliotpaquette/blob/master/ODE_resolvent_stuff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def ode_resolvent_adaGrad(K, x, x_star, h, Dh, cov_grad_f, b, eta, ridge, times, label_noise, m):
  """Integrate the spectral (resolvent) ODE for streaming SGD with AdaGrad-norm.

  The state is kept in the eigenbasis of ``K``.  Writing
  ``K = sum_i lambda_i v_i v_i^T`` (from ``jnp.linalg.eigh``), the function
  tracks, for every eigenpair,

      S_11[i] = (x^T v_i)(x^T v_i)^T            (o x o)
      S_12[i] = (x^T v_i)(x_star^T v_i)^T       (o x t)
      S_22[i] = (x_star^T v_i)(x_star^T v_i)^T  (t x t, never updated)

  and contracts them against the eigenvalues to get
  ``C_ab = sum_i lambda_i S_ab[i]`` (so ``C_11 = x^T K x`` initially).
  ``S_11`` and ``S_12`` are advanced by explicit Euler steps on ``times``.
  The step size at ``times[i]`` is

      gamma_i = eta / sqrt(b**2 + (2 tr K / d) * integral_{times[0]}^{times[i]} risk),

  where the integral is a left Riemann sum of the recorded risks.

  Parameters
  ----------
  K : array (d x d)
      Symmetric covariance matrix (diagonalised with ``jnp.linalg.eigh``).
  x : array (d x o)
      Initial iterate x_0.
  x_star : array (d x t)
      Target matrix.  It is held fixed, so ``S_22`` never changes.
  h : callable ``h(C_11, C_12, C_22, label_noise) -> scalar``
      Risk expressed through the C matrices.
  Dh : callable ``Dh(C_11, C_12, C_22, m) -> (DH_11, DH_21)``
      Derivatives of ``h``.  ``DH_11`` must be o x o and ``DH_21`` t x o.
  cov_grad_f : callable ``cov_grad_f(C_11, C_12, C_22, label_noise) -> (o x o)``
      Matrix multiplying the gradient-noise term (described in the original
      notes as ``E_a[grad f(x) grad f(x)^T]``).
  b : float
      Constant in the denominator of the AdaGrad-norm step size.
  eta : float
      Constant in the numerator of the AdaGrad-norm step size.
  ridge : float
      Ridge parameter.  It contributes a ``-2*gamma*ridge*S_11`` drift to
      ``S_11`` and a ``-gamma*ridge*S_12`` drift to ``S_12``.
  times : 1-D array (n_grid,)
      Time grid.  Step ``i`` uses ``times[i+1] - times[i]``; presumably the
      grid is non-decreasing (not checked).
  label_noise : array
      Passed through unchanged to ``h`` and ``cov_grad_f``.
  m : int
      Passed through unchanged to ``Dh``.

  Returns
  -------
  times : array (n_grid,)
      The input grid, returned unchanged.
  risks : numpy.ndarray (n_grid,)
      ``risks[i]`` is the risk of the state at ``times[i]``.
  cross_terms : numpy.ndarray (n_grid, o, t)
      ``C_12`` at each grid time.
  norm_terms : numpy.ndarray (n_grid, o, o)
      ``C_11`` at each grid time.
  adaGrad_gammas : numpy.ndarray (n_grid,)
      AdaGrad-norm step size at each grid time.
  """
  n_grid = len(times)
  o = x.shape[1]
  t = x_star.shape[1]

  # Always float outputs: an integer time grid must not truncate the results.
  risks = np.zeros(n_grid, dtype=float)
  adaGrad_gammas = np.zeros(n_grid, dtype=float)
  cross_terms = np.zeros((n_grid, o, t))  # same shape as C_12 at each time
  norm_terms = np.zeros((n_grid, o, o))   # same shape as C_11 at each time

  Keigs, Kvecs = jnp.linalg.eigh(K)
  trace_K = jnp.sum(Keigs)
  length_d = len(Keigs)

  # Coordinates of x and x_star along each eigenvector of K.
  halfS_x = x.transpose() @ Kvecs            # o x d
  halfS_x_star = x_star.transpose() @ Kvecs  # t x d

  # Per-eigenvalue outer products: d x o x o, d x o x t, d x t x t.
  S_11 = jnp.einsum('ki,ji->ijk', halfS_x, halfS_x)
  S_12 = jnp.einsum('ji,ki->ijk', halfS_x, halfS_x_star)
  S_22 = jnp.einsum('ki,ji->ijk', halfS_x_star, halfS_x_star)

  # Running left Riemann sum of the risk over [times[0], times[i]].  This
  # avoids re-summing the whole risk history on every step.
  integral_risk = 0.0

  for i in range(n_grid):
    C_11 = jnp.tensordot(S_11, Keigs, axes=(0, 0))
    C_12 = jnp.tensordot(S_12, Keigs, axes=(0, 0))
    C_22 = jnp.tensordot(S_22, Keigs, axes=(0, 0))

    # Record the state at times[i].  All Euler steps up to this time have
    # already been applied, so the outputs are aligned with ``times``.
    risks[i] = h(C_11, C_12, C_22, label_noise)
    cross_terms[i] = C_12
    norm_terms[i] = C_11
    adaGrad_gammas[i] = eta / jnp.sqrt(b**2 + (2.0 * trace_K) / length_d * integral_risk)
    gamma = adaGrad_gammas[i]

    if i == n_grid - 1:
      break  # nothing left to integrate past the last grid point

    dt = float(times[i + 1] - times[i])

    # DH_11 is o x o and DH_21 is t x o.
    DH_11, DH_21 = Dh(C_11, C_12, C_22, m)

    # dS_11/dt = -2*gamma*lambda_i*(S_11 DH_11 + S_12 DH_21 + transposes) - 2*gamma*ridge*S_11.
    # The last term must be the transpose of the second one.  Otherwise the
    # drift (and hence C_11) stops being symmetric when o, t > 1.
    S_11_gr = -2.0 * gamma * jnp.einsum('i,ijk->ijk', Keigs, (
        jnp.tensordot(S_11, DH_11, axes=(2, 0))     # S_11 DH_11
        + jnp.tensordot(S_12, DH_21, axes=(2, 0))   # S_12 DH_21
        + jnp.einsum('ijk,jl->ilk', S_11, DH_11)    # DH_11^T S_11
        + jnp.einsum('ijk,kl->ilj', S_12, DH_21)    # (S_12 DH_21)^T
    )) - 2.0 * gamma * ridge * S_11

    # dS_12/dt = -2*gamma*lambda_i*(DH_11^T S_12 + DH_21^T S_22) - gamma*ridge*S_12.
    # The tensordots give d x t x o; the 'ikj' output transposes back to d x o x t.
    S_12_gr = -2.0 * gamma * jnp.einsum('i,ijk->ikj', Keigs, (
        jnp.tensordot(S_12, DH_11, axes=(1, 0))
        + jnp.tensordot(S_22, DH_21, axes=(1, 0))
    )) - 1.0 * gamma * ridge * S_12

    # Gradient-noise contribution: lambda_i * cov_grad_f, scaled by gamma^2 / d.
    S_11_noise = (gamma**2 / float(length_d)) * jnp.tensordot(
        Keigs, cov_grad_f(C_11, C_12, C_22, label_noise), axes=0)

    S_11 = S_11 + dt * (S_11_gr + S_11_noise)
    S_12 = S_12 + dt * S_12_gr

    # The risk at times[i] is held over [times[i], times[i+1]] (left Riemann sum).
    integral_risk += risks[i] * dt

  return times, risks, cross_terms, norm_terms, adaGrad_gammas

In [None]:
# Model-specific callables passed to ode_resolvent_adaGrad: the risk h, its derivative pair Dh = (Dh_11, Dh_21), and cov_grad_f.

def h(B11, B12, B22, label_noise):
  """Return the risk in terms of the C blocks.

  Computes 0.5*tr(B11) - 0.5*tr(B12) - 0.5*tr(B12^T) + 0.5*tr(B22)
  + 0.5*sum(label_noise**2).
  """
  noise_term = 0.5 * jnp.sum(label_noise**2)
  return (0.5 * jnp.trace(B11)
          - 0.5 * jnp.trace(B12)
          - 0.5 * jnp.trace(B12.T)
          + 0.5 * jnp.trace(B22)
          + noise_term)

def Dh(C11, C21, C22, m):
  """Return the derivative pair (Dh_11, Dh_21) = (0.5*I_m, -0.5*I_m).

  The C arguments are accepted for interface compatibility but are unused,
  because these derivatives are constant.
  """
  eye = jnp.identity(m)
  return (0.5 * eye, -0.5 * eye)

def cov_grad_f(C11, C21, C22, label_noise):
  """Return C11 - C21 - C21^T + C22 + label_noise label_noise^T.

  This is the matrix that scales the gradient-noise term in the ODE.
  ``label_noise`` is expected to be 1-D (an 'i,j->ij' outer product).
  """
  signal = C11 - C21 - C21.transpose() + C22
  noise_outer = jnp.einsum('i,j->ij', label_noise, label_noise)
  return signal + noise_outer


In [None]:
# Run the AdaGrad-norm ODE. ``b`` (denominator) and ``eta`` (numerator) are the
# AdaGrad-norm step-size constants; adjust them here to try other schedules.
# NOTE(review): Dh returns m x m identities, so this call presumably assumes
# phase_o == phase_t == m -- confirm against the cells that define them.
adaGrad_times, adaGrad_risks, adaGrad_cross_terms, adaGrad_norm_terms, adaGrad_gammas = ode_resolvent_adaGrad(
    K=K,
    x=jnp.reshape(initial, (d, phase_o)),
    x_star=jnp.reshape(strlin_xstar, (d, phase_t)),
    h=h,
    Dh=Dh,
    cov_grad_f=cov_grad_f,
    b=b,
    eta=eta,
    ridge=ridge,
    times=times,
    label_noise=jnp.reshape(epsilon, (m,)),
    m=m,
)