In [1]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make('CartPole-v1')

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

<contextlib.ExitStack at 0x247b05cee30>

In [2]:
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state'))

class BatchMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)
    def push(self, *args):
        self.memory.append(Transition(*args))
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    def __len__(self):
        return len(self.memory)
    
class DQN(nn.Module):
    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)
    
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

class Agent:
    def __init__(self, n_observations, n_actions):
        self.batch_size = 128
        self.gamma = 0.99
        self.eps_start = 0.9
        self.eps_end = 0.05
        self.eps_decay = 1000
        self.tau = 0.005
        self.lr = 1e-4

        self.policy_net = DQN(n_observations, n_actions)
        self.target_net = DQN(n_observations, n_actions)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        
        self.steps_done = 0
        self.memory = BatchMemory(10000)
        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=self.lr, amsgrad=True)

    def get_action(self, state):
        sample = random.random()
        epsilon = self.eps_end + (self.eps_start - self.eps_end) * math.exp(-1. * self.steps_done / self.eps_decay)
        self.steps_done += 1
        if sample > epsilon:
            with torch.no_grad():
                return self.policy_net(state).max(1).indices.view(1, 1)
        else:
            return torch.tensor([[env.action_space.sample()]], dtype=torch.long)
    
    def batch_update(self):
        if len(self.memory) < self.batch_size:
            return
        transitions = self.memory.sample(self.batch_size)
        batch = Transition(*zip(*transitions))
        
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        
        q_values = self.policy_net(state_batch).gather(1, action_batch)
        next_state_values = torch.zeros(self.batch_size)
        with torch.no_grad():
            next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1).values
        q_pred = (next_state_values * self.gamma) + reward_batch
        
        criterion = nn.SmoothL1Loss()
        loss = criterion(q_values, q_pred.unsqueeze(1))
        
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()

In [3]:
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)

agent = Agent(n_observations, n_actions)
done = False

episode_durations = []

for i in range(500):
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    for t in count():
        action = agent.get_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward])
        done = terminated or truncated
        
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32).unsqueeze(0)
        
        agent.memory.push(state, action, reward, next_state)
        state = next_state
        agent.batch_update()
        
        target_net_state_dict = agent.target_net.state_dict()
        policy_net_state_dict = agent.policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = \
            policy_net_state_dict[key]*agent.tau + target_net_state_dict[key]*(1 - agent.tau)
        agent.target_net.load_state_dict(target_net_state_dict)
        
        if done:
            episode_durations.append(t+1)
            print(episode_durations[-1])
            break
            
    if (episode_durations[-1] == 500) | (len(episode_durations) == 500):
        break

10
24
9
15
13
18
14
11
11
10
46
12
14
24
10
21
16
11
9
38
15
17
14
18
21
10
14
20
12
14
10
15
9
10
12
20
15
12
12
10
26
9
12
10
10
21
10
13
10
10
11
10
10
10
12
11
11
12
9
10
12
9
11
11
9
10
12
12
9
11
9
12
9
12
11
14
20
13
9
12
9
9
13
12
12
13
12
18
11
10
21
11
10
12
15
11
9
9
12
10
10
9
11
10
12
14
9
15
14
10
12
12
12
11
13
10
11
13
13
13
14
14
14
13
11
18
24
16
14
11
13
11
15
16
14
14
13
15
13
16
18
14
15
14
19
12
14
14
17
25
17
16
24
27
22
20
29
63
130
90
203
143
149
176
160
305
239
226
168
240
249
187
222
123
149
167
135
137
154
144
150
133
200
162
132
151
127
297
146
156
124
216
174
135
151
182
151
128
155
130
131
154
140
121
170
130
140
126
119
123
136
120
112
139
116
113
140
127
119
148
126
125
114
122
128
120
132
136
126
144
133
136
141
135
181
138
119
150
177
136
132
139
133
152
205
154
146
118
157
124
154
133
172
138
137
118
277
125
195
163
155
137
125
129
162
160
161
137
113
121
169
132
128
132
121
165
124
123
129
162
140
127
144
112
136
171
158
130
115
175
124
137
150
115


In [4]:
agent.policy_net(state)

tensor([[45.6271, 45.6334]], grad_fn=<AddmmBackward0>)

In [5]:
state

tensor([[ 0.2734, -0.0045,  0.0459,  0.2576]])

In [6]:
next_state

tensor([[ 0.2734, -0.0045,  0.0459,  0.2576]])

In [7]:
action.item()

0

In [8]:
reward

tensor([1.])