In [2]:
import os
import random
import time
import wandb
import numpy as np
import gymnasium as gym
import flax.linen as nn
from flax.training.train_state import TrainState
from flax.training.common_utils import onehot
import jax.numpy as jnp
import jax
import optax

from torch.utils.tensorboard import SummaryWriter

In [3]:
def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id,render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env,f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env
    return thunk

In [4]:
class MLPPolicy(nn.Module):
    action_dims: int

    @nn.compact
    def __call__(self, input: jnp.ndarray):
        x = nn.Dense(16)(input)
        x = nn.relu(x)
        x = nn.Dense(16)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dims)(x)
        probs = nn.softmax(x)
        log_probs = nn.log_softmax(x)

        return probs,log_probs

In [5]:
def train(env_id,gamma,episodes,max_termination,seed,num_envs,learning_rate):
    run_name = f"{env_id}__{seed}_{int(time.time())}"
    wandb.init(
        project="reinforce-classic-control-benchmark",
        config={
            "env_id":env_id,
            "gamma":gamma,
            "epsisodes":episodes,
            "max_termination":max_termination,
            "seed":seed,
            "num_envs":num_envs,
            "lr":learning_rate,
        },
        sync_tensorboard=True,
        monitor_gym=True,
        name=run_name
    )
    writer = SummaryWriter(f"runs/{run_name}")

    random.seed(seed)
    np.random.seed(seed=seed)
    key = jax.random.PRNGKey(seed)

    env = make_env(env_id, seed, 0, True, run_name)()

    obs, _ = env.reset(seed=seed)
    action_dims = env.action_space.n

    policy = MLPPolicy(action_dims=action_dims)

    policy_state = TrainState.create(
        apply_fn=policy.apply,
        params=policy.init(key,obs),
        tx=optax.adam(learning_rate=learning_rate)
    )

    policy.apply = jax.jit(policy.apply)

    @jax.jit
    def update(policy_state,observations,actions,rewards):
        def loss_fn(params):
            _,logprobs = policy.apply(params,observations)
            onehot_actions = onehot(actions,num_classes=logprobs.shape[-1]).reshape(logprobs.shape)
            selected_action_logprobs = jnp.sum(onehot_actions * logprobs, axis=-1)
            loss_value = -jnp.sum(selected_action_logprobs * rewards)
            return loss_value

        loss_value,grads = jax.value_and_grad(loss_fn)(policy_state.params)
        policy_state = policy_state.apply_gradients(grads=grads)
        return loss_value,policy_state

    for episode in range(episodes):
        rewards = []
        # logprobs = []
        actions = []
        observations = []

        done = False
        obs, _ = env.reset(seed=seed)
        for _ in range(max_termination):
            observations.append(obs)
            prob,_ = policy.apply(policy_state.params,jnp.array(obs)[None,...])
            prob = jax.device_get(prob)[0]
            action = np.random.choice(action_dims,p=prob)
            obs, reward, done, _, _ = env.step(int(action))
            actions.append(action)
            rewards.append(reward)
            if done:
                break

        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0,R)

        returns = jnp.array(returns)
        eps = np.finfo(np.float32).eps.item()
        returns = (returns - returns.mean())/(returns.std() + eps)
        actions = jnp.array(actions)
        observations = jnp.array(observations)

        loss, policy_state = update(
            policy_state,
            observations,
            actions,
            returns
        )
        print(f"Episode:{episode}   Loss:{loss}  Reward:{sum(rewards)} Episode Lenght:{len(rewards)}")
        writer.add_scalar("loss",jax.device_get(loss),episode)
        writer.add_scalar("rewards",sum(rewards),episode)

    writer.close()
    env.close()

    return policy_state


In [10]:
envs = {"cartPole":"CartPole-v1","acrobot":"Acrobot-v1","mountainCar":"MountainCar-v0"}
gamma = 0.9
episodes = 4000
max_termination = 1000
seed = 0
num_envs = 1
learning_rate = 0.001

In [11]:
train(envs["mountainCar"],gamma,episodes,max_termination,seed,num_envs,learning_rate)

