In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.2"
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import grad, jit, vmap, random, nn, lax
from jax import value_and_grad
import numpy as np
from jax import config
config.update('jax_platform_name', 'cpu')
from jax.lib import xla_bridge
device = xla_bridge.get_backend().platform
print(device)

Metal device set to: Apple M2 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

cpu




In [2]:
def uniform_2D_pc_weights(npc, nact,seed=0,sigma=0.1, alpha=1,envsize=1):
    x = np.linspace(-envsize,envsize,int(npc**0.5))
    xx,yy = np.meshgrid(x,x)
    pc_cent = np.concatenate([xx.reshape(-1)[:,None],yy.reshape(-1)[:,None]],axis=1)
    pc_sigma = np.tile(np.eye(2),(npc,1,1))*sigma
    # pc_sigma = np.tile(np.ones([2,2]),(npc,1,1))*sigma
    pc_constant = np.ones(npc) * alpha
    actor_key, critic_key = random.split(random.PRNGKey(seed), num=2)
    return [jnp.array(pc_cent), jnp.array(pc_sigma), jnp.array(pc_constant), 
            1e-5 * random.normal(actor_key, (npc,nact)), 1e-5 * random.normal(critic_key, (npc,1))]


def predict_placecell(params, x):
    pc_centers, pc_sigmas, pc_constant, actor_weights, critic_weights = params
    diff = x - pc_centers  # Shape: (npc, dim)
    inv_sigma = jnp.linalg.inv(pc_sigmas)  # Shape: (npc, dim, dim)
    exponent = jnp.einsum('ni,nij,nj->n', diff, inv_sigma, diff)
    pcacts = jnp.exp(-0.5 * exponent) * pc_constant**2
    return pcacts

def compute_reward_prediction_error(rewards, values, gamma=0.95):
    new_values = jnp.concatenate([values[1:], jnp.array([[0]])])
    td = rewards + gamma * new_values - values
    return td


def td_loss(params, coords, actions, rewards, gamma, betas):
    aprobs = []
    values = []
    for coord in coords:
        pcact = predict_placecell(params, coord)
        aprob = predict_action(params, pcact)
        value = predict_value(params, pcact)
        aprobs.append(aprob)
        values.append(value)
    aprobs = jnp.array(aprobs)
    values = jnp.array(values)

    log_likelihood = jnp.log(aprobs) * actions  # log probability of action as policy
    tde = jnp.array(compute_reward_prediction_error(rewards[:,None], values, gamma))
    actor_loss = jnp.sum(log_likelihood * lax.stop_gradient(tde))  # log policy * discounted reward
    critic_loss = -jnp.sum(tde ** 2) # grad decent
    tot_loss = actor_loss + betas[0] * critic_loss
    return tot_loss

def predict_value(params, pcact):
    pc_centers, pc_sigmas, pc_constant, actor_weights,critic_weights = params
    value = jnp.matmul(pcact, critic_weights)
    return value


def predict_action(params, pcact, beta=1):
    pc_centers, pc_sigmas, pc_constant, actor_weights,critic_weights = params
    actout = jnp.matmul(pcact, actor_weights)
    aprob = nn.softmax(beta * actout)
    return aprob

@jit
def update_td_params(params, coords, actions, rewards, etas, gamma, betas):
    loss, grads = value_and_grad(td_loss)(params, coords,actions, rewards, gamma, betas)
    pc_centers, pc_sigmas, pc_constant, actor_weights,critic_weights = params
    dpcc, dpcs, dpca, dact, dcri = grads

    # + for gradient ascent
    pc_eta, sigma_eta,constant_eta, actor_eta, critic_eta = etas
    newpc_centers = pc_centers + pc_eta * dpcc
    newpc_sigma = pc_sigmas + sigma_eta * dpcs
    newpc_const = pc_constant + constant_eta * dpca
    newactor_weights = actor_weights + actor_eta * dact
    newcritic_weights = critic_weights + critic_eta * dcri  # gradient descent
    return [newpc_centers, newpc_sigma,newpc_const, newactor_weights,newcritic_weights], grads, loss


