In [1]:
import gym
from utils import *

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import math

class Critic(nn.Module):
    """Two-layer MLP that maps an input vector to ``out_`` output values.

    Besides the forward pass, the class hosts the epsilon-greedy action
    selector used during data collection. That selector reads and updates
    notebook-level globals (``counter``, ``EPS_START``, ``EPS_END``,
    ``EPS_DECAY``, ``env``), which must exist before it is first called.
    """

    def __init__(self, in_, out_, hidden_ = 256):
        """Create the layers.

        Args:
            in_: number of input features.
            out_: number of outputs.
            hidden_: width of the single hidden layer.
        """
        super(Critic, self).__init__()
        layers = [
            nn.Linear(in_, hidden_),
            nn.ReLU(),
            nn.Linear(hidden_, out_),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the network and return its raw output."""
        return self.net(x)

    def get_action(self, act_v):
        """Pick an action epsilon-greedily from the values in ``act_v``.

        Every call bumps the global ``counter``; epsilon decays
        exponentially in that counter from ``EPS_START`` towards ``EPS_END``
        with time constant ``EPS_DECAY``.

        Args:
            act_v: object exposing ``argmax()`` (the per-action values).

        Returns:
            ``env.action_space.sample()`` with probability epsilon,
            otherwise ``act_v.argmax()`` (returned as-is, not converted).
        """
        global counter
        counter += 1
        decay = math.exp(-counter / EPS_DECAY)
        eps = EPS_END + (EPS_START - EPS_END) * decay
        explore = random.random() < eps
        if explore:
            return env.action_space.sample()
        return act_v.argmax()

In [3]:
# --- Optimisation hyper-parameters -----------------------------------------
LR = 0.0005         # learning rate handed to Adam
GAMMA = 0.95        # discount applied to the bootstrapped next-state value
TARGET_UPDATE = 5   # the target net is re-synced every TARGET_UPDATE episodes

# --- Epsilon-greedy schedule (read by Critic.get_action) --------------------
EPS_START = 0.9     # epsilon limit as counter -> 0
EPS_END = 0.02      # epsilon limit as counter -> infinity
EPS_DECAY = 2000    # time constant of the exponential decay, in get_action calls
counter = 0         # number of get_action calls so far; mutated via `global`

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook does not crash on machines without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Environment ------------------------------------------------------------
env = gym.make("CartPole-v1")
act_n = env.action_space.n
obs_n = env.observation_space.shape[0]

# --- Online / target networks and optimiser ---------------------------------
net = Critic(obs_n, act_n).to(DEVICE)
net_tgt = Critic(obs_n, act_n).to(DEVICE)
# Start the target network as an exact copy of the online network.
net_tgt.load_state_dict(net.state_dict())
opt = optim.Adam(net.parameters(), LR)
loss = nn.MSELoss()

# Agent comes from `utils`; it is iterated via agent.episode(...) in the
# training cell.
agent = Agent(env, net)

# --- Replay storage ---------------------------------------------------------
ST_SIZE = 10000     # passed to Replay -- presumably its capacity; confirm in utils
ST_INIT = 1000      # no gradient step is taken until storage holds this many items
ST_SAMPLE = 256     # number of transitions drawn per gradient step
storage = Replay(ST_SIZE)

In [4]:
EPOCH = 1000  # number of episodes to run

# Value written over the target-net estimate for transitions flagged `done`.
# NOTE(review): the textbook DQN target uses 0 here (target = reward only);
# -1000 acts as a large terminal penalty and dominates the MSE loss. Kept
# as-is to preserve training behaviour -- confirm it is intentional.
TERMINAL_Q = -1000

# Build batches on whichever device the online network already lives on,
# instead of hard-coding CUDA.
train_device = next(net.parameters()).device

for epoch in range(EPOCH):
    # Reset per episode so the report below is correct (0) even if the
    # episode yields no steps, rather than reusing a stale loop index.
    steps = 0
    for steps, step in enumerate(agent.episode(epoch), start=1):
        storage.push(step)

        # Keep collecting experience until the warm-up threshold is reached.
        if len(storage) < ST_INIT:
            continue

        # Transpose the sampled list of 7-field records into per-field tuples.
        # `act_v` and the last field are not needed for the update.
        obs, act_v, act, next_obs, rew, done, _ = zip(*storage.sample(ST_SAMPLE))

        obs_ = torch.FloatTensor(obs).to(train_device)
        act_ = torch.LongTensor(act).unsqueeze(1).to(train_device)
        next_obs_ = torch.FloatTensor(next_obs).to(train_device)
        rew_ = torch.FloatTensor(rew).unsqueeze(1).to(train_device)
        done_ = torch.BoolTensor(done).to(train_device)

        # The target is a constant w.r.t. the online parameters, so compute it
        # without recording an autograd graph.
        with torch.no_grad():
            q_next = net_tgt(next_obs_).max(1)[0].unsqueeze(1)
            q_next[done_] = TERMINAL_Q
            q = rew_ + GAMMA * q_next

        # Value the online net assigns to the action actually taken.
        q_pred = net(obs_).gather(1, act_)

        opt.zero_grad()
        loss_ = loss(q_pred, q)
        loss_.backward()
        opt.step()

    # Periodically copy the online weights into the target network.
    if epoch % TARGET_UPDATE == 0:
        net_tgt.load_state_dict(net.state_dict())
    # Log: episode index, number of steps the episode lasted.
    print(epoch, steps)

agent.reset()

0 45
1 20
2 15
3 29
4 12
5 15
6 16
7 13
8 17
9 15
10 12
11 11
12 39
13 15
14 10
15 32
16 8
17 11
18 13
19 25
20 11
21 17
22 30
23 12
24 21
25 10
26 9
27 14
28 9
29 12
30 9
31 20
32 11
33 13
34 14
35 13
36 31
37 9
38 11
39 12
40 12
41 14
42 31
43 16
44 19
45 23
46 17
47 11
48 12
49 14
50 9
51 16
52 13
53 12
54 17
55 12
56 18
57 12
58 25
59 15
60 15
61 11
62 14
63 11
64 12
65 11
66 12
67 11
68 11
69 15
70 17
71 13
72 14
73 13
74 11
75 10
76 8
77 14
78 23
79 12
80 13
81 11
82 12
83 61
84 36
85 70
86 89
87 35
88 23
89 39
90 51
91 59
92 74
93 50
94 25
95 20
96 27
97 17
98 19
99 20
100 21
101 19
102 31
103 17
104 19
105 22
106 18
107 22
108 14
109 19
110 15
111 17
112 19
113 19
114 22
115 20
116 17
117 15
118 19
119 12
120 17
121 13
122 13
123 13
124 16
125 13
126 18
127 20
128 26
129 24
130 25
131 23
132 20
133 17
134 27
135 26
136 18
137 24
138 28
139 28
140 24
141 32
142 52
143 49
144 40
145 38
146 33
147 32
148 39
149 117
150 35
151 64
152 103
153 103
154 51
155 241
156 178
157 45
158 20

KeyboardInterrupt: 