In [7]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make('CartPole-v1')

# Seed every RNG this notebook draws from so "Restart & Run All" reproduces results:
# Python's `random` (replay sampling, epsilon draws), torch (network init),
# the environment's reset RNG, and the action space's own sampler.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
env.reset(seed=SEED)
env.action_space.seed(SEED)

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Trailing semicolon suppresses the `<contextlib.ExitStack ...>` repr as cell output.
plt.ion();

<contextlib.ExitStack at 0x12595244880>

In [8]:
Transition = namedtuple('Transition', ['state', 'action', 'reward', 'next_state'])


class BatchMemory:
    """Fixed-capacity FIFO store of ``Transition`` records with uniform random sampling.

    Once ``capacity`` records are held, every new push evicts the oldest one.
    """

    def __init__(self, capacity):
        # A bounded deque discards from the left when full, giving FIFO eviction.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Pack the positional fields (state, action, reward, next_state) into a Transition and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records chosen uniformly with ``random.sample``."""
        picked = random.sample(self.memory, batch_size)
        return picked

    def __len__(self):
        """Number of records currently held."""
        return len(self.memory)
    
class DQN(nn.Module):
    """Three-layer MLP that maps an observation vector to one Q-value per action.

    Args:
        n_observations: size of the input feature dimension.
        n_actions: number of outputs (one Q-value per discrete action).
        hidden_size: width of both hidden layers; defaults to 128, the value
            previously hard-coded, so existing callers and checkpoints are unaffected.
    """

    def __init__(self, n_observations, n_actions, hidden_size=128):
        super().__init__()
        # Layer attribute names are kept so state_dict keys stay layer1/2/3.
        self.layer1 = nn.Linear(n_observations, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, n_actions)

    def forward(self, x):
        """Return raw (un-activated) Q-values; the last dim of ``x`` must be ``n_observations``."""
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

class Agent:
    """DQN agent: epsilon-greedy action selection plus one-step Q-learning updates.

    Holds a policy network (the one being trained), a target network (used to
    compute bootstrap targets), a replay memory, and an AdamW optimizer over the
    policy network's parameters.

    Args:
        n_observations: length of the observation vector fed to the networks.
        n_actions: number of discrete actions; exploration draws uniformly
            from ``range(n_actions)``.
        batch_size, gamma, eps_start, eps_end, eps_decay, tau, lr, memory_capacity:
            hyperparameters; the defaults are the values previously hard-coded.
    """

    def __init__(self, n_observations, n_actions, batch_size=128, gamma=0.99,
                 eps_start=0.9, eps_end=0.05, eps_decay=1000, tau=0.005,
                 lr=1e-4, memory_capacity=10000):
        self.n_actions = n_actions
        self.batch_size = batch_size
        self.gamma = gamma
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        # tau is consumed by the caller's soft target-network update, not inside this class.
        self.tau = tau
        self.lr = lr

        self.policy_net = DQN(n_observations, n_actions)
        self.target_net = DQN(n_observations, n_actions)
        # Start the target network as an exact copy of the policy network.
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.steps_done = 0
        self.memory = BatchMemory(memory_capacity)
        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=self.lr, amsgrad=True)

    def get_action(self, state):
        """Pick an action for ``state`` with an exponentially decaying epsilon-greedy policy.

        Epsilon decays from ``eps_start`` toward ``eps_end`` with time constant
        ``eps_decay`` measured in calls to this method; ``steps_done`` is
        incremented on every call.

        Args:
            state: observation tensor accepted by the policy network; ``.max(1)``
                below assumes a leading batch dim, e.g. shape (1, n_observations).

        Returns:
            A (1, 1) ``torch.long`` tensor holding the chosen action index.
        """
        epsilon = self.eps_end + (self.eps_start - self.eps_end) * math.exp(-self.steps_done / self.eps_decay)
        self.steps_done += 1
        if random.random() > epsilon:
            with torch.no_grad():
                return self.policy_net(state).max(1).indices.view(1, 1)
        # Explore uniformly over action indices; no dependency on a global env object.
        return torch.tensor([[random.randrange(self.n_actions)]], dtype=torch.long)

    def optimize_model(self):
        """Run one gradient step on a sampled minibatch using the Huber (smooth L1) TD loss.

        Does nothing until the memory holds at least ``batch_size`` transitions.

        Returns:
            The loss as a Python float, or None when no update was performed.
        """
        if len(self.memory) < self.batch_size:
            return None
        transitions = self.memory.sample(self.batch_size)
        # Transpose the sampled (state, action, reward, next_state) records into four tuples.
        states, actions, rewards, next_states = zip(*transitions)

        # next_state is None for terminal transitions; those keep a bootstrap value of 0.
        non_final_mask = torch.tensor([s is not None for s in next_states], dtype=torch.bool)

        state_batch = torch.cat(states)
        action_batch = torch.cat(actions)
        reward_batch = torch.cat(rewards)

        # Q(s, a) for the action that was actually taken.
        q_values = self.policy_net(state_batch).gather(1, action_batch)

        next_state_values = torch.zeros(self.batch_size)
        # torch.cat([]) raises, so only query the target net if some next state exists.
        if non_final_mask.any():
            non_final_next_states = torch.cat([s for s in next_states if s is not None])
            with torch.no_grad():
                next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1).values
        td_target = reward_batch + self.gamma * next_state_values

        loss = F.smooth_l1_loss(q_values, td_target.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        # Clamp each gradient element to [-100, 100] to keep updates bounded.
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()
        return loss.item()

In [None]:
# Episode budget and the env's own time limit are different quantities that both
# happened to be 500 for CartPole-v1; name them separately.
NUM_EPISODES = 500
MAX_EPISODE_STEPS = (
    env.spec.max_episode_steps
    if env.spec is not None and env.spec.max_episode_steps
    else 500
)

n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)

