In [10]:
import tensorflow as tf
import numpy as np

from tensorflow.keras import optimizers, losses
from tensorflow.keras import Model

import random
import gym

class A2C(Model):
    """Actor-critic network with a shared trunk and two separate heads.

    Two 128-unit ReLU layers form the shared trunk. The actor head is a
    64-unit ReLU layer followed by a softmax over the actions; the critic
    head is a 64-unit ReLU layer followed by a single linear unit.

    Args:
        action_size: Number of discrete actions, i.e. the width of the
            softmax output. Defaults to 2.
    """

    def __init__(self, action_size=2):
        super().__init__()
        # Shared trunk.
        self.layer1 = tf.keras.layers.Dense(128, activation='relu')
        self.layer2 = tf.keras.layers.Dense(128, activation='relu')
        # Head-specific hidden layers (actor / critic).
        self.layer_a1 = tf.keras.layers.Dense(64, activation='relu')
        self.layer_c1 = tf.keras.layers.Dense(64, activation='relu')
        # NOTE: despite the attribute name, this layer applies a softmax, so
        # it outputs action probabilities rather than raw logits. The name is
        # kept so existing attribute access keeps working.
        self.logits = tf.keras.layers.Dense(action_size, activation='softmax')
        self.value = tf.keras.layers.Dense(1)

    def call(self, state):
        """Run the network on a batch of states.

        Args:
            state: Float tensor of observations; assumed to be shaped
                (batch, state_dim) -- the callers in this file pass 2-D input.

        Returns:
            A ``(policy, value)`` pair: ``policy`` is the softmax output with
            ``action_size`` entries on the last axis, ``value`` is the critic
            output with a single entry on the last axis.
        """
        shared = self.layer2(self.layer1(state))
        policy = self.logits(self.layer_a1(shared))
        value = self.value(self.layer_c1(shared))
        return policy, value

