# Goal

Using gradient descent, find the correct array _index_ to set to one.

_The constraint is that the loss must be MSE. Note that something like `scipy.stats.wasserstein_distance` would be a smoother distance measure for this specific problem, but in my application, finding an index to set is only one subproblem within a tree of parameters that all contribute to an output signal to be optimized._

In [None]:
import numpy as np
import jax.numpy as jnp
from jax import value_and_grad, jit, lax, random
from jax.ops import index_update
import scipy
from matplotlib import pyplot as plt

In [None]:
# Target signal: a unit impulse at `target_index` in an otherwise all-zero
# array of length 10.
Y = np.zeros(10)
target_index = 3
# Index with `target_index` instead of repeating the literal, so the impulse
# position and the reference line drawn in later plots cannot drift apart.
Y[target_index] = 1.0

plt.figure(figsize=(14, 3))
plt.stem(Y, basefmt=' ')
_ = plt.title('Target function', size=16)

In [None]:
def correlation(X, Y):
    """Return one minus the mean of the cross-correlation of ``X`` and ``Y``.

    ``jnp.correlate`` is called with its default arguments; the result is
    averaged and subtracted from 1 so that a larger correlation gives a
    smaller value.
    """
    overlap = jnp.correlate(X, Y)
    return 1 - overlap.mean()

def mse(X, Y):
    """Return the mean squared error between ``X`` and ``Y``.

    Both arguments must support elementwise subtraction and the result must
    expose ``.mean()`` (e.g. NumPy or JAX arrays).
    """
    squared_error = (Y - X) ** 2
    return squared_error.mean()

In [None]:
def single_index(index, key=None):
    """Return an array shaped like ``Y`` with 1.0 at ``index`` cast to int32.

    Args:
        index: Position to set. Must expose ``.astype`` (a NumPy/JAX scalar
            or array); a plain Python float will not work.
        key: Unused. Accepted so every index-construction function can be
            called as ``fn(index, key=key)``.

    Returns:
        A JAX array of zeros of length ``Y.size`` with a single 1.0.
    """
    X = jnp.zeros(Y.size)
    # `jax.ops.index_update` is no longer available in current JAX releases;
    # the `.at[...].set(...)` functional update is its replacement.
    return X.at[index.astype('int32')].set(1.0)

In [None]:
def linear_interp_index(index, key=None):
    """Spread a unit mass linearly over the two integer positions around ``index``.

    The position ``floor(index)`` receives ``floor(index) + 1 - index`` and
    the position ``floor(index) + 1`` receives ``index - floor(index)``, so
    the two weights sum to one.

    Args:
        index: Fractional position in the array.
        key: Unused. Accepted so every index-construction function can be
            called as ``fn(index, key=key)``.

    Returns:
        A JAX array of length ``Y.size`` holding the two interpolation weights.
    """
    X = jnp.zeros(Y.size)
    i0 = jnp.floor(index)
    i1 = i0 + 1
    # `jax.ops.index_update` is no longer available in current JAX releases;
    # the `.at[...].set(...)` functional update is its replacement.
    X = X.at[i0.astype('int32')].set(i1 - index)
    # NOTE(review): when `i1` equals `Y.size` this write is out of bounds and
    # is presumably dropped by JAX -- confirm this is the intended edge behavior.
    X = X.at[i1.astype('int32')].set(index - i0)
    return X

In [None]:
def gaussian(x, x0, sigma):
    """Evaluate an unnormalized Gaussian bump at ``x``.

    The curve is centered on ``x0`` with width ``sigma`` and has a peak
    value of 1 at ``x == x0`` (there is no normalizing prefactor).
    """
    z = (x - x0) / sigma
    return jnp.exp(-z**2 / 2)

def gaussian_interp_index(index, key=None):
    """Return a Gaussian bump (sigma 2.0) centered on ``index`` over all positions of ``Y``.

    ``key`` is unused; it is accepted so every index-construction function
    can be called as ``fn(index, key=key)``.
    """
    positions = jnp.arange(Y.size)
    return gaussian(positions, index, 2.0)

