In [1]:
import torch

In [2]:
import torch.nn as nn

In [3]:
import torch.nn.functional as F

In [4]:
import random

In [5]:
from collections import deque

In [6]:
from torch.distributions import Categorical

In [7]:
import numpy as np

In [8]:
import gym
import os
import copy
import matplotlib.pyplot as plt
import seaborn as sns

In [34]:
class ActorSoftmax(nn.Module):
    """Policy network: a three-layer MLP that outputs action probabilities.

    Two ReLU-activated hidden layers of width ``hidden_dim`` are followed by
    a linear layer with ``output_dim`` units whose outputs are normalised
    with a softmax over the last dimension.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        # Layers are created in the order fc1, fc2, fc3 so that seeded
        # parameter initialisation and state_dict keys stay the same.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Return the softmax of the final layer's output for input ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        logits = self.fc3(hidden)
        return F.softmax(logits, dim=-1)

In [35]:
class Critic(nn.Module):
    """Value network: a three-layer MLP with a linear (unbounded) output.

    Two ReLU-activated hidden layers of width ``hidden_dim`` are followed by
    a linear layer with ``output_dim`` units and no output activation.
    """

    def __init__(self,input_dim,output_dim,hidden_dim=256):
        # Fix: the original called super(ActorSoftmax, self).__init__(), which
        # raises TypeError because a Critic is not an ActorSoftmax instance.
        super().__init__()
        self.fc1 = nn.Linear(input_dim,hidden_dim)
        self.fc2 = nn.Linear(hidden_dim,hidden_dim)
        self.fc3 = nn.Linear(hidden_dim,output_dim)

    def forward(self,x):
        """Return the raw output of the final linear layer for input ``x``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        values = self.fc3(x)
        return values

In [36]:
class ReplayBufferQue:
    """Fixed-capacity FIFO buffer of transitions backed by a ``deque``.

    Once ``capacity`` items are stored, pushing a new item drops the oldest.
    """

    def __init__(self,capacity: int) -> None:
        self.capacity = capacity
        self.buffer = deque(maxlen=self.capacity)

    def push(self,transitions):
        """Append one transition (any tuple-like record) to the buffer."""
        self.buffer.append(transitions)

    def sample(self,batch_size: int,sequential: bool=False):
        """Return ``batch_size`` transitions, transposed field by field.

        Args:
            batch_size: number of transitions to draw.
            sequential: if True, draw one contiguous window starting at a
                random position; otherwise draw without replacement.

        Returns:
            A ``zip`` iterator yielding one tuple per transition field.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored items.
        """
        if sequential:
            # random.randint is inclusive on both ends, so the last valid
            # start is len - batch_size. The original subtracted one more,
            # which made the final window unreachable and raised when
            # batch_size == len(buffer).
            last_start = len(self.buffer) - batch_size
            if last_start < 0:
                raise ValueError("batch_size larger than number of stored transitions")
            index = random.randint(0,last_start)
            res = [self.buffer[i] for i in range(index,index+batch_size)]
            return zip(*res)
        else:
            batch = random.sample(self.buffer,batch_size)
            return zip(*batch)

    def clear(self):
        """Remove every stored transition."""
        self.buffer.clear()

    def __len__(self):
        return len(self.buffer)

In [37]:
class PGReplay(ReplayBufferQue):
    """Unbounded on-policy buffer that hands back everything it holds.

    The parent constructor is not called: the buffer is a plain ``deque``
    with no ``maxlen`` and no ``capacity`` attribute is set.
    """

    def __init__(self):
        self.buffer = deque()

    def sample(self):
        """Return all stored transitions in insertion order, transposed field by field."""
        snapshot = [transition for transition in self.buffer]
        return zip(*snapshot)

