In [115]:
import warnings
import numpy as np
import tensorflow as tf

from SAC import SAC

%load_ext autoreload
%autoreload 2
warnings.filterwarnings('ignore', category=DeprecationWarning)

# hyperparameters for the SAC agent
# All step budgets below are written as "EPISODE_LENGTH * <number of episodes>"
# so that changing the episode length keeps them consistent.
EPISODE_LENGTH = 20 # environment steps per episode

config = {
    'state_size': 10,
    'action_size': (2,2), # tuple (t, k) where t*k actions are sampled from t k-dimensional distributions
    'max_episode_steps': EPISODE_LENGTH, # episode length
    'max_steps': EPISODE_LENGTH*200, # total training steps (= 200 episodes)
    'min_steps': EPISODE_LENGTH*1, # steps before training starts (= 1 episode)
    'warmup': False, # random actions before training starts
    'buffer_size': EPISODE_LENGTH*160, # should be 80-100% of max_steps
    'minibatch_size': 256, # should be 128-512
    'update_interval': 1, # steps between two agent updates
    'validation_interval': EPISODE_LENGTH*10, # validate performance every 10 episodes
    'preprocess_state': False, # normalize state, implement on your own
    'actor_weights_scaling': 0.1,
    'activation_function': 'tanh',
    'weights_initializer': 'glorot_uniform',
    'pol_nr_layers': 2,
    'pol_hidden_size': 16,
    'val_nr_layers': 2,
    'val_hidden_size': 16,
    'gamma': 0.99,
    'lr': 8e-4, # should be 1e-4-1e-3
    'alpha_init': 0.001, # requires careful tuning: monitor output std.
    'alpha_lr': 0.0, # hard to converge, should be <= lr
    'alpha_decay_rate': 0.0, # depends on max_steps
    'alpha_to_zero_steps': EPISODE_LENGTH*100, # should be ca. 40-60% of max_steps
    'polyak': 0.995,
    'huber_delta': 2.0,
    'gradient_clip_norm': 2.0, # should be 2.0-5.0
    'reg_coef': 0.0,
    'std_initial_value': 0.4, # should be 0.25-0.5
    'seed': 42
}


# define test environment
class Env():
    """Toy single-state environment used to sanity-check the agent.

    Every observation is the same all-ones row vector; the reward depends
    only on how close the action is to a fixed target value.
    """

    def __init__(self, agent, target=0.6):
        """Keep a reference to the agent and remember the reward target.

        Args:
            agent: object exposing a ``state_size`` attribute.
            target: value each action component is rewarded for matching;
                used by ``step`` unless a per-call target is given.
        """
        self.agent = agent
        self.state_size = self.agent.state_size
        self.target = target

    def get_state(self):
        """Return the constant observation: a (1, state_size) float32 tensor of ones."""
        return tf.ones((1,self.state_size), dtype=tf.float32)

    def reset(self):
        """Return the initial observation (identical to every other observation)."""
        return self.get_state()

    def step(self, action, target=None):
        """Score ``action`` against the target and return the next observation.

        The reward is ``1 - MAE(action, target)`` clipped to [0.0, 1.0], so a
        perfect action earns 1.0.

        Args:
            action: tensor of action values.
            target: optional per-call override; when None (the default) the
                target passed to the constructor is used.

        Returns:
            Tuple ``(next_state, reward)``.
        """
        # previously the constructor's target was silently ignored here
        if target is None:
            target = self.target
        target_ = tf.ones_like(action) * target
        mae = tf.reduce_mean(tf.abs(action - target_))
        reward = tf.clip_by_value(1. - mae, 0.0, 1.0)
        next_state = self.get_state()
        return next_state, reward


# build the agent from the hyperparameters in `config`, wrap it in the toy
# environment and fetch the first observation
agent = SAC(config)
env = Env(agent=agent)
observation = env.reset()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [116]:
# training loop
def run_validation_episode(agent, env):
    """Roll out one episode with ``training=False`` and return its rewards.

    Uses only local variables so the training loop's current observation,
    action and reward are not overwritten by the validation rollout.

    Args:
        agent: provides ``max_episode_steps`` and ``select_action``.
        env: provides ``reset`` and ``step``.

    Returns:
        List with one reward per validation step.
    """
    val_observation = env.reset()
    rewards = []
    for _ in range(agent.max_episode_steps):
        _, _, val_action, _, _, _ = agent.select_action(val_observation, training=False)
        val_observation, val_reward = env.step(val_action)
        rewards.append(val_reward)
    return rewards


total_steps = 0
episode_steps = 0
observation = env.reset()

while total_steps < agent.max_steps:

    # schedule alpha temperature parameter: decay and/or set to zero
    agent.manage_alpha_value(total_steps)

    # action selection, depending on warmup
    if total_steps < agent.warmup_steps:
        mean, std, action, raw_action = agent.random_action()
    else:
        mean, std, action, _, _, raw_action = agent.select_action(observation, training=True)

    # interact with the environment and store the transition
    next_observation, reward = env.step(action)
    agent.buffer.store(observation, action, reward, next_observation)
    observation = next_observation
    total_steps += 1
    episode_steps += 1

    # at the end of the episode, reset
    if episode_steps == agent.max_episode_steps:
        observation = env.reset()
        episode_steps = 0

    # update agent (policy and value networks)
    if total_steps > agent.min_steps and total_steps % agent.update_interval == 0:
        q1_loss, q2_loss, min_q, policy_loss, min_logpi, max_logpi, mean_logpi, alpha_loss, entropy, grads_norm = agent.train()

    # validate agent (optional)
    if agent.validation_interval > 0 and total_steps % agent.validation_interval == 0:
        val_rews = run_validation_episode(agent, env)

        # validation ran on the same `env`, so begin a fresh training episode;
        # the step counter is reset together with the environment so the two
        # stay in sync even when validation_interval is not a multiple of the
        # episode length
        observation = env.reset()
        episode_steps = 0
        print(f'step: {total_steps:>8} ({(total_steps/agent.max_steps):>6.1%}) - validation reward: {np.mean(val_rews):.2f}')

step:      200 (  5.0%) - validation reward: 0.70
step:      400 ( 10.0%) - validation reward: 0.66
step:      600 ( 15.0%) - validation reward: 0.65
step:      800 ( 20.0%) - validation reward: 0.60
step:     1000 ( 25.0%) - validation reward: 0.67
step:     1200 ( 30.0%) - validation reward: 0.68
step:     1400 ( 35.0%) - validation reward: 0.74
step:     1600 ( 40.0%) - validation reward: 0.74
step:     1800 ( 45.0%) - validation reward: 0.98
step:     2000 ( 50.0%) - validation reward: 0.86
step:     2200 ( 55.0%) - validation reward: 0.91
step:     2400 ( 60.0%) - validation reward: 0.90
step:     2600 ( 65.0%) - validation reward: 0.90
step:     2800 ( 70.0%) - validation reward: 0.90
step:     3000 ( 75.0%) - validation reward: 0.90
step:     3200 ( 80.0%) - validation reward: 0.90
step:     3400 ( 85.0%) - validation reward: 0.90
step:     3600 ( 90.0%) - validation reward: 0.92
step:     3800 ( 95.0%) - validation reward: 0.94
step:     4000 (100.0%) - validation reward: 0.97
