In [None]:
import torch
# from __future__ import absolute_import
# from __future__ import print_function

from itertools import product

import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn

import gymnasium as gym
import torch.optim as optim
from matplotlib import pylab as plt

import relaxit

from relaxit.rl_benchmarks.algorithms.reinforce import REINFORCE
from relaxit.rl_benchmarks.algorithms.a2c import A2C
from relaxit.rl_benchmarks.algorithms.relax import RELAX

In [None]:
# Global matplotlib styling shared by every figure in this notebook:
# serif font, thick lines, and large text for titles, labels, ticks and legends.
plt.rcParams.update({
    'font.family': 'DejaVu Serif',
    'lines.linewidth': 2,
    'lines.markersize': 12,
    'xtick.labelsize': 24,
    'ytick.labelsize': 24,
    'legend.fontsize': 24,
    'axes.titlesize': 36,
    'axes.labelsize': 24,
})

In [None]:
def get_z_tilde_z_samples_params(logits):  # logits = log P(b | theta)
    """Draw Gumbel-perturbed scores, the selected category and conditional scores.

    The category axis is the last axis of ``logits``; any leading axes are
    treated as independent batch entries. ``exp(logits)`` is used directly as
    the category probability, so ``logits`` is expected to hold
    log-probabilities (see the inline note on the signature).

    Args:
        logits: Floating-point tensor of shape ``(..., K)``.

    Returns:
        Tuple ``(z, tilde_z, samples)``:
            z: ``logits`` plus Gumbel noise built from uniform noise ``u``;
                same shape as ``logits``.
            tilde_z: scores built from independent uniform noise ``v`` such
                that the entry at ``samples`` is the largest one along the
                last axis; same shape as ``logits``.
            samples: ``argmax`` of ``z`` over the last axis (integer tensor,
                shape ``logits.shape[:-1]``; 0-dim for 1-D input).
    """
    # torch.rand* samples from [0, 1); an exact 0 would turn into -inf after
    # the double logarithm, so clamp to the smallest positive normal number.
    tiny = torch.finfo(logits.dtype).tiny
    u = torch.rand_like(logits).clamp_min(tiny)
    v = torch.rand_like(logits).clamp_min(tiny)

    z = logits - torch.log(-torch.log(u))
    samples = torch.argmax(z, dim=-1)

    log_v = torch.log(v)
    # Keep a trailing axis so the selected entry broadcasts over categories.
    index = samples.unsqueeze(-1)
    log_v_selected = log_v.gather(-1, index)

    # Entries that were not selected: -log(-log(v_i) / p_i - log(v_b)).
    tilde_z = -torch.log(-log_v / torch.exp(logits) - log_v_selected)
    # Selected entry: plain Gumbel noise -log(-log(v_b)).
    tilde_z.scatter_(-1, index, -torch.log(-log_v_selected))

    return z, tilde_z, samples

In [None]:
import gymnasium as gym
import torch.optim as optim
from matplotlib import pylab as plt



def apply_benchmark(algorithm : str = 'REINFORCE', 
                    env_name = 'CartPole-v1', 
                    max_steps : int = 256,
                    max_episode: int = 100):
    """Train one agent on a Gymnasium environment and record its running reward.

    Args:
        algorithm: Which agent to train: 'REINFORCE', 'RELAX' or 'A2C'.
        env_name: Environment id passed to ``gym.make``.
        max_steps: Forwarded to the agent constructor as ``max_steps``.
        max_episode: Upper bound on the number of training episodes.

    Returns:
        List with one entry per trained episode: the exponentially smoothed
        running reward (smoothing factor 0.01) after that episode. Training
        stops early once this value exceeds the environment's
        ``reward_threshold``, when the environment defines one.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names.
    """
    LR = 0.002  # Learning rate
    SEED = 42  # Random seed for reproducibility
    LOG_INTERVAL = 10  # Print progress every LOG_INTERVAL episodes
    HIDDEN_SIZE = 64
    SMOOTHING = 0.01  # Weight of the newest episode in the running reward

    # Resolve the agent class first so an invalid name fails fast, before any
    # environment is created.
    if algorithm == 'REINFORCE':
        agent_cls = REINFORCE
    elif algorithm == 'RELAX':
        agent_cls = RELAX
    elif algorithm == 'A2C':
        agent_cls = A2C
    else:
        raise ValueError(
            "Unknown algorithm {!r}; expected one of 'REINFORCE', 'RELAX', 'A2C'".format(algorithm)
        )

    env = gym.make(env_name)
    try:
        agent = agent_cls(env, hidden_size=HIDDEN_SIZE, gamma=0.99,
                          max_steps=max_steps, random_seed=SEED)

        # REINFORCE trains a single network; the other agents also own a critic.
        actor_optim = optim.Adam(agent.actor.parameters(), lr=LR)
        if algorithm == 'REINFORCE':
            train_kwargs = {'optimizer': actor_optim}
        else:
            critic_optim = optim.Adam(agent.critic.parameters(), lr=LR)
            train_kwargs = {'actor_optimizer': actor_optim,
                            'critic_optimizer': critic_optim}

        # Not every environment declares a reward threshold; None disables
        # the early-stopping check instead of raising on the comparison.
        spec = env.spec
        reward_threshold = None if spec is None else spec.reward_threshold

        # None (rather than 0) marks "not initialised yet", so a running
        # reward that is genuinely 0 does not re-seed the average.
        running_reward = None
        running_rewards = []

        for i in range(max_episode):
            ep_reward = agent.train_one_episode(**train_kwargs)

            if running_reward is None:
                running_reward = ep_reward
            running_reward = SMOOTHING * ep_reward + (1 - SMOOTHING) * running_reward
            running_rewards.append(running_reward)

            if i % LOG_INTERVAL == 0:
                print(
                    "Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}".format(
                        i, ep_reward, running_reward
                    )
                )
            if reward_threshold is not None and running_reward > reward_threshold:
                print(
                    "Solved! Running reward is now {:.2f} after {} episodes.".format(
                        running_reward, i + 1
                    )
                )
                break
        return running_rewards
    finally:
        # Release the environment's resources even if training raised.
        env.close()

