In [1]:
import gym

import numpy as np

import torch
import torch.nn as nn
import torch.optim as opt

import torch.nn.functional as F

import random

from collections import namedtuple
import collections

import math

In [2]:
step = namedtuple("step", ("state", "action", "next_state", "reward", "done"))

class Memory():
    def __init__(self, size:int):
        self.saved = collections.deque(maxlen = size)
           
    def fill_memory(self, env, init_size:int):
        count = 0
        obs = env.reset()
        while count<init_size:
            act = env.action_space.sample()
            next_obs, rew, done, _ = env.step(act)
            
            if done:
                rew= 0

            self.push(step(obs, act, next_obs, rew, done))
            count += 1
            if done:
                next_obs = env.reset()
            obs = next_obs
    
    def sample(self, size):
        if len(self.saved) < size:
            return None
        else:
            return random.sample(self.saved, size)
            
    def push(self, data):
        self.saved.append(data)
        
    def __len__(self):
        return len(self.saved)

In [3]:
class PGNet(nn.Module):
    def __init__(self, in_size, out_size, hidden_size=128):
        super(PGNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, out_size)
        )
        
    def forward(self, x):
        return self.net(x)

In [4]:
env = gym.make("CartPole-v0")
obs = env.reset()

In [5]:
IN_SIZE = env.observation_space.shape[0]
OUT_SIZE = env.action_space.n
HIDDEN = 64

LR = 0.001
GAMMA = 0.95

EPS = 0.9
EPS_ = 0.01
EPS_DECAY = 100

HIST_INIT = 10000
HIST_SIZE = 30000
BATCH = 256

EPOCH = 2000
TARGET_LOAD = 10
FRAME_SAVE = 20

POS_REW = 100
NEG_REW = -1

step_count = 0

hist = Memory(HIST_SIZE)
hist.fill_memory(env, HIST_INIT)
net = PGNet(IN_SIZE, OUT_SIZE)

optim = opt.Adam(net.parameters(), lr=LR)
criterion = nn.MSELoss()

snapshot = []

In [6]:
for epoch in range(EPOCH):
    obs = env.reset()
    count = 0
    reward = 0
    loss_sum = 0
    frame = []
    
    while True:
        eps = EPS_ + (EPS-EPS_) * math.exp(-1*epoch/EPS_DECAY)
        
        if random.random() < eps:
            act = env.action_space.sample()
        else:
            with torch.no_grad():
                out = net(torch.FloatTensor(obs))
                act = np.random.choice(env.action_space.n, p=F.softmax(out).numpy())
                
        next_obs, rew, done, _ = env.step(act)
        if done:
            rew = 0
        step_count += 1
        count += 1 
        
        hist.push(step(obs, act, next_obs, rew, done))
        obs = next_obs
        
        sample = hist.sample(BATCH)
        if sample:
            sample = step(*zip(*sample))
            
            states = torch.FloatTensor(sample.state)
            actions = torch.LongTensor(sample.action)
            next_states = torch.FloatTensor(sample.next_state)
            rewards  =torch.FloatTensor(sample.reward)
            dones = torch.BoolTensor(sample.done)
            
            optim.zero_grad()
            
            logit = net(states)
            prob = F.softmax(logit, 1)
            prob_log = F.log_softmax(logit, 1).gather(1, actions.unsqueeze(-1))
            
            v_loss = -(rewards * prob_log).mean()
            entropy_loss = (prob * prob_log).sum(dim=1).mean()
            
            loss = v_loss
            
            loss.backward()
            optim.step()
            
            loss_sum+= loss.item()
            
        if done:
            break
    print("epoch %d] count : %d, loss : %.5f"%(epoch, count, loss_sum/count))

  app.launch_new_instance()


epoch 0] count : 13, loss : 0.66438
epoch 1] count : 14, loss : 0.66410
epoch 2] count : 25, loss : 0.66232
epoch 3] count : 11, loss : 0.66091
epoch 4] count : 10, loss : 0.66245
epoch 5] count : 22, loss : 0.66171
epoch 6] count : 14, loss : 0.65986
epoch 7] count : 19, loss : 0.66392
epoch 8] count : 20, loss : 0.66193
epoch 9] count : 13, loss : 0.65768
epoch 10] count : 22, loss : 0.66231
epoch 11] count : 44, loss : 0.66187
epoch 12] count : 17, loss : 0.66602
epoch 13] count : 41, loss : 0.66094
epoch 14] count : 16, loss : 0.66129
epoch 15] count : 10, loss : 0.66707
epoch 16] count : 12, loss : 0.66094
epoch 17] count : 18, loss : 0.66126
epoch 18] count : 24, loss : 0.66292
epoch 19] count : 13, loss : 0.66162
epoch 20] count : 15, loss : 0.66138
epoch 21] count : 24, loss : 0.66324
epoch 22] count : 11, loss : 0.65468
epoch 23] count : 61, loss : 0.66143
epoch 24] count : 59, loss : 0.65931
epoch 25] count : 30, loss : 0.66314
epoch 26] count : 13, loss : 0.66519
epoch 27] c

