In [1]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

# CartPole-v1 environment used by the rest of the notebook.
env = gym.make("CartPole-v1")

# Matplotlib setup: treat an inline backend as a sign that we are running
# under IPython/Jupyter, and only then pull in IPython's display helpers.
is_ipython = matplotlib.get_backend().find('inline') != -1
if is_ipython:
    from IPython import display

# Turn on matplotlib interactive mode.
plt.ion()

from tensorflow import keras
from tensorflow.keras import layers

In [2]:
# One stored experience step; fields are positional in push().
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)


class ReplayMemory:
    """Bounded FIFO buffer of ``Transition`` records for experience replay.

    Backed by ``deque(maxlen=capacity)``: once ``capacity`` items are held,
    each new push silently evicts the oldest one.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition built as ``Transition(*args)``."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly.

        Sampling is without replacement (``random.sample``), so this raises
        ``ValueError`` when ``batch_size`` exceeds ``len(self)``.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [3]:
# Network dimensions. For CartPole-v1 these should equal
# env.observation_space.shape[0] and env.action_space.n
# (TODO confirm if the environment is changed).
n_observations = 4
n_actions = 2
n_hidden = 128


def create_q_model(n_inputs=n_observations, n_outputs=n_actions, n_units=n_hidden):
    """Build a one-hidden-layer MLP mapping an observation to per-action Q-values.

    Args:
        n_inputs: length of the flat observation vector (default ``n_observations``).
        n_outputs: number of discrete actions, one Q-value each (default ``n_actions``).
        n_units: width of the ReLU hidden layer (default ``n_hidden``).

    Returns:
        An uncompiled ``keras.Model`` with input shape ``(None, n_inputs)`` and
        output shape ``(None, n_outputs)``.
    """
    observations = layers.Input(shape=(n_inputs,))
    hidden = layers.Dense(n_units, activation='relu')(observations)
    # Q-values are unbounded estimates of expected return, so the head must be
    # linear. A softmax would force every output into (0, 1) and make them sum
    # to 1, so the network could never regress toward the actual TD targets.
    q_values = layers.Dense(n_outputs, activation='linear')(hidden)
    return keras.Model(inputs=observations, outputs=q_values)

# Two identically shaped Q-networks: the policy (online) network and the
# target network. create_q_model() gives each its own random initialization,
# so copy the policy weights into the target to start them identical, as DQN
# prescribes. Later syncs presumably happen in the training loop (not visible here).
policy_network = create_q_model()
target_network = create_q_model()
target_network.set_weights(policy_network.get_weights())