In [98]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import jax
from jax import numpy as jnp
from jax.experimental import host_callback

from tqdm.auto import tqdm
from functools import partial
import optax
import matplotlib.pyplot as plt

from flax import linen as nn
from typing import Sequence

from functools import partial
from typing import Callable, Sequence
from flax import struct
from flax import linen as nn
import jax
import jax.numpy as jnp
import optax

import qlearning
import replay
from frozen_lake import EnvState, FrozenLake, ObsType, ActType, RNGKey



In [99]:
# Show the devices JAX can use; CUDA_VISIBLE_DEVICES was set to "2" in the import cell.
jax.devices()

[cuda(id=0)]

In [100]:
from typing import Any
import optax


class Network(nn.Module):
    """Minimal Flax module that multiplies its input by one learnable parameter.

    The parameter is named ``a``, has shape ``(1,)`` and is initialised to ones.
    """

    @nn.compact
    def __call__(self, x):
        """Return ``a * x``; the single-element parameter broadcasts over ``x``."""
        scale = self.param('a', nn.initializers.ones, (1,))
        return scale * x


# Smoke test: initialise the module on a vector of ones, then run one forward pass.
x = jnp.ones((16,))
network = Network()

# `init` returns the variable dict ({'params': {...}}) that `apply` consumes.
params = network.init(jax.random.PRNGKey(0), x)
print(params)

y = network.apply(params, x)
print(y)


{'params': {'a': Array([1.], dtype=float32)}}
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [101]:
# Take a single Adam step on the raw parameter dict, using the sum of the
# network outputs as the scalar objective.
tx = optax.adam(1e-3)
opt_state = tx.init(params)

