# Soft Actor-Critic (SAC)

In [None]:
import torch
import gym
from gym.envs.classic_control import PendulumEnv, CartPoleEnv
from gym.envs.mujoco import HalfCheetahEnv, HopperEnv, Walker2dEnv, HumanoidEnv
from matplotlib import pyplot as plt

from src import SACConfig, RandomAgent, SACAgent, EnvWrapper, ReplayBuffer, SAC

In [None]:
# Environment Config: the environment class, the sizes of its observation and
# action vectors, the action bounds, the starting temperature and the episode
# length cap.
# NOTE(review): the dimensions and bounds presumably mirror PendulumEnv's
# spaces -- confirm if the environment class is swapped.
environment_settings = {
    "env": PendulumEnv,
    "observation_dim": 3,
    "action_dim": 1,
    "action_min": -2.0,
    "action_max": 2.0,
    "temperature": 1.0,
    "max_episode_length": 1_000,
    "discrete_actions": False,
}

# Neural Network Config
network_settings = {
    "adjust_temperature": True,
}

# Fields read elsewhere in this file but not set here (e.g. random_steps,
# env_steps, training_steps) are left to SACConfig itself.
config = SACConfig(**environment_settings, **network_settings)

In [None]:
# Initialise Networks, Agents, Dataset and Environment
# Move the model to the GPU only when one is actually present, so this cell
# no longer raises on CPU-only machines. With CUDA available the call is the
# same `.cuda()` as before.
# NOTE(review): the other `src` components may still assume CUDA tensors --
# confirm before relying on the CPU path.
sac = SAC(config)
if torch.cuda.is_available():
    sac = sac.cuda()

# Two acting agents: one built from the config alone, one wrapping the SAC
# policy network.
random_agent = RandomAgent(config)
agent = SACAgent(sac.policy_network, config)

# Storage for collected environment steps.
dataset = ReplayBuffer(config)

# The wrapper starts out driven by the random agent; it is switched to the
# SAC agent later in the file.
env_wrapper = EnvWrapper(config, random_agent)

In [None]:
# Collect Initial Data
# Phase 1: take config.random_steps steps with whichever agent the wrapper
# currently holds, then store all results in one go.
random_phase_results = [
    env_wrapper.step() for _ in range(config.random_steps)
]
dataset.extend(random_phase_results)

# Phase 2: hand control to the SAC agent and take
# config.initial_policy_steps further steps before any training happens.
env_wrapper.agent = agent
policy_phase_results = [
    env_wrapper.step() for _ in range(config.initial_policy_steps)
]
dataset.extend(policy_phase_results)

In [None]:
# Visualise One Episode (before training)
# Calls the wrapper's test routine with rendering switched on. This runs
# before the training loop, so it serves as a baseline to compare against
# the post-training visualisation.
# NOTE(review): presumably this uses the agent currently attached to
# env_wrapper -- confirm in EnvWrapper.test.
env_wrapper.test(render=True)

In [None]:
# Train Model
# An evaluation is triggered whenever total_steps modulo this interval falls
# below config.env_steps (see the check at the bottom of the loop).
TEST_INTERVAL = 1000

# Checkpoint file name: the environment class name with "Env" stripped.
checkpoint_path = f"{config.env.__name__.replace('Env', '')}.pt"

test_returns = []
# Not used inside this cell; kept because other cells may read it.
episode_idx = len(env_wrapper.ep_returns)

while env_wrapper.total_steps < config.total_train_steps:
    # Step the environment config.env_steps times, storing each result as
    # soon as it is produced.
    for _ in range(config.env_steps):
        dataset.extend([env_wrapper.step()])

    # One sac.train call per freshly sampled batch.
    for _ in range(config.training_steps):
        states, actions, rewards, next_states, is_done = dataset.sample()
        sac.train(states, actions, rewards, next_states, is_done)

    # Skip evaluation unless the step counter has just passed a multiple of
    # TEST_INTERVAL.
    # NOTE(review): this assumes total_steps grows by config.env_steps per
    # outer iteration -- confirm in EnvWrapper.step.
    steps_past_interval = env_wrapper.total_steps % TEST_INTERVAL
    if steps_past_interval >= config.env_steps:
        continue

    # Evaluate without rendering and record the result.
    test_return = env_wrapper.test(render=False)
    test_returns.append(test_return)

    # The temperature is stored as a log value; exponentiate it for display.
    temperature = sac.temperature.log_temperature.exp().item()
    print(
        f"Step: {env_wrapper.total_steps}\t"
        f"Episode: {env_wrapper.num_episodes}\t"
        f"Test Return: {test_return:6.2f}\t"
        f"Temperature: {temperature:8.4f}"
    )

    # Overwrite the checkpoint after every evaluation.
    # NOTE(review): this saves the whole `sac` object, not a state_dict.
    torch.save(sac, checkpoint_path)

In [None]:
# Visualise One Episode (after training)
# Same rendered test call as the pre-training visualisation, run once the
# training loop has finished, so the two can be compared.
# NOTE(review): presumably this uses the agent currently attached to
# env_wrapper -- confirm in EnvWrapper.test.
env_wrapper.test(render=True)