In [3]:
def learn(params, reward, state, onehotg, gamma, etas,beta=1):
    pc_centers, pc_sigmas, pc_constant, actor_weights, critic_weights = params
    pcact = predict_placecell(params, state)
    # newpcact = predict_placecell(params, newstate)
    value =  predict_value(params, pcact)
    # newvalue = predict_value(params, newpcact)
    td = compute_reward_prediction_error(reward[:,None], np.array([value]), gamma)[:,0]

    # td = (reward + gamma * newvalue - value)[0]
    aprob = predict_action(params, pcact)

    # Critic grads
    dcri = pcact[:, None] * td
    
    # Actor grads
    decay = beta * (onehotg[:, None] - aprob[:, None])
    dact = np.dot(pcact[:, None], decay.T) * td
    
    # Grads for field parameters
    post_td = (actor_weights @ decay + critic_weights) * td

    df = state - pc_centers
    inv_sigma = np.linalg.inv(pc_sigmas)

    dpcc = post_td * pcact[:,None] * np.einsum('nji,nj->ni', inv_sigma, df)
    outer = np.einsum('nj,nk->njk',df,df)
    dpcs = 0.5 * (post_td * pcact[:,None])[:,:,None] * np.einsum('njl,njk,nik->nji',inv_sigma, outer, inv_sigma)
    dpca = (post_td * pcact[:,None] * (2/pc_constant[:,None]))[:,0]

    grads = [dpcc, dpcs, dpca, dact, dcri]  # dpcc needs to be transposed back
    
    for p in range(len(params)):
        params[p] += etas[p] * grads[p]

    return params, grads, td

In [4]:
# gradients computed by jax

npc = 9
nact = 4
seed = 2020
sigma = 0.1
alpha = 0.5
envsize=1.0
gamma= 0.6

# model 
params = uniform_2D_pc_weights(npc, nact, seed, sigma=sigma, alpha=alpha, envsize=envsize)
#params[3] /=params[3]
#params[4] /=params[4]
#params[4] *= 1
etas = [1, 1, 1, 1, 1]
betas = [0.5]

# dataset
coords = np.array([[0.0, 0.0]])
rewards = np.array([1.0])
actions = np.array([[0,1,0,0]])


In [5]:
print(params[1][0])
print(jnp.linalg.det(params[1]))
print(jnp.linalg.inv(params[1])[0])

[[0.1 0. ]
 [0.  0.1]]
[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01]
[[10.  0.]
 [ 0. 10.]]


In [6]:
# update by jax
_, jaxgrads, loss = update_td_params(params.copy(), coords, actions, rewards, etas, gamma, betas)

for g in jaxgrads[1]: print(g)

[[-1.0465312e-08 -1.0465312e-08]
 [-1.0465312e-08 -1.0465312e-08]]
[[0.000000e+00 0.000000e+00]
 [0.000000e+00 4.299794e-07]]
[[ 1.6479046e-08 -1.6479046e-08]
 [-1.6479046e-08  1.6479046e-08]]
[[9.997724e-07 0.000000e+00]
 [0.000000e+00 0.000000e+00]]
[[0. 0.]
 [0. 0.]]
[[8.2153485e-07 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00]]
[[-3.7963463e-09  3.7963463e-09]
 [ 3.7963463e-09 -3.7963463e-09]]
[[ 0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00 -2.1201467e-06]]
[[-1.6723596e-09 -1.6723596e-09]
 [-1.6723596e-09 -1.6723596e-09]]


In [7]:
# update by numpy
_, npgrads, td = learn(params.copy(), rewards, coords[0], actions[0], gamma, etas) 

for g in npgrads[1]: print(g)

[[-1.0465313e-08 -1.0465313e-08]
 [-1.0465313e-08 -1.0465313e-08]]
[[0.000000e+00 0.000000e+00]
 [0.000000e+00 4.299794e-07]]
