In [1]:
import gym
import torch
from collections import namedtuple

class Agent:
    def __init__(self, env, actor, noise, rend_wait = -1, rend_interval = -1 \
                 , frame = None, max_step = None, device = None):
        self.env = env
        if max_step is not None:
            self.env._max_episode_steps = max_step
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")            
        self.actor = actor
        self.noise = noise
        self.device = device
        
        self.wait = rend_wait
        self.interval = rend_interval
        self.frame = frame
        
    def reset(self):
        self.env.close()
        self.env.reset()
        
    def render(self, epoch):
        if self.wait >= 0 and epoch < self.wait:
            return
        if self.interval >= 0 and epoch % self.interval == 0:
            rend = self.env.render("rgb_array")
            if self.frame is not None:
                self.frame.append(self.env.render("rgb_array"))
        
    def episode(self, epoch):
        self.obs = self.env.reset()
        self.render(epoch)
            
        while True:
            with torch.no_grad():
                act_v = self.actor(torch.FloatTensor(self.obs).to(self.device)).cpu().numpy()
                act_v += self.noise.get_noise()
                if self.env.action_space.shape:
                    act_v = act_v.clip(self.env.action_space.low, self.env.action_space.high)
                act = self.actor.get_action(act_v)

            next_obs, rew, done, etc = self.env.step(act)
            self.render(epoch)

            obs = self.obs
            self.obs = next_obs

            yield obs, act_v, act, next_obs, rew, done, etc
            if done:
                break

In [2]:
import numpy as np
import math

class NoiseMaker():
    def __init__(self, action_size, n_type = None, param = None, decay = False):
        self.action_size = action_size
        self.state = np.zeros(action_size, dtype=np.float32)
        self.count = 0
        self.decay = decay
        if n_type is None:
            n_type = "normal"
        self.type = n_type
        
        if param is None:
            self.param = {
                "start": 0.9,
                "end":0.02,
                "decay": 2000
            }
            if n_type =="ou":
                self.param["ou_mu"] = 0.0
                self.param["ou_th"] = 0.15
                self.param["ou_sig"] = 0.2
        else:
            self.param = param
            
    def get_noise(self):
        eps = self.param["end"] + (self.param["start"] - self.param["end"]) \
                * math.exp(-1*self.count/ self.param["decay"])
        
        noise = np.random.normal(size=self.action_size)
        if self.type == "ou":
            self.state += self.param["ou_th"] * (self.param["ou_mu"] - self.state) \
                        + self.param["ou_sig"] * noise
            noise = self.state
        if not self.decay:
            eps = 1
        self.count += 1
            
        return noise * eps

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gym

import numpy as np

import random
import math

import collections
from collections import namedtuple

class Actor(nn.Module):
    def __init__(self, state_n, action_n, hidden = 512):
        super(Actor, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_n, hidden),
            nn.ReLU(),
            nn.Linear(hidden, int(hidden/2)),
            nn.ReLU(),
            nn.Linear(int(hidden/2), action_n),
            nn.Tanh()
        )
        
    def get_action(self, value):
        return value
        
    def forward(self,x):
        return self.net(x)
    
class Critic(nn.Module):
    def __init__(self, state_n, action_n, hidden = 512):
        super(Critic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_n, hidden),
            nn.ReLU(),
        )
        self.out = nn.Sequential(
            nn.Linear(hidden + action_n, int(hidden/2)),
            nn.ReLU(),
            nn.Linear(int(hidden/2), 1)
        )
        
    def forward(self, state, act):
        temp = self.net(state)
        return self.out(torch.cat([temp, act], dim=1))

In [4]:
step_data = namedtuple("step", ("state", "action_value", "action", "next_state", "reward", "done"))

class Replay:
    def __init__(self, size):
        self.memory = collections.deque(maxlen = size)
        
    def push(self, data):
        self.memory.append(data)
        
    def prepare(self, env):
        pass
        
    def sample(self, size):
        if len(self.memory) >= size:
            return random.sample(self.memory, size)

In [5]:
EPOCH = 5000
GAME_NAME = "Pendulum-v0"