epoch 218] count : 19, loss : 0.66193
epoch 219] count : 32, loss : 0.65857
epoch 220] count : 32, loss : 0.65974
epoch 221] count : 15, loss : 0.65916
epoch 222] count : 21, loss : 0.66298
epoch 223] count : 68, loss : 0.66193
epoch 224] count : 18, loss : 0.65770
epoch 225] count : 16, loss : 0.66268
epoch 226] count : 22, loss : 0.66013
epoch 227] count : 13, loss : 0.66172
epoch 228] count : 16, loss : 0.66221
epoch 229] count : 22, loss : 0.66194
epoch 230] count : 30, loss : 0.66083
epoch 231] count : 21, loss : 0.65763
epoch 232] count : 22, loss : 0.66354
epoch 233] count : 27, loss : 0.66349
epoch 234] count : 46, loss : 0.66076
epoch 235] count : 15, loss : 0.66096
epoch 236] count : 38, loss : 0.66228
epoch 237] count : 21, loss : 0.66039
epoch 238] count : 18, loss : 0.65915
epoch 239] count : 13, loss : 0.66203
epoch 240] count : 11, loss : 0.65974
epoch 241] count : 27, loss : 0.65984
epoch 242] count : 18, loss : 0.65839
epoch 243] count : 15, loss : 0.66166
epoch 244] c

epoch 433] count : 29, loss : 0.66352
epoch 434] count : 16, loss : 0.65886
epoch 435] count : 55, loss : 0.65817
epoch 436] count : 15, loss : 0.66310
epoch 437] count : 17, loss : 0.65556
epoch 438] count : 26, loss : 0.66075
epoch 439] count : 14, loss : 0.66391
epoch 440] count : 14, loss : 0.65970
epoch 441] count : 23, loss : 0.66006
epoch 442] count : 25, loss : 0.66205
epoch 443] count : 25, loss : 0.66102
epoch 444] count : 22, loss : 0.65880
epoch 445] count : 14, loss : 0.66417
epoch 446] count : 26, loss : 0.66138
epoch 447] count : 11, loss : 0.66114
epoch 448] count : 10, loss : 0.65461
epoch 449] count : 9, loss : 0.66391
epoch 450] count : 13, loss : 0.66187
epoch 451] count : 18, loss : 0.65895
epoch 452] count : 34, loss : 0.65803
epoch 453] count : 32, loss : 0.66370
epoch 454] count : 17, loss : 0.65783
epoch 455] count : 12, loss : 0.65703
epoch 456] count : 15, loss : 0.66193
epoch 457] count : 35, loss : 0.65968
epoch 458] count : 52, loss : 0.66082
epoch 459] co

epoch 648] count : 16, loss : 0.66189
epoch 649] count : 47, loss : 0.65984
epoch 650] count : 31, loss : 0.65992
epoch 651] count : 19, loss : 0.66113
epoch 652] count : 21, loss : 0.65739
epoch 653] count : 12, loss : 0.65863
epoch 654] count : 13, loss : 0.66166
epoch 655] count : 14, loss : 0.65744
epoch 656] count : 18, loss : 0.65414
epoch 657] count : 51, loss : 0.65827
epoch 658] count : 24, loss : 0.66437
epoch 659] count : 19, loss : 0.65835
epoch 660] count : 60, loss : 0.65924
epoch 661] count : 32, loss : 0.65947
epoch 662] count : 23, loss : 0.65758
epoch 663] count : 15, loss : 0.65973
epoch 664] count : 10, loss : 0.66319
epoch 665] count : 12, loss : 0.65569
epoch 666] count : 26, loss : 0.65942
epoch 667] count : 17, loss : 0.65978
epoch 668] count : 22, loss : 0.66334
epoch 669] count : 23, loss : 0.65804
epoch 670] count : 83, loss : 0.65988
epoch 671] count : 34, loss : 0.66139
epoch 672] count : 43, loss : 0.65863
epoch 673] count : 21, loss : 0.65938
epoch 674] c

