Commit

Actually add missing probability term
cswinter committed Aug 11, 2019
1 parent 069dbec commit c4f82bb
Showing 2 changed files with 14 additions and 8 deletions.
main.py (8 changes: 6 additions & 2 deletions)
@@ -66,18 +66,20 @@ def train(hps: HyperParams) -> None:
     entropies = []
     all_obs = []
     all_actions = []
+    all_probs = []
     all_rewards = []
     all_dones = []

     # Rollout
     for step in range(hps.seq_rosteps):
         obs_tensor = torch.tensor(obs).to(device)
-        actions, entropy = policy.evaluate(obs_tensor)
+        actions, probs, entropy = policy.evaluate(obs_tensor)

         entropies.append(entropy)

         all_obs.extend(obs)
         all_actions.extend(actions)
+        all_probs.extend(probs)

         obs, rews, dones, infos = env.step(actions)

@@ -107,6 +109,7 @@ def train(hps: HyperParams) -> None:
     all_obs = np.array(all_obs)[perm]
     all_returns = all_returns[perm]
     all_actions = np.array(all_actions)[perm]
+    all_probs = np.array(all_probs)[perm]

     # Policy Update
     episode_loss = 0
@@ -116,10 +119,11 @@ def train(hps: HyperParams) -> None:

         o = torch.tensor(all_obs[start:end]).to(device)
         actions = torch.tensor(all_actions[start:end]).to(device)
+        probs = torch.tensor(all_probs[start:end]).to(device)
         returns = torch.tensor(all_returns[start:end]).to(device)

         optimizer.zero_grad()
-        episode_loss += policy.backprop(o, actions, returns)
+        episode_loss += policy.backprop(o, actions, probs, returns)
         optimizer.step()

     epoch += 1
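For context, a minimal sketch (hypothetical names, not code from this repo) of the bookkeeping pattern the main.py change follows: the per-action probabilities are buffered alongside observations and actions during the rollout, then shuffled with the same permutation so all three stay aligned for the minibatched update.

import numpy as np

rng = np.random.default_rng(0)

obs_buf, act_buf, prob_buf = [], [], []   # mirrors all_obs / all_actions / all_probs
for _ in range(4):                        # stand-in for the rollout loop
    obs = rng.normal(size=3)              # fake observation
    action = int(rng.integers(8))         # sampled discrete action
    prob = float(rng.uniform(0.1, 1.0))   # probability the policy assigned to that action
    obs_buf.append(obs)
    act_buf.append(action)
    prob_buf.append(prob)

perm = rng.permutation(len(obs_buf))      # one shared permutation keeps the buffers aligned
all_obs = np.array(obs_buf)[perm]
all_actions = np.array(act_buf)[perm]
all_probs = np.array(prob_buf)[perm]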
policy.py (14 changes: 8 additions & 6 deletions)
@@ -16,16 +16,18 @@ def __init__(self, layers, nhidden):
     def evaluate(self, observation):
         probs = self.forward(observation)
         actions = []
+        ps = []
         probs.detach_()
         for i in range(probs.size()[0]):
-            actions.append(np.random.choice(8, 1, p=probs[i].cpu().numpy())[0])
-        return actions, self.entropy(probs)
+            probs_np = probs[i].cpu().numpy()
+            action = np.random.choice(8, 1, p=probs_np)[0]
+            actions.append(action)
+            ps.append(probs_np[action])
+        return actions, ps, self.entropy(probs)

-    def backprop(self, obs, actions, returns):
+    def backprop(self, obs, actions, probs, returns):
         logits = self.logits(obs)
-        # TODO: should this use probability value at rollout time before policy updates?
-        p = torch.clamp_min(F.softmax(logits.data, dim=1).gather(1, actions.view(-1, 1)), 1).view(-1)
-        loss = torch.sum(returns * F.cross_entropy(logits, actions) / p)
+        loss = torch.sum(returns * F.cross_entropy(logits, actions) / torch.clamp_min(probs, 0.01))
         loss.backward()
         return loss.data.tolist()

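The backprop change is the substance of the commit: the old code clamped the freshly computed action probability at a minimum of 1, which made the division a no-op, while the new code divides by the probability recorded at rollout time, clamped at 0.01. Below is a minimal standalone sketch of that weighting, with hypothetical names and a per-sample reduction (the diff's own F.cross_entropy call uses the default mean reduction), not code from this repo.

import torch
import torch.nn.functional as F

def weighted_pg_loss(logits, actions, rollout_probs, returns, min_prob=0.01):
    # Per-sample negative log-likelihood of the actions actually taken.
    nll = F.cross_entropy(logits, actions, reduction="none")
    # Weight by return and by 1 / (rollout-time action probability), clamped so
    # very small recorded probabilities cannot blow up the gradient.
    weights = returns / torch.clamp_min(rollout_probs, min_prob)
    return torch.sum(weights * nll)

# Example: a batch of 4 samples over 8 discrete actions.
logits = torch.randn(4, 8, requires_grad=True)
actions = torch.tensor([0, 3, 7, 2])
rollout_probs = torch.tensor([0.20, 0.05, 0.50, 0.009])  # 0.009 is clamped to 0.01
returns = torch.tensor([1.0, -0.5, 2.0, 0.3])
weighted_pg_loss(logits, actions, rollout_probs, returns).backward()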