class Agent:
    """Advantage actor-critic (A2C) trainer for gym's ``CartPole-v1``.

    Collects fixed-length rollouts with the current policy and performs one
    gradient step on the combined policy/value loss after each rollout.

    Attributes:
        lr: Adam learning rate.
        gamma: Discount factor used in the one-step TD target.
        a2c: The actor-critic network.
        opt: Optimizer applied to the network's trainable variables.
        rollout: Number of environment steps gathered per update.
        batch_size: Maximum number of transitions used in one update.
        state_size: Observation dimensionality (not read by this class).
        action_size: Number of discrete actions.
    """

    def __init__(self):
        self.lr = 0.001
        self.gamma = 0.99

        self.a2c = A2C()
        # `learning_rate` is the supported argument name; the older `lr`
        # alias is deprecated and rejected by newer Keras releases.
        self.opt = optimizers.Adam(learning_rate=self.lr)

        self.rollout = 128
        self.batch_size = 128
        self.state_size = 4
        self.action_size = 2

    def get_action(self, state):
        """Sample an action index from the policy for a single observation.

        Args:
            state: One observation (without a batch axis).

        Returns:
            An integer in ``[0, action_size)`` drawn according to the policy.
        """
        batch = tf.convert_to_tensor([state], dtype=tf.float32)
        policy, _ = self.a2c(batch)
        probabilities = np.array(policy)[0]
        return np.random.choice(self.action_size, p=probabilities)

    def update(self, state, next_state, reward, done, action):
        """Perform one gradient step on a batch of transitions.

        Up to ``batch_size`` transitions are drawn without replacement from
        the supplied sequences. The loss is the sum of a one-step TD value
        loss and a policy-gradient loss with an entropy bonus.

        Args:
            state: Sequence of observations.
            next_state: Sequence of successor observations.
            reward: Sequence of scalar rewards.
            done: Sequence of episode-termination flags.
            action: Sequence of integer action indices.

        All five sequences must have the same length. An empty batch is a
        no-op.
        """
        n_samples = len(state)
        if n_samples == 0:
            return

        # Index by the actual batch length rather than `self.rollout`, so a
        # shorter-than-usual rollout does not raise IndexError.
        sample_idx = np.random.permutation(n_samples)[:self.batch_size]

        def gather(values, dtype):
            return tf.convert_to_tensor(
                np.asarray(values, dtype=dtype)[sample_idx])

        states = gather(state, np.float32)
        next_states = gather(next_state, np.float32)
        rewards = gather(reward, np.float32)
        dones = gather(done, np.float32)
        actions = gather(action, np.int32)

        # The bootstrap target is treated as a constant, so it is computed
        # outside the tape. Squeeze only the trailing axis so a batch of one
        # stays one-dimensional.
        _, next_value = self.a2c(next_states)
        next_value = tf.squeeze(next_value, axis=-1)
        target = rewards + self.gamma * (1.0 - dones) * next_value

        with tf.GradientTape() as tape:
            # A single forward pass yields both the policy and the value.
            policy, current_value = self.a2c(states)
            current_value = tf.squeeze(current_value, axis=-1)

            td_error = target - current_value
            value_loss = tf.reduce_mean(tf.square(td_error) * 0.5)

            # NOTE: the mean runs over every (sample, action) entry rather
            # than summing over actions first, so this is the per-sample
            # entropy divided by `action_size`; the 0.1 coefficient is tuned
            # for that scale.
            entropy = tf.reduce_mean(-policy * tf.math.log(policy + 1e-8)) * 0.1

            onehot_action = tf.one_hot(actions, self.action_size)
            action_policy = tf.reduce_sum(onehot_action * policy, axis=1)
            # The advantage must not propagate gradients into the critic.
            adv = tf.stop_gradient(td_error)
            pi_loss = -tf.reduce_mean(
                tf.math.log(action_policy + 1e-8) * adv) - entropy

            total_loss = pi_loss + value_loss

        # Read the variables after the forward pass: a subclassed model only
        # creates its weights on first call, so reading earlier could yield
        # an empty list. Trainable variables are watched by the tape
        # automatically.
        variables = self.a2c.trainable_variables
        grads = tape.gradient(total_loss, variables)
        self.opt.apply_gradients(zip(grads, variables))

    @staticmethod
    def _shape_reward(done, score, max_score=500):
        """Return the training reward: 0 mid-episode, +1/-1 at episode end."""
        if not done:
            return 0
        return 1 if score == max_score else -1

    @staticmethod
    def _reset_env(env):
        """Reset ``env`` and return only the initial observation."""
        result = env.reset()
        # gym >= 0.26 returns (observation, info); older versions return the
        # observation alone.
        if (isinstance(result, tuple) and len(result) == 2
                and isinstance(result[1], dict)):
            return result[0]
        return result

    @staticmethod
    def _step_env(env, action):
        """Step ``env`` and return ``(next_state, reward, done)``."""
        result = env.step(action)
        if len(result) == 5:
            # gym >= 0.26: (obs, reward, terminated, truncated, info).
            next_state, reward, terminated, truncated, _ = result
            return next_state, reward, bool(terminated or truncated)
        next_state, reward, done, _ = result
        return next_state, reward, bool(done)

    def run(self, max_episodes=None):
        """Train on ``CartPole-v1``, printing ``episode score`` per episode.

        Args:
            max_episodes: Stop once at least this many episodes have
                finished. The check happens between rollouts, so a few extra
                episodes may complete. ``None`` (default) trains until
                interrupted.
        """
        env = gym.make('CartPole-v1')
        try:
            state = self._reset_env(env)
            episode = 0
            score = 0

            while max_episodes is None or episode < max_episodes:
                state_list, next_state_list = [], []
                reward_list, done_list, action_list = [], [], []

                for _ in range(self.rollout):
                    action = self.get_action(state)
                    next_state, env_reward, done = self._step_env(env, action)

                    # `score` tracks the raw environment return; the network
                    # is trained on the sparse shaped reward instead.
                    score += env_reward
                    reward = self._shape_reward(done, score)

                    state_list.append(state)
                    next_state_list.append(next_state)
                    reward_list.append(reward)
                    done_list.append(done)
                    action_list.append(action)

                    if done:
                        print(episode, score)
                        state = self._reset_env(env)
                        episode += 1
                        score = 0
                    else:
                        state = next_state

                self.update(
                    state=state_list, next_state=next_state_list,
                    reward=reward_list, done=done_list, action=action_list)
        finally:
            # Release the environment even on KeyboardInterrupt.
            env.close()


# Script entry point: build an agent and start its training loop.
if __name__ == '__main__':
    Agent().run()

0 15.0
1 19.0
2 24.0
3 25.0
4 13.0
5 13.0
6 63.0
7 39.0
8 55.0
9 93.0
10 42.0
11 27.0
12 51.0
13 28.0
14 36.0
15 59.0
16 28.0
17 35.0
18 20.0
19 19.0
20 27.0
21 67.0
22 14.0
23 39.0
24 40.0
25 17.0
26 37.0
27 76.0
28 32.0
29 34.0
30 74.0
31 28.0
32 28.0
33 123.0
34 19.0
35 12.0
36 21.0
37 19.0
38 11.0
39 8.0
40 9.0
41 13.0
42 12.0
43 32.0
44 33.0
45 22.0
46 20.0
47 80.0
48 34.0
49 14.0
50 85.0
51 70.0
52 21.0
53 43.0
54 62.0
55 38.0
56 122.0
57 29.0
58 13.0
59 61.0
60 114.0
61 16.0
62 40.0
63 35.0
64 148.0
65 115.0
66 79.0
67 19.0
68 33.0
69 121.0
70 161.0
71 30.0
72 13.0
73 10.0
74 17.0
75 145.0
76 38.0
77 12.0
78 18.0
79 186.0
80 95.0
81 69.0
82 21.0
83 13.0
84 10.0
85 41.0
86 50.0
87 18.0
88 160.0
89 53.0
90 96.0
91 27.0
92 15.0
93 259.0
94 13.0
95 21.0
96 46.0
97 22.0
98 206.0
99 16.0


KeyboardInterrupt: 