epoch 863] count : 11, loss : 0.65466
epoch 864] count : 18, loss : 0.66086
epoch 865] count : 24, loss : 0.66059
epoch 866] count : 15, loss : 0.66107
epoch 867] count : 20, loss : 0.66034
epoch 868] count : 16, loss : 0.65837
epoch 869] count : 22, loss : 0.66085
epoch 870] count : 13, loss : 0.66047
epoch 871] count : 35, loss : 0.65728
epoch 872] count : 95, loss : 0.65780
epoch 873] count : 14, loss : 0.65827
epoch 874] count : 31, loss : 0.65926
epoch 875] count : 14, loss : 0.66175
epoch 876] count : 12, loss : 0.66284
epoch 877] count : 15, loss : 0.65890
epoch 878] count : 21, loss : 0.66287
epoch 879] count : 13, loss : 0.65866
epoch 880] count : 18, loss : 0.65924
epoch 881] count : 17, loss : 0.65850
epoch 882] count : 21, loss : 0.66321
epoch 883] count : 29, loss : 0.66238
epoch 884] count : 15, loss : 0.65925
epoch 885] count : 15, loss : 0.65858
epoch 886] count : 47, loss : 0.65807
epoch 887] count : 12, loss : 0.65801
epoch 888] count : 16, loss : 0.65923
epoch 889] c

epoch 1076] count : 14, loss : 0.66334
epoch 1077] count : 27, loss : 0.65989
epoch 1078] count : 33, loss : 0.65873
epoch 1079] count : 23, loss : 0.66194
epoch 1080] count : 10, loss : 0.65963
epoch 1081] count : 27, loss : 0.66185
epoch 1082] count : 35, loss : 0.66027
epoch 1083] count : 18, loss : 0.65679
epoch 1084] count : 19, loss : 0.65639
epoch 1085] count : 19, loss : 0.65809
epoch 1086] count : 30, loss : 0.65868
epoch 1087] count : 29, loss : 0.65847
epoch 1088] count : 16, loss : 0.65581
epoch 1089] count : 17, loss : 0.65677
epoch 1090] count : 24, loss : 0.65710
epoch 1091] count : 57, loss : 0.66116
epoch 1092] count : 11, loss : 0.65824
epoch 1093] count : 27, loss : 0.65722
epoch 1094] count : 19, loss : 0.66144
epoch 1095] count : 16, loss : 0.66052
epoch 1096] count : 20, loss : 0.65733
epoch 1097] count : 18, loss : 0.66212
epoch 1098] count : 14, loss : 0.65878
epoch 1099] count : 40, loss : 0.65860
epoch 1100] count : 40, loss : 0.65796
epoch 1101] count : 19, l

epoch 1286] count : 13, loss : 0.65975
epoch 1287] count : 15, loss : 0.65813
epoch 1288] count : 21, loss : 0.65756
epoch 1289] count : 34, loss : 0.65839
epoch 1290] count : 32, loss : 0.65685
epoch 1291] count : 53, loss : 0.65844
epoch 1292] count : 17, loss : 0.66084
epoch 1293] count : 55, loss : 0.65682
epoch 1294] count : 35, loss : 0.66070
epoch 1295] count : 42, loss : 0.65923
epoch 1296] count : 25, loss : 0.66002
epoch 1297] count : 27, loss : 0.65880
epoch 1298] count : 30, loss : 0.66085
epoch 1299] count : 26, loss : 0.65950
epoch 1300] count : 30, loss : 0.65979
epoch 1301] count : 24, loss : 0.65908
epoch 1302] count : 16, loss : 0.66217
epoch 1303] count : 14, loss : 0.65989
epoch 1304] count : 21, loss : 0.65888
epoch 1305] count : 90, loss : 0.65804
epoch 1306] count : 14, loss : 0.65732
epoch 1307] count : 21, loss : 0.66163
epoch 1308] count : 30, loss : 0.65508
epoch 1309] count : 28, loss : 0.65978
epoch 1310] count : 12, loss : 0.66043
epoch 1311] count : 15, l

