In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gym

import numpy as np

import random
import math

import collections
from collections import namedtuple

In [2]:
# One environment transition as stored in the replay memory.
step = namedtuple("step", "state action next_state reward done".split())

class Replay:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, size):
        """Create an empty memory that keeps at most ``size`` items."""
        # A deque with ``maxlen`` drops the oldest entry once it is full.
        self.memory = collections.deque(maxlen=size)

    def push(self, data):
        """Append one item to the memory."""
        self.memory.append(data)

    def prepare(self, env):
        """Placeholder hook; currently does nothing and returns ``None``."""
        return None

    def sample(self, size):
        """Return ``size`` distinct stored items chosen at random.

        Returns ``None`` when fewer than ``size`` items are stored.
        """
        if len(self.memory) < size:
            return None
        return random.sample(self.memory, size)

In [3]:
class Actor(nn.Module):
    """Policy network: maps a state vector to an action vector in [-1, 1]."""

    def __init__(self, state_n, action_n, hidden=256):
        """Build a three-layer MLP.

        Args:
            state_n: size of the input state vector.
            action_n: size of the output action vector.
            hidden: width of the first hidden layer; the second is half of it.
        """
        super().__init__()
        half = int(hidden / 2)
        layers = [
            nn.Linear(state_n, hidden),
            nn.ReLU(),
            nn.Linear(hidden, half),
            nn.ReLU(),
            nn.Linear(half, action_n),
            nn.Tanh(),  # squashes every action component into [-1, 1]
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the action computed for state(s) ``x``."""
        action = self.net(x)
        return action
    
class Critic(nn.Module):
    """Q-network: scores a (state, action) pair with a single value."""

    def __init__(self, state_n, action_n, hidden=256):
        """Build the state encoder and the joint state/action head.

        Args:
            state_n: size of the state vector.
            action_n: size of the action vector.
            hidden: width of the state encoding; the head's hidden layer is
                half of it.
        """
        super().__init__()
        half = int(hidden / 2)

        # The state is encoded on its own first ...
        state_encoder = [nn.Linear(state_n, hidden), nn.ReLU()]
        self.net = nn.Sequential(*state_encoder)

        # ... then the action is appended to that encoding before the head.
        value_head = [
            nn.Linear(hidden + action_n, half),
            nn.ReLU(),
            nn.Linear(half, 1),
        ]
        self.out = nn.Sequential(*value_head)

    def forward(self, state, act):
        """Return the value for each row of ``state`` paired with ``act``.

        Both inputs are concatenated along dim 1, so they must be 2-D with a
        matching first dimension.
        """
        features = self.net(state)
        joined = torch.cat((features, act), dim=1)
        return self.out(joined)

In [4]:
# --- Environment --------------------------------------------------------------
EPOCH = 1000                      # number of episodes the training loop runs
GAME_NAME = "BipedalWalker-v2"

env = gym.make(GAME_NAME)
# Override the step limit on the wrapper returned by gym.make (private attribute).
env._max_episode_steps = 400
obs_n = env.observation_space.shape[0]
act_n = env.action_space.shape[0]

# --- Optimisation hyper-parameters --------------------------------------------
LR_ACT = 0.0001                   # Adam learning rate for the actor
LR_CRT = 0.0005                   # Adam learning rate for the critic
TAU = 0.0005                      # weight of the online net in the target soft update
GAMMA = 0.999                     # discount applied to the bootstrapped next-state value

# --- Exploration noise schedule (exponential decay over environment steps) ----
EPS_START = 0.5
EPS_END = 0.1
EPS_DECAY = 20000

# --- Networks -----------------------------------------------------------------
# Use the GPU when one is present, otherwise fall back to the CPU instead of
# crashing on an unconditional .cuda() call.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

actor = Actor(obs_n, act_n).to(DEVICE)
actor_optim = optim.Adam(actor.parameters(), lr=LR_ACT)
# Target networks start as exact copies of their online counterparts.
actor_tgt = Actor(obs_n, act_n).to(DEVICE)
actor_tgt.load_state_dict(actor.state_dict())

critic = Critic(obs_n, act_n).to(DEVICE)
critic_optim = optim.Adam(critic.parameters(), lr=LR_CRT)
critic_tgt = Critic(obs_n, act_n).to(DEVICE)
critic_tgt.load_state_dict(critic.state_dict())

# --- Replay memory ------------------------------------------------------------
MAX_MEMORY = 20000                # capacity of the replay memory
# NOTE(review): MEM_INIT is not referenced anywhere in the visible notebook;
# presumably meant as a warm-up size before learning starts -- confirm.
MEM_INIT = 2000
BATCH = 256                       # number of transitions sampled per update
storage = Replay(MAX_MEMORY)
step_count = 0                    # environment steps taken across all episodes

In [None]:
def _soft_update(target_net, online_net, tau):
    """Blend ``online_net`` parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * online + (1 - tau) * target``.
    """
    with torch.no_grad():
        for tgt, real in zip(target_net.parameters(), online_net.parameters()):
            tgt.copy_(tau * real + (1 - tau) * tgt)


def _as_float_tensor(values, device):
    """Stack a sequence of scalars/arrays into one float32 tensor on ``device``.

    Stacking with numpy first avoids the slow element-by-element path torch
    takes when handed a tuple of ndarrays.
    """
    return torch.as_tensor(np.asarray(values, dtype=np.float32), device=device)


# Follow whatever device the networks already live on instead of assuming CUDA.
device = next(actor.parameters()).device

try:
    for epoch in range(EPOCH):
        obs = env.reset()
        env.render()

        count = 0
        total_rew = 0
        while True:
            # Noise scale decays exponentially from EPS_START towards EPS_END.
            eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1 * step_count / EPS_DECAY)
            with torch.no_grad():
                obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
                act_v = actor(obs_t).cpu().numpy()
            # Uniform noise in [-eps/2, eps/2), then clamp back to the action range.
            noise = (np.random.random(act_n) - 0.5) * eps
            act_v = np.clip(act_v + noise, -1, 1).astype(np.float32)

            next_obs, rew, done, info = env.step(act_v)

            env.render()
            count += 1
            step_count += 1
            total_rew += rew

            # An episode cut off by the step limit is not a real terminal state:
            # keep bootstrapping from next_obs in that case. The key is only set
            # by gym versions whose TimeLimit wrapper reports truncation; when it
            # is absent this falls back to the plain ``done`` flag.
            terminal = bool(done) and not info.get("TimeLimit.truncated", False)
            storage.push(step(obs, act_v, next_obs, rew, terminal))
            obs = next_obs

            sample = storage.sample(BATCH)  # None until BATCH transitions exist
            if sample:
                batch = step(*zip(*sample))

                states = _as_float_tensor(batch.state, device)
                actions = _as_float_tensor(batch.action, device)
                next_states = _as_float_tensor(batch.next_state, device)
                rewards = _as_float_tensor(batch.reward, device).unsqueeze(-1)
                terminals = _as_float_tensor(batch.done, device).unsqueeze(-1)

                # critic learning
                # The target is a constant for the regression, so build it
                # without tracking gradients.
                with torch.no_grad():
                    next_action_v = actor_tgt(next_states)
                    q_next = critic_tgt(next_states, next_action_v)
                    # (1 - terminals) zeroes the bootstrap term on terminal steps.
                    q_target = rewards + GAMMA * (1.0 - terminals) * q_next

                critic_optim.zero_grad()
                q_pred = critic(states, actions)
                critic_loss = F.mse_loss(q_pred, q_target)
                critic_loss.backward()
                critic_optim.step()

                # actor learning: ascend the critic's value of the actor's actions
                actor_optim.zero_grad()
                actor_loss = -critic(states, actor(states)).mean()
                actor_loss.backward()
                actor_optim.step()

                # tgt soft update
                _soft_update(actor_tgt, actor, TAU)
                _soft_update(critic_tgt, critic, TAU)

            if done:
                break
        print("epoch %d count %d reward %.5f" % (epoch, count, total_rew))
finally:
    # Release the environment even if the run is interrupted or raises.
    env.close()

epoch 0 count 400 reward -46.69813
epoch 1 count 400 reward -46.55329
epoch 2 count 400 reward -46.65319
epoch 3 count 400 reward -46.13652
epoch 4 count 400 reward -46.55728
epoch 5 count 400 reward -46.57750
epoch 6 count 400 reward -46.60644
epoch 7 count 400 reward -46.68356
epoch 8 count 400 reward -46.64587
epoch 9 count 400 reward -46.60472
epoch 10 count 400 reward -46.16692
epoch 11 count 400 reward -46.54889
epoch 12 count 400 reward -46.41477
epoch 13 count 400 reward -46.57433
epoch 14 count 400 reward -46.63324
epoch 15 count 400 reward -46.56693
