In [79]:
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal, Categorical

import numpy as np
import random

import collections

# Device string handed to every `.to(device)` call in this notebook:
# the first CUDA GPU when torch reports one, otherwise the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda:0"

In [101]:
# Environment shared by the data-collection, training and evaluation cells.
env = gym.make("Acrobot-v1")

# Overwrite the step limit stored on the object returned by gym.make.
# NOTE(review): presumably this is the time-limit wrapper's episode cap;
# it is a private attribute, so confirm it exists in the installed gym version.
setattr(env, "_max_episode_steps", 400)

# Sizes used for the networks' input and output layers: length of the
# observation vector and number of discrete actions.
obs_n, act_n = env.observation_space.shape[0], env.action_space.n

In [102]:
class Actor(nn.Module):
    """Policy network: maps an observation batch to one bounded logit per action.

    The network is a two-hidden-layer ReLU MLP whose final activation is
    ``Tanh``, so every logit lies in [-1, 1]. Callers apply softmax to the
    output to obtain action probabilities; because the logits are bounded,
    those probabilities can never become arbitrarily peaked.

    Args:
        hidden: Widths of the two hidden layers.
        obs_dim: Size of the input layer. When ``None`` the module-level
            ``obs_n`` is used, which matches the original behaviour.
        act_dim: Size of the output layer. When ``None`` the module-level
            ``act_n`` is used, which matches the original behaviour.
    """

    def __init__(self, hidden=(128, 128), obs_dim=None, act_dim=None):
        super().__init__()
        # Only touch the notebook globals when the caller did not supply sizes,
        # so the class can be built (and tested) without an environment.
        in_features = obs_n if obs_dim is None else obs_dim
        out_features = act_n if act_dim is None else act_dim
        # Layer indices (0, 2, 4 for the Linear layers) are kept unchanged so
        # existing state dicts still load.
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden[0]),
            nn.ReLU(),
            nn.Linear(hidden[0], hidden[1]),
            nn.ReLU(),
            nn.Linear(hidden[1], out_features),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return the Tanh-bounded action logits for the observations ``x``."""
        return self.net(x)
    
class Critic(nn.Module):
    """Action-value network: maps an observation batch to one value per action.

    The network is a two-hidden-layer ReLU MLP with a linear output layer.
    ``forward`` returns the raw output multiplied by 2; the reason for that
    scaling is not stated in this file.

    Args:
        hidden: Widths of the two hidden layers.
        obs_dim: Size of the input layer. When ``None`` the module-level
            ``obs_n`` is used, which matches the original behaviour.
        act_dim: Size of the output layer. When ``None`` the module-level
            ``act_n`` is used, which matches the original behaviour.
    """

    def __init__(self, hidden=(128, 128), obs_dim=None, act_dim=None):
        super().__init__()
        # Only touch the notebook globals when the caller did not supply sizes,
        # so the class can be built (and tested) without an environment.
        in_features = obs_n if obs_dim is None else obs_dim
        out_features = act_n if act_dim is None else act_dim
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden[0]),
            nn.ReLU(),
            nn.Linear(hidden[0], hidden[1]),
            nn.ReLU(),
            nn.Linear(hidden[1], out_features),
        )

    def forward(self, x):
        """Return the per-action values for ``x`` (network output times 2)."""
        return self.net(x) * 2

In [103]:
def get_episode(env, net=None):
    """Roll out one episode and return its transitions.

    Uses the interface the rest of this notebook relies on: ``env.reset()``
    returns the first observation and ``env.step(act)`` returns a 4-tuple
    ``(next_obs, rew, done, info)``.

    Args:
        env: Environment to step through.
        net: Optional policy. When ``None`` actions come from
            ``env.action_space.sample()``. Otherwise ``net`` is called on a
            batch containing the single current observation, its output is
            turned into probabilities with softmax, and an action index is
            sampled from them with ``np.random.choice``.

    Returns:
        list of ``(obs, act, rew, next_obs, done)`` tuples, one per step, ending
        with the step on which the environment reported ``done``.
    """
    policy_device = None
    if net is not None:
        # Run the policy wherever its weights live; a callable without
        # parameters falls back to the notebook-wide `device`.
        first_param = next(net.parameters(), None) if hasattr(net, "parameters") else None
        policy_device = device if first_param is None else first_param.device

    obs = env.reset()
    episode = []
    done = False
    while not done:
        if net is None:
            act = env.action_space.sample()
        else:
            with torch.no_grad():
                # Convert via a single float32 ndarray: building a tensor from
                # a Python list of ndarrays takes PyTorch's slow path.
                obs_batch = torch.as_tensor(
                    np.asarray(obs, dtype=np.float32), device=policy_device
                ).unsqueeze(0)
                probs = F.softmax(net(obs_batch), 1).cpu().numpy()[0]
            # The number of actions is taken from the policy output itself.
            act = np.random.choice(len(probs), 1, replace=True, p=probs)[0]

        next_obs, rew, done, _ = env.step(act)
        episode.append((obs, act, rew, next_obs, done))
        obs = next_obs
    return episode

In [107]:
# ---- Hyperparameters ---------------------------------------------------------
ACT_LR = 0.0005   # Adam learning rate for the actor
CRT_LR = 0.007    # Adam learning rate for the critic

BATCH = 5         # episodes collected per training iteration
EPOCH = 1000      # number of training iterations

# Weight on the entropy term of the actor loss.
# NOTE(review): the recorded training log shows the entropy term staying near
# -1.098 (about -ln 3) throughout; this weight may be too large - confirm.
ALPH = 5
GAMMA = 0.99      # discount factor used in the bootstrap targets
TAU = 0.01        # only referenced by the commented-out soft target update
TGT_UPDATE = 4    # the target critic is re-synced every TGT_UPDATE iterations

# ---- Networks and optimizers -------------------------------------------------
actor = Actor().to(device)
actor_optim = optim.Adam(params=actor.parameters(), lr=ACT_LR)

# Online critic plus a target copy that starts from identical weights and is
# kept in eval mode.
q1 = Critic().to(device)
q1_target = Critic().to(device).eval()
q1_target.load_state_dict(q1.state_dict())
q1_optim = optim.Adam(params=q1.parameters(), lr=CRT_LR)

# Disabled second critic, kept for reference:
# q2 = Critic().to(device)
# q2_target = Critic().to(device)
# q2_target.load_state_dict(q2.state_dict())
# q2_target.eval()
# q2_optim = optim.Adam(q2.parameters(), lr = CRT_LR)

In [None]:
# Actor-critic training: each iteration collects BATCH on-policy episodes,
# takes one critic step on a one-step TD target and one actor step on an
# advantage-weighted log-likelihood loss with an entropy term.
for epoch in range(EPOCH):
    # Collect fresh episodes with the current actor.
    data = []
    for i in range(BATCH):
        data.extend(get_episode(env, actor))
    size = len(data)
    # Transpose the transition list into per-field tuples:
    # (obs, act, rew, next_obs, done).
    data = list(zip(*data))

    # Stack each field into one ndarray before handing it to torch: building a
    # tensor from a tuple of ndarrays takes PyTorch's slow path. Actions are
    # created as int64 directly instead of round-tripping through float.
    obs = torch.as_tensor(np.asarray(data[0], dtype=np.float32), device=device)
    act = torch.as_tensor(np.asarray(data[1], dtype=np.int64), device=device).unsqueeze(1)
    rew = torch.as_tensor(np.asarray(data[2], dtype=np.float32), device=device).unsqueeze(1)
    next_obs = torch.as_tensor(np.asarray(data[3], dtype=np.float32), device=device)
    done = torch.as_tensor(np.asarray(data[4], dtype=np.float32), device=device).unsqueeze(1)

    # Only q_pred needs a graph (it feeds the critic loss).
    q_pred = q1(obs)

    # The TD target and the advantage are constants for both losses, so they
    # are computed without autograd instead of being detached afterwards.
    with torch.no_grad():
        q_next = q1(next_obs)
        q_target_next = q1_target(next_obs)

        # One-step target from the target critic; (1 - done) removes the
        # bootstrap term on the final step of an episode.
        q_target = rew + (1. - done) * GAMMA * q_target_next.max(1, keepdim=True)[0]

        # State values are approximated by the online critic's max over actions.
        v = q_pred.max(1, keepdim=True)[0]
        v_next = q_next.max(1, keepdim=True)[0]
        adv = rew + (1. - done) * GAMMA * v_next - v

    q_loss = F.mse_loss(q_pred.gather(1, act), q_target)

    logits = actor(obs)
    probs = F.softmax(logits, 1)
    log_probs = F.log_softmax(logits, 1)

    policy_loss = -log_probs.gather(1, act) * adv
    policy_loss = policy_loss.mean()

    # entropy_loss is the NEGATIVE mean policy entropy, so adding it to the
    # actor loss pushes the entropy up.
    entropy_loss = -probs * log_probs
    entropy_loss = -entropy_loss.sum(1).mean()

    q1_optim.zero_grad()
    q_loss.backward()
    q1_optim.step()

    actor_optim.zero_grad()
    (policy_loss + entropy_loss * ALPH).backward()
    actor_optim.step()

    # Hard-sync the target critic on every TGT_UPDATE-th iteration.
    if epoch % TGT_UPDATE == TGT_UPDATE - 1:
        q1_target.load_state_dict(q1.state_dict())

    # Disabled alternative: soft (Polyak) target update. The pairing must be
    # (online, target); the earlier draft zipped q1_target with itself.
#     for active, target in zip(q1.parameters(), q1_target.parameters()):
#         target.data.copy_(active.data*TAU + target.data*(1-TAU))

    # Columns: iteration, mean episode length, policy loss, entropy loss, critic loss.
    print(epoch, size//BATCH, policy_loss.item(), entropy_loss.item(), q_loss.item())

0 400 -1.100490689277649 -1.0968422889709473 0.7688949108123779
1 400 -1.091426134109497 -1.0981680154800415 0.1775425672531128
2 400 -1.0889679193496704 -1.098305106163025 0.2273746281862259
3 400 -1.0905869007110596 -1.097922682762146 0.17423531413078308
4 400 -1.0930625200271606 -1.0979046821594238 0.8243275880813599
5 400 -1.0902132987976074 -1.0981847047805786 0.6213397979736328
6 400 -1.0878074169158936 -1.0982903242111206 0.295564740896225
7 400 -1.0820720195770264 -1.0982519388198853 0.1617027372121811
8 400 -1.0777937173843384 -1.0983917713165283 0.7313821315765381
9 400 -1.0693484544754028 -1.098337173461914 0.19819524884223938
10 400 -1.0613654851913452 -1.098413348197937 0.36180537939071655
11 400 -1.056788444519043 -1.0983587503433228 0.695479154586792
12 400 -1.0612369775772095 -1.0983645915985107 0.767270565032959
13 400 -1.0602086782455444 -1.098266839981079 0.5979579091072083
14 400 -1.0546168088912964 -1.0981419086456299 0.2773421108722687
15 400 -1.0477561950683594 -

125 386 -0.8305429220199585 -1.0857181549072266 11.353801727294922
126 343 -0.8235642313957214 -1.0877625942230225 7.15909481048584
127 359 -0.8218488097190857 -1.0892295837402344 7.170810222625732
128 344 -0.8499815464019775 -1.090505599975586 16.226436614990234
129 383 -0.8696690201759338 -1.0919225215911865 8.136871337890625
130 380 -0.8517560958862305 -1.092980146408081 5.638513088226318
131 400 -0.8535199165344238 -1.094016671180725 6.331318378448486
132 400 -0.8738980889320374 -1.0947109460830688 16.60857391357422
133 400 -0.8830201029777527 -1.0952328443527222 9.927619934082031
134 388 -0.8806273341178894 -1.0954173803329468 8.232995986938477
135 400 -0.8716461062431335 -1.0954384803771973 5.130454063415527
136 400 -0.8836795091629028 -1.095398187637329 17.800275802612305
137 400 -0.8917862176895142 -1.0952351093292236 12.938148498535156
138 400 -0.8951452970504761 -1.0950706005096436 7.149137496948242
139 400 -0.8770073652267456 -1.094544529914856 5.531949996948242
140 400 -0.8

249 400 -0.888936460018158 -1.098052740097046 28.354875564575195
250 400 -0.8980958461761475 -1.0981050729751587 24.836923599243164
251 400 -0.9091688394546509 -1.0981346368789673 12.400674819946289
252 400 -0.8964598774909973 -1.098147988319397 20.189781188964844
253 389 -0.8948830962181091 -1.0981321334838867 19.228595733642578
254 400 -0.9004342555999756 -1.098124623298645 8.996601104736328
255 400 -0.9104244112968445 -1.098056674003601 10.287306785583496
256 396 -0.9163981080055237 -1.0979880094528198 12.1209077835083
257 400 -0.9254488945007324 -1.0979773998260498 7.384682655334473
258 400 -0.9238749146461487 -1.0978314876556396 7.831243515014648
259 400 -0.9141194224357605 -1.0977894067764282 7.223423480987549
260 400 -0.911492109298706 -1.0978533029556274 5.984072208404541
261 400 -0.8949152231216431 -1.0978227853775024 4.743949890136719
262 400 -0.8921169638633728 -1.0976601839065552 5.683803081512451
263 400 -0.8951807618141174 -1.0976754426956177 2.395581007003784
264 400 -0.

In [78]:
# Render one episode driven by the current actor and print the summed reward.
# Actions are sampled from the softmax policy, exactly as during training.
actor.eval()
cnt = 0
try:
    obs = env.reset()
    env.render()
    done = False
    while not done:
        with torch.no_grad():
            # Convert via a single float32 ndarray: building a tensor from a
            # Python list of ndarrays takes PyTorch's slow path.
            obs_batch = torch.as_tensor(
                np.asarray(obs, dtype=np.float32), device=device
            ).unsqueeze(0)
            logits = actor(obs_batch)
            probs = F.softmax(logits, 1)
            act = np.random.choice(act_n, 1, replace=True, p=probs.cpu().numpy()[0])[0]

        next_obs, rew, done, _ = env.step(act)
        env.render()
        cnt += rew
        obs = next_obs
    print(cnt)
finally:
    # Always put the actor back in training mode and release the environment,
    # even if the rollout or the renderer raised part-way through.
    actor.train()
    env.close()

500.0
