In [1]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the CartPole environment with RGB-array rendering and strip the gym wrappers.
# NOTE(review): CartPole-v1 has a discrete action space, while the SAC class in this
# notebook reads action_space.shape[0] / .high / .low -- confirm the intended environment.
env = gym.make('CartPole-v1', render_mode='rgb_array').unwrapped

# Matplotlib setup: import IPython's display helper only when an inline backend is active.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Turn on matplotlib's interactive mode.
plt.ion()


In [2]:
# One environment step as stored in the replay buffer.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'done'))


class ReplayMemory(object):
    """Fixed-capacity cyclic buffer of ``Transition`` records.

    Once ``capacity`` records are stored, each new record overwrites the
    oldest one.

    Attributes:
        capacity: Maximum number of transitions kept.
        memory: List holding the stored transitions.
        position: Index that the next ``push`` writes to.
    """

    def __init__(self, capacity):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions to keep; must be positive.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        if capacity <= 0:
            # A non-positive capacity would otherwise only fail on the first
            # push, with an IndexError unrelated to the real cause.
            raise ValueError("capacity must be a positive integer")
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Store one transition, overwriting the oldest entry when full.

        Args:
            *args: The five ``Transition`` fields, in order
                (state, action, next_state, reward, done).

        Raises:
            TypeError: If the number of fields does not match ``Transition``.
        """
        # Build the record first: if the arguments are invalid the buffer must
        # stay untouched instead of being left with a ``None`` placeholder.
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity  # cyclic buffer

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored items
                (raised by ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

In [7]:
class PolicyNetwork(nn.Module):
    """Gaussian policy network producing a mean and standard deviation per action dimension.

    Three ReLU hidden layers of width 64 feed two linear heads: one for the
    mean and one for the log standard deviation. The module owns its Adam
    optimizer as ``self.optimizer``.
    """

    # Bounds on the log standard deviation so that exp() cannot overflow to
    # inf or collapse to zero.
    LOG_STD_MIN = -20.0
    LOG_STD_MAX = 2.0

    def __init__(self, state_dim, action_dim, lr=0.003):
        """Build the layers and the optimizer.

        Args:
            state_dim: Size of the state vector.
            action_dim: Number of action dimensions.
            lr: Learning rate of the Adam optimizer.
        """
        super(PolicyNetwork, self).__init__()
        self.layer1 = nn.Linear(state_dim, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, 64)
        self.pi_mean = nn.Linear(64, action_dim)
        self.pi_stddev = nn.Linear(64, action_dim)

        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Return ``(mean, stddev)`` of the action distribution for states ``x``.

        Args:
            x: Float tensor whose last dimension is ``state_dim``.

        Returns:
            Tuple ``(mean, stddev)``; both have ``action_dim`` as their last
            dimension and ``stddev`` is strictly positive.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        # The heads are already linear layers. F.linear(input, weight) needs an
        # explicit weight argument and must not be used as an identity activation.
        mean = self.pi_mean(x)
        log_stddev = torch.clamp(self.pi_stddev(x), self.LOG_STD_MIN, self.LOG_STD_MAX)

        # The head predicts log(stddev); exponentiate to get a positive value.
        stddev = torch.exp(log_stddev)

        return mean, stddev

In [8]:
class DualQNetwork(nn.Module):
    """Two independent Q-value networks evaluated on the same (state, action) input.

    Each critic has its own three ReLU hidden layers of width 64 and a scalar
    output head. The module owns one Adam optimizer over both critics as
    ``self.optimizer``.
    """

    def __init__(self, state_dim, action_dim, lr=0.003):
        """Build both critics and the shared optimizer.

        Args:
            state_dim: Size of the state vector.
            action_dim: Number of action dimensions.
            lr: Learning rate of the Adam optimizer.
        """
        super(DualQNetwork, self).__init__()
        # QNetwork 1
        self.layer1 = nn.Linear(state_dim + action_dim, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, 64)
        self.q1 = nn.Linear(64, 1)
        # QNetwork 2
        self.layer4 = nn.Linear(state_dim + action_dim, 64)
        self.layer5 = nn.Linear(64, 64)
        self.layer6 = nn.Linear(64, 64)
        self.q2 = nn.Linear(64, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, s, a):
        """Return ``(q1, q2)`` for states ``s`` and actions ``a``.

        Args:
            s: Float tensor whose last dimension is ``state_dim``.
            a: Float tensor whose last dimension is ``action_dim``; leading
                dimensions must match ``s``.

        Returns:
            Tuple of two tensors whose last dimension is 1.
        """
        x = torch.cat((s, a), -1)  # concatenate state and action features

        # QNetwork 1
        x1 = F.relu(self.layer1(x))
        x1 = F.relu(self.layer2(x1))
        x1 = F.relu(self.layer3(x1))
        q_value1 = self.q1(x1)
        # QNetwork 2 uses its own trunk (layer4-6) so the two estimates are
        # independent; sharing layer1-3 would leave layer4-6 untrained.
        x2 = F.relu(self.layer4(x))
        x2 = F.relu(self.layer5(x2))
        x2 = F.relu(self.layer6(x2))
        q_value2 = self.q2(x2)

        return q_value1, q_value2


In [6]:
class SAC():
    """Soft Actor-Critic agent with a tanh-squashed Gaussian policy and twin Q-networks.

    The agent owns a policy network, an online and a target twin-Q network, a
    replay memory and a learnable log-temperature ``log_alpha`` that is tuned
    towards ``target_entropy``.
    """

    def __init__(self, state_space, action_space, buffer_size, gamma, soft_target_tau, hard_target_interval,
                 target_entropy, policy_lr, q_lr, alpha_lr):
        """Build networks, replay memory and the temperature optimizer.

        Args:
            state_space: Observation space; ``shape[0]`` is used as the state size.
            action_space: Action space; ``shape[0]``, ``high`` and ``low`` are read,
                so it must be a bounded continuous space.
            buffer_size: Capacity of the replay memory.
            gamma: Discount factor.
            soft_target_tau: Interpolation factor of the soft target update.
            hard_target_interval: Stored on the instance; not used by ``update``.
            target_entropy: Entropy target for the temperature update. ``None``
                selects ``-action_dim``.
            policy_lr: Learning rate of the policy network.
            q_lr: Learning rate of the Q-networks.
            alpha_lr: Learning rate of the temperature parameter.
        """
        super(SAC, self).__init__()

        self.state_dim = state_space.shape[0]
        self.action_dim = action_space.shape[0]

        # Affine map from the policy's [-1, 1] output to the environment's action range.
        self.action_center = (action_space.high + action_space.low) / 2
        self.action_scale = action_space.high - self.action_center

        # Neural networks
        self.policy_net = PolicyNetwork(self.state_dim, self.action_dim, policy_lr)

        self.q_net = DualQNetwork(self.state_dim, self.action_dim, q_lr)
        self.target_q_net = DualQNetwork(self.state_dim, self.action_dim, q_lr)

        # Start the target network as an exact copy of the online network.
        for target_param, param in zip(self.target_q_net.parameters(), self.q_net.parameters()):
            target_param.data.copy_(param.data)

        self.replay_memory = ReplayMemory(buffer_size)

        # Learnable log-temperature. It must stay the same tensor object that the
        # optimizer holds, so it is never reassigned afterwards.
        self.log_alpha = torch.zeros(1, requires_grad=True)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)

        # Hyper parameters
        self.gamma = gamma
        self.soft_target_tau = soft_target_tau
        self.hard_target_interval = hard_target_interval
        self.target_entropy = -self.action_dim if target_entropy is None else target_entropy

        # Not modified by update(); left for the caller to manage.
        self.train_count = 0

    def sample_action(self, state):
        """Sample a squashed action for a batch of states.

        Args:
            state: Float tensor accepted by the policy network.

        Returns:
            Tuple ``(action, mean, stddev, action_org)`` where ``action_org`` is
            the pre-tanh Gaussian sample and ``action = tanh(action_org)``.
        """
        mean, stddev = self.policy_net(state)

        # Reparameterization trick: sample = mean + stddev * N(0, 1)
        normal_random = torch.normal(0.0, 1.0, size=mean.shape)
        action_org = mean + stddev * normal_random

        # Squashed Gaussian policy
        action = torch.tanh(action_org)

        return action, mean, stddev, action_org

    def scaled_sample_action(self, state):
        """Sample one action and rescale it to the environment's action range.

        Args:
            state: A single state (array-like or 1-D tensor) or a batch with a
                leading batch dimension.

        Returns:
            Tuple ``(env_action, action)``: ``env_action`` is a NumPy array for the
            first batch element in environment units; ``action`` is the detached
            squashed action tensor in [-1, 1].
        """
        if not isinstance(state, torch.Tensor):
            state = torch.FloatTensor(np.array(state, dtype=np.float32))
        if state.dim() == 1:
            state = state.unsqueeze(0)

        # No gradient is needed for acting, and .numpy() refuses tensors that
        # require grad.
        with torch.no_grad():
            action, _, _, _ = self.sample_action(state)
        env_action = action.cpu().numpy()[0] * self.action_scale + self.action_center

        return env_action, action

    def compute_logpi(self, mean, stddev, action):
        """Element-wise log-density of ``action`` under N(mean, stddev**2)."""
        a1 = -0.5 * math.log(2 * math.pi)
        a2 = -torch.log(stddev)
        a3 = -0.5 * (((action - mean) / stddev) ** 2)
        return a1 + a2 + a3

    def compute_logpi_sgp(self, mean, stddev, action_org):
        """Log-probability of the tanh-squashed action, computed from the pre-tanh sample.

        Args:
            mean, stddev, action_org: Tensors of shape (batch, action_dim).

        Returns:
            Tensor of shape (batch, 1).
        """
        # The joint density of independent dimensions is the sum of per-dimension
        # log-densities, so reduce over the action axis before the tanh correction.
        logmu = torch.sum(self.compute_logpi(mean, stddev, action_org), 1, keepdim=True)
        tmp = 1 - torch.tanh(action_org) ** 2
        tmp = torch.clip(tmp, 1e-10, 1.0)  # avoid log(0)
        logpi = logmu - torch.sum(torch.log(tmp), 1, keepdim=True)
        return logpi

    @staticmethod
    def _stack_batch(items):
        """Stack per-transition values into a float tensor of shape (batch, features)."""
        rows = []
        for item in items:
            if isinstance(item, torch.Tensor):
                item = item.detach().cpu().numpy()
            rows.append(np.asarray(item, dtype=np.float32).reshape(-1))
        return torch.FloatTensor(np.stack(rows))

    def update(self, batch_size, q_net_sync=False):
        """Run one SAC training step on a batch sampled from the replay memory.

        Args:
            batch_size: Number of transitions to sample.
            q_net_sync: When true, hard-copy the online Q-network into the target
                network after the soft update.

        Returns:
            Tuple ``(policy_loss, q_loss)`` of detached scalar tensors.
        """
        # Sample a batch and transpose the list of transitions into per-field tuples.
        transitions = self.replay_memory.sample(batch_size)
        states, actions, n_states, rewards, dones = zip(*transitions)

        state_batch = self._stack_batch(states)
        n_state_batch = self._stack_batch(n_states)
        action_batch = self._stack_batch(actions)
        # Rewards and done flags become (batch, 1) so they line up with the Q
        # outputs instead of broadcasting to (batch, batch).
        reward_batch = self._stack_batch(rewards)
        done_batch = self._stack_batch(dones)

        # The temperature is a constant for the Q and policy losses; only
        # alpha_loss updates it.
        alpha = torch.exp(self.log_alpha).detach()

        # Soft Bellman target: r + gamma * (1 - done) * (min Q_target(s', a') - alpha * log pi(a'|s'))
        with torch.no_grad():
            n_action, n_mean, n_stddev, n_action_org = self.sample_action(n_state_batch)

            n_logpi = self.compute_logpi_sgp(n_mean, n_stddev, n_action_org)
            n_q1, n_q2 = self.target_q_net(n_state_batch, n_action)
            n_q_min = torch.minimum(n_q1, n_q2)

            q_est = reward_batch + (1.0 - done_batch) * self.gamma * (n_q_min - alpha * n_logpi)
        q1, q2 = self.q_net(state_batch, action_batch)
        q1_loss = F.mse_loss(q1, q_est)
        q2_loss = F.mse_loss(q2, q_est)
        q_loss = q1_loss + q2_loss

        # Train the Q-networks on q_loss.
        self.q_net.optimizer.zero_grad()
        q_loss.backward()
        self.q_net.optimizer.step()

        # Policy loss, evaluated on actions drawn from the current policy.
        action, mean, stddev, action_org = self.sample_action(state_batch)
        logpi = self.compute_logpi_sgp(mean, stddev, action_org)
        q1, q2 = self.q_net(state_batch, action)
        q_min = torch.minimum(q1, q2)
        policy_loss = torch.mean(-q_min + alpha * logpi)

        # Train the policy network on policy_loss. Gradients that leak into the
        # Q-network here are cleared by zero_grad() at the next Q step.
        self.policy_net.optimizer.zero_grad()
        policy_loss.backward()
        self.policy_net.optimizer.step()

        # Automatic temperature tuning. logpi is detached because its graph was
        # already consumed by policy_loss.backward() and must not be traversed again.
        alpha_loss = torch.mean(-self.log_alpha * (logpi.detach() + self.target_entropy))
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        # Soft target update.
        for target_param, param in zip(self.target_q_net.parameters(), self.q_net.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - self.soft_target_tau) + param.data * self.soft_target_tau)

        # Hard synchronization of the target network when requested.
        if q_net_sync:
            for target_param, param in zip(self.target_q_net.parameters(), self.q_net.parameters()):
                target_param.data.copy_(param.data)

        return policy_loss.detach(), q_loss.detach()