In [None]:
def gaussian_sample_index(index, key):
    """Sample one integer position around ``index`` and return it as a one-hot array.

    Args:
        index: Center of the Gaussian (sigma 3.0) used to weight positions.
        key: JAX PRNG key passed to ``random.choice``.

    Returns:
        A JAX array of zeros of length ``Y.size`` with 1.0 at the sampled
        position.
    """
    # The weights are passed to `random.choice` without being normalized.
    # NOTE(review): confirm `random.choice` tolerates weights that do not sum to 1.
    weights = gaussian(jnp.arange(Y.size), index, 3.0)
    discrete_index = random.choice(key, Y.size, p=weights)
    index_array = jnp.zeros(Y.size)
    # `jax.ops.index_update` is no longer available in current JAX releases;
    # the `.at[...].set(...)` functional update is its replacement.
    return index_array.at[discrete_index].set(1.0)

In [None]:
import scipy.signal

def window(size):
    """Return a Blackman-Harris window with ``size`` points.

    Args:
        size: Number of samples in the window.

    Returns:
        The window produced by ``scipy.signal.windows.blackmanharris``.
    """
    # The module is imported as `scipy.signal`, so the bare name `signal` is
    # not bound; use the fully qualified path and pass the requested size.
    return scipy.signal.windows.blackmanharris(size)

In [None]:
def triangle(X, center):
    """Return a triangular bump over ``X`` that peaks at 1 where ``X == center``.

    The height falls off linearly with the distance from ``center`` and
    reaches 0 at the position in ``X`` that is farthest from ``center``.

    Args:
        X: Array of positions.
        center: Position of the peak.

    Returns:
        An array shaped like ``X`` with values in ``[0, 1]``.
    """
    distance = jnp.abs(X - center)
    # Normalize by the largest absolute distance. Using the largest signed
    # difference instead yields negative values left of the peak and a
    # division by zero when `center` is the largest position in `X`.
    # If every element of `X` equals `center` the denominator is still zero.
    return 1 - distance / jnp.max(distance)

def triangle_interp_index(index, key=None):
    """Return a triangular bump centered on ``index`` over all positions of ``Y``.

    Args:
        index: Position of the peak.
        key: Unused. Accepted so this function can be called as
            ``fn(index, key=key)`` like the other index-construction
            functions; without it that call raises ``TypeError``.

    Returns:
        The result of ``triangle`` evaluated on ``0 .. Y.size - 1``.
    """
    return triangle(jnp.arange(Y.size), index)

In [None]:
def index_loss(index_guess, create_indices_fn, key):
    """Return the MSE between the array built from ``index_guess`` and the target ``Y``.

    Args:
        index_guess: Candidate index handed to ``create_indices_fn``.
        create_indices_fn: Callable invoked as ``fn(index_guess, key=key)``
            that returns an array comparable with ``Y``.
        key: Forwarded to ``create_indices_fn`` as the ``key`` keyword.
    """
    candidate = create_indices_fn(index_guess, key=key)
    return mse(candidate, Y)

In [None]:
from functools import partial

def optimize(create_indices_fn, index_guess, steps=10, learning_rate=1.0):
    """Run plain gradient descent on the index and report the losses.

    Args:
        create_indices_fn: Callable invoked as ``fn(index, key=key)`` that
            builds the array compared against the target.
        index_guess: Starting value for the index.
        steps: Number of gradient steps to take.
        learning_rate: Multiplier applied to the gradient at each step. The
            default of 1.0 reproduces the plain ``index -= grad`` update.

    Returns:
        A tuple ``(estimated_index, initial_loss, loss)``. ``loss`` comes
        from the last gradient evaluation, i.e. it is measured before the
        final update is applied. When ``steps`` is 0 it equals
        ``initial_loss``.
    """
    key = random.PRNGKey(0)
    loss_fn = partial(index_loss, create_indices_fn=create_indices_fn)
    grad_fn = jit(value_and_grad(loss_fn))
    estimated_index = index_guess
    initial_loss, _ = grad_fn(estimated_index, key=key)
    # Only the first half of each split is carried forward as the next key.
    key = random.split(key)[0]
    # Without this, `steps=0` would reach the return with `loss` unbound.
    loss = initial_loss
    for _ in range(steps):
        loss, grad = grad_fn(estimated_index, key=key)
        key = random.split(key)[0]
        estimated_index -= learning_rate * grad
    return estimated_index, initial_loss, loss

