In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.2"
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import grad, jit, vmap, random, nn, lax
from jax import value_and_grad
import numpy as np
from jax import config
config.update('jax_platform_name', 'cpu')
from jax.lib import xla_bridge
device = xla_bridge.get_backend().platform
print(device)

Metal device set to: Apple M2 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

cpu




In [2]:
def uniform_2D_pc_weights(npc, nact,seed,sigma=0.1, alpha=1,envsize=1, numsigma=1):
    x = np.linspace(-envsize,envsize,int(npc**0.5))
    xx,yy = np.meshgrid(x,x)
    pc_cent = np.concatenate([xx.reshape(-1)[:,None],yy.reshape(-1)[:,None]],axis=1)
    pc_sigma = jnp.ones([npc,numsigma])*sigma
    pc_constant = jnp.ones(npc) * alpha #/ jnp.sqrt((2*jnp.pi * jnp.sum(pc_sigma,axis=1)**2))
    actor_key, critic_key = random.split(random.PRNGKey(seed), num=2)
    return [jnp.array(pc_cent), jnp.array(pc_sigma), jnp.array(pc_constant), 
            1e-5 * random.normal(actor_key, (npc,nact)), 1e-5 * random.normal(critic_key, (npc,1))]


def predict_placecell(params, x):
    pc_centers, pc_sigmas, pc_constant, actor_weights,critic_weights = params
    exponent = jnp.sum((x - pc_centers)**2 / (2 * pc_sigmas ** 2),axis=1)
    pcact = jnp.exp(-exponent) * pc_constant #/ jnp.sqrt((2*jnp.pi * jnp.sum(pc_sigmas,axis=1)**2))
    return pcact


def compute_reward_prediction_error(rewards, values, gamma=0.95):
    new_values = jnp.concatenate([values[1:], jnp.array([[0]])])
    td = rewards + gamma * new_values - values
    return td


def td_loss(params, coords, actions, rewards, gamma, betas):
    aprobs = []
    values = []
    for coord in coords:
        pcact = predict_placecell(params, coord)
        aprob = predict_action(params, pcact)
        value = predict_value(params, pcact)
        aprobs.append(aprob)
        values.append(value)
    aprobs = jnp.array(aprobs)
    values = jnp.array(values)

    log_likelihood = jnp.log(aprobs) * actions  # log probability of action as policy
    tde = jnp.array(compute_reward_prediction_error(rewards[:,None], values, gamma))
    actor_loss = jnp.sum(log_likelihood * lax.stop_gradient(tde))  # log policy * discounted reward
    critic_loss = -jnp.sum(tde ** 2) # grad decent
    tot_loss = actor_loss + betas[0] * critic_loss
    return tot_loss

def predict_value(params, pcact):
    pc_centers, pc_sigmas, pc_constant, actor_weights,critic_weights = params
    value = jnp.matmul(pcact, critic_weights)
    return value


def predict_action(params, pcact, beta=1):
    pc_centers, pc_sigmas, pc_constant, actor_weights,critic_weights = params
    actout = jnp.matmul(pcact, actor_weights)
    aprob = nn.softmax(beta * actout)
    return aprob

@jit
def update_td_params(params, coords, actions, rewards, etas, gamma, betas):
    loss, grads = value_and_grad(td_loss)(params, coords,actions, rewards, gamma, betas)
    pc_centers, pc_sigmas, pc_constant, actor_weights,critic_weights = params
    dpcc, dpcs, dpca, dact, dcri = grads

    # + for gradient ascent
    pc_eta, sigma_eta,constant_eta, actor_eta, critic_eta = etas
    newpc_centers = pc_centers + pc_eta * dpcc
    newpc_sigma = pc_sigmas + sigma_eta * dpcs
    newpc_const = pc_constant + constant_eta * dpca
    newactor_weights = actor_weights + actor_eta * dact
    newcritic_weights = critic_weights + critic_eta * dcri  # gradient descent
    return [newpc_centers, newpc_sigma,newpc_const, newactor_weights,newcritic_weights], grads, loss


