In [18]:
from functools import partial
from typing import List, Tuple

import equinox as eqx
import jax
from jax import Array
import jax.numpy as jnp
import numpy as np
import optax

from function_learning_task import init_function_learning_task, step_function_learning_task
from training import train_on_sequence
from utils import tree_replace

In [260]:
class SwiftTDState(eqx.Module):
    # Static params
    n_features: int = eqx.field(static=True) # Number of input features
    meta_lr: float = eqx.field(static=True) # Meta learning rate

    epsilon: float = eqx.field(static=True) # LR decay factor
    eta: float = eqx.field(static=True) # Max learning rate
    trace_decay: float = eqx.field(static=True) # Lambda trace decay
    gamma: float = eqx.field(static=True) # Discount factor

    # State vars
    beta: Array # Learning rate exponent
    h_old: Array
    h_temp: Array
    z_delta: Array
    p: Array # Eligibility trace of lr exponent
    h: Array
    z: Array # Eligibility trace of weights
    z_bar: Array
    V_delta: Array
    V_old: Array

    def __init__(
            self,
            n_features,
            lr_init: float = 1e-7,
            meta_lr: float = 1e-3,
            epsilon: float = 0.9,
            eta: float = 0.5,
            trace_decay: float = 0.95,
            gamma: float = 0.1,
        ):
        self.n_features = n_features
        self.beta = jnp.log(lr_init) * jnp.ones(self.n_features)
        self.meta_lr = meta_lr
        self.epsilon = epsilon
        self.eta = eta
        self.trace_decay = trace_decay
        self.gamma = gamma

        self.h_old, self.h_temp, self.z_delta, self.p, self.h, self.z, self.z_bar = [jnp.zeros(self.n_features) for _ in range(7)]
        self.V_delta, self.V_old = [jnp.array(0.0) for _ in range(2)]

In [261]:
def swift_td_step(
        state: SwiftTDState,
        weights: Array,
        features: Array,
        cumulant: float,
    ) -> Tuple[SwiftTDState, Array, float]:
    """SwiftTD update step.
    
    Args:
        state (SwiftTDState): Current state of the Swift-TD algorithm.
        weights (Array): Current weights of the model. Must have elements of shape (n_features,).
        features (Array): Input features.
        cumulant (float): Scalar cumulant signal.

    Returns:
        Tuple[SwiftTDState, Array, float]: Updated state, updated weights, and TD error.
    """
    orig_weight_shape = weights.shape
    weights = weights.flatten()
    V = jnp.dot(weights, features)
    delta = cumulant + state.gamma * V - state.V_old

    # Weight and lr updates
    out = {}
    delta_w = delta * state.z - state.z_delta * state.V_delta
    delta_w = jnp.where(state.z == 0, jnp.zeros_like(features), delta_w)
    weights = jnp.where(state.z == 0, weights, weights + delta_w) # Weight update

    out['beta'] = state.beta + state.meta_lr / (jnp.exp(state.beta) + 1e-8) # Meta learning rate update
    out['beta'] = jnp.minimum(out['beta'], jnp.log(state.eta)) # Clip learning rate
    out['h_old'] = state.h
    out['h'] = state.h_temp
    out['h_temp'] = out['h'] + delta * state.z_bar - state.z_delta * state.V_delta
    out['z_delta'] = jnp.zeros_like(state.z_delta)

    # Decay traces
    out['z'] = state.gamma * state.trace_decay * state.z
    out['p'] = state.gamma * state.trace_decay * state.p
    out['z_bar'] = state.gamma * state.trace_decay * state.z_bar

    # Replace state variables with out values only where z != 0
    state = tree_replace(
        state,
        **{k: jnp.where(state.z == 0, getattr(state, k), v) for k, v in out.items()}
    )

    state = tree_replace(state, V_delta=jnp.array(0.0))
    lr = jnp.exp(state.beta)
    E = jnp.maximum(jnp.array(state.eta), jnp.dot(lr, features ** 2)) # Rate of learning
    T = jnp.dot(state.z, features)
    state = tree_replace(state, V_delta=state.V_delta + jnp.dot(delta_w, features)) # Minor error because delta_w may not be defined

    # Eligibility trace updates
    out = {}
    out['z_delta'] = state.eta / E * jnp.exp(state.beta) * features
    out['z'] = state.z + out['z_delta'] * (1 - T) # Update weight eligibility trace
    out['p'] = state.p + features * state.h # Update lr eligibility trace
    out['z_bar'] = state.z_bar + out['z_delta'] * (1 - T - features * state.z_bar)
    out['h_temp'] = state.h_temp - state.h_old * features * (out['z'] - out['z_delta']) \
        - state.h * out['z_delta'] * features
    
    # Conditionally decay lr
    out['beta'] = jax.lax.cond(
        E <= state.eta,
        lambda beta: beta,
        lambda beta: beta + jnp.abs(features) * jnp.log(state.epsilon),
        state.beta,
    )
    out['h_temp'], out['h'], out['z_bar'] = jax.lax.cond(
        E <= state.eta,
        lambda h_temp, h, z_bar: (h_temp, h, z_bar),
        lambda h_temp, h, z_bar: (jnp.zeros_like(h_temp), jnp.zeros_like(h), jnp.zeros_like(z_bar)),
        out['h_temp'], state.h, out['z_bar'],
    )

    # Replace state variables with out values only where features != 0
    state = tree_replace(
        state,
        **{k: jnp.where(features == 0, getattr(state, k), v) for k, v in out.items()}
    )

    state = tree_replace(state, V_old=V)

    return state, weights.reshape(orig_weight_shape), delta

