Reorganize code
ikostrikov2 committed Apr 2, 2018
1 parent 2bb5160 commit 17ea833
Showing 2 changed files with 37 additions and 21 deletions.
6 changes: 3 additions & 3 deletions main.py
@@ -145,9 +145,9 @@ def update_current_obs(obs):
update_current_obs(obs)
rollouts.insert(current_obs, states.data, action.data, action_log_prob.data, value.data, reward, masks)

-next_value = actor_critic(Variable(rollouts.observations[-1], volatile=True),
-                          Variable(rollouts.states[-1], volatile=True),
-                          Variable(rollouts.masks[-1], volatile=True))[0].data
+next_value = actor_critic.get_value(Variable(rollouts.observations[-1], volatile=True),
+                                    Variable(rollouts.states[-1], volatile=True),
+                                    Variable(rollouts.masks[-1], volatile=True)).data

rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)
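
In short, forward() no longer returns the value as its first element, so the bootstrap value for the returns computation is now obtained through the dedicated get_value() method rather than by indexing the forward output. A minimal before/after sketch (the obs, states, and masks names stand in for the wrapped Variable arguments above):

# Before: forward() returned (value, hidden, states), so [0] picked out the value.
next_value = actor_critic(obs, states, masks)[0].data

# After: forward() returns hidden features; get_value() applies the value head
# (critic_linear) to the critic features and returns the value alone.
next_value = actor_critic.get_value(obs, states, masks).data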

52 changes: 34 additions & 18 deletions model.py
@@ -13,26 +13,45 @@ def weights_init(m):
m.bias.data.fill_(0)


-class FFPolicy(nn.Module):
+class Policy(nn.Module):
def __init__(self):
-super(FFPolicy, self).__init__()
+super(Policy, self).__init__()
+"""
+All classes that inherit from Policy are expected to have
+a feature extractor for the actor and the critic (see examples below)
+and modules called critic_linear and dist, where critic_linear
+maps critic features to a value and dist represents
+a distribution over actions.
+"""

def forward(self, inputs, states, masks):
raise NotImplementedError

def act(self, inputs, states, masks, deterministic=False):
-value, x, states = self(inputs, states, masks)
-action = self.dist.sample(x, deterministic=deterministic)
-action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, action)
+hidden_critic, hidden_actor, states = self(inputs, states, masks)
+
+action = self.dist.sample(hidden_actor, deterministic=deterministic)
+
+action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(hidden_actor, action)
+value = self.critic_linear(hidden_critic)

return value, action, action_log_probs, states

+def get_value(self, inputs, states, masks):
+hidden_critic, _, states = self(inputs, states, masks)
+value = self.critic_linear(hidden_critic)
+return value

def evaluate_actions(self, inputs, states, masks, actions):
-value, x, states = self(inputs, states, masks)
-action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, actions)
+hidden_critic, hidden_actor, states = self(inputs, states, masks)
+
+action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(hidden_actor, actions)
+value = self.critic_linear(hidden_critic)

return value, action_log_probs, dist_entropy, states


-class CNNPolicy(FFPolicy):
+class CNNPolicy(Policy):
def __init__(self, num_inputs, action_space, use_gru):
super(CNNPolicy, self).__init__()
self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
@@ -101,7 +120,7 @@ def forward(self, inputs, states, masks):
hx = states = self.gru(x[i], states * masks[i])
outputs.append(hx)
x = torch.cat(outputs, 0)
-return self.critic_linear(x), x, states
+return x, x, states


def weights_init_mlp(m):
@@ -113,7 +132,7 @@ def weights_init_mlp(m):
m.bias.data.fill_(0)


-class MLPPolicy(FFPolicy):
+class MLPPolicy(Policy):
def __init__(self, num_inputs, action_space):
super(MLPPolicy, self).__init__()

@@ -124,8 +143,8 @@ def __init__(self, num_inputs, action_space):

self.v_fc1 = nn.Linear(num_inputs, 64)
self.v_fc2 = nn.Linear(64, 64)
-self.v_fc3 = nn.Linear(64, 1)
+self.critic_linear = nn.Linear(64, 1)

self.dist = get_distribution(64, action_space)

self.train()
@@ -154,15 +173,12 @@ def forward(self, inputs, states, masks):
x = F.tanh(x)

x = self.v_fc2(x)
-x = F.tanh(x)
-
-x = self.v_fc3(x)
-value = x
+hidden_critic = F.tanh(x)

x = self.a_fc1(inputs)
x = F.tanh(x)

x = self.a_fc2(x)
-x = F.tanh(x)
+hidden_actor = F.tanh(x)

-return value, x, states
+return hidden_critic, hidden_actor, states
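
Taken together, the commit moves the value head out of each subclass's forward() and into the base Policy class: subclasses now return (hidden_critic, hidden_actor, states), and Policy.act(), Policy.get_value(), and Policy.evaluate_actions() apply critic_linear and dist on top of those features. A minimal sketch of a subclass satisfying the new contract (the class and layer names here are hypothetical, for illustration only; it assumes the nn, F, and get_distribution names already used in model.py):

# Hypothetical example, not part of this commit: a minimal Policy subclass.
class TinyPolicy(Policy):
    def __init__(self, num_inputs, action_space):
        super(TinyPolicy, self).__init__()
        self.actor_fc = nn.Linear(num_inputs, 64)   # actor feature extractor
        self.critic_fc = nn.Linear(num_inputs, 64)  # critic feature extractor
        # Required by the base class: critic_linear maps critic features
        # to a value; dist represents a distribution over actions.
        self.critic_linear = nn.Linear(64, 1)
        self.dist = get_distribution(64, action_space)
        self.train()

    def forward(self, inputs, states, masks):
        hidden_critic = F.tanh(self.critic_fc(inputs))
        hidden_actor = F.tanh(self.actor_fc(inputs))
        # The base class applies critic_linear and dist to these features.
        return hidden_critic, hidden_actor, states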
