# Soft Actor-Critic (SAC)

In [None]:
import torch
from gym.envs.classic_control import PendulumEnv
from gym.envs.mujoco import HalfCheetahEnv, HopperEnv, Walker2dEnv

from sac import SAC, SACAgent, SACConfig
from sac.rl import EnvWrapper, ReplayBuffer, RandomContinuousAgent

In [None]:
# Run configuration: the environment class plus the settings the cells below
# read (buffer_size, batch_size, random_steps, initial_policy_steps, env_steps,
# training_steps, total_train_steps, max_episode_length, action bounds, ...).
# NOTE(review): the other imported environments (PendulumEnv, HalfCheetahEnv,
# HopperEnv) are presumably drop-in alternatives for `env` — confirm SACConfig
# derives action_dim/action_min/action_max from whichever class is passed.
config = SACConfig(env=Walker2dEnv)

In [None]:
# Build the learner and data pipeline shared by all later cells.
# `.cuda()` places the SAC model on the GPU, so this notebook presumably needs
# a CUDA device as written.
sac = SAC(config).cuda()
# Acting agent backed by the SAC policy network (continuous action space).
sac_agent = SACAgent(sac.policy_network, discrete_actions=False)
# Exploration agent for warm-up; presumably samples actions within
# [action_min, action_max] — confirm in sac.rl.
random_agent = RandomContinuousAgent(config.action_dim, config.action_min, config.action_max)
replay_buffer = ReplayBuffer(config.buffer_size, config.batch_size)
# The wrapper starts out acting with the random agent; it is switched to the
# SAC agent with `update_agent` during initial data collection.
env_wrapper = EnvWrapper(config.env, random_agent, config.max_episode_length)

In [None]:
# Collect Initial Data
# Fill the replay buffer before any gradient updates so the first minibatches
# are drawn from a non-trivial pool of transitions.
# Phase 1: `config.random_steps` transitions from the random agent the wrapper
# was constructed with.
replay_buffer.extend([env_wrapper.step() for _ in range(config.random_steps)])

# Phase 2: switch the wrapper to the (still untrained) SAC policy and collect
# `config.initial_policy_steps` more transitions with it.
env_wrapper.update_agent(sac_agent)
replay_buffer.extend([env_wrapper.step() for _ in range(config.initial_policy_steps)])

In [None]:
# Visualise One Episode (before training)
# Runs one rendered test episode with the wrapper's current agent (the SAC
# agent, set via update_agent in the previous cell). The return value is left
# as the cell's last expression so the notebook displays it.
env_wrapper.test(render=True)

In [None]:
# Train Model
# Alternate between collecting `config.env_steps` transitions with the current
# policy and running `config.training_steps` SAC updates on minibatches sampled
# from the replay buffer, until `config.total_train_steps` environment steps
# have been taken. Every `eval_interval` environment steps: run one unrendered
# test episode, record and print its return, and checkpoint the SAC module.
eval_interval = 1000  # environment steps between evaluation/checkpoint rounds
# NOTE(review): torch.save on the whole module pickles the class, so loading
# requires the `sac` package to be importable; PyTorch recommends saving
# `sac.state_dict()` instead. Kept as-is so existing loaders keep working.
checkpoint_path = f"{config.env.__name__.replace('Env', '')}.pt"

test_returns = []
env_wrapper.reset_statistics()
# Track the next evaluation boundary explicitly. The previous check
# `total_steps % 1000 < env_steps` only fires once per interval when
# total_steps advances by exactly env_steps per iteration; this gives the same
# schedule in that case and stays correct if the count advances irregularly.
next_eval_step = (env_wrapper.total_steps // eval_interval + 1) * eval_interval
while env_wrapper.total_steps < config.total_train_steps:
    # Gather fresh experience with the current policy.
    for _ in range(config.env_steps):
        replay_buffer.extend([env_wrapper.step()])

    # Gradient updates on uniformly re-sampled replay data.
    for _ in range(config.training_steps):
        states, actions, rewards, next_states, is_done = replay_buffer.sample()
        sac.step(states, actions, rewards, next_states, is_done)

    if env_wrapper.total_steps >= next_eval_step:
        test_return = env_wrapper.test(render=False)
        test_returns.append(test_return)
        print( 
            f"Step: {env_wrapper.total_steps}\t"
            f"Episode: {env_wrapper.total_episodes}\t"
            f"Test Return: {test_return:6.2f}\t"
            f"Temperature: {sac.temperature.log_temperature.exp().item():8.4f}"
        )

        torch.save(sac, checkpoint_path)
        # Move to the first boundary past the current count, skipping any
        # boundaries crossed in one iteration (env_steps > eval_interval).
        next_eval_step = (env_wrapper.total_steps // eval_interval + 1) * eval_interval

# Always persist the final model: checkpoints are otherwise only written on
# evaluation rounds, so training after the last boundary would be lost.
torch.save(sac, checkpoint_path)

In [None]:
# Visualise One Episode (after training)
# Runs one rendered test episode with the trained SAC agent for a visual
# comparison with the pre-training episode. The return value is left as the
# cell's last expression so the notebook displays it.
env_wrapper.test(render=True)