[[ 1.6479046e-08 -1.6479046e-08]
 [-1.6479046e-08  1.6479046e-08]]
[[9.997724e-07 0.000000e+00]
 [0.000000e+00 0.000000e+00]]
[[-0. -0.]
 [-0. -0.]]
[[8.2153485e-07 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00]]
[[-3.7963463e-09  3.7963463e-09]
 [ 3.7963463e-09 -3.7963463e-09]]
[[-0.0000000e+00 -0.0000000e+00]
 [-0.0000000e+00 -2.1201467e-06]]
[[-1.6723598e-09 -1.6723598e-09]
 [-1.6723598e-09 -1.6723598e-09]]


In [8]:
idx = 7
pidx = 3
print(np.array(jaxgrads[pidx])[idx])
print(np.array(npgrads[pidx])[idx])
print(np.array(jaxgrads[pidx])[idx] /np.array(npgrads[pidx])[idx])


[-0.00042112  0.00126337 -0.00042112 -0.00042112]
[-0.00042112  0.00126337 -0.00042112 -0.00042112]
[0.99999994 1.0000001  1.0000001  1.0000001 ]


In [9]:
idx = 2
pidx = 2
print(np.array(jaxgrads[pidx])[idx])
print(np.array(npgrads[pidx])[idx])
print(np.array(jaxgrads[pidx])[idx] /np.array(npgrads[pidx])[idx])


1.3183236e-09
1.3183237e-09
0.99999994


In [10]:
pidx = 1
print(np.array(jaxgrads[pidx]) /np.array(npgrads[pidx]))

[[[0.99999994 0.99999994]
  [0.99999994 0.99999994]]

 [[       nan        nan]
  [       nan 1.        ]]

 [[1.         1.        ]
  [1.         1.        ]]

 [[1.                nan]
  [       nan        nan]]

 [[       nan        nan]
  [       nan        nan]]

 [[1.                nan]
  [       nan        nan]]

 [[1.         1.        ]
  [1.         1.        ]]

 [[       nan        nan]
  [       nan 1.        ]]

 [[0.9999999  0.9999999 ]
  [0.9999999  0.9999999 ]]]


  print(np.array(jaxgrads[pidx]) /np.array(npgrads[pidx]))


In [11]:
jnp.linalg.inv(params[1])

Array([[[10.,  0.],
        [ 0., 10.]],

       [[10.,  0.],
        [ 0., 10.]],

       [[10.,  0.],
        [ 0., 10.]],

       [[10.,  0.],
        [ 0., 10.]],

       [[10.,  0.],
        [ 0., 10.]],

       [[10.,  0.],
        [ 0., 10.]],

       [[10.,  0.],
        [ 0., 10.]],

       [[10.,  0.],
        [ 0., 10.]],

       [[10.,  0.],
        [ 0., 10.]]], dtype=float32)

In [15]:
def rbf(s, c, W):
    """
    s : 2d vector
    c : center of the basis function (2d vector)
    W : covariance matrix (2x2 matrix with diagonals = 0.1^2)
    """
    diff = s - c
    print(diff.shape)
    exponent =  np.dot(np.dot(diff.T, np.linalg.inv(W)), diff)
    return np.exp(-0.5 * exponent)

print(rbf(coords[0], params[0], params[1]))

(9, 2)


ValueError: shapes (2,9) and (9,2,2) not aligned: 9 (dim 1) != 2 (dim 1)

In [None]:
def radial_basis_function_no_inverse(s, c, W):
    """
    s : 2D vector
    c : center of the basis function (2D vector)
    W : covariance matrix (2x2 matrix with diagonals = 0.1^2)
    """
    diff = s - c
    result = np.dot(diff.T, np.linalg.solve(W, diff))
    return np.exp(-0.5 * result)

In [18]:
sig = np.array([[0.1,0.0],[0.0,0.1]])
B = np.array([0.5,0.2])

print(np.linalg.solve(sig, B))
print(np.linalg.inv(sig) @ B)

[5. 2.]
[5. 2.]