In [None]:
def plot_optimized_indexes(create_indices_fn, start_indices=None, steps=20):
    """Plot the optimization outcome for a sweep of initial index guesses.

    Three stacked plots are drawn against the initial guess: the estimated
    index after optimization, the initial loss, and the loss after
    optimization.

    Args:
        create_indices_fn: Index-construction function handed to ``optimize``.
        start_indices: Initial guesses to sweep. Defaults to 90 evenly spaced
            values in ``[0, Y.size - 1)``.
        steps: Number of gradient steps per initial guess.
    """
    # Build the default at call time: an array default would be created once
    # at definition time and would not follow later changes to `Y`.
    if start_indices is None:
        start_indices = np.linspace(0, Y.size - 1, 90, endpoint=False)

    _, plots = plt.subplots(3, 1, figsize=(14, 8))
    estimated_index_plot, initial_loss_plot, optimized_loss_plot = plots

    # One row per start index: (estimated_index, initial_loss, final_loss).
    optimized = np.array([optimize(create_indices_fn, start_index, steps=steps) for start_index in start_indices])

    estimated_index_plot.plot(start_indices, optimized[:, 0], linewidth=3)
    estimated_index_plot.set_title('Estimated index', size=18)
    estimated_index_plot.set_ylabel('Estimated index after {} steps'.format(steps), size=11)
    estimated_index_plot.axhline(target_index, linestyle='--', c='r', label='Target index')
    estimated_index_plot.legend()

    initial_loss_plot.plot(start_indices, optimized[:, 1], linewidth=3)
    initial_loss_plot.set_title('Initial loss', size=18)
    # This label has no placeholder, so no formatting call is needed.
    initial_loss_plot.set_ylabel('Initial MSE loss', size=11)

    optimized_loss_plot.plot(start_indices, optimized[:, 2], linewidth=3)
    optimized_loss_plot.set_title('Optimized loss', size=18)
    optimized_loss_plot.set_xlabel('Initial guess for index value', size=16)
    optimized_loss_plot.set_ylabel('MSE loss after {} steps'.format(steps), size=11)

    # Every axis gets its grid here, so no extra figure-level grid call is needed.
    for plot in plots:
        plot.grid(True)
        plot.set_xticks(np.arange(Y.size))
    plt.tight_layout()

In [None]:
# Baseline: hard indexing, where the guess is cast to int32 before being used as a position.
# NOTE(review): the integer cast presumably gives a zero gradient, so the estimate should not move -- confirm.
plot_optimized_indexes(single_index)

In [None]:
# Linear interpolation: the unit mass is split between the two integer positions around the guess.
plot_optimized_indexes(linear_interp_index)

In [None]:
# Gaussian soft indexing, run for 50 steps.
# Observation: near the edges, moving the gaussian pushes part of its mass out of the
# smooth-indexing array that is compared with the target array (a unit impulse at the target index).
# As that mass is removed, MSE loss decreases, because the area between the gaussian curve and
# the zeros of the target array shrinks.
# TODO: might be able to extend the successful range by using a different distribution.
# The ideal distribution would be smooth, keep a constant area under the curve, extend over the
# full window, and have its skew centered around a parameter.
plot_optimized_indexes(gaussian_interp_index, steps=50)

In [None]:
# Open question from the author: why is the optimized loss constant here?
# NOTE(review): `gaussian_sample_index` returns a one-hot array at a sampled integer position,
# so its output presumably has zero gradient with respect to the index and the estimate never
# moves; the MSE of a one-hot array against a one-hot target can also only take two values.
# Confirm by inspecting the gradients returned inside `optimize`.
plot_optimized_indexes(gaussian_sample_index, steps=20)

In [None]:
# Empirical check of the sampler: draw 100 one-hot samples centered on 4.2 and
# accumulate them into a per-position hit count.
key = random.PRNGKey(0)
summed = jnp.zeros(Y.size)
for _ in range(100):
    a = gaussian_sample_index(4.2, key)
    summed += a
    # Advance the key so the next draw uses fresh randomness; `subkey` is not used.
    key, subkey = random.split(key)

In [None]:
# Display the accumulated per-position sample counts from the loop above.
summed