agent = Agent(n_observations, n_actions)

episode_durations = []

for i_episode in range(NUM_EPISODES):
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    for t in count():
        action = agent.get_action(state)
        # Avoid unpacking into `_`, which would clobber IPython's last-output variable.
        observation, reward, terminated, truncated, step_info = env.step(action.item())
        reward = torch.tensor([reward], dtype=torch.float32)
        done = terminated or truncated

        # Only true termination drops the bootstrap term; a time-limit truncation
        # keeps next_state so the target still bootstraps from it.
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32).unsqueeze(0)

        agent.memory.push(state, action, reward, next_state)
        state = next_state
        agent.optimize_model()

        # Soft (Polyak) update: target <- tau * policy + (1 - tau) * target.
        target_net_state_dict = agent.target_net.state_dict()
        policy_net_state_dict = agent.policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = (
                policy_net_state_dict[key] * agent.tau
                + target_net_state_dict[key] * (1 - agent.tau)
            )
        agent.target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            print(episode_durations[-1])
            break

    # Stop early once an episode survives the environment's full time limit.
    if episode_durations[-1] >= MAX_EPISODE_STEPS:
        break

21
11
10
15
21
12
28
41
16
18
45
17
18
10
23
14
15
11
18
14
26
14
11
12
13
17
15
12
21
10
11
11
22
11
13
19
9
14
12
10
11
11
12
10
17
9
11
12
17
13
11
10
11
10
12
9
12
9
13
16
10
10
10
9
15
10
11
10
10
15
15
9
10
10
10
11
15
15
13
9
9
13
9
11
9
9
11
14
13
13
9
10
10
11
16
9
9
14
9
10
11
11
9
11
11
9
8
11
9
9
14
10
9
10
9
9
15
13
14
12
10
10
10
11
9
10
10
16
13
13
13
13
13
10
14
16
14
14
13
13
14
15
16
19
14
13
17
19
20
17
68
91
85
75
60
56
55
73
40
99
82
60
67
89
83
65
103
81
106
60
78
83
80
91
72
58
76
91
77
93
128
97
87
134
71
78
84
93
85
82
119
121
86
86
148
78
118
110
78
134
153
98
100
87
82
86
90
90
81
83
96
104
94
81
97
109
246
96
95
97
110
99
94
215
106
130
104
105
89
80
95
124
78
89
83
121
83
96
74
78
68
91
83
75
129
87
78
90
104
113
95
69
87
81
94
91
104
106
130
79
94
100
134
97
133
113
143
93
179
136
133
117
112
116
124
167
190
180
117
118
138
128
118
120
161
185
104
119
125
128
164
136
134
128
142
113
136
152
134
145
153
149
148
155
146
137
142
197
163
159
144
166
88
106
120