In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gym

import numpy as np

import random
import math

import collections
from collections import namedtuple

In [0]:
# One replay-buffer record: a single environment transition.
step = namedtuple("step", ["state", "action", "next_state", "reward", "done"])

class Replay:
    """Bounded FIFO store of items with uniform random sampling.

    Items are kept in a ``collections.deque`` created with ``maxlen=size``,
    so once ``size`` items are held, each new push evicts the oldest one.
    """

    def __init__(self, size):
        """Create an empty buffer that holds at most ``size`` items."""
        self.memory = collections.deque(maxlen=size)

    def push(self, data):
        """Append ``data`` to the buffer."""
        self.memory.append(data)

    def prepare(self, env):
        """Placeholder hook: accepts ``env`` but currently does nothing."""
        return None

    def sample(self, size):
        """Draw ``size`` distinct stored items at random.

        Returns:
            A list of ``size`` items, or ``None`` when fewer than ``size``
            items are currently stored.
        """
        # Guard clause: not enough data collected yet.
        if len(self.memory) < size:
            return None
        return random.sample(self.memory, size)

In [0]:
class Actor(nn.Module):
    """Policy network mapping a state vector to ``action_n`` outputs.

    Architecture: Linear(state_n, hidden) -> ReLU -> Linear(hidden, hidden/2)
    -> ReLU -> Linear(hidden/2, action_n) -> Tanh, so every output lies in
    [-1, 1].
    """

    def __init__(self, state_n, action_n, hidden = 256):
        """Build the network.

        Args:
            state_n: size of the input state vector.
            action_n: number of outputs.
            hidden: width of the first hidden layer; the second is half of it.
        """
        super().__init__()
        half_width = int(hidden / 2)
        # Layers are created in the same order as they are applied, and the
        # container keeps the name ``net`` so state_dict keys stay stable.
        stack = [
            nn.Linear(state_n, hidden),
            nn.ReLU(),
            nn.Linear(hidden, half_width),
            nn.ReLU(),
            nn.Linear(half_width, action_n),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the tanh-squashed network output for input ``x``."""
        out = self.net(x)
        return out
    
class Critic(nn.Module):
    """Q-network scoring a (state, action) pair with a single value.

    The state first passes through one Linear+ReLU layer (``net``). Its
    features are then concatenated with the action vector and fed to a
    two-layer head (``out``) that produces one scalar per input row.
    """

    def __init__(self, state_n, action_n, hidden =256):
        """Build the network.

        Args:
            state_n: size of the state vector.
            action_n: size of the action vector.
            hidden: width of the state-feature layer; the head's hidden layer
                is half of it.
        """
        super(Critic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_n, hidden),
            nn.ReLU(),
        )
        self.out = nn.Sequential(
            nn.Linear(hidden+action_n, int(hidden/2)),
            nn.ReLU(),
            nn.Linear(int(hidden/2), 1)
        )

    def forward(self, state, act):
        """Return the value estimate for ``state`` and ``act``.

        ``state`` and ``act`` must share all leading dimensions; only their
        last (feature) dimension differs. The result has the same leading
        dimensions and a trailing dimension of size 1.
        """
        temp = self.net(state)
        # Concatenate along the feature axis. For the usual 2-D
        # (batch, features) input this is the same as dim=1, but it also
        # accepts a single unbatched sample or extra leading dimensions.
        return self.out(torch.cat([temp, act], dim=-1))

In [8]:
# --- Environment -----------------------------------------------------------
EPOCH = 1000                     # number of episodes run by the training loop
GAME_NAME = "MountainCar-v0"

env = gym.make(GAME_NAME)
obs_n = env.observation_space.shape[0]   # length of the observation vector
act_n = env.action_space.n               # number of discrete actions

# --- Optimisation hyper-parameters -----------------------------------------
LR_ACT = 0.001    # Adam learning rate for the actor
LR_CRT = 0.0002   # Adam learning rate for the critic
TAU = 0.0005      # soft-update mixing factor for the target networks
GAMMA = 0.99      # discount applied to the bootstrapped next-state value

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Networks: each online network has a target copy with equal weights ----
actor = Actor(obs_n, act_n).to(DEVICE)
actor_optim = optim.Adam(actor.parameters(), lr = LR_ACT)
actor_tgt = Actor(obs_n, act_n).to(DEVICE)
actor_tgt.load_state_dict(actor.state_dict())

critic = Critic(obs_n, act_n).to(DEVICE)
critic_optim = optim.Adam(critic.parameters(), lr = LR_CRT)
critic_tgt = Critic(obs_n, act_n).to(DEVICE)
critic_tgt.load_state_dict(critic.state_dict())

# --- Replay buffer ----------------------------------------------------------
MAX_MEMORY = 20000   # capacity of the replay buffer
MEM_INIT = 2000      # NOTE(review): not referenced by any visible cell
BATCH = 256          # transitions drawn per training step
storage = Replay(MAX_MEMORY)



In [11]:
def _soft_update(target_net, source_net, tau):
    """Blend ``source_net`` weights into ``target_net``: t <- tau*s + (1-tau)*t."""
    with torch.no_grad():
        for tgt, real in zip(target_net.parameters(), source_net.parameters()):
            tgt.copy_(tau * real + (1 - tau) * tgt)


def _batch_tensor(values, dtype, device):
    """Stack a sequence of per-transition values into one tensor on ``device``."""
    # Going through a single ndarray avoids the slow element-wise conversion
    # torch performs on a Python sequence of numpy arrays.
    return torch.as_tensor(np.array(values), dtype=dtype, device=device)


# Run on whichever device the networks already live on, so this cell does not
# hard-code CUDA.
device = next(actor.parameters()).device

for epoch in range(EPOCH):
    obs = env.reset()

    count = 0
    while True:
        with torch.no_grad():
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
            act_v = actor(obs_t).cpu().numpy()
            # Exploration: uniform [0, 1) noise whose scale shrinks as
            # 1 / (epoch + 1); the discrete action is the arg-max entry.
            noise = np.random.random(act_n)/(epoch+1)
            act_v += noise
            act = act_v.argmax().item()

        next_obs, rew, done, info = env.step(act)
        # Reward shaping: the environment reward is replaced by the first
        # observation component (the car position in MountainCar, per the Gym
        # docs -- confirm if GAME_NAME changes).
        rew = next_obs[0]
        count += 1

        # An episode cut off by Gym's TimeLimit wrapper is not a true terminal
        # state, so the critic should still bootstrap from next_obs. Only a
        # genuine termination is stored as ``done`` in the replay buffer.
        terminal = bool(done) and not info.get("TimeLimit.truncated", False)

        storage.push(step(obs, act_v, next_obs, rew, terminal))
        obs = next_obs

        sample = storage.sample(BATCH)
        if sample:
            # Transpose a list of transitions into one tuple per field.
            batch = step(*zip(*sample))

            states = _batch_tensor(batch.state, torch.float32, device)
            actions = _batch_tensor(batch.action, torch.float32, device)
            next_states = _batch_tensor(batch.next_state, torch.float32, device)
            rewards = _batch_tensor(batch.reward, torch.float32, device).unsqueeze(-1)
            dones = _batch_tensor(batch.done, torch.bool, device).unsqueeze(-1)

            # critic learning
            # The TD target comes from the target networks only and must not
            # carry gradients, so build it without autograd tracking.
            with torch.no_grad():
                next_action_v = actor_tgt(next_states)
                q_next = critic_tgt(next_states, next_action_v)
                q_next = q_next.masked_fill(dones, 0.0)  # no bootstrap past terminal
                q_target = rewards + GAMMA * q_next

            critic_optim.zero_grad()
            q_pred = critic(states, actions)
            critic_loss = F.mse_loss(q_pred, q_target)
            critic_loss.backward()
            critic_optim.step()

            # actor learning
            # Maximise the critic's value of the actor's own actions. The
            # gradients this leaves on the critic are cleared by
            # critic_optim.zero_grad() on the next update.
            actor_optim.zero_grad()
            actor_loss = -critic(states, actor(states)).mean()
            actor_loss.backward()
            actor_optim.step()

            # tgt soft update
            _soft_update(actor_tgt, actor, TAU)
            _soft_update(critic_tgt, critic, TAU)

        if done:
            break
    print("epoch %d count %d"%(epoch, count))

env.close()

epoch 0 count 200
epoch 1 count 200
epoch 2 count 200
epoch 3 count 200
epoch 4 count 200
epoch 5 count 200
epoch 6 count 200
epoch 7 count 200
epoch 8 count 200
epoch 9 count 200
epoch 10 count 200
epoch 11 count 200
epoch 12 count 200
epoch 13 count 200
epoch 14 count 200
epoch 15 count 200
epoch 16 count 200
epoch 17 count 200
epoch 18 count 200
epoch 19 count 200
epoch 20 count 200
epoch 21 count 200
epoch 22 count 200
epoch 23 count 200
epoch 24 count 200
epoch 25 count 200
epoch 26 count 200
epoch 27 count 200
epoch 28 count 200
epoch 29 count 200
epoch 30 count 200
epoch 31 count 200
epoch 32 count 200
epoch 33 count 200
epoch 34 count 200
epoch 35 count 200
epoch 36 count 200
epoch 37 count 200
epoch 38 count 200
epoch 39 count 200
epoch 40 count 200
epoch 41 count 200
epoch 42 count 200
epoch 43 count 200
epoch 44 count 200
epoch 45 count 200
epoch 46 count 200
epoch 47 count 200
epoch 48 count 200
epoch 49 count 200
epoch 50 count 200
epoch 51 count 200
epoch 52 count 200
epo