In [300]:
std_state = SwiftTDState(n_features=2, gamma=0)
weights = jnp.array([0.0, 0.0])
true_weights = jnp.array([-1.0, 1.0])

feature_sequence = jnp.tile(jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]), (10, 1))
order = jax.random.permutation(jax.random.PRNGKey(0), len(feature_sequence))
feature_sequence = feature_sequence[order]
cumulant_sequence = jnp.dot(feature_sequence, true_weights)

# Randomize order
feature_sequence = feature_sequence[1:]
cumulant_sequence = cumulant_sequence[:-1]

update_fn = jax.jit(swift_td_step)

In [302]:
for features, cumulant in zip(feature_sequence, cumulant_sequence):
    std_state, weights, td_error = update_fn(std_state, weights, features, cumulant)
    print(td_error)
print(weights)

0.002597332
0.004532099
0.009252369
0.008685052
0.0
0.0
-0.007547915
0.0
0.0
0.005201757
0.0
0.00401783
-0.005837679
-0.005837679
-0.0045090914
-0.00037896633
0.0
0.0030030608
0.00022655725
0.0023105145
-0.0002501607
0.0018559098
0.0020457506
-0.0009137988
0.0
-0.0021799803
0.0015004277
-0.0016685724
0.0
-0.00013160706
-0.00013160706
-0.0012443662
-0.0012279749
0.00025689602
-0.0007109046
0.0010761023
0.0
0.0
-0.00062686205
[-0.99952525  0.9991839 ]


Array([ 0.,  0.,  1.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1.,
        1.,  1.,  1., -1., -1.,  0.,  0.,  0.,  0., -1., -1.,  0., -1.,
        0.,  1., -1., -1.,  1.,  0.,  1., -1.,  1.,  0.,  0.,  1., -1.,
        0.,  0.,  1.,  0.,  0., -1.,  0.,  1.,  0.,  1., -1.,  0.,  1.,
        0.,  0.,  0.,  1.,  0., -1.,  1.,  0.,  0.,  1., -1.,  1.,  0.,
        0.,  0.,  0.,  0.,  0., -1.,  0., -1.,  0.,  0.,  0., -1.,  0.,
        1.,  0.,  1., -1.,  0.,  0., -1.,  0., -1.,  0.,  1.,  0.,  1.,
        0., -1., -1.,  0.,  0.,  0., -1.,  0.,  0.,  0.,  0., -1.,  1.,
       -1.,  1.,  0.,  1., -1., -1.,  0., -1.,  1.,  0.,  1.,  1.,  1.,
        1., -1., -1.,  0., -1., -1.,  0.,  0., -1.,  1., -1.,  1.,  0.,
        0.,  0.,  1., -1.,  0.,  0.,  0.,  1., -1.,  0.,  1.,  1.,  1.,
       -1.,  0., -1.,  1., -1.,  0.,  0., -1.,  0.,  0.,  1., -1.,  0.,
        0., -1.,  0., -1.,  1.,  0.,  0., -1.,  0.,  1.,  1.,  0., -1.,
        0.,  0.,  0.,  1.,  0., -1.,  0.,  0.,  0.,  1.,  1.,  1

In [None]:
lr_init = 0.1
meta_step_size = 1e-3
trace_decay = 0.95
gamma = 0.1 # 0.1
eta = 0.5
lr_decay = 0.9

In [None]:
StepData = namedtuple('StepData', ['input', 'reward', 'prediction'])
history = []

agent = Agent()

prev_obs, env_state = init_env()
prev_obs = agent.perception(prev_obs)

trace, p, h, prev_h, trace, trace_bar = [torch.zeros_like(prev_obs) for _ in range(6)]
beta = torch.empty_like(prev_obs).fill_(np.log(lr_init))
lr = torch.exp(beta)
old_value_pred = 0

for _ in range(2000):
    # Env step
    obs, reward, env_state = env_step(env_state)
    obs = agent.perception(obs)
    obs[1] = 0

    # Update weights
    prev_value_pred = (agent.value_weights @ prev_obs).squeeze()
    value_pred = (agent.value_weights @ obs).squeeze()
    value_diff = prev_value_pred - old_value_pred
    td_error = reward + gamma * value_pred - prev_value_pred
    trace_value = torch.dot(trace, prev_obs)
    lr_scale_max = torch.max(torch.tensor(eta), torch.sum(lr * prev_obs ** 2))

    lr = torch.exp(beta)
    trace *= gamma * trace_decay
    trace += (eta / lr_scale_max) * lr * prev_obs \
        - lr * gamma * trace_decay * trace_value * prev_obs
    agent.value_weights += td_error * trace - lr * prev_obs * value_diff

    # Step-size optimization
    p = trace_decay * gamma * p + prev_obs * h
    beta += meta_step_size / (lr + 1e-8) * td_error * p
    trace_bar *= gamma * trace_decay
    trace_bar += lr * prev_obs * (
        1 - gamma * trace_decay * trace_value \
            - gamma * trace_decay * prev_obs * trace_bar
    )
    k = h.clone()
    h = h * (1 - lr * prev_obs ** 2) \
        - prev_h * prev_obs * (trace - prev_obs * lr) \
        + td_error * trace_bar - prev_obs * lr * value_diff
    prev_h = k
    
    lr_decay_mask = ((lr_scale_max > eta) & (prev_obs != 0)).float()
    beta -= torch.log(torch.tensor(lr_decay)) * lr_decay_mask

    # Log
    history.append(StepData(prev_obs, reward, prev_value_pred))

    # Update for next iter
    old_value_pred = value_pred
    prev_obs = obs