In [3]:
def learn(params, reward, state, onehotg, gamma, etas,beta=1):
    pc_centers, pc_sigmas, pc_constant, actor_weights, critic_weights = params
    pcact = predict_placecell(params, state)
    # newpcact = predict_placecell(params, newstate)
    value =  predict_value(params, pcact)
    # newvalue = predict_value(params, newpcact)
    td = compute_reward_prediction_error(reward[:,None], np.array([value]), gamma)[:,0]

    # td = (reward + gamma * newvalue - value)[0]
    aprob = predict_action(params, pcact)

    # Critic grads
    dcri = pcact[:, None] * td
    
    # Actor grads
    decay = beta * (onehotg[:, None] - aprob[:, None])
    dact = np.dot(pcact[:, None], decay.T) * td
    
    # Grads for field parameters
    post_td = (actor_weights @ decay + critic_weights) * td

    df = state - pc_centers
    inv_sigma = np.linalg.inv(np.diag(pc_sigmas.flatten()))

    dpcc = post_td * pcact[:, None] * (inv_sigma @ df.T).T
    outer = np.einsum('nj,nk->njk', df, df)
    dpcs = 0.5 * (post_td * pcact[:, None])[:, :, None] * (inv_sigma @ outer @ inv_sigma)
    dpca = (post_td * pcact[:, None] * (2 / pc_constant[:, None]))[:, 0]

    grads = [dpcc, dpcs, dpca, dact, dcri]  # dpcc needs to be transposed back
    
    for p in range(len(params)):
        params[p] += etas[p] * grads[p]

    return params, grads, td

In [4]:
# gradients computed by jax

npc = 8**2
nact = 4
seed = 2020
sigma = 0.1
alpha = 0.5
envsize=1.0
gamma= 0.6

# model 
params = uniform_2D_pc_weights(npc, nact, seed, sigma=sigma, alpha=alpha, envsize=envsize)
#params[3] /=params[3]
#params[4] /=params[4]
#params[4] *= 1
etas = [1, 1, 1, 1, 1]
betas = [0.5]

# dataset
coords = np.array([[0.0, 0.0]])
rewards = np.array([1.0])
actions = np.array([[0,1,0,0]])


In [5]:
print(params[1][0])

[0.1]


In [6]:
# update by jax
_, jaxgrads, loss = update_td_params(params.copy(), coords, actions, rewards, etas, gamma, betas)

for g in jaxgrads[4]: print(g)

[0.]
[8.040925e-34]
[9.904398e-27]
[3.476075e-23]
[3.476075e-23]
[9.904398e-27]
[8.040925e-34]
[0.]
[8.040925e-34]
[3.476075e-23]
[4.2816507e-16]
[1.5026971e-12]
[1.5026971e-12]
[4.2816507e-16]
[3.476075e-23]
[8.040925e-34]
[9.904398e-27]
[4.2816507e-16]
[5.273907e-09]
[1.850945e-05]
[1.850945e-05]
[5.273907e-09]
[4.2816507e-16]
[9.904398e-27]
[3.476075e-23]
[1.5026971e-12]
[1.850945e-05]
[0.06496122]
[0.06496122]
[1.850945e-05]
[1.5026971e-12]
[3.476075e-23]
[3.476075e-23]
[1.5026971e-12]
[1.850945e-05]
[0.06496122]
[0.06496122]
[1.850945e-05]
[1.5026971e-12]
[3.476075e-23]
[9.904398e-27]
[4.2816507e-16]
[5.273907e-09]
[1.850945e-05]
[1.850945e-05]
[5.273907e-09]
[4.2816507e-16]
[9.904398e-27]
[8.040925e-34]
[3.476075e-23]
[4.2816507e-16]
[1.5026971e-12]
[1.5026971e-12]
[4.2816507e-16]
[3.476075e-23]
[8.040925e-34]
[0.]
[8.040925e-34]
[9.904398e-27]
[3.476075e-23]
[3.476075e-23]
[9.904398e-27]
[8.040925e-34]
[0.]


In [7]:
# update by numpy
_, npgrads, td = learn(params.copy(), rewards, coords[0], actions[0], gamma, etas) 

for g in npgrads[4]: print(g)

TypeError: dot_general requires contracting dimensions to have the same shape, got (64,) and (2,).

In [None]:
idx = 7
pidx = 3
print(np.array(jaxgrads[pidx])[idx])
print(np.array(npgrads[pidx])[idx])
print(np.array(jaxgrads[pidx])[idx] /np.array(npgrads[pidx])[idx])


In [None]:
idx = 2
pidx = 2
print(np.array(jaxgrads[pidx])[idx])
print(np.array(npgrads[pidx])[idx])
print(np.array(jaxgrads[pidx])[idx] /np.array(npgrads[pidx])[idx])


In [None]:
pidx = 0
print(np.array(jaxgrads[pidx]) /np.array(npgrads[pidx]))

In [None]:
jnp.linalg.inv(params[1])