In [3]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import gym

In [28]:
# hyperparameters
BATCH_SIZE = 32
LR = 0.1
EPSILON = 0.9
GAMMA = 0.9
TARGET_REPLACE_ITER = 100
MEMORY_CAPACITY = 2000
env = gym.make('CartPole-v0')
env = env.unwrapped
N_ACTIONS = env.action_space.n
N_STATES = env.observation_space.shape[0]

In [29]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(N_STATES, 10)
        self.fc1.weight.data.normal_(0, 0.1)
        self.out = nn.Linear(10, N_ACTIONS)
        self.out.weight.data.normal_(0, 0.1)
    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        actions_value = self.out(x)
        return actions_value

In [65]:
class DQN(object):
    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()
        self.learn_step_counter = 0
        self.memory_counter = 0
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES*2+2))
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()
        
    def choose_action(self, x):
        print(x)
        
        x = Variable(torch.unsqueeze(torch.FloatTensor(x),0))# add one dim at beginning
        print(x)
        if np.random.uniform() < EPSILON:
            action_value = self.eval_net.forward(x)
            action = torch.max(action_value, 1)[1].data.numpy()[0]
        else:
            action = np.random.randint(0, N_ACTIONS)
        return action
    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s,[a,r],s_))
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index,:] = transition
        self.memory_counter += 1
    def learn(self):
        #target patameter update
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        b_memory = self.memory[sample_index,:]
        b_s = Variable(torch.FloatTensor(b_memory[:, :N_STATES]))
        b_a = Variable(torch.FloatTensor(b_memory[:, :N_STATES:N_STATES+1].astype(int)))
        b_r = Variable(torch.FloatTensor(b_memory[:, :N_STATES+1:N_STATES+2]))
        b_s_ = Variable(torch.FloatTensor(b_memory[:, -N_STATES:])) #test later
        
        #q_eval
        
        q_eval = self.eval_net(b_s).max(1)[0] #get max value each line in batch
        q_next = self.target_net(b_s_).detach()
        q_target = b_r + GAMMA * q_next.max(1)[0] #test q_next.max(1)
        loss = self.loss_func(q_eval, q_target)
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
                                         
        
        

In [66]:
dqn = DQN()

In [67]:

print('\nCollecting experience...')
for i_episode in range(400):
    s = env.reset()
    ep_r = 0
    while True:
        env.render()
        a = dqn.choose_action(s)
        print(a)

        # take action
        s_, r, done, info = env.step(a)
        #print(s_)

        # modify the reward
        x, x_dot, theta, theta_dot = s_
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        r = r1 + r2
        print(r)

        dqn.store_transition(s, a, r, s_)

        ep_r += r
        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()
            if done:
                print('Ep: ', i_episode,
                      '| Ep_r: ', round(ep_r, 2))

        if done:
            break
        s = s_


