In [1]:
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import pandas as pd
import random
import math

import collections

from matplotlib import pyplot as plt

from torch.distributions import Normal, Categorical

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [2]:
GAME = "Pendulum-v0"
env = gym.make(GAME)

obs_n = env.observation_space.shape[0]
act_n = env.action_space.shape[0]

action_scale = env.action_space.high[0]

In [3]:
STD_MIN = -20
STD_MAX = 2

class Actor(nn.Module):
    def __init__(self, hidden = (128, 128)):
        super(Actor, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_n, hidden[0]),
            nn.ReLU(),
            nn.Linear(hidden[0], hidden[1]),
            nn.ReLU(),
        )
        
        self.mu = nn.Linear(hidden[1], act_n)
        self.log_std = nn.Linear(hidden[1], act_n)
        
        for p in self.parameters():
            p.data.normal_(0, 1e-14)
        
    def forward(self, x):
        x = self.net(x)
        
        mu = self.mu(x)
        logit_factor = torch.tanh(self.log_std(x)) + 1
        std_logit = STD_MIN + 0.5*(STD_MAX - STD_MIN) * logit_factor
        std = torch.exp(std_logit)
        
        dist = Normal(mu, std)
        pi = dist.rsample()
        log_pi = dist.log_prob(pi)
        
        mu = torch.tanh(mu) * action_scale
        pi = torch.tanh(pi) * action_scale
        
        return mu, pi, log_pi
    
class Critic(nn.Module):
    def __init__(self, hidden = (128, 128)):
        super(Critic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_n, hidden[0]),
            nn.ReLU()
        )
        self.out = nn.Sequential(
            nn.Linear(hidden[0]+act_n, hidden[1]),
            nn.ReLU(),
            nn.Linear(hidden[1], 1),
        )
        
    def forward(self, obs, act):
        h = self.net(obs)
        x = torch.cat([h, act], 1)
        return self.out(x)

In [4]:
def get_episode(env, net=None, show=False):
    obs = env.reset()
    if show:
        env.render()
    ep = []
    count = 0
    
    while True:
        if net is None:
            act = env.action_space.sample()
        else:
            with torch.no_grad():
                _, pi, log_pi = net(torch.FloatTensor([obs]).to(device))
                act = pi.cpu().numpy()[0]
            
        next_obs, rew, done, _ = env.step(act)
        if show:
            env.render()
        count += rew
        if done:
            print(count, end=' ')
            
        step = (obs, act, next_obs, rew, done)
        ep.append(step)
        yield step
                
        obs = next_obs
        if done:
            break
        
class Memory():
    def __init__(self, size):
        self.data = collections.deque(maxlen=size)
        
    def append(self, data):
        self.data.append(data)
    
    def extend(self, data):
        self.data.extend(data)
        
    def __len__(self):
        return len(self.data)
    
    def sample(self, size):
        assert size <= len(self)
        return random.sample(self.data, size)

In [5]:
ACT_LR = 3e-5
ACT_CLIP = 5e-1
actor = Actor().to(device)
actor_optim = optim.Adam(actor.parameters(), lr = ACT_LR)

CRT_LR = 1e-4
CRT_CLIP = 1e0
q1 = Critic().to(device)
q1_target = Critic().to(device)
q1_target.load_state_dict(q1.state_dict())
q1_target.eval()

q2 = Critic().to(device)
q2_target = Critic().to(device)
q2_target.load_state_dict(q2.state_dict())
q2_target.eval()
q_optim = optim.Adam(list(q1.parameters()) + list(q2.parameters()), lr = CRT_LR)

EPOCH = 4000
BATCH = 64
REPEAT = 5
GAMMA = 0.99
ALPH = 0.2
TAU = 1e-2

ST_MAX = 20000
ST_INIT = 2e3
storage = Memory(ST_MAX)
while len(storage) < ST_INIT:
    ep = get_episode(env)
    storage.extend([step for step in ep])
env.close()

-1065.4127951244188 -959.334909060996 -1401.4506011561741 -1392.4472609479715 -1305.5754711955979 -896.6362547756713 -886.8930023966681 -1380.3654445152101 -793.4858866253998 -1462.1185644035668 

In [None]:
for epoch in range(EPOCH):
    ep = get_episode(env, actor)
    loss = [0,0]
    print(epoch, end=' ')
    for j, step in enumerate(ep):
        storage.append(step)
        for i in range(REPEAT):
            sample = list(zip(*storage.sample(BATCH)))

            obs = torch.FloatTensor(sample[0]).to(device)
            probs = torch.FloatTensor(sample[1]).to(device)
            next_obs = torch.FloatTensor(sample[2]).to(device)
            rew = torch.FloatTensor(sample[3]).to(device).unsqueeze(1)
            done = torch.FloatTensor(sample[4]).to(device).unsqueeze(1)

            _, cur_pi, cur_log_pi = actor(obs)
            _, next_pi, next_log_pi = actor(next_obs)

            q1_pred = q1(obs, probs)
            q2_pred = q2(obs, probs)

            min_q_next = torch.min(q1_target(next_obs, next_pi), q2_target(next_obs, next_pi))
            min_q = torch.min(q1(obs, cur_pi), q2(obs, cur_pi))

            v_next = (min_q_next - ALPH*next_log_pi)
            q_target = rew + GAMMA * (1-done) * v_next
            
            #

            q1_loss = F.mse_loss(q1_pred, q_target.detach())
            q2_loss = F.mse_loss(q2_pred, q_target.detach())
            q_loss = q1_loss + q2_loss
            actor_loss = (ALPH*cur_log_pi - min_q).mean()
            
            q_optim.zero_grad()
            q_loss.backward()
            q_optim.step()
            
            actor_optim.zero_grad()
            actor_loss.backward()
            actor_optim.step()

            for l, t in zip(q1.parameters(), q1_target.parameters()):
                t.data.copy_(l.data*TAU + t.data*(1-TAU))
            for l, t in zip(q2.parameters(), q2_target.parameters()):
                t.data.copy_(l.data*TAU + t.data*(1-TAU))
            
            loss[0] += actor_loss.item()
            loss[1] += q_loss.item()
    print(loss[0]/(REPEAT*(j+1)), loss[1]/(REPEAT*(j+1)))

0 -1164.2690592025485 22.620750230550765 45.48706240653992
1 -1326.3695022035797 67.0290850982666 117.02176022148133
2 -1168.1848270693913 104.96734637451172 211.806271900177
3 -1281.9446734525613 137.07224115753175 250.96506354904176
4 -1411.1252882447227 168.92873377990722 292.6690738353729
5 -1618.370359763798 200.71528378295898 406.2027742137909
6 -1370.664154445216 227.86358085632324 505.6253550758362
7 -1518.383385316265 250.8569178161621 580.5058912925721
8 -1569.8677938695412 273.9906424407959 727.8111974906922
9 -1050.3983667955392 297.8849736633301 855.0175288162231
10 -1192.4537986367898 318.0206702880859 964.5460550689697
11 -1425.1604330488606 333.2628228149414 1022.6053234443665
12 -1548.2389779067494 343.8177347412109 1219.1489986305237
13 -1435.1310521544067 353.96732656860354 1251.4891275787354
14 -1370.9793324352747 364.83096267700193 1230.6084582500457
15 -1650.1519033706224 376.47714126586914 1260.0177343006135
16 -1508.9141825623415 386.38756546020505 1401.74952117

In [None]:
actor(obs)