# Final Review

In [1]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import grad, random

## The Bregman Proximal method


Formulate the Bregman proximal method given a function $E:\mathbb{R}^N \to \mathbb{R}$ for the choice 

$$
    J({\bf x}) = \frac{1}{\tau}\sum_{n=1}^N\left[(1-x_n)\log(1-x_n) + x_n \log(x_n) -1 \right]
$$


The Bregman distance for the given choice $J$ is given by

$$
    D_J({\bf u}, {\bf v}) = \frac{1}{\tau}\sum_{n=1}^N \left[ u_n \log\left(\frac{u_n}{v_n}\right) + (1 - u_n) \log\left(\frac{1-u_n}{1-v_n}\right) \right]
$$

In [4]:
# Fixed seed so the sampled test vectors are reproducible on re-run.
key = random.PRNGKey(0)
# One Beta(2, 4) draw of shape (2, 15); the two rows are unpacked into the
# test points x and y, whose entries lie in the unit interval.
x, y = jnp.abs(random.beta(key, a=2, b=4, shape=(2, 15)))
x

array([0.6010489 , 0.4508557 , 0.35213393, 0.02290265, 0.69061667,
       0.48838946, 0.44977003, 0.7096753 , 0.85323936, 0.19534549,
       0.45806506, 0.3516461 , 0.21225834, 0.45255467, 0.19804347],
      dtype=float32)

In [5]:
def J(x, τ=1):
    """
    Shifted negative Bernoulli entropy, scaled by 1/τ:

        J(x) = 1/τ * sum_n [(1 - x_n) log(1 - x_n) + x_n log(x_n) - 1]

    Parameters
    ----------
    x : array
        Entries must lie strictly inside (0, 1); at 0 or 1 the
        logarithms are not finite.
    τ : float, optional
        Scaling parameter (default 1). Must be non-zero.

    Returns
    -------
    Scalar array holding the sum over all entries of x.
    """
    # The constant -1 belongs inside the bracket that is scaled by 1/τ,
    # so it is added before multiplying. For τ = 1 this is unchanged.
    per_entry = (1 - x) * jnp.log(1 - x) + x * jnp.log(x) - 1
    return jnp.sum(1/τ * per_entry)

def dxJ_exact(x, τ=1):
    """
    Closed-form gradient of J: the element-wise log-odds of x
    scaled by 1/τ, i.e. 1/τ * log(x / (1 - x)).
    """
    odds = x / (1 - x)
    scale = 1/τ
    return scale * jnp.log(odds)

# Autodiff gradient of J with respect to its first argument x
# (τ stays at its default); used to cross-check dxJ_exact.
dxJ = grad(J, argnums=0)

In [6]:
# Autodiff gradient of J at x (τ at its default of 1); compare with dxJ_exact(x).
dxJ(x)

DeviceArray([ 0.40983742, -0.19721389, -0.6096724 , -3.7533333 ,
              0.8030039 , -0.0464505 , -0.2015999 ,  0.89380765,
              1.7602372 , -1.4156432 , -0.16813481, -0.61181146,
             -1.3113661 , -0.19035393, -1.3985679 ], dtype=float32)

In [7]:
# Closed-form gradient log(x / (1 - x)) at the same x, for comparison with dxJ(x).
dxJ_exact(x)

DeviceArray([ 0.40983737, -0.19721383, -0.60967237, -3.7533333 ,
              0.80300385, -0.04645047, -0.2015999 ,  0.89380765,
              1.7602372 , -1.4156432 , -0.16813473, -0.6118114 ,
             -1.3113661 , -0.19035396, -1.3985679 ], dtype=float32)

In [8]:
def bregman_bern_exact(u, v, τ=1):
    """
    Closed-form Bregman distance between u and v for the
    Bernoulli-entropy choice of J, scaled by 1/τ:

        1/τ * sum_n [u_n log(u_n / v_n)
                     + (1 - u_n) log((1 - u_n) / (1 - v_n))]
    """
    u_term = u * jnp.log(u / v)
    complement_term = (1 - u) * jnp.log((1 - u) / (1 - v))
    per_entry = 1 / τ * (u_term + complement_term)
    return jnp.sum(per_entry)


def bregman(u, v, J):
    """
    Bregman distance D_J(u, v) = J(u) - J(v) - <∇J(v), u - v>
    for a differentiable, convex function J, with the gradient
    obtained by automatic differentiation.
    """
    grad_at_v = grad(J)(v)
    displacement = u - v
    return J(u) - J(v) - grad_at_v @ displacement

In [9]:
# Closed-form Bregman distance between x and y with τ = 0.4.
bregman_bern_exact(x, y, 0.4)

DeviceArray(7.044682, dtype=float32)

In [10]:
# Same distance via the generic autodiff-based bregman(); the lambda fixes τ = 0.4 in J.
bregman(x, y, lambda u: J(u, 0.4))

DeviceArray(7.0446863, dtype=float32)

### Comparing the asymmetry

In [12]:
# D_J(x, y) with J at its default τ = 1.
bregman(x, y, J)

DeviceArray(2.817871, dtype=float32)

In [13]:
# Arguments swapped: D_J(y, x). The value differs from D_J(x, y),
# showing that the Bregman distance is not symmetric.
bregman(y, x, J)

DeviceArray(3.3040211, dtype=float32)