Collecting experience...
[ 0.04846513 -0.0016873   0.01597049  0.01593173]
tensor([[ 0.0485, -0.0017,  0.0160,  0.0159]])
1
0.6020454390002788
[ 0.04843138  0.19320202  0.01628912 -0.27166984]
tensor([[ 0.0484,  0.1932,  0.0163, -0.2717]])
0
0.6263779809712814
[ 0.05229542 -0.00214853  0.01085572  0.02610591]
tensor([[ 0.0523, -0.0021,  0.0109,  0.0261]])
0
0.6239029549805455
[ 0.05225245 -0.19742447  0.01137784  0.32219407]
tensor([[ 0.0523, -0.1974,  0.0114,  0.3222]])
0
0.5947808916033133
[ 0.04830396 -0.39270658  0.01782172  0.61844329]
tensor([[ 0.0483, -0.3927,  0.0178,  0.6184]])
0
0.5389964626620769
[ 0.04044983 -0.58807288  0.03019059  0.91668551]
tensor([[ 0.0404, -0.5881,  0.0302,  0.9167]])
0
0.45636005227634546
[ 0.02868837 -0.78358976  0.0485243   1.21870187]
tensor([[ 0.0287, -0.7836,  0.0485,  1.2187]])
0
0.34651251101719216
[ 0.01301658 -0.97930255  0.07289834  1.52618605]
tensor([[ 0.0130, -0.9793,  0.0729,  1.5262]])
0
0.2034587713190873
[-0.00656947 -1.17522471  0.

[-0.1134214  -1.38374718  0.17893277  2.21214425]
tensor([[-0.1134, -1.3837,  0.1789,  2.2121]])
0
-0.42437537984144846
[-0.03204379  0.03144458  0.04760025  0.02403185]
tensor([[-0.0320,  0.0314,  0.0476,  0.0240]])
0
0.4573411330467746
[-0.03141489 -0.16432653  0.04808089  0.3313451 ]
tensor([[-0.0314, -0.1643,  0.0481,  0.3313]])
0
0.4243306192485038
[-0.03470143 -0.36009874  0.05470779  0.63879436]
tensor([[-0.0347, -0.3601,  0.0547,  0.6388]])
0
0.3603294281198187
[-0.0419034  -0.5559391   0.06748368  0.94819148]
tensor([[-0.0419, -0.5559,  0.0675,  0.9482]])
0
0.2651509856235763
[-0.05302218 -0.75190158  0.08644751  1.26129129]
tensor([[-0.0530, -0.7519,  0.0864,  1.2613]])
0
0.13844069330467884
[-0.06806021 -0.94801617  0.11167333  1.57974858]
tensor([[-0.0681, -0.9480,  0.1117,  1.5797]])
0
-0.020314318938236875
[-0.08702054 -1.14427672  0.1432683   1.90506941]
tensor([[-0.0870, -1.1443,  0.1433,  1.9051]])
1
-0.21177068684638334
[-0.10990607 -0.95096426  0.18136969  1.66004848

0
-0.21503768501892945
[-0.10296688 -1.38570992  0.18265949  2.22752109]
tensor([[-0.1030, -1.3857,  0.1827,  2.2275]])
0
-0.4392978626075442
[0.02666979 0.04640593 0.0449115  0.00088531]
tensor([[0.0267, 0.0464, 0.0449, 0.0009]])
0
0.47397969473344537
[ 0.02759791 -0.14933038  0.04492921  0.30739327]
tensor([[ 0.0276, -0.1493,  0.0449,  0.3074]])
0
0.44587021995231735
[ 0.02461131 -0.34506279  0.05107708  0.61390033]
tensor([[ 0.0246, -0.3451,  0.0511,  0.6139]])
0
0.39012257984407195
[ 0.01771005 -0.54085987  0.06335508  0.92222306]
tensor([[ 0.0177, -0.5409,  0.0634,  0.9222]])
0
0.3065639300091787
[ 0.00689285 -0.736778    0.08179954  1.2341246 ]
tensor([[ 0.0069, -0.7368,  0.0818,  1.2341]])
0
0.18831793860950113
[-0.00784271 -0.93285065  0.10648203  1.55127242]
tensor([[-0.0078, -0.9329,  0.1065,  1.5513]])
0
0.03240857905570188
[-0.02649972 -1.12907649  0.13750748  1.87519068]
tensor([[-0.0265, -1.1291,  0.1375,  1.8752]])
0
-0.15606791087819472
[-0.04908125 -1.32540557  0.17501

[-0.04576252 -0.63667721  0.102435    0.96649776]
tensor([[-0.0458, -0.6367,  0.1024,  0.9665]])
0
0.0942418069006476
[-0.05849606 -0.83301457  0.12176496  1.28952328]
tensor([[-0.0585, -0.8330,  0.1218,  1.2895]])
0
-0.03584038412873419
[-0.07515635 -1.02945648  0.14755542  1.61771491]
tensor([[-0.0752, -1.0295,  0.1476,  1.6177]])
0
-0.19889958258742108
[-0.09574548 -1.22597782  0.17990972  1.9525177 ]
tensor([[-0.0957, -1.2260,  0.1799,  1.9525]])
0
-0.3955677706428775
[-0.04242048 -0.03943844  0.00456463  0.00612313]
tensor([[-0.0424, -0.0394,  0.0046,  0.0061]])
0
0.659616929348544
[-0.04320925 -0.23462556  0.00468709  0.30024275]
tensor([[-0.0432, -0.2346,  0.0047,  0.3002]])
0
0.6289906460125284
[-0.04790176 -0.429814    0.01069195  0.5944002 ]
tensor([[-0.0479, -0.4298,  0.0107,  0.5944]])
0
0.5686478249828115
[-0.05649804 -0.62508396  0.02257995  0.89043177]
tensor([[-0.0565, -0.6251,  0.0226,  0.8904]])
0
0.4784088217801581
[-0.06899972 -0.82050488  0.04038859  1.19012636]
te

KeyboardInterrupt: 