In [38]:
class Agent:
    """PPO (clipped surrogate) agent for a discrete action space.

    Holds a softmax actor, a critic, one Adam optimizer for each, and a
    ``PGReplay`` memory of ``(state, action, log_prob, reward, done)``
    transitions that is consumed and cleared by ``update``.
    """

    def __init__(self,cfg) -> None:
        """Build networks, optimizers and memory from the fields of ``cfg``.

        Reads ``gamma``, ``n_states``, ``n_actions``, ``actor_hidden_dim``,
        ``critic_hidden_dim``, ``actor_lr``, ``critic_lr``, ``k_epochs``,
        ``eps_clip``, ``entropy_coef`` and ``update_freq``; ``device`` is
        optional and defaults to ``'cpu'``.
        """
        self.gamma = cfg.gamma
        # Fix: the original stored the torch.device *class* here, which makes
        # every later `.to(self.device)` / `device=self.device` call fail.
        self.device = torch.device(getattr(cfg,'device','cpu'))
        self.actor = ActorSoftmax(cfg.n_states,cfg.n_actions,hidden_dim=cfg.actor_hidden_dim).to(self.device)
        self.critic = Critic(cfg.n_states,1,hidden_dim=cfg.critic_hidden_dim).to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=cfg.actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=cfg.critic_lr)
        self.memory = PGReplay()
        self.k_epochs = cfg.k_epochs
        self.eps_clip = cfg.eps_clip
        self.entropy_coef = cfg.entropy_coef
        self.sample_count = 0
        self.update_freq = cfg.update_freq

    def sample_action(self,state):
        """Sample an action for training and remember its log-probability.

        Increments ``sample_count`` and stores the detached log-probability
        of the sampled action in ``self.log_probs``.

        Returns:
            The sampled action as a Python scalar.
        """
        self.sample_count += 1
        state = torch.tensor(state,device=self.device,dtype=torch.float32).unsqueeze(0)
        probs = self.actor(state)
        dist = Categorical(probs)
        action = dist.sample()
        self.log_probs = dist.log_prob(action).detach()
        return action.detach().cpu().numpy().item()

    @torch.no_grad()
    def predict_action(self, state):
        """Sample an action from the current policy without tracking gradients."""
        state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
        probs = self.actor(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.detach().cpu().numpy().item()

    def _discounted_returns(self,rewards,dones):
        """Return per-step discounted returns, resetting the sum at each done flag."""
        returns = []
        running = 0.0
        for reward,done in zip(reversed(rewards),reversed(dones)):
            if done:
                running = 0.0
            running = reward + self.gamma * running
            returns.append(running)
        # Built back-to-front with append, then reversed once (linear time).
        returns.reverse()
        return returns

    def update(self):
        """Run ``k_epochs`` PPO updates on the stored transitions, then clear them.

        Does nothing unless ``sample_count`` is a multiple of ``update_freq``
        and the memory is non-empty.
        """
        if self.sample_count % self.update_freq != 0:
            return
        if len(self.memory) == 0:
            # Nothing collected yet; unpacking an empty sample would raise.
            return
        old_states,old_actions,old_log_probs,old_rewards,old_dones = self.memory.sample()
        old_states = torch.tensor(np.array(old_states),device=self.device,dtype=torch.float32)
        # Actions are category indices, so keep them as integers.
        old_actions = torch.tensor(np.array(old_actions),device=self.device,dtype=torch.long).reshape(-1)
        # Each stored log-prob may be a 1-element tensor; flatten to shape (N,)
        # so that `new_probs - old_log_probs` is elementwise instead of
        # broadcasting to (N, N).
        old_log_probs = torch.cat([
            torch.as_tensor(lp,device=self.device,dtype=torch.float32).reshape(-1)
            for lp in old_log_probs
        ]).detach()
        returns = self._discounted_returns(old_rewards,old_dones)
        returns = torch.tensor(returns, device=self.device, dtype=torch.float32)
        if returns.numel() > 1:
            # std() of a single element is NaN, so only normalise real batches.
            returns = (returns - returns.mean()) / (returns.std() + 1e-5)
        for _ in range(self.k_epochs):
            # Critic output is (N, 1); squeeze so it lines up with returns (N,).
            values = self.critic(old_states).squeeze(-1)
            advantage = returns - values.detach()
            probs = self.actor(old_states)
            dist = Categorical(probs)
            new_probs = dist.log_prob(old_actions)
            ratio = torch.exp(new_probs-old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio,1-self.eps_clip,1+self.eps_clip) * advantage
            # Entropy is a bonus in the PPO objective, so it is subtracted from
            # the loss; adding it (as before) penalised exploration.
            actor_loss = -torch.min(surr1,surr2).mean() - self.entropy_coef * dist.entropy().mean()
            critic_loss = (returns - values).pow(2).mean()
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()
        self.memory.clear()

In [39]:
def _run_eval_episode(cfg,env,agent):
    """Play one episode with ``agent.predict_action`` and return its total reward."""
    eval_ep_reward = 0
    state,_ = env.reset()
    for _ in range(cfg.max_steps):
        action = agent.predict_action(state)
        next_state,reward,terminated,truncated,_ = env.step(action)
        state = next_state
        eval_ep_reward += reward
        if terminated or truncated:
            break
    return eval_ep_reward


def train(cfg,env,agent):
    """Train ``agent`` on ``env`` and keep a copy of the best-evaluated agent.

    Runs ``cfg.train_eps`` episodes of at most ``cfg.max_steps`` steps each.
    Every ``cfg.eval_per_episode`` episodes the agent is evaluated over
    ``cfg.eval_eps`` episodes; when the mean evaluation reward is at least
    as good as the best so far, a deep copy of the agent is saved.

    Returns:
        ``(output_agent, info)`` where ``output_agent`` is the best saved copy
        (``None`` if no evaluation ever ran) and ``info`` is a dict with the
        per-episode ``'rewards'`` and ``'steps'``.
    """
    print('begin to train!')
    rewards = []
    steps = []
    # Start at -inf so tasks with negative rewards can still record a best agent.
    best_ep_reward = -float('inf')
    output_agent = None
    for i_ep in range(cfg.train_eps):
        ep_reward = 0
        ep_step = 0
        state,_ = env.reset()
        for _ in range(cfg.max_steps):
            ep_step += 1
            action = agent.sample_action(state)
            next_state,reward,terminated,truncated,_ = env.step(action)
            env.render()
            # Flag every episode boundary (termination, truncation or the step
            # cap) so discounted returns do not leak into the next episode.
            done = bool(terminated or truncated or ep_step >= cfg.max_steps)
            agent.memory.push((state,action,agent.log_probs,reward,done))
            state = next_state
            agent.update()
            ep_reward += reward
            if done:
                break
        ### eval
        if (i_ep + 1) % cfg.eval_per_episode == 0:
            sum_eval_reward = 0
            for _ in range(cfg.eval_eps):
                sum_eval_reward += _run_eval_episode(cfg,env,agent)
            mean_eval_reward = sum_eval_reward / cfg.eval_eps
            if mean_eval_reward >= best_ep_reward:
                best_ep_reward = mean_eval_reward
                # Snapshot the agent so later training cannot degrade it.
                output_agent = copy.deepcopy(agent)
        steps.append(ep_step)
        rewards.append(ep_reward)
    # Close and return only after all episodes have run.
    env.close()
    return output_agent,{'rewards':rewards,'steps':steps}