In [1]:
import gym
import random
import collections
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
#Hyperparameters
lr_mu        = 0.0005
lr_q         = 0.001
gamma        = 0.99
batch_size   = 32
buffer_limit = 50000
tau          = 0.005 # for target network soft update

In [3]:
class ReplayBuffer():
    def __init__(self):
        self.buffer=collections.deque(maxlen=buffer_limit)
    def put(self,transition):
        self.buffer.append(transition)
    def sample(self,n):
        mini_batch=random.sample(self.buffer,n)
        s_lst,a_lst,r_lst,s_prime_lst,done_mask_lst=[],[],[],[],[]
        for transition in mini_batch:
            s,a,r,s_prime,done_mask=transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])
        return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst),\
                        torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float),\
                        torch.tensor(done_mask_lst)
    def size(self):
        return len(self.buffer)

In [4]:
class Munet(nn.Module):
    def __init__(self):
        super(Munet,self).__init__()
        self.fc1=nn.Linear(3,128)
        self.fc2=nn.Linear(128,64)
        self.fc_mu=nn.Linear(64,1)
    def forward(self,x):
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        mu=torch.tanh(self.fc_mu(x))*2
        return mu
    

In [5]:
class Qnet(nn.Module):
    def __init__(self):
        super(Qnet,self).__init__()
        self.fc_s=nn.Linear(3,64)
        self.fc_a=nn.Linear(1,64)
        self.fc_1=nn.Linear(128,32)
        self.fc_3=nn.Linear(32,1)
    def forward(self,x,a):
        h1=F.relu(self.fc_s(x))
        h2=F.relu(self.fc_a(a))
        cat=torch.cat([h1,h2],dim=1)
        q=F.relu(self.fc_1(cat))
        q=self.fc_3(q)
        return q

In [6]:
class OrnsteinUhlenbeckNoise:
    def __init__(self, mu):
        self.theta, self.dt, self.sigma = 0.1, 0.01, 0.1
        self.mu = mu
        self.x_prev = np.zeros_like(self.mu)

    def __call__(self):
        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + \
                self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape)
        self.x_prev = x
        return x

In [10]:
def train(mu,mu_target,q,q_target,memory,q_optimizer,mu_optimizer):
    s,a,r,s_prime,done_mask=memory.sample(batch_size)
    target=r+gamma*q_target(s_prime,mu_target(s_prime))
    q_loss=F.smooth_l1_loss(q(s,a),target.detach())
    q_optimizer.zero_grad()
    q_loss.backward()
    q_optimizer.step()
    
    mu_loss=-q(s,mu(s)).mean()
    mu_optimizer.zero_grad()
    mu_loss.backward()
    mu_optimizer.step()

In [13]:
def soft_update(net,net_target):
    for param_target,param in zip(net_target.parameters(),net.parameters()):
        param_target.data.copy_(param_target.data*(1.0-tau)+param.data*tau)

In [14]:
env=gym.make('Pendulum-v0')
memory=ReplayBuffer()

q,q_target=Qnet(),Qnet()
q_target.load_state_dict(q.state_dict())
mu, mu_target = Munet(), Munet()
mu_target.load_state_dict(mu.state_dict())

score=0.0
print_interval=20

mu_optimizer=optim.Adam(mu.parameters(),lr=lr_mu)
q_optimizer=optim.Adam(q.parameters(),lr=lr_q)
ou_noise=OrnsteinUhlenbeckNoise(mu=np.zeros(1))

for n_epi in range(2000):
    s=env.reset()
    for t in range(300):
        a=mu(torch.from_numpy(s).float())
        a=a.item()+ou_noise()[0]
        s_prime,r,done,info=env.step([a])
        memory.put((s,a,r/100.0,s_prime,done))
        score+=r
        s=s_prime
        
        if done:
            break
    if memory.size()>2000:
        for i in range(10):
            train(mu,mu_target,q,q_target,memory,q_optimizer,mu_optimizer)
            soft_update(mu,mu_target)
            soft_update(q,q_target)
    if n_epi%print_interval==0 and n_epi!=0:
        print("# of episode :{}, avg score : {:.1f}".format(n_epi, score/print_interval))
        score = 0.0
env.close()

# of episode :20, avg score : -1616.7
# of episode :40, avg score : -1707.1
# of episode :60, avg score : -1690.0
# of episode :80, avg score : -1508.9
# of episode :100, avg score : -1486.9
# of episode :120, avg score : -1466.9
# of episode :140, avg score : -1512.7
# of episode :160, avg score : -1475.9
# of episode :180, avg score : -1448.1
# of episode :200, avg score : -1424.1
# of episode :220, avg score : -1354.5
# of episode :240, avg score : -1394.9
# of episode :260, avg score : -1322.1
# of episode :280, avg score : -1277.5
# of episode :300, avg score : -1255.1
# of episode :320, avg score : -1571.4
# of episode :340, avg score : -1571.1
# of episode :360, avg score : -1580.8
# of episode :380, avg score : -1495.7
# of episode :400, avg score : -1449.2
# of episode :420, avg score : -1472.1
# of episode :440, avg score : -1475.0
# of episode :460, avg score : -1522.0
# of episode :480, avg score : -1405.5
# of episode :500, avg score : -1494.0
# of episode :520, avg score 