In [None]:
# default_exp optimizers

In [None]:
#exporti
from jax import grad
from functools import partial
import numpy as np
from typing import Callable, Union

# Optimizers

> This module implements a common interface for running several well-known optimizers, so that they can be tested against different functions.

In [None]:
#exporti
def tuple_float_cast(_tuple):
    """
    Round every value of `_tuple` to 3 decimals, for pretty-printing parameters.

    Each value is converted with `float()` first, so scalar arrays are accepted.
    Any number of values is supported, not only `(x, y)` pairs.

    Returns:
        A tuple of builtin Python `float`s. `np.round` returns `np.float64`, which
        NumPy >= 2 prints as ``np.float64(1.0)``. Converting back to `float` keeps
        printed histories readable (``(1.0, 2.0)``).
    """
    return tuple(float(np.round(float(value), 3)) for value in _tuple)

class History(list):
    """
    A list of optimizer states, in the order the optimizer visited them.

    It behaves exactly like a plain `list`; only its text representation differs.
    The states are `jax` objects, so printing them directly is not useful.
    Once a `_get_params` callable is attached to the instance, each state is
    passed through it and then through `tuple_float_cast`. Printing then shows
    the rounded parameter values. Without `_get_params` it prints like a plain list.
    """
    def __repr__(self):
        if hasattr(self, '_get_params'):
            rendered = [tuple_float_cast(self._get_params(entry)) for entry in self]
            return str(rendered)
        return super().__repr__()

In [None]:
#hide
# Without a `_get_params` attribute, History prints exactly like a plain list.
h = History([(1, 2), (2, 3), (4, 5)])
assert str(h) == str([(1, 2), (2, 3), (4, 5)])

# Once `_get_params` is attached, each element is passed through it and its
# values are cast to float and rounded to 3 decimals before printing.
# NOTE(review): if `tuple_float_cast` returns `np.float64` values, NumPy >= 2
# prints them as `np.float64(1.0)` and this assertion fails; confirm the NumPy version.
h._get_params = lambda x: x
assert str(h) == str([(1.0, 2.0), (2.0, 3.0), (4.0, 5.0)]), str(h)

In [None]:
#exporti
"""
Each class of optimizers has a special calling convention.
It's unfortunte that we can't just subclass the optimizers and inject our custom method
and we need to do this. This happens because the optimziers are written in a pure functional
style, and the methods are just functions that share state between them, back and forth.
"""
def _derivatives_based_update(i, state, update_fn, get_params_fn, grad_fn):
    params = get_params_fn(state)
    grads = grad_fn(*params)
    return update_fn(i, grads, state)

def _derivatives_free_update(i, state, update_fn, function):
    return update_fn(i, function, state)

In [None]:
#export
class optimize:
    """
    Fluent wrapper that runs a JAX-style optimizer on `function`.

    Usage::

        optimize(f).using(sgd(step_size=0.001)).start_from([1., 1.]).update(10)

    - `using` takes an `(init, update, get_params)` triple.
    - `start_from` creates the initial state.
    - `update` performs optimisation steps and returns the `History` of every
      state visited since the last `start_from`.
    """
    def __init__(self, function):
        self.function = function
        self.history = History()

    def using(self, optimizer=(None, None, None), name='sgd', derivatives_based=True, render_decorator: Union[Callable, None]=None):
        """
        Configure which optimizer drives the updates.

        Args:
            optimizer: an `(init, update, get_params)` triple.
            name: label stored as `self.optimizer_name`.
            derivatives_based: if True, the optimizer's `update` is fed the
                gradients of `function` w.r.t. its first two positional
                arguments. Otherwise it is fed `function` itself.
            render_decorator: stored as-is on the instance (not used by this class).

        Returns:
            self, so calls can be chained.
        """
        self.derivatives_based = derivatives_based
        self.__init, self.__update, self._get_params = optimizer
        self.render_decorator = render_decorator

        # Let the history extract parameter values for its __repr__;
        # otherwise printing it would show a list of raw `jax` states.
        self.history._get_params = self._get_params

        # Adapt the optimizer's `update` calling convention: gradient-based
        # optimizers get gradients, derivative-free ones get the function itself.
        if derivatives_based:
            self._update_fn = partial(
                _derivatives_based_update,
                update_fn=self.__update,
                get_params_fn=self._get_params,
                # argnums=(0, 1): gradients w.r.t. the first two positional
                # arguments only, i.e. the parameters are assumed to be a pair.
                grad_fn=grad(self.function, argnums=(0, 1))
            )
        else:
            self._update_fn = partial(
                _derivatives_free_update,
                update_fn=self.__update,
                function=self.function
            )

        self.optimizer = optimizer
        self.optimizer_name = name
        return self

    def start_from(self, params):
        """
        Initialise the optimizer state from `params` and start a new trajectory.

        A fresh `History` replaces the previous one, so `update` restarts its
        iteration index at 0 for the new run. That index is passed to the
        optimizer's `update` function, which may depend on it (e.g. for
        learning-rate schedules). A new object is used instead of clearing the
        old one, so histories returned by earlier `update` calls keep their contents.

        Returns:
            self, so calls can be chained.
        """
        state = self.__init(tuple(params))
        history = History()
        # Keep the parameter extractor installed by `using`, if any.
        if hasattr(self.history, '_get_params'):
            history._get_params = self.history._get_params
        self.history = history
        self.state = state
        self.history.append(self.state)
        return self

    def update(self, nr_iterations=1):
        """
        Perform `nr_iterations` optimisation steps from the current state.

        Each new state is appended to `self.history`. The iteration index passed
        to the optimizer continues from the number of updates already performed
        since `start_from`.

        Returns:
            `self.history` (not `self`), so the result displays the trajectory.
        """
        # The initial state is stored as history entry 0 and is not an update,
        # so the number of updates already made is one less than the length.
        current_iteration = len(self.history) - 1
        for i in range(nr_iterations):
            self.state = self._update_fn(current_iteration + i, self.state)
            self.history.append(self.state)
        return self.history

