import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

from buffer import ReplayBuffer
from model import QNetwork
from utils import soft_update
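
# Note: ReplayBuffer, QNetwork, and soft_update live in project-local modules
# not shown here. From how they are used below, the assumed interfaces are:
#   ReplayBuffer(action_size, buffer_size, batch_size, seed, device), with
#       .add(state, action, reward, next_state, done), .sample(), and len()
#   QNetwork(state_size, action_size, seed): a torch.nn.Module mapping a
#       batch of states to per-action Q-values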


class Agent:
    """DQN agent with local/target Q-networks and an experience replay buffer."""

    def __init__(self, state_size, action_size, seed=0, lr=1e-3, update_every=4,
                 batch_size=4, buffer_size=64, gamma=0.0994, tau=1e-3,
                 model_path="model.pth"):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("=== AGENT ===")
        print(f"Created agent on device: {self.device}")
        self.model_path = model_path
        self.state_size = state_size
        self.action_size = action_size
        self.seed = seed
        random.seed(seed)  # random.seed() returns None, so store the seed itself
        self.update_every = update_every
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau

        # Network variables: the local network is trained every learning step;
        # the target network trails it through soft updates in learn().
        self.qnetwork_local = QNetwork(state_size, action_size, seed).to(self.device)
        self.qnetwork_target = QNetwork(state_size, action_size, seed).to(self.device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=lr)
        self.load()

        # Control variables
        self.memory = ReplayBuffer(action_size, buffer_size, self.batch_size, seed, self.device)
        self.t_step = 0

    def act(self, state, eps=0.0):
        """Return an epsilon-greedy action for `state`."""
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection: exploit with probability 1 - eps
        if random.random() > eps:
            return int(np.argmax(action_values.cpu().data.numpy()))
        else:
            return random.choice(np.arange(self.action_size))

    def step(self, state, action, reward, next_state, done):
        """Store one transition and trigger learning every `update_every` calls."""
        self.memory.add(state, action, reward, next_state, done)
        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:
            # If enough samples are available in memory, get a random subset and learn
            if len(self.memory) > self.batch_size:
                experiences = self.memory.sample()
                self.learn(experiences, self.gamma)

    def learn(self, experiences, gamma):
        """Update the local network from a batch of experience tuples."""
        states, actions, rewards, next_states, dones = experiences

        # Max predicted Q-values for the next states, taken from the target network
        Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
        # TD targets: r + gamma * max_a' Q_target(s', a'), zeroed at terminal states
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        # Q-values the local network currently assigns to the actions actually taken
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss and backpropagate
        loss = F.mse_loss(Q_expected, Q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update target network toward the local network
        soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)
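
    # `soft_update` (imported from utils) is assumed here to implement the
    # standard Polyak averaging update; a minimal sketch of that assumption:
    #
    #     def soft_update(local_model, target_model, tau):
    #         for t_param, l_param in zip(target_model.parameters(),
    #                                     local_model.parameters()):
    #             t_param.data.copy_(tau * l_param.data + (1.0 - tau) * t_param.data)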

    def save(self):
        """Write both networks' weights to disk."""
        torch.save(self.qnetwork_local.state_dict(), self.model_path)
        torch.save(self.qnetwork_target.state_dict(), self.model_path.replace('.pth', '_target.pth'))
        print("Saved agent model.")

    def load(self):
        """Restore both networks' weights if saved checkpoints exist."""
        target_path = self.model_path.replace('.pth', '_target.pth')
        if os.path.isfile(self.model_path) and os.path.isfile(target_path):
            # map_location keeps checkpoints loadable across CPU/GPU machines
            self.qnetwork_local.load_state_dict(torch.load(self.model_path, map_location=self.device))
            self.qnetwork_target.load_state_dict(torch.load(target_path, map_location=self.device))
            print(f"Loaded agent model: {self.model_path}")