epoch 1496] count : 12, loss : 0.65466
epoch 1497] count : 9, loss : 0.65717
epoch 1498] count : 15, loss : 0.65105
epoch 1499] count : 25, loss : 0.65619
epoch 1500] count : 12, loss : 0.65579
epoch 1501] count : 14, loss : 0.65843
epoch 1502] count : 59, loss : 0.65798
epoch 1503] count : 32, loss : 0.65611
epoch 1504] count : 52, loss : 0.65713
epoch 1505] count : 9, loss : 0.66268
epoch 1506] count : 46, loss : 0.65853
epoch 1507] count : 15, loss : 0.66196
epoch 1508] count : 16, loss : 0.65229
epoch 1509] count : 26, loss : 0.65956
epoch 1510] count : 22, loss : 0.65921
epoch 1511] count : 28, loss : 0.66052
epoch 1512] count : 14, loss : 0.65902
epoch 1513] count : 29, loss : 0.65510
epoch 1514] count : 14, loss : 0.66179
epoch 1515] count : 21, loss : 0.65722
epoch 1516] count : 16, loss : 0.65522
epoch 1517] count : 18, loss : 0.65962
epoch 1518] count : 17, loss : 0.66186
epoch 1519] count : 12, loss : 0.66000
epoch 1520] count : 50, loss : 0.65490
epoch 1521] count : 24, los

epoch 1706] count : 29, loss : 0.65575
epoch 1707] count : 24, loss : 0.65598
epoch 1708] count : 15, loss : 0.65571
epoch 1709] count : 14, loss : 0.65346
epoch 1710] count : 17, loss : 0.65994
epoch 1711] count : 13, loss : 0.65967
epoch 1712] count : 12, loss : 0.66015
epoch 1713] count : 22, loss : 0.65896
epoch 1714] count : 18, loss : 0.65839
epoch 1715] count : 46, loss : 0.66191
epoch 1716] count : 29, loss : 0.66042
epoch 1717] count : 18, loss : 0.65912
epoch 1718] count : 18, loss : 0.65497
epoch 1719] count : 27, loss : 0.65751
epoch 1720] count : 14, loss : 0.65871
epoch 1721] count : 32, loss : 0.65855
epoch 1722] count : 27, loss : 0.65827
epoch 1723] count : 32, loss : 0.65834
epoch 1724] count : 16, loss : 0.65951
epoch 1725] count : 21, loss : 0.66141
epoch 1726] count : 28, loss : 0.65654
epoch 1727] count : 30, loss : 0.66032
epoch 1728] count : 34, loss : 0.65633
epoch 1729] count : 16, loss : 0.65048
epoch 1730] count : 12, loss : 0.65707
epoch 1731] count : 30, l

epoch 1916] count : 28, loss : 0.65739
epoch 1917] count : 22, loss : 0.65441
epoch 1918] count : 30, loss : 0.65841
epoch 1919] count : 23, loss : 0.65602
epoch 1920] count : 15, loss : 0.65883
epoch 1921] count : 17, loss : 0.65578
epoch 1922] count : 15, loss : 0.65601
epoch 1923] count : 18, loss : 0.65842
epoch 1924] count : 13, loss : 0.65792
epoch 1925] count : 14, loss : 0.65502
epoch 1926] count : 39, loss : 0.65486
epoch 1927] count : 74, loss : 0.65710
epoch 1928] count : 11, loss : 0.66099
epoch 1929] count : 28, loss : 0.65403
epoch 1930] count : 12, loss : 0.65693
epoch 1931] count : 22, loss : 0.65766
epoch 1932] count : 11, loss : 0.65509
epoch 1933] count : 11, loss : 0.65666
epoch 1934] count : 18, loss : 0.65695
epoch 1935] count : 25, loss : 0.65682
epoch 1936] count : 26, loss : 0.65758
epoch 1937] count : 25, loss : 0.65924
epoch 1938] count : 12, loss : 0.65415
epoch 1939] count : 19, loss : 0.65849
epoch 1940] count : 18, loss : 0.65514
epoch 1941] count : 21, l

In [7]:
entropy_loss

tensor(-0.6831, grad_fn=<MeanBackward0>)

In [8]:
v_loss

tensor(0.6591, grad_fn=<NegBackward>)

In [9]:
np.random.choice(env.action_space.n, p=F.softmax(out).numpy())

  """Entry point for launching an IPython kernel.


0

In [10]:
import frame_player

In [11]:
idx = 5
print(snapshot[idx][0], snapshot[idx][1])
display_frames_as_gif(snapshot[idx][2])

IndexError: list index out of range