In [None]:
from optimisations.functions import himmelblau
# NOTE(review): newer JAX releases moved these optimizers to
# `jax.example_libraries.optimizers`; confirm against the installed version.
from jax.experimental.optimizers import sgd

# Run 10 SGD steps on the function returned by `himmelblau()`, starting from (1, 1).
# The cell displays the returned History: the initial state plus one entry per update.
(
    optimize(himmelblau())
        .using(sgd(step_size=0.001))
        .start_from([1., 1.])
        .update(10)
)



[(1.0, 1.0), (1.046, 1.038), (1.093, 1.076), (1.141, 1.114), (1.189, 1.152), (1.238, 1.189), (1.288, 1.226), (1.338, 1.263), (1.389, 1.3), (1.44, 1.336), (1.491, 1.371)]

`JAX` is kind of rough, the optimizers (for now) sit inside the `experimental` submodule which means that their API might change in the future. 

An optimizer is a function that has some initialization parameters, and which returns 3 functions:
* `init` - is a function to which you pass all the initial values of your hidden parameters and you get back a `state` object, which is a `pytree` structure (some internal representation). This is a bit confusing and I'm guessing this intermediate `pytree` thing might disappear from the API in the near future.
* `update` - is the function that does a single update pass over the whole parameters. It receives as inputs:
    * `i` - the count of the current iteration. This is useful because, depending on the optimizer implementation, you can have different learning properties at each iteration (like some annealing strategy for the learning rate, etc.)
    * `g` - the gradient values. You get these by extracting the params (the variables that will get updated by the optimizer) from the `state` object, using the `get_params` function below. Then pass them to your gradient function; its results are the input to this function.
    * `state` - the `pytree` structure you got after calling `init` (and which you'll constantly replace with the result of this `update` function call)
* `get_params` - a utility function that extracts the params object from a given `state` object (which is a `pytree`).

So the full flow of the above, in code, is shown below:

In [None]:
from jax.experimental.optimizers import sgd

init, update, get_params = sgd(step_size=0.001) # instantiate the optimizer; it returns the (init, update, get_params) functions

state = init((1., 2.)) # initialize the optimizer state from the initial weights and get a state object back
print(state)
print(get_params(state))    # extract the weight values back out of the state object

grad_function = grad(himmelblau(), argnums=(0, 1))  # build the function that computes the gradients
                                                    # argnums=(0, 1) asks for the derivative w.r.t. both positional parameters of the function, not only the first one.
    
state = update(0, grad_function(*get_params(state)), state)    # call update with the iteration number, the gradients at the current params and the previous state; get a new state back
print(state)

OptimizerState(packed_state=([1.0], [2.0]), tree_def=PyTreeDef(tuple, [*,*]), subtree_defs=(*, *))
(1.0, 2.0)
OptimizerState(packed_state=([DeviceArray(1.036, dtype=float32)], [DeviceArray(2.032, dtype=float32)]), tree_def=PyTreeDef(tuple, [*,*]), subtree_defs=(*, *))


In [None]:
from jax import grad
from optimisations.functions import himmelblau

# Gradient of the function at the parameters held by the `state` produced in the previous cell.
grad(himmelblau(), argnums=(0, 1))(*get_params(state))

(DeviceArray(-36.385605, dtype=float32),
 DeviceArray(-30.704094, dtype=float32))

And you can see the result of running 10 iterations of the above in a loop. It moves in some direction, and I'm sure you're eager to see where on the graph...

In [None]:
grad_function = grad(himmelblau(), argnums=(0, 1))
def run():
    # Generator: yields the parameters *before* each of 10 updates, so the first
    # value yielded is the initial point (1., 2.). The state produced by the
    # 10th update is never yielded.
    state = init((1., 2.))
    for i in range(10):
        params = get_params(state)
        yield params
        state = update(i, grad_function(*params), state)
    
# Convert the jax scalars to Python floats for display.
[(float(x), float(y)) for x, y in run()]

[(1.0, 2.0),
 (1.0360000133514404, 2.0320000648498535),
 (1.0723856687545776, 2.062704086303711),
 (1.1091352701187134, 2.092081069946289),
 (1.1462260484695435, 2.1201066970825195),
 (1.18363356590271, 2.1467630863189697),
 (1.22133207321167, 2.1720387935638428),
 (1.2592941522598267, 2.1959288120269775),
 (1.2974909543991089, 2.2184340953826904),
 (1.335891842842102, 2.239561080932617)]

In [None]:
#exporti
import re

def heuristic_get_jax_optimizer_name(init):
    """
    Guess the name of the optimizer that created `init` by parsing the string
    representation of the function.

    JAX optimizer constructors return closures, whose ``str`` looks like::

        <function sgd.<locals>.init at 0x1296cab70>

    so the first dotted component after ``function`` is the constructor name
    (``'sgd'`` above). For a plain function such as ``<function foo at 0x...>``
    the function's own name (``'foo'``) is returned.

    Returns:
        The guessed name, or ``None`` when ``str(init)`` has no ``function`` token.
    """
    # Stop at whitespace as well as '.', so "foo at 0x...>" is not captured.
    match = re.search(r"function\s+([^\s.]+)", str(init))
    return match.group(1) if match is not None else None

In [None]:
# `init` is the function returned by `sgd(step_size=0.001)` in an earlier cell.
heuristic_get_jax_optimizer_name(init)

'sgd'

In [None]:
#exporti
def build_optimizer_params(elements_list):
    """
    Normalise the different ways an optimizer can be specified into keyword
    arguments for `optimize.using`.

    Accepted shapes:
        * ``[(init, update, get_params)]``: the name is guessed from `init`.
        * ``[(init, update, get_params), {other: configs}]``: the configs dict is
          copied (never mutated) and extended. If it has no ``'name'`` key, the
          name is guessed from `init`.
        * ``(init, update, get_params)``, e.g. the object returned by
          ``sgd(...)``: the name is guessed from `init`.

    Returns:
        A new dict with an ``'optimizer'`` key (the triple) and a ``'name'`` key
        (``None`` if it could not be guessed), plus any extra configs.

    Raises:
        TypeError: if any of init / update / get_params is not callable.
        ValueError: if `elements_list` does not have one of the shapes above.
    """
    size = len(elements_list)
    if size == 1:
        optimizer, optimizer_params = elements_list[0], {}
    elif size == 2:
        # Copy so the caller's configuration dict is not modified.
        optimizer, optimizer_params = elements_list[0], dict(elements_list[1])
    elif size == 3:
        optimizer, optimizer_params = elements_list, {}
    else:
        raise ValueError(
            f"Unknown optimizer constructor list shape or size {size}. "
            "Expected either "
            "1 for [(init, update, get_params)] or "
            "2 for [(init, update, get_params), {other: configs}] or "
            "3 for (init, update, get_params). "
            f"Received {elements_list}"
        )

    init, update, get_params = optimizer
    for fn in (init, update, get_params):
        if not callable(fn):
            raise TypeError(f"Expected {fn} be a callable.")

    optimizer_params['optimizer'] = optimizer
    if 'name' not in optimizer_params:
        optimizer_params['name'] = heuristic_get_jax_optimizer_name(init)
    return optimizer_params

In [None]:
# The three accepted input shapes: [optimizer, configs], [optimizer] and a bare optimizer triple.
print(build_optimizer_params([sgd(step_size=0.01), {"name":"sgd", "derivative":True}]))
print(build_optimizer_params([sgd(step_size=0.01)]))
print(build_optimizer_params(sgd(step_size=0.01)))

{'name': 'sgd', 'derivative': True, 'optimizer': Optimizer(init_fn=<function sgd.<locals>.init at 0x1296cab70>, update_fn=<function sgd.<locals>.update at 0x1296cabf8>, params_fn=<function sgd.<locals>.get_params at 0x1296cac80>)}
{'optimizer': Optimizer(init_fn=<function sgd.<locals>.init at 0x1296ca950>, update_fn=<function sgd.<locals>.update at 0x1296cac80>, params_fn=<function sgd.<locals>.get_params at 0x1296caae8>), 'name': 'sgd'}
{'optimizer': Optimizer(init_fn=<function sgd.<locals>.init at 0x1296cab70>, update_fn=<function sgd.<locals>.update at 0x1296caae8>, params_fn=<function sgd.<locals>.get_params at 0x1296caa60>), 'name': 'sgd'}


In [None]:
#export
class optimize_multi:
    """
    Run the same `function` through several optimizers, for comparison.

    Chain `.using([...])` and `.start_from(params)`, then call `.tolist()` to get
    one configured and started `optimize` instance per optimizer. Each entry of
    the optimizers list may take any shape accepted by `build_optimizer_params`.
    """
    def __init__(self, function):
        self.function = function

    def using(self, optimizers):
        """Remember the optimizers to compare; returns self for chaining."""
        self.optimizers = optimizers
        return self

    def start_from(self, params):
        """Remember the shared starting parameters; returns self for chaining."""
        self.params = params
        return self

    def tolist(self):
        """Return a list with one `optimize` object per optimizer, each started from `self.params`."""
        runs = []
        for spec in self.optimizers:
            kwargs = build_optimizer_params(spec)
            run = optimize(self.function).using(**kwargs).start_from(self.params)
            runs.append(run)
        return runs

Use `optimize_multi` when you want to compare the performance of multiple optimizers.

In [None]:
from jax.experimental.optimizers import sgd, adam
from optimisations.functions import himmelblau

# Build one `optimize` instance per optimizer, all starting from (-1, 1).
(optimizers) = (
    optimize_multi(himmelblau())
        .using([
            sgd(step_size=0.01),
            adam(step_size=0.3),
        ])
        .start_from([-1., 1.])
        .tolist()
)

optimizers

[<__main__.optimize at 0x12abfbcf8>, <__main__.optimize at 0x12abfb8d0>]

In [None]:
# Run 10 more updates for every optimizer and print its name and trajectory.
# Re-running this cell continues from each optimizer's last state. The recorded
# output below has 21 entries per optimizer, presumably because the cell was run twice.
for optimizer in optimizers:
    print(optimizer.optimizer_name, optimizer.update(10))

sgd [(-1.0, 1.0), (-1.22, 1.46), (-1.491, 1.977), (-1.805, 2.475), (-2.132, 2.846), (-2.419, 3.036), (-2.619, 3.103), (-2.728, 3.122), (-2.776, 3.128), (-2.795, 3.13), (-2.801, 3.131), (-2.804, 3.131), (-2.805, 3.131), (-2.805, 3.131), (-2.805, 3.131), (-2.805, 3.131), (-2.805, 3.131), (-2.805, 3.131), (-2.805, 3.131), (-2.805, 3.131), (-2.805, 3.131)]
adam [(-1.0, 1.0), (-1.3, 1.3), (-1.6, 1.6), (-1.9, 1.901), (-2.201, 2.202), (-2.501, 2.502), (-2.79, 2.796), (-3.044, 3.078), (-3.221, 3.332), (-3.297, 3.534), (-3.288, 3.662), (-3.219, 3.711), (-3.113, 3.695), (-2.988, 3.629), (-2.858, 3.529), (-2.737, 3.408), (-2.634, 3.278), (-2.557, 3.149), (-2.507, 3.031), (-2.486, 2.929), (-2.491, 2.85)]