env = gym.make(GAME_NAME)
obs_n = env.observation_space.shape[0]
act_n = env.action_space.shape[0]

LR_ACT = 0.001
LR_CRT = 0.005
TAU = 0.05
GAMMA = 0.99

actor = Actor(obs_n, act_n).cuda()
actor_optim = optim.Adam(actor.parameters(), lr = LR_ACT)
actor_tgt = Actor(obs_n, act_n).cuda()
actor_tgt.load_state_dict(actor.state_dict())

critic = Critic(obs_n, act_n).cuda()
critic_optim = optim.Adam(critic.parameters(), lr = LR_CRT)
critic_tgt = Critic(obs_n, act_n).cuda()
critic_tgt.load_state_dict(critic.state_dict())

MAX_MEMORY = 10000
BATCH = 512
storage = Replay(MAX_MEMORY)
noise = NoiseMaker(act_n, "ou", decay = True)

agent = Agent(env, actor, noise, 0, 1)

In [6]:
for epoch in range(EPOCH):
    rew_total = 0
    for i, step in enumerate(agent.episode(epoch)):
        obs, act_v, act, next_obs, rew, done, etc = step
        rew_total += rew
        
        storage.push(step_data(obs, act_v, act, next_obs, rew, done))        
        sample = storage.sample(BATCH)
        if sample:
            sample = step_data(*zip(*sample))
            
            states = torch.FloatTensor(sample.state).cuda()
            actions = torch.FloatTensor(sample.action_value).cuda()
            next_states = torch.FloatTensor(sample.next_state).cuda()
            rewards = torch.FloatTensor(sample.reward).unsqueeze(-1).cuda()
            dones = torch.BoolTensor(sample.done).unsqueeze(-1).cuda()
            
            # critic learning
            critic_optim.zero_grad()
            q_pred = critic(states, actions)
            
            next_action_v = actor_tgt(next_states)
            q_next = critic_tgt(next_states, next_action_v)
            q_next[dones] = 0
            q_target = rewards + GAMMA * q_next
            
            critic_loss = F.mse_loss(q_pred, q_target.detach())
            critic_loss.backward()
            critic_optim.step()
            
            # actor learning
            actor_optim.zero_grad()
            actor_loss = -critic(states, actor(states))
            actor_loss = actor_loss.mean()
            actor_loss.backward()
            actor_optim.step()
            
            # tgt soft update
            for tgt, real  in zip(actor_tgt.parameters(), actor.parameters()):
                tgt.data.copy_(TAU*real.data + (1-TAU)*tgt.data)
                
            for tgt, real  in zip(critic_tgt.parameters(),critic.parameters()):
                tgt.data.copy_(TAU*real.data + (1-TAU)*tgt.data)
            
        if done:
            break
    print("epoch %d count %d" % (epoch, i + 1))
    
agent.reset()

epoch 0 count 200
epoch 1 count 200
epoch 2 count 200
epoch 3 count 200
epoch 4 count 200
epoch 5 count 200
epoch 6 count 200
epoch 7 count 200
epoch 8 count 200
epoch 9 count 200
epoch 10 count 200
epoch 11 count 200
epoch 12 count 200
epoch 13 count 200
epoch 14 count 200
epoch 15 count 200
epoch 16 count 200
epoch 17 count 200
epoch 18 count 200
epoch 19 count 200
epoch 20 count 200
epoch 21 count 200
epoch 22 count 200
epoch 23 count 200
epoch 24 count 200
epoch 25 count 200
epoch 26 count 200
epoch 27 count 200
epoch 28 count 200
epoch 29 count 200
epoch 30 count 200
epoch 31 count 200
epoch 32 count 200
epoch 33 count 200
epoch 34 count 200
epoch 35 count 200
epoch 36 count 200
epoch 37 count 200
epoch 38 count 200
epoch 39 count 200
epoch 40 count 200
epoch 41 count 200
epoch 42 count 200
epoch 43 count 200
epoch 44 count 200
epoch 45 count 200
epoch 46 count 200
epoch 47 count 200
epoch 48 count 200
epoch 49 count 200
epoch 50 count 200
epoch 51 count 200
epoch 52 count 200
epo

KeyboardInterrupt: 