In [1]:
import flax
import gym
import jax
import jax.numpy as jnp
import numpy as np

from argparse import Namespace
from flax import linen as nn
from jax import grad, jit, vmap, random

from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer

In [2]:
# Experiment hyperparameters. `cfg_dict` keeps the pristine defaults, while
# `cfg` is an attribute-style copy that later cells extend with derived fields.
cfg_dict = dict(
    batch_size=128,
    lr=3e-4,
    max_timesteps=1_000_000,
    memory_size=1_000_000,
    env="MountainCar-v0",
    seed=0,
)
cfg = Namespace(**cfg_dict)

In [3]:
# Instantiate the Gym environment whose id is stored in `cfg.env`.
env = gym.make(cfg.env)

  logger.warn(


In [4]:
# Action dimension as a 1-tuple holding `env.action_space.n` (the recorded
# run shows `(3,)`).
# NOTE(review): presumably sized for a one-hot / per-action encoding — confirm
# this is the action shape the buffer expects for a discrete action space.
cfg.act_dim = (env.action_space.n, )

In [5]:
# Observation shape copied straight from the environment's observation space
# (the recorded run shows `(2,)`).
cfg.obs_dim = env.observation_space.shape

In [6]:
# Fixed 1-element shapes handed to the buffer constructor.
# NOTE(review): `h_state_dim` presumably refers to a hidden-state slot and
# `rew_dim` to a scalar reward — confirm against the buffer implementation.
cfg.h_state_dim = (1,)
cfg.rew_dim = (1,)

In [7]:
# Display the assembled config.
# NOTE(review): the saved output of this cell already lists `rng`, which is
# only assigned in the following cell — the cells were executed out of order,
# so a fresh top-to-bottom run will not show `rng` here.
cfg

Namespace(batch_size=128, lr=0.0003, max_timesteps=1000000, memory_size=1000000, env='MountainCar-v0', seed=0, act_dim=(3,), obs_dim=(2,), h_state_dim=(1,), rew_dim=(1,), rng=RandomState(MT19937) at 0x113D26C40)

In [None]:
# Seeded NumPy `RandomState` stored on the config; it is passed to the buffer
# constructor as `rng`.
cfg.rng = np.random.RandomState(cfg.seed)

In [9]:
# Build the NextStateNumPyBuffer from the values gathered on `cfg`: its
# `memory_size`, the per-field dimensions, and the seeded RNG.
buffer = NextStateNumPyBuffer(
    memory_size=cfg.memory_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=cfg.act_dim,
    rew_dim=cfg.rew_dim,
    rng=cfg.rng,
)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  self._checkpoint_idxes = np.ones(shape=memory_size, dtype=np.bool)


In [None]:
def interact(policy, buffecfg):
    """Roll `policy` out in the module-level `env` for `cfg.max_timesteps` steps.

    Args:
        policy: callable mapping the current observation to an action that
            `env.step` accepts.
        buffecfg: currently unused.
            NOTE(review): the name looks like a typo for two separate
            parameters (`buffer, cfg`); it is kept unchanged so existing calls
            keep working — confirm the intended signature.

    Returns:
        None. Transitions are not stored anywhere yet; the function only
        drives the environment.

    Reads `cfg` and `env` from module scope.
    """
    max_timesteps = cfg.max_timesteps

    obs = env.reset()
    for timestep_i in range(max_timesteps):
        act = policy(obs)
        next_obs, rew, done, info = env.step(act)
        # Advance the rollout: previously `obs` was never updated, so the
        # policy was queried on the initial observation at every step, and the
        # environment kept being stepped after the episode had ended.
        obs = env.reset() if done else next_obs

In [None]:
# Gradient sanity check.
# This cell used to rely on `key` and `f` leaking in from other cells: `key`
# is only created further down and `f` is not defined anywhere in the visible
# notebook, so it raised NameError on a fresh kernel. Both are now defined
# locally. `f` is reconstructed from the "should be equal to v" note below:
# f(x) = <x, x> / 2 has gradient x.
# NOTE(review): confirm this matches the `f` originally intended.
grad_demo_key = random.PRNGKey(0)


def f(x):
  """Half the squared Euclidean norm of `x`; its gradient is `x` itself."""
  return jnp.dot(x, x) / 2.0


v = random.normal(grad_demo_key, (4,))
print("Original v:")
print(v)
print("Gradient of f taken at point v")
print(jax.grad(f)(v)) # should be equal to v !

In [None]:
def predict(W, b, x):
  """Affine model: project `x` through the weights `W`, then add the bias `b`."""
  projected = jnp.dot(x, W)
  return projected + b

# Loss function: Mean squared error.
def mse(W, b, x_batched, y_batched):
  """Average over the batch of the halved squared prediction error.

  For each pair (x, y) taken along the leading axis of `x_batched` and
  `y_batched`, the residual `y - predict(W, b, x)` is formed and its inner
  product with itself is divided by 2; the per-sample values are then
  averaged along axis 0.
  """
  def half_squared_residual(x, y):
    residual = y - predict(W, b, x)
    return jnp.inner(residual, residual) / 2.0

  # vmap maps the single-sample loss over the leading (batch) axis.
  per_sample = jax.vmap(half_squared_residual)(x_batched, y_batched)
  return jnp.mean(per_sample, axis=0)

In [None]:
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# One independent subkey per random draw. Previously `k1` was used to draw `W`
# and then split again to obtain the sample/noise keys; reusing a consumed key
# makes the resulting draws statistically dependent, which JAX's PRNG
# guidelines warn against.
key = random.PRNGKey(0)
k1, k2, key_sample, key_noise = random.split(key, 4)

# Generate random ground truth W and b.
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# Generate samples with additional noise.
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = predict(W, b, x_samples) + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

In [None]:
# Start both estimates at zero, with the same shape and dtype as the
# ground-truth parameters they are meant to recover.
W_hat, b_hat = (jnp.zeros_like(param) for param in (W, b))

# Ensure we jit the largest-possible jittable block.
@jax.jit
def update_params(W, b, x, y, lr):
  """Apply one gradient-descent step on `mse` to both parameters.

  Args:
    W, b: current parameter values.
    x, y: batched inputs and targets forwarded to `mse`.
    lr: step size multiplying each gradient.

  Returns:
    The tuple `(W - lr * dmse/dW, b - lr * dmse/db)`.
  """
  # argnums=(0, 1) returns both gradients from one differentiation of `mse`,
  # instead of differentiating the loss separately for each parameter.
  grad_W, grad_b = jax.grad(mse, argnums=(0, 1))(W, b, x, y)
  return W - lr * grad_W, b - lr * grad_b

learning_rate = 0.3  # Step size used for every gradient update.
print('Loss for "true" W,b: ', mse(W, b, x_samples, y_samples))
for i in range(101):
  # Take a single full-batch gradient step on the running estimates.
  W_hat, b_hat = update_params(W_hat, b_hat, x_samples, y_samples, learning_rate)
  if i % 5:
    continue
  # Report the loss on every fifth iteration only.
  print(f"Loss step {i}: ", mse(W_hat, b_hat, x_samples, y_samples))