In [2]:
from datetime import datetime
from collections import deque
import os
import random
import gym
import torch
from torch.distributions import Categorical
from torch.nn import Module, Linear
import torch.nn.functional as F


class QNetwork(Module):
    def __init__(self):
        super().__init__()
        self.fc = Linear(4, 48)
        self.fcQ1 = Linear(48, 64)
        self.fcQ2 = Linear(64, 2)

    def forward(self, x):
        x = self.fc(x)
        x = F.relu(x)
        x = self.fcQ1(x)
        x = F.relu(x)
        x = self.fcQ2(x)

        return x


class PolicyNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 48, bias=False)
        self.fcA1 = Linear(48, 64)
        self.fcA2 = Linear(64, 2)
    def forward(self, x):
        x = self.fc(x)
        x = F.relu(x)
        x = self.fcA1(x)
        x = self.fcA2(x)
        return torch.nn.functional.softmax(x, dim=-1)
    
    
    

# network and optimizer
Q = QNetwork()
Q_optimizer = torch.optim.Adam(Q.parameters(), lr=0.0001)

pi = PolicyNetwork()
Actor_optimizer = torch.optim.Adam(pi.parameters(), lr=0.0001)

# target network
Q_target = QNetwork()

replay_memory_size = 100
minibatch_size = 20

history = deque(maxlen=replay_memory_size)  # replay buffer
discount = 0.99  # discount factor gamma


def update_Q():
    loss = 0

    for state, action, state_next, reward, done in random.sample(history, min(minibatch_size, len(history))):
        with torch.no_grad():
            if done:
                target = reward
            else:
                target = reward + discount * torch.dot(pi(state_next),Q_target(state_next))

        loss = loss + (target - Q(state)[action])**2

    loss = loss/min(minibatch_size, len(history))
    Q_optimizer.zero_grad()
    loss.backward()
    Q_optimizer.step()

def update_pi():
    loss = 0
    for state, action, state_next, reward, done in random.sample(history, min(minibatch_size, len(history))):
        loss = loss + Q(state)[action] * pi(state)[action].log()
    loss = loss/min(minibatch_size, len(history))
    Actor_optimizer.zero_grad()
    loss.backward()
    Actor_optimizer.step()    
    

# gym environment
env = gym.make("CartPole-v0")
max_time_steps = 1000


# training
for episode in range(1000):
    # sum of accumulated rewards
    rewards = 0

    # get initial observation
    observation = env.reset()
    state = torch.tensor(observation, dtype=torch.float32)

    # loop until an episode ends
    for t in range(1, max_time_steps + 1):
        # display current environment
        #env.render()

        # epsilon greedy policy for current observation
        probs = pi(state)
        action = torch.multinomial(probs, 1).item()
        
        # get next observation and current reward for the chosen action
        observation_next, reward, done, info = env.step(action)
        state_next = torch.tensor(observation_next, dtype=torch.float32)

        # collect reward
        rewards = rewards + reward

        # collect a transition
        history.append([state, action, state_next, reward, done])

        update_Q()
        update_pi()
        
        #Q_target.load_state_dict(Q.state_dict())
        # Soft update
        for target_param, param in zip(Q_target.parameters(), Q.parameters()):
            target_param.data.copy_(param.data * 0.8 + target_param.data * (1.0 - 0.8))


        if done:
            break

        # pass observation to the next step
        observation = observation_next
        state = state_next

    # compute average reward
    print('episode: {}, reward: {:.1f}'.format(episode, rewards))

env.close()


# TEST     
episode = 0
state = env.reset()     
while episode < 5:  # episode loop
    env.render()
    state = torch.tensor(state, dtype=torch.float32)
    probs = pi(state)
    action = torch.multinomial(probs, 1).item()
    next_state, reward, done, info = env.step(action)  # take a random action
    state = next_state

    if done:
        episode = episode + 1
        state = env.reset()
env.close()     




episode: 0, reward: 31.0
episode: 1, reward: 16.0
episode: 2, reward: 15.0
episode: 3, reward: 11.0
episode: 4, reward: 21.0
episode: 5, reward: 19.0
episode: 6, reward: 16.0
episode: 7, reward: 26.0
episode: 8, reward: 13.0
episode: 9, reward: 26.0
episode: 10, reward: 12.0
episode: 11, reward: 16.0
episode: 12, reward: 25.0
episode: 13, reward: 23.0
episode: 14, reward: 13.0
episode: 15, reward: 13.0
episode: 16, reward: 26.0
episode: 17, reward: 15.0
episode: 18, reward: 13.0
episode: 19, reward: 21.0
episode: 20, reward: 13.0
episode: 21, reward: 23.0
episode: 22, reward: 19.0
episode: 23, reward: 10.0
episode: 24, reward: 27.0
episode: 25, reward: 18.0
episode: 26, reward: 10.0
episode: 27, reward: 26.0
episode: 28, reward: 12.0
episode: 29, reward: 10.0
episode: 30, reward: 13.0
episode: 31, reward: 10.0
episode: 32, reward: 13.0
episode: 33, reward: 13.0
episode: 34, reward: 14.0
episode: 35, reward: 10.0
episode: 36, reward: 9.0
episode: 37, reward: 11.0
episode: 38, reward: 9.

