utils.py
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 22 2021
@author: George Yiasemis
"""
import torch
import torch.nn as nn
import torch.distributions as distributions


def train(agent, policy, optimiser, gamma, num_episode_steps):
    """Run one episode with the current policy and perform a REINFORCE update."""
    policy.train()
    action_log_probs = []
    rewards = []
    episode_reward = 0
    agent.reset()
    state = agent.state
    trace = [state]
    end_episode = False
    for i in range(num_episode_steps):
        state = torch.tensor(state).float()
        # Make a prediction using the Policy network
        pred = policy(state)
        # Calculate \pi_{\theta}(a | s_t) for all a in {0, 1, ..., num_actions - 1}
        action_probs = nn.Softmax(-1)(pred)
        # Categorical distribution over the actions
        p = distributions.Categorical(probs=action_probs)
        # Sample an action a_t in {0, 1, ..., num_actions - 1}
        # with probabilities action_probs
        action = p.sample()
        # Calculate log \pi_{\theta}(a_t | s_t)
        action_log_prob = p.log_prob(action)
        # Take a step using the sampled action; observe the reward and the next state
        (_, _, reward, state), end_episode = agent.step(action.item())
        action_log_probs.append(action_log_prob)
        rewards.append(reward)
        trace.append(state)
        episode_reward += reward
        if end_episode:
            break
    # Tensor of shape (min(num_episode_steps, steps_to_goal),)
    action_log_probs = torch.stack(action_log_probs)
    R = calculate_returns(rewards, gamma)
    loss = update_policy(R, action_log_probs, optimiser)
    return loss, episode_reward, trace
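
# Note: `train` applies the softmax itself, so any `policy` passed in is expected to
# return unnormalised logits of shape (num_actions,). As an illustration only (the
# actual policy class lives elsewhere in this repository), a compatible network could
# be as simple as
#   nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, num_actions))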


def calculate_returns(rewards, discount_factor, normalise=True):
    """Compute the discounted return G_t for every time step of an episode."""
    returns = []
    R = 0
    # Iterate backwards: G_t = r_t + discount_factor * G_{t+1}
    for r in reversed(rewards):
        R = r + R * discount_factor
        returns.insert(0, R)
    returns = torch.tensor(returns)
    if normalise:
        # Standardise the returns; the small epsilon guards against zero variance
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    return returns
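
# Worked example of the recursion above (discount_factor = 0.9, normalise=False):
#   rewards = [1.0, 0.0, 2.0]
#   G_2 = 2.0
#   G_1 = 0.0 + 0.9 * 2.0 = 1.8
#   G_0 = 1.0 + 0.9 * 1.8 = 2.62
#   calculate_returns([1.0, 0.0, 2.0], 0.9, normalise=False) -> tensor([2.6200, 1.8000, 2.0000])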


def update_policy(returns, action_log_probs, optimiser):
    """Perform one REINFORCE policy-gradient update and return the scalar loss."""
    returns = returns.detach()
    loss = -(returns * action_log_probs).sum()
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    return loss.item()
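
# The loss minimised above is the negated (sample estimate of the) REINFORCE objective
#   J(theta) = sum_t G_t * log pi_theta(a_t | s_t),
# so a gradient step on -J(theta) performs gradient ascent on J(theta). The returns are
# detached because they act as fixed weights on the log-probabilities: only the policy
# parameters should receive gradients.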


def evaluate(agent, policy, num_episode_steps):
    """Roll out one episode greedily (argmax action) without gradient tracking."""
    policy.eval()
    reached_goal = False
    episode_reward = 0
    agent.reset()
    state = agent.state
    trace = [state]
    for i in range(num_episode_steps):
        state = torch.tensor(state).float()
        with torch.no_grad():
            pred = policy(state)
            # Greedy action selection: pick the most probable action
            action_prob = nn.Softmax(-1)(pred)
            action = action_prob.argmax(-1)
        (_, _, reward, state), reached_goal = agent.step(action.item())
        trace.append(state)
        episode_reward += reward
        if reached_goal:
            break
    return episode_reward, trace
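

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the original module): the toy
# `ChainAgent` below is a hypothetical stand-in for the repository's agent and only
# illustrates the interface that `train`/`evaluate` rely on, namely `reset()`, a
# `state` attribute, and `step(action)` returning ((_, _, reward, next_state), done).
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ChainAgent:
        """Walk along a chain of `n` cells; the goal is the last cell."""

        def __init__(self, n=5):
            self.n = n
            self.reset()

        def reset(self):
            self.position = 0
            self.state = self._one_hot(self.position)

        def _one_hot(self, position):
            return [1.0 if i == position else 0.0 for i in range(self.n)]

        def step(self, action):
            # Action 0 moves left, action 1 moves right (clipped to the chain)
            self.position = max(0, min(self.n - 1, self.position + (1 if action == 1 else -1)))
            self.state = self._one_hot(self.position)
            reached_goal = self.position == self.n - 1
            reward = 0.0 if reached_goal else -1.0
            # The first two slots are ignored by train/evaluate
            return (None, action, reward, self.state), reached_goal

    agent = ChainAgent(n=5)
    # A small logits network compatible with `train` (which applies the softmax itself)
    policy = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 2))
    optimiser = torch.optim.Adam(policy.parameters(), lr=1e-2)

    for episode in range(200):
        loss, episode_reward, _ = train(agent, policy, optimiser,
                                        gamma=0.99, num_episode_steps=50)

    episode_reward, trace = evaluate(agent, policy, num_episode_steps=50)
    print("Greedy episode reward:", episode_reward)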