VBox(children=(Label(value='0.734 MB of 0.734 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,▄▅▆▃▅▆▆▄▅▄▃▃▄▃▅▅▃▆▆▅▄▆▄▅▅▃▄▄▇▇▆▆▅▃▅▁█▃▃▅
rewards,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
global_step,3999.0
loss,1.16523
rewards,-1000.0




Moviepy - Building video /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-0.mp4.
Moviepy - Writing video /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-0.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-0.mp4
Episode:0   Loss:-1.880661964416504  Reward:-1000.0 Episode Lenght:1000
Moviepy - Building video /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-1.mp4.
Moviepy - Writing video /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-1.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-1.mp4
Episode:1   Loss:3.0718746185302734  Reward:-1000.0 Episode Lenght:1000
Episode:2   Loss:-3.5280046463012695  Reward:-1000.0 Episode Lenght:1000
Episode:3   Loss:11.557256698608398  Reward:-1000.0 Episode Lenght:1000
Episode:4   Loss:-1.5823020935058594  Reward:-1000.0 Episode Lenght:1000
Episode:5   Loss:0.5212497711181641  Reward:-1000.0 Episode Lenght:1000
Episode:6   Loss:4.905426025390625  Reward:-1000.0 Episode Lenght:1000
Episode:7   Loss:0.02987957000732422  Reward:-1000.0 Episode Lenght:1000
Moviepy - Building video /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-8.mp4.
Moviepy - Writing video /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-8.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-8.mp4
Episode:8   Loss:-2.1715354919433594  Reward:-1000.0 Episode Lenght:1000
Episode:9   Loss:-0.3814048767089844  Reward:-1000.0 Episode Lenght:1000
Episode:10   Loss:-0.028176307678222656  Reward:-1000.0 Episode Lenght:1000
Episode:11   Loss:7.27703857421875  Reward:-726.0 Episode Lenght:726
Episode:12   Loss:-7.04747200012207  Reward:-1000.0 Episode Lenght:1000
Episode:13   Loss:3.608450412750244  Reward:-1000.0 Episode Lenght:1000
Episode:14   Loss:-3.529386520385742  Reward:-1000.0 Episode Lenght:1000
Episode:15   Loss:-9.172393798828125  Reward:-1000.0 Episode Lenght:1000
Episode:16   Loss:-6.013286590576172  Reward:-1000.0 Episode Lenght:1000
Episode:17   Loss:-2.438650131225586  Reward:-1000.0 Episode Lenght:1000
Episode:18   Loss:-20.434568405151367  Reward:-1000.0 Episode Lenght:1000
Episode:19   Loss:-4.137092590332031  Reward:-1000.0 Episode Lenght:1000
Episode:

                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-27.mp4
Episode:27   Loss:-10.739022254943848  Reward:-1000.0 Episode Lenght:1000
Episode:28   Loss:5.1284074783325195  Reward:-1000.0 Episode Lenght:1000
Episode:29   Loss:-6.559283256530762  Reward:-1000.0 Episode Lenght:1000
Episode:30   Loss:-3.032811164855957  Reward:-1000.0 Episode Lenght:1000
Episode:31   Loss:-4.378208160400391  Reward:-1000.0 Episode Lenght:1000
Episode:32   Loss:-15.733776092529297  Reward:-1000.0 Episode Lenght:1000
Episode:33   Loss:1.3398151397705078  Reward:-1000.0 Episode Lenght:1000
Episode:34   Loss:-6.935384750366211  Reward:-1000.0 Episode Lenght:1000
Episode:35   Loss:-5.177606582641602  Reward:-1000.0 Episode Lenght:1000
Episode:36   Loss:1.6409168243408203  Reward:-1000.0 Episode Lenght:1000
Episode:37   Loss:0.9207248687744141  Reward:-1000.0 Episode Lenght:1000
Episode:38   Loss:-28.10767936706543  Reward:-1000.0 Episode Lenght:1000
Epi

                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-64.mp4
Episode:64   Loss:-2.4087953567504883  Reward:-1000.0 Episode Lenght:1000
Episode:65   Loss:20.77193832397461  Reward:-1000.0 Episode Lenght:1000
Episode:66   Loss:-32.69367599487305  Reward:-1000.0 Episode Lenght:1000
Episode:67   Loss:-0.2896108627319336  Reward:-1000.0 Episode Lenght:1000
Episode:68   Loss:2.1710433959960938  Reward:-1000.0 Episode Lenght:1000
Episode:69   Loss:6.030326843261719  Reward:-1000.0 Episode Lenght:1000
Episode:70   Loss:-0.4811859130859375  Reward:-1000.0 Episode Lenght:1000
Episode:71   Loss:-27.125520706176758  Reward:-1000.0 Episode Lenght:1000
Episode:72   Loss:32.11925506591797  Reward:-1000.0 Episode Lenght:1000
Episode:73   Loss:-22.38017463684082  Reward:-1000.0 Episode Lenght:1000
Episode:74   Loss:-4.518184661865234  Reward:-1000.0 Episode Lenght:1000
Episode:75   Loss:1.8288021087646484  Reward:-1000.0 Episode Lenght:1000
Epis

                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-125.mp4
Episode:125   Loss:0.3138999938964844  Reward:-1000.0 Episode Lenght:1000
Episode:126   Loss:22.42351531982422  Reward:-1000.0 Episode Lenght:1000
Episode:127   Loss:-10.89342212677002  Reward:-1000.0 Episode Lenght:1000
Episode:128   Loss:10.105932235717773  Reward:-1000.0 Episode Lenght:1000
Episode:129   Loss:-3.061141014099121  Reward:-1000.0 Episode Lenght:1000
Episode:130   Loss:2.1764144897460938  Reward:-1000.0 Episode Lenght:1000
Episode:131   Loss:7.896398544311523  Reward:-1000.0 Episode Lenght:1000
Episode:132   Loss:0.574101448059082  Reward:-1000.0 Episode Lenght:1000
Episode:133   Loss:15.291563034057617  Reward:-1000.0 Episode Lenght:1000
Episode:134   Loss:17.06829261779785  Reward:-1000.0 Episode Lenght:1000
Episode:135   Loss:-3.675874710083008  Reward:-1000.0 Episode Lenght:1000
Episode:136   Loss:5.621259689331055  Reward:-1000.0 Episode Lenght:10

                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-216.mp4
Episode:216   Loss:-5.870028495788574  Reward:-1000.0 Episode Lenght:1000
Episode:217   Loss:-26.386138916015625  Reward:-1000.0 Episode Lenght:1000
Episode:218   Loss:-9.17416763305664  Reward:-1000.0 Episode Lenght:1000
Episode:219   Loss:-5.776474952697754  Reward:-1000.0 Episode Lenght:1000
Episode:220   Loss:-14.727201461791992  Reward:-1000.0 Episode Lenght:1000
Episode:221   Loss:18.022029876708984  Reward:-1000.0 Episode Lenght:1000
Episode:222   Loss:4.4015350341796875  Reward:-1000.0 Episode Lenght:1000
Episode:223   Loss:24.546781539916992  Reward:-1000.0 Episode Lenght:1000
Episode:224   Loss:-17.32560157775879  Reward:-1000.0 Episode Lenght:1000
Episode:225   Loss:-7.887300968170166  Reward:-1000.0 Episode Lenght:1000
Episode:226   Loss:15.291756629943848  Reward:-1000.0 Episode Lenght:1000
Episode:227   Loss:9.858161926269531  Reward:-1000.0 Episode Leng

                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-343.mp4
Episode:343   Loss:12.422237396240234  Reward:-1000.0 Episode Lenght:1000
Episode:344   Loss:6.038848876953125  Reward:-1000.0 Episode Lenght:1000
Episode:345   Loss:15.461750030517578  Reward:-1000.0 Episode Lenght:1000
Episode:346   Loss:-22.468969345092773  Reward:-1000.0 Episode Lenght:1000
Episode:347   Loss:-4.447035789489746  Reward:-1000.0 Episode Lenght:1000
Episode:348   Loss:1.8423271179199219  Reward:-1000.0 Episode Lenght:1000
Episode:349   Loss:-2.5695419311523438  Reward:-1000.0 Episode Lenght:1000
Episode:350   Loss:12.182157516479492  Reward:-1000.0 Episode Lenght:1000
Episode:351   Loss:8.423749923706055  Reward:-1000.0 Episode Lenght:1000
Episode:352   Loss:27.582931518554688  Reward:-1000.0 Episode Lenght:1000
Episode:353   Loss:-7.088788986206055  Reward:-1000.0 Episode Lenght:1000
Episode:354   Loss:3.006956100463867  Reward:-1000.0 Episode Lengh

                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-512.mp4
Episode:512   Loss:1.8703584671020508  Reward:-1000.0 Episode Lenght:1000
Episode:513   Loss:9.70389461517334  Reward:-1000.0 Episode Lenght:1000
Episode:514   Loss:-15.540826797485352  Reward:-1000.0 Episode Lenght:1000
Episode:515   Loss:-9.823139190673828  Reward:-1000.0 Episode Lenght:1000
Episode:516   Loss:-2.9429121017456055  Reward:-1000.0 Episode Lenght:1000
Episode:517   Loss:2.7230749130249023  Reward:-1000.0 Episode Lenght:1000
Episode:518   Loss:7.8421783447265625  Reward:-1000.0 Episode Lenght:1000
Episode:519   Loss:-20.06130599975586  Reward:-1000.0 Episode Lenght:1000
Episode:520   Loss:6.745099067687988  Reward:-1000.0 Episode Lenght:1000
Episode:521   Loss:-1.1125450134277344  Reward:-1000.0 Episode Lenght:1000
Episode:522   Loss:18.854780197143555  Reward:-1000.0 Episode Lenght:1000
Episode:523   Loss:7.490301132202148  Reward:-1000.0 Episode Lengh

                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-729.mp4
Episode:729   Loss:13.145273208618164  Reward:-1000.0 Episode Lenght:1000
Episode:730   Loss:4.297815322875977  Reward:-1000.0 Episode Lenght:1000
Episode:731   Loss:-17.79135513305664  Reward:-1000.0 Episode Lenght:1000
Episode:732   Loss:-29.010480880737305  Reward:-1000.0 Episode Lenght:1000
Episode:733   Loss:-6.029407501220703  Reward:-1000.0 Episode Lenght:1000
Episode:734   Loss:-2.839508056640625  Reward:-1000.0 Episode Lenght:1000
Episode:735   Loss:-7.344195365905762  Reward:-1000.0 Episode Lenght:1000
Episode:736   Loss:-17.151811599731445  Reward:-1000.0 Episode Lenght:1000
Episode:737   Loss:1.536433219909668  Reward:-1000.0 Episode Lenght:1000
Episode:738   Loss:-13.526663780212402  Reward:-1000.0 Episode Lenght:1000
Episode:739   Loss:2.9349870681762695  Reward:-1000.0 Episode Lenght:1000
Episode:740   Loss:7.367565155029297  Reward:-1000.0 Episode Leng

                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-1000.mp4
Episode:1000   Loss:-0.9080810546875  Reward:-1000.0 Episode Lenght:1000
Episode:1001   Loss:10.299215316772461  Reward:-1000.0 Episode Lenght:1000
Episode:1002   Loss:10.204967498779297  Reward:-1000.0 Episode Lenght:1000
Episode:1003   Loss:-3.2799606323242188  Reward:-1000.0 Episode Lenght:1000
Episode:1004   Loss:-22.775943756103516  Reward:-1000.0 Episode Lenght:1000
Episode:1005   Loss:7.537966728210449  Reward:-1000.0 Episode Lenght:1000
Episode:1006   Loss:-0.7899265289306641  Reward:-1000.0 Episode Lenght:1000
Episode:1007   Loss:-4.808506011962891  Reward:-1000.0 Episode Lenght:1000
Episode:1008   Loss:-0.9596967697143555  Reward:-1000.0 Episode Lenght:1000
Episode:1009   Loss:-20.238126754760742  Reward:-1000.0 Episode Lenght:1000
Episode:1010   Loss:-10.253435134887695  Reward:-1000.0 Episode Lenght:1000
Episode:1011   Loss:4.573627471923828  Reward:-1000

                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-2000.mp4
Episode:2000   Loss:12.280218124389648  Reward:-1000.0 Episode Lenght:1000
Episode:2001   Loss:-5.962685585021973  Reward:-1000.0 Episode Lenght:1000
Episode:2002   Loss:24.006797790527344  Reward:-1000.0 Episode Lenght:1000
Episode:2003   Loss:19.620420455932617  Reward:-1000.0 Episode Lenght:1000
Episode:2004   Loss:18.79877471923828  Reward:-1000.0 Episode Lenght:1000
Episode:2005   Loss:-1.6074481010437012  Reward:-1000.0 Episode Lenght:1000
Episode:2006   Loss:-24.88214874267578  Reward:-1000.0 Episode Lenght:1000
Episode:2007   Loss:47.646141052246094  Reward:-1000.0 Episode Lenght:1000
Episode:2008   Loss:-4.525152206420898  Reward:-1000.0 Episode Lenght:1000
Episode:2009   Loss:2.6621837615966797  Reward:-1000.0 Episode Lenght:1000
Episode:2010   Loss:26.06179428100586  Reward:-1000.0 Episode Lenght:1000
Episode:2011   Loss:9.930980682373047  Reward:-1000.0 E

                                                               

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-3000.mp4
Episode:3000   Loss:-14.086091041564941  Reward:-1000.0 Episode Lenght:1000
Episode:3001   Loss:-2.6843137741088867  Reward:-1000.0 Episode Lenght:1000
Episode:3002   Loss:-23.403575897216797  Reward:-1000.0 Episode Lenght:1000
Episode:3003   Loss:20.798912048339844  Reward:-1000.0 Episode Lenght:1000
Episode:3004   Loss:-22.43109703063965  Reward:-1000.0 Episode Lenght:1000
Episode:3005   Loss:19.44402503967285  Reward:-1000.0 Episode Lenght:1000
Episode:3006   Loss:4.826547622680664  Reward:-1000.0 Episode Lenght:1000
Episode:3007   Loss:-5.947260856628418  Reward:-1000.0 Episode Lenght:1000
Episode:3008   Loss:-12.108234405517578  Reward:-1000.0 Episode Lenght:1000
Episode:3009   Loss:-24.606449127197266  Reward:-1000.0 Episode Lenght:1000
Episode:3010   Loss:9.00093936920166  Reward:-1000.0 Episode Lenght:1000
Episode:3011   Loss:-19.485910415649414  Reward:-1000

                                                  

Moviepy - Done !
Moviepy - video ready /notebooks/rl-algos/videos/MountainCar-v0__0_1699956346/rl-video-episode-4000.mp4




TrainState(step=Array(4000, dtype=int32, weak_type=True), apply_fn=<bound method Module.apply of MLPPolicy(
    # attributes
    action_dims = 3
)>, params=FrozenDict({
    params: {
        Dense_0: {
            bias: Array([-0.04243626, -0.04255573, -0.04262176,  0.05070423, -0.04024125,
                   -0.04434826, -0.05914595, -0.02557674, -0.04410316, -0.1105606 ,
                   -0.02047009,  0.01735428, -0.03230472, -0.07481626, -0.03225683,
                   -0.02252995], dtype=float32),
            kernel: Array([[-0.44260874, -1.5601882 , -1.1346977 ,  0.9341963 ,  0.02998589,
                     0.81851006,  0.04668314,  0.14245598, -1.1609331 , -0.2739778 ,
                    -0.68549204,  0.27432388,  0.9784788 , -0.766398  , -0.73717576,
                    -0.6822712 ],
                   [-0.24176899, -0.4008417 ,  0.19514215, -0.20621327,  0.94435525,
                     1.0735552 , -1.5468733 ,  0.42790997, -0.11145167,  0.8676161 ,
                    -0.1