episode: 313, reward: 10.0
episode: 314, reward: 10.0
episode: 315, reward: 8.0
episode: 316, reward: 8.0
episode: 317, reward: 10.0
episode: 318, reward: 11.0
episode: 319, reward: 9.0
episode: 320, reward: 10.0
episode: 321, reward: 10.0
episode: 322, reward: 10.0
episode: 323, reward: 10.0
episode: 324, reward: 9.0
episode: 325, reward: 10.0
episode: 326, reward: 10.0
episode: 327, reward: 9.0
episode: 328, reward: 9.0
episode: 329, reward: 10.0
episode: 330, reward: 10.0
episode: 331, reward: 9.0
episode: 332, reward: 9.0
episode: 333, reward: 9.0
episode: 334, reward: 10.0
episode: 335, reward: 9.0
episode: 336, reward: 10.0
episode: 337, reward: 9.0
episode: 338, reward: 10.0
episode: 339, reward: 8.0
episode: 340, reward: 13.0
episode: 341, reward: 9.0
episode: 342, reward: 10.0
episode: 343, reward: 8.0
episode: 344, reward: 9.0
episode: 345, reward: 9.0
episode: 346, reward: 10.0
episode: 347, reward: 8.0
episode: 348, reward: 10.0
episode: 349, reward: 9.0
episode: 350, rewar

episode: 623, reward: 9.0
episode: 624, reward: 9.0
episode: 625, reward: 8.0
episode: 626, reward: 9.0
episode: 627, reward: 10.0
episode: 628, reward: 8.0
episode: 629, reward: 10.0
episode: 630, reward: 10.0
episode: 631, reward: 9.0
episode: 632, reward: 9.0
episode: 633, reward: 10.0
episode: 634, reward: 9.0
episode: 635, reward: 10.0
episode: 636, reward: 9.0
episode: 637, reward: 9.0
episode: 638, reward: 10.0
episode: 639, reward: 10.0
episode: 640, reward: 10.0
episode: 641, reward: 9.0
episode: 642, reward: 11.0
episode: 643, reward: 9.0
episode: 644, reward: 9.0
episode: 645, reward: 10.0
episode: 646, reward: 9.0
episode: 647, reward: 11.0
episode: 648, reward: 10.0
episode: 649, reward: 11.0
episode: 650, reward: 11.0
episode: 651, reward: 8.0
episode: 652, reward: 10.0
episode: 653, reward: 9.0
episode: 654, reward: 10.0
episode: 655, reward: 9.0
episode: 656, reward: 10.0
episode: 657, reward: 10.0
episode: 658, reward: 10.0
episode: 659, reward: 9.0
episode: 660, rewar

episode: 932, reward: 10.0
episode: 933, reward: 10.0
episode: 934, reward: 8.0
episode: 935, reward: 10.0
episode: 936, reward: 10.0
episode: 937, reward: 8.0
episode: 938, reward: 9.0
episode: 939, reward: 9.0
episode: 940, reward: 10.0
episode: 941, reward: 9.0
episode: 942, reward: 10.0
episode: 943, reward: 13.0
episode: 944, reward: 9.0
episode: 945, reward: 9.0
episode: 946, reward: 10.0
episode: 947, reward: 10.0
episode: 948, reward: 8.0
episode: 949, reward: 10.0
episode: 950, reward: 10.0
episode: 951, reward: 10.0
episode: 952, reward: 9.0
episode: 953, reward: 10.0
episode: 954, reward: 10.0
episode: 955, reward: 9.0
episode: 956, reward: 10.0
episode: 957, reward: 10.0
episode: 958, reward: 9.0
episode: 959, reward: 10.0
episode: 960, reward: 9.0
episode: 961, reward: 10.0
episode: 962, reward: 10.0
episode: 963, reward: 10.0
episode: 964, reward: 10.0
episode: 965, reward: 10.0
episode: 966, reward: 9.0
episode: 967, reward: 11.0
episode: 968, reward: 10.0
episode: 969, 