In [None]:
# Maps algorithm name -> list of running rewards; filled by the benchmark cells.
algorithm2rewards = dict()

In [None]:
# CartPole-v1: train each algorithm from scratch and keep its running-reward curve.
algorithm2rewards = dict()
for algorithm in ('RELAX', 'REINFORCE', 'A2C'):
    algorithm2rewards[algorithm] = apply_benchmark(
        algorithm=algorithm,
        env_name='CartPole-v1',
        max_steps=70,
        max_episode=200,
    )

In [None]:
import matplotlib.pyplot as plt

# Compare the running-reward curves collected for CartPole-v1.
fig, ax = plt.subplots(figsize=(10, 5))
for algorithm in ['REINFORCE', 'A2C', 'RELAX']:
    # Algorithms that were not run for this environment are simply skipped.
    if algorithm in algorithm2rewards:
        ax.plot(algorithm2rewards[algorithm], label=algorithm)
# Only draw a legend when at least one curve was plotted.
if ax.get_legend_handles_labels()[0]:
    ax.legend()
ax.set(
    title='CartPole-v1',
    # apply_benchmark returns the smoothed running reward, one value per episode.
    ylabel='running reward',
    xlabel='episode',
)
ax.grid()
plt.show()

In [None]:
# Acrobot-v1: train REINFORCE and A2C from scratch (RELAX is not run here)
# and keep each running-reward curve.
algorithm2rewards = dict()
for algorithm in ('REINFORCE', 'A2C'):
    algorithm2rewards[algorithm] = apply_benchmark(
        algorithm=algorithm,
        env_name='Acrobot-v1',
        max_steps=650,
        max_episode=200,
    )

In [None]:
import matplotlib.pyplot as plt

# Compare the running-reward curves collected for Acrobot-v1.
fig, ax = plt.subplots(figsize=(10, 5))
for algorithm in ['REINFORCE', 'A2C', 'RELAX']:
    # Algorithms that were not run for this environment are simply skipped.
    if algorithm in algorithm2rewards:
        ax.plot(algorithm2rewards[algorithm], label=algorithm)
# Only draw a legend when at least one curve was plotted.
if ax.get_legend_handles_labels()[0]:
    ax.legend()
ax.set(
    title='Acrobot-v1',
    # apply_benchmark returns the smoothed running reward, one value per episode.
    ylabel='running reward',
    xlabel='episode',
)
ax.grid()
plt.show()

In [None]:
# Taxi-v3: train each algorithm from scratch and keep its running-reward curve.
algorithm2rewards = dict()
for algorithm in ('RELAX', 'REINFORCE', 'A2C'):
    algorithm2rewards[algorithm] = apply_benchmark(
        algorithm=algorithm,
        env_name='Taxi-v3',
        max_steps=30,
        max_episode=200,
    )

In [None]:
import matplotlib.pyplot as plt

# Compare the running-reward curves collected for Taxi-v3.
fig, ax = plt.subplots(figsize=(10, 5))
for algorithm in ['REINFORCE', 'A2C', 'RELAX']:
    # Algorithms that were not run for this environment are simply skipped.
    if algorithm in algorithm2rewards:
        ax.plot(algorithm2rewards[algorithm], label=algorithm)
# Only draw a legend when at least one curve was plotted.
if ax.get_legend_handles_labels()[0]:
    ax.legend()
ax.set(
    title='Taxi-v3',
    # apply_benchmark returns the smoothed running reward, one value per episode.
    ylabel='running reward',
    xlabel='episode',
)
ax.grid()
plt.show()