In [2]:
import gym

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import random

import functools

# Id of the gym environment used throughout the notebook.
GAME = 'CartPole-v1'

# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
# One shared environment instance for the whole notebook.
env = gym.make(GAME)

# Sizes used to shape the network: length of the observation vector and
# number of discrete actions reported by the environment.
obs_n, act_n = env.observation_space.shape[0], env.action_space.n

In [13]:
class QNet(nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear.

    Args:
        in_: number of input features.
        out_: number of outputs produced per input row.
        hidden: width of the single hidden layer (default 64).
    """

    def __init__(self, in_, out_, hidden=64):
        super().__init__()
        # Layers are instantiated input-side first and wrapped in a
        # Sequential stored as ``self.net``, so parameters are exposed as
        # ``net.0.*`` and ``net.2.*`` in the state dict.
        layers = [
            nn.Linear(in_, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        out = self.net(x)
        return out

In [74]:
import collections

class Memory:
    """Bounded FIFO replay buffer backed by a ``collections.deque``.

    Once ``size`` items are stored, appending a new item silently drops the
    oldest one. ``parse`` reads the first five fields of every item as
    ``(obs, act, next_obs, rew, done)``.
    """

    def __init__(self, size):
        self.data = collections.deque(maxlen=size)

    def __len__(self):
        return len(self.data)

    def append(self, item):
        """Store a single item, evicting the oldest one when full."""
        self.data.append(item)

    def extend(self, items):
        """Store every item of an iterable, in order."""
        self.data.extend(items)

    def sample(self, size):
        """Draw ``size`` distinct items uniformly at random and batch them.

        Returns the 5-tuple of tensors produced by ``parse``.
        Raises AssertionError when fewer than ``size`` items are stored.
        """
        assert len(self) >= size, (
            'cannot sample %d items from a buffer holding %d' % (size, len(self))
        )
        indices = np.random.choice(len(self), size, False)
        sample = [self.data[i] for i in indices]
        return self.parse(sample)

    @staticmethod
    def _to_tensor(values, dtype):
        """Stack ``values`` into one numpy array of ``dtype`` and move it to ``device``."""
        # Building the tensor from a single contiguous ndarray avoids the slow
        # element-by-element path PyTorch takes for a sequence of ndarrays.
        return torch.from_numpy(np.asarray(values, dtype=dtype)).to(device)

    def parse(self, sample):
        """Turn a list of stored items into batched tensors on ``device``.

        Returns ``(obs, act, next_obs, rew, done)`` where ``obs``/``next_obs``
        are float32 with one row per item, ``act`` is an int64 column, and
        ``rew``/``done`` are float32 columns.
        """
        parsed = list(zip(*sample))

        obs = self._to_tensor(parsed[0], np.float32)
        act = self._to_tensor(parsed[1], np.int64).unsqueeze(1)
        next_obs = self._to_tensor(parsed[2], np.float32)
        rew = self._to_tensor(parsed[3], np.float32).unsqueeze(1)
        done = self._to_tensor(parsed[4], np.float32).unsqueeze(1)

        return obs, act, next_obs, rew, done

In [92]:
# --- Training hyperparameters -------------------------------------------
EPOCH = 1_000      # number of iterations of the outer training loop
LR = 1e-3          # learning rate handed to the Adam optimizer
GAMMA = 0.99       # discount factor applied to the bootstrapped next value
MEM_SIZE = 2_000   # capacity of the replay buffer
MEM_INIT = 500     # transitions collected with random actions before training
SAMPLE = 128       # number of transitions drawn from the buffer per update

# The target network is re-synced every TARGET_UPDATE outer iterations.
TARGET_UPDATE = 5

# Replay buffer shared by the warm-up and training loops.
memory = Memory(MEM_SIZE)

# Online network first, then the target network; the target starts as an
# exact copy of the online weights and is kept in eval mode.
critic = QNet(obs_n, act_n).to(device)
critic_target = QNet(obs_n, act_n).to(device).eval()
critic_target.load_state_dict(critic.state_dict())

# Only the online network is optimized.
critic_opt = optim.Adam(params=critic.parameters(), lr=LR)

# Warm-up: fill the buffer with MEM_INIT transitions from uniformly random
# actions so the first training batches can be sampled.
# NOTE(review): unpacking four values from env.step assumes the older gym
# step/reset API — confirm against the installed gym version.
obs = env.reset()
while len(memory) < MEM_INIT:
    act = env.action_space.sample()
    next_obs, rew, done, _ = env.step(act)
    memory.append((obs, act, next_obs, rew, done))
    # Start a fresh episode once the current one has ended.
    obs = env.reset() if done else next_obs

In [None]:
# Main training loop: one environment episode per outer iteration, with one
# gradient step on a replay batch after every environment step.
# NOTE(review): unpacking four values from env.step assumes the older gym
# step/reset API — confirm against the installed gym version.
for epoch in range(EPOCH):
    obs = env.reset()
    loss_sum = 0
    rew_sum = 0
    # Exploration probability decays exponentially with the episode index;
    # it is constant within an episode, so compute it once.
    epsilon = np.exp(-epoch / 100)
    while True:
        if random.random() < epsilon:
            act = env.action_space.sample()
        else:
            with torch.no_grad():
                # Convert the single observation directly instead of going
                # through a Python list of ndarrays.
                obs_v = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
                act = critic(obs_v).argmax(dim=1).item()

        next_obs, rew, done, _ = env.step(act)
        rew_sum += rew
        memory.append((obs, act, next_obs, rew, done))

        # Training: one-step Q-learning against the target network.
        obs_sample, act_sample, next_obs_sample, rew_sample, done_sample = memory.sample(SAMPLE)

        q_pred = critic(obs_sample).gather(1, act_sample)
        with torch.no_grad():
            q_next = critic_target(next_obs_sample).max(1)[0].unsqueeze(1)
            # The target must use the rewards of the sampled transitions
            # (rew_sample), not the scalar reward of the current env step.
            # (1 - done) removes the bootstrap term for terminal transitions.
            q_target = rew_sample + GAMMA * (1 - done_sample) * q_next

        critic_loss = F.mse_loss(q_pred, q_target)
        critic_opt.zero_grad()
        critic_loss.backward()
        critic_opt.step()

        loss_sum += critic_loss.item()

        if done:
            break
        obs = next_obs

    print(epoch, rew_sum, loss_sum)
    # Sync the target network with the online network every TARGET_UPDATE episodes.
    if epoch % TARGET_UPDATE == TARGET_UPDATE - 1:
        critic_target.load_state_dict(critic.state_dict())

0 13.0 1.4757527858018875
1 26.0 1.9395144507288933
2 14.0 0.782274916768074
3 15.0 0.623052978888154
4 12.0 0.4666457287967205
5 64.0 21.438378289341927
6 72.0 4.222232828848064
7 40.0 1.866068258881569
8 20.0 0.9269589819014072
9 18.0 0.9066288731992245
10 16.0 11.841343194246292
11 15.0 3.6297172158956528
12 15.0 2.7825524136424065
13 32.0 5.512712623924017
14 12.0 2.272049516439438
15 12.0 11.805711209774017
16 18.0 7.627584740519524
17 13.0 4.966764569282532
18 12.0 4.177481785416603
19 11.0 3.955346181988716
20 13.0 13.935448408126831
21 22.0 14.210505671799183
22 18.0 10.486097544431686
23 23.0 12.073438882827759
24 11.0 5.604661718010902
25 17.0 19.983860552310944
26 19.0 14.694125294685364
27 14.0 10.55323476344347
28 27.0 20.077901154756546
29 15.0 11.170775145292282
30 10.0 13.016668677330017
31 32.0 28.683239832520485
32 11.0 8.193803071975708
33 51.0 39.243862479925156
34 37.0 32.91406315565109
35 13.0 17.397015392780304
36 33.0 34.131644904613495
37 37.0 36.55689910054207

294 113.0 3.9935500444844365
295 118.0 5.48254253808409
296 118.0 4.032810552045703
297 119.0 4.2583521362394094
298 118.0 4.742721417918801
299 123.0 4.471115238964558
300 115.0 7.028815028257668
301 119.0 7.1517453556880355
302 114.0 6.211063971742988
303 115.0 5.219894526526332
304 111.0 6.018487532623112
305 107.0 12.49426136445254
306 102.0 25.052767112851143
307 117.0 29.74973373953253
308 125.0 29.078563819639385
309 114.0 27.28889308217913
310 127.0 29.996587141416967
311 127.0 26.603164572268724
312 121.0 23.91496716812253
313 122.0 28.461707582697272
314 126.0 25.474444112740457
315 130.0 43.78756702039391
316 128.0 33.68724964838475
317 138.0 33.17687949817628
318 132.0 25.03647953271866
319 122.0 29.456359111703932
320 125.0 17.7144087087363
321 119.0 23.794248891994357
322 126.0 16.191788487136364
323 119.0 15.086150428280234
324 131.0 11.198785591870546
325 131.0 7.9864520179107785
326 129.0 2.3355018035508692
327 125.0 2.258252264931798
328 127.0 2.4243955397978425
329 1

581 93.0 56.289156056940556
582 126.0 121.73022018745542
583 142.0 230.69027038663626
584 143.0 163.2161608338356
585 141.0 252.89636337384582
586 153.0 272.8876776956022
587 154.0 219.9892657659948
588 150.0 217.46263831853867
589 161.0 288.83058734610677
590 172.0 301.5306608825922
591 160.0 265.8710467219353
592 150.0 185.0608152449131
593 169.0 247.1714870519936
594 166.0 228.7565893754363
595 161.0 121.15430251136422
596 178.0 91.26844260469079
597 166.0 86.326626021415
598 174.0 91.38692975789309
599 189.0 86.67254461720586
600 189.0 80.07818285003304
601 186.0 86.88378889486194
602 197.0 93.48751696571708
603 202.0 94.8502232581377
604 221.0 95.01931618154049
605 221.0 95.41097378358245
606 216.0 84.26033167541027
607 232.0 104.17463549412787
608 280.0 119.80336894467473
609 334.0 125.00980868376791
610 259.0 92.0406849887222
611 261.0 102.03978919796646
612 304.0 104.82481945306063
613 323.0 131.17729790508747
614 337.0 124.14556998200715
615 284.0 134.9717866331339
616 341.0 1

In [94]:
# Display whatever observation is currently bound to `obs` in the kernel.
obs

array([ 0.03141891,  0.19712675,  0.00175316, -0.33283627])