In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.2"
import jax.numpy as jnp
import matplotlib.pyplot as plt
from backend import *
import numpy as np
from jax import config
config.update('jax_platform_name', 'cpu')
from jax.lib import xla_bridge
device = xla_bridge.get_backend().platform
print(device)

Metal device set to: Apple M2 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

cpu




In [2]:
def td_loss(params, coords, actions, rewards, gamma, betas):
    aprobs = []
    values = []
    for coord in coords:
        pcact = predict_placecell(params, coord)
        aprob = predict_action(params, pcact)
        value = predict_value(params, pcact)
        aprobs.append(aprob)
        values.append(value)
    aprobs = jnp.array(aprobs)
    values = jnp.array(values)

    log_likelihood = jnp.log(aprobs) * actions  # log probability of action as policy
    tde = jnp.array(compute_reward_prediction_error(rewards[:,None], values, gamma))
    actor_loss = jnp.sum(log_likelihood * lax.stop_gradient(tde))  # log policy * discounted reward
    critic_loss = -jnp.sum(tde ** 2) # grad decent
    tot_loss = actor_loss + betas[0] * critic_loss
    return tot_loss

@jit
def update_td_params(params, coords, actions, rewards, etas, gamma, betas):
    loss, grads = value_and_grad(td_loss)(params, coords,actions, rewards, gamma, betas)
    pc_centers, pc_sigmas, pc_constant, actor_weights,critic_weights = params
    dpcc, dpcs, dpca, dact, dcri = grads

    # + for gradient ascent
    pc_eta, sigma_eta,constant_eta, actor_eta, critic_eta = etas
    newpc_centers = pc_centers + pc_eta * dpcc
    newpc_sigma = pc_sigmas + sigma_eta * dpcs
    newpc_const = pc_constant + constant_eta * dpca
    newactor_weights = actor_weights + actor_eta * dact
    newcritic_weights = critic_weights + critic_eta * dcri  # gradient descent
    return [newpc_centers, newpc_sigma,newpc_const, newactor_weights,newcritic_weights], grads, loss


In [3]:
def learn(params, reward, state, onehotg, gamma, etas,beta=2):
    pcact = predict_placecell(params, state)
    # newpcact = predict_placecell(params, newstate)
    value =  predict_value(params, pcact)
    # newvalue = predict_value(params, newpcact)
    td = compute_reward_prediction_error(reward[:,None], np.array([value]), gamma)[:,0]

    # td = (reward + gamma * newvalue - value)[0]
    aprob = predict_action(params, pcact)
    onehotg = onehotg[0]

    # get critic grads
    dcri =  pcact[:,None] * td

    # get actor grads
    decay = beta * (onehotg[:,None]- aprob[:,None])  # derived from softmax grads
    # decay = onehotg[:,None]  # foster et al. 2000 rule, simplified form of the derivative
    dact = (pcact[:,None] @ decay.T) * td

    # get phi grads: dp = phi' (W^actor @ act + W^critic) * td
    post_td = (params[3] @ decay + params[4]) * td

    dpcc = (post_td * (pcact[:,None]) * ((state - params[0])/params[1]**2)[:,None])[:,0]
    dpcs = (post_td * (pcact[:,None]) * ((state - params[0])**2/params[1]**3)[:,None])[:,0]
    dpca = (post_td * (pcact[:,None]) * (2 / params[2][:,None]))[:,0]

    grads = [dpcc, dpcs, dpca, dact, dcri]

    #update weights
    for p in range(len(params)):
        params[p] += etas[p] * grads[p]

    return params, grads, td

In [4]:
# gradients computed by jax

npc = 2
nact = 2
seed = 2020
sigma = 0.25
alpha = 0.1
envsize=1.0
gamma= 0.6

# model 
params = uniform_pc_weights(npc, nact, seed, sigma=sigma, alpha=alpha, envsize=envsize)
#params[3] /=params[3]
#params[4] /=params[4]
#params[4] *= 1
etas = [1, 1, 1, 1, 1]
betas = [0.5]

# dataset
coords = np.array([-0.75])
rewards = np.array([1.0])
actions = np.array([[0,1]])


In [5]:
# update by jax
_, jaxgrads, loss = update_td_params(params.copy(), coords, actions, rewards, etas, gamma, betas)
print(loss)

for g in jaxgrads: print(g)

-1.1931472
[-3.592929e-07  8.143338e-17]
[-3.5929290e-07 -5.7003364e-16]
[-1.7964643e-06 -5.8166699e-17]
[[-6.0653072e-03  6.0653081e-03]
 [-2.2897350e-13  2.2897352e-13]]
[[6.065307e-03]
 [2.289735e-13]]


In [6]:
# update by numpy
_, npgrads, td = learn(params.copy(), rewards, coords, actions, gamma, etas)
print(td)

for g in npgrads: print(g)

[1.]
[-3.592929e-07  8.143337e-17]
[-3.592929e-07 -5.700336e-16]
[-1.7964645e-06 -5.8166692e-17]
[[-6.065307e-03  6.065307e-03]
 [-2.289735e-13  2.289735e-13]]
[[6.065307e-03]
 [2.289735e-13]]


In [7]:
np.array(npgrads[:3])/np.array(jaxgrads[:3])

array([[1.        , 0.99999994],
       [1.        , 0.9999999 ],
       [1.0000001 , 0.9999999 ]], dtype=float32)