grads = jax.grad(lambda p: jnp.sum(network.apply(p, x)))(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
print(params)

# Forward pass with the updated parameter.
y = network.apply(params, x)
print(y)

{'params': {'a': Array([0.999], dtype=float32)}}
[0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999 0.999
 0.999 0.999 0.999 0.999]


In [108]:
class Policy(struct.PyTreeNode):
    """Pytree bundling network parameters, the function that applies them, and ``eps``."""

    # Variables handed to ``apply_fn`` as its first argument.
    params: Any
    # Declared with ``pytree_node=False``: stored as static metadata, not as a pytree leaf.
    apply_fn: Callable = struct.field(pytree_node=False)
    # Regular pytree field, so it is a leaf that jax.grad / optax also operate on.
    eps: float = 1.0

    def fow(self, x):
        """Forward pass: evaluate ``apply_fn(params, x)`` and return the result."""
        forward = self.apply_fn
        return forward(self.params, x)
    

# Repeat the single Adam step, this time optimising the whole Policy pytree
# instead of the bare parameter dict.
params = network.init(jax.random.PRNGKey(0), x)
policy = Policy(params=params, apply_fn=network.apply, eps=1.0)
print(policy)

tx = optax.adam(1e-3)
opt_state = tx.init(policy)

# Differentiate the summed forward pass with respect to the Policy itself.
grads = jax.grad(lambda p: p.fow(x).sum())(policy)
updates, opt_state = tx.update(grads, opt_state, policy)

policy = optax.apply_updates(policy, updates)
print(policy)





Policy(params={'params': {'a': Array([1.], dtype=float32)}}, apply_fn=<bound method Module.apply of Network()>, eps=1.0)
Policy(params={'params': {'a': Array([0.999], dtype=float32)}}, apply_fn=<bound method Module.apply of Network()>, eps=Array(1., dtype=float32))


: 

In [5]:
from utils import ConvNet

n_actions = 5
n_layers = 3

MultiConvNet = nn.vmap(
    ConvNet,
    in_axes=None,
    axis_size=n_actions,
    variable_axes={"params": 0},
    split_rngs={"params": True},
)


x = jnp.ones((8, 8, 3))
y = jnp.stack([x] * 256)

net = MultiConvNet([16, 16, 16], n_actions)
params = net.init(jax.random.PRNGKey(0), x)
jitted = jax.jit(jax.vmap(net.apply, in_axes=(None, 0)))
print(jitted(params, y).shape)
%timeit jitted(params, y)

(256, 5, 5)
188 µs ± 2.19 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
class MangoLayer(struct.PyTreeNode):
    """Epsilon-greedy agent built around a single Q-network apply function.

    The Q-network parameters are not stored on the layer; they are passed
    explicitly to every method.
    """

    # Q-network forward function, called as ``qnet_apply_fn(params, obs)``.
    # It is a static (non-pytree) field, so it must be a single hashable
    # callable for ``jax.jit`` to accept ``self`` as an argument.
    qnet_apply_fn: Callable = struct.field(pytree_node=False)

    @jax.jit
    def get_action(
        self, params: optax.Params, rng_key: RNGKey, obs: ObsType, epsilon: float
    ) -> ActType:
        """Pick an action epsilon-greedily from the Q-values of ``obs``.

        Args:
            params: parameters forwarded to ``qnet_apply_fn``.
            rng_key: PRNG key; split into one key for the exploration draw
                and one for the random action.
            obs: observation forwarded to ``qnet_apply_fn``.
            epsilon: scalar probability of taking a uniformly random action.

        Returns:
            A scalar action index: uniform in ``[0, qval.size)`` with
            probability ``epsilon``, otherwise ``qval.argmax()``.
        """
        rng_eps, rng_action = jax.random.split(rng_key)
        qval = self.qnet_apply_fn(params, obs)
        explore = jax.random.uniform(rng_eps) < epsilon
        random_action = jax.random.randint(
            rng_action, shape=(), minval=0, maxval=qval.size
        )
        greedy_action = qval.argmax()
        return jax.lax.select(explore, random_action, greedy_action)

    def rollout(
        self,
        params: optax.Params,
        rng_key: RNGKey,
        env: FrozenLake,
        n_steps: int,
        epsilon: float,
    ):
        """Run the epsilon-greedy policy in ``env`` for ``n_steps`` steps.

        The environment is reset once at the start and again after every step
        that reports ``done``, so the rollout may span several episodes.

        Args:
            params: Q-network parameters forwarded to ``get_action``.
            rng_key: PRNG key, split into one reset key and ``n_steps`` step keys.
            env: environment providing ``reset(key)`` and ``step(key, state, action)``.
            n_steps: number of environment steps to collect.
            epsilon: exploration probability forwarded to ``get_action``
                (a scalar, because ``get_action`` compares it to one uniform draw).

        Returns:
            The per-step ``Transition`` objects stacked by ``jax.lax.scan``.
        """

        def scan_compatible_step(carry, step_key: RNGKey):
            # The carry is exactly what ``env.reset`` returns: (env_state, obs).
            env_state, obs = carry
            rng_action, rng_step, rng_reset = jax.random.split(step_key, 3)
            action = self.get_action(params, rng_action, obs, epsilon)
            next_env_state, next_obs, reward, done, info = env.step(
                rng_step, env_state, action
            )
            # NOTE(review): `Transition` is not imported in the visible cells;
            # presumably it is defined elsewhere in the notebook -- confirm.
            transition = Transition(env_state, obs, action, next_obs, reward, done, info)

            # reset the environment if done
            carry = jax.lax.cond(
                done,
                lambda: env.reset(rng_reset),
                lambda: (next_env_state, next_obs),
            )
            return carry, transition

        # Split once and slice: no Python list of keys has to be rebuilt into
        # an array, and n_steps == 0 still gives a key array of the right shape.
        rng_keys = jax.random.split(rng_key, n_steps + 1)
        rng_reset, rng_steps = rng_keys[0], rng_keys[1:]
        _, transitions = jax.lax.scan(
            scan_compatible_step, env.reset(rng_reset), rng_steps
        )
        return transitions
