Refactor the models
ikostrikov2 committed Apr 30, 2018
1 parent 3c0be88 commit df5c12c
Showing 2 changed files with 30 additions and 29 deletions.
main.py: 13 changes (4 additions, 9 deletions)
@@ -15,7 +15,7 @@
 from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
 from baselines.common.vec_env.vec_normalize import VecNormalize
 from envs import make_env
-from model import CNNPolicy, MLPPolicy
+from model import Policy
 from storage import RolloutStorage
 from visualize import visdom_plot

@@ -68,13 +68,8 @@ def main():
     obs_shape = envs.observation_space.shape
     obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

-    if len(envs.observation_space.shape) == 3:
-        actor_critic = CNNPolicy(obs_shape[0], envs.action_space, args.recurrent_policy)
-    else:
-        assert not args.recurrent_policy, \
-            "Recurrent policy is not implemented for the MLP controller"
-        actor_critic = MLPPolicy(obs_shape[0], envs.action_space)
+    actor_critic = Policy(obs_shape, envs.action_space, args.recurrent_policy)

     if envs.action_space.__class__.__name__ == "Discrete":
         action_shape = 1
     else:
@@ -97,7 +92,7 @@
         agent = algo.A2C_ACKTR(actor_critic, args.value_loss_coef,
                                args.entropy_coef, acktr=True)

-    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape, envs.action_space, actor_critic.state_size)
+    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape, envs.action_space, actor_critic.base.state_size)
     current_obs = torch.zeros(args.num_processes, *obs_shape)

     def update_current_obs(obs):
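For orientation, here is a minimal sketch of how main.py constructs the model after this commit; it assumes the surrounding setup (envs, args, algo) from main.py and only restates the calls shown in the diff above.

    from model import Policy
    from storage import RolloutStorage

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

    # A single entry point for both image and vector observations:
    # Policy picks CNNBase or MLPBase internally, so main.py no longer branches.
    actor_critic = Policy(obs_shape, envs.action_space, args.recurrent_policy)

    # The recurrent state size now lives on the wrapped base module.
    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape,
                              envs.action_space, actor_critic.base.state_size)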
model.py: 46 changes (26 additions, 20 deletions)
@@ -3,13 +3,16 @@
 import torch.nn.functional as F
 from distributions import get_distribution

+def zero_bias_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
+        m.bias.data.fill_(0)


 def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1 or classname.find('Linear') != -1:
         nn.init.orthogonal_(m.weight.data)
-        if m.bias is not None:
-            m.bias.data.fill_(0)


 class Flatten(nn.Module):
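As a side note, both init helpers in the hunk above rely on nn.Module.apply, which calls the function on every submodule; a small standalone sketch of that pattern (the toy net below is illustrative, not from the repository):

    import torch.nn as nn

    def zero_bias_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1 or classname.find('Linear') != -1:
            m.bias.data.fill_(0)  # assumes the layer was built with bias=True

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1 or classname.find('Linear') != -1:
            nn.init.orthogonal_(m.weight.data)

    # apply() walks every submodule, so each Linear layer gets both hooks.
    net = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 4))
    net.apply(weights_init)
    net.apply(zero_bias_init)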
@@ -18,25 +21,26 @@ def forward(self, x):


 class Policy(nn.Module):
-    def __init__(self):
+    def __init__(self, obs_shape, action_space, recurrent_policy):
         super(Policy, self).__init__()
-        """
-        All classes that inheret from Policy are expected to have
-        a feature exctractor for actor and critic (see examples below)
-        and modules called linear_critic and dist. Where linear_critic
-        takes critic features and maps them to value and dist
-        represents a distribution of actions.
-        """
+        if len(obs_shape) == 3:
+            self.base = CNNBase(obs_shape[0], action_space, recurrent_policy)
+        elif len(obs_shape) == 1:
+            assert not recurrent_policy, \
+                "Recurrent policy is not implemented for the MLP controller"
+            self.base = MLPBase(obs_shape[0], action_space)
+        else:
+            raise NotImplementedError

     def forward(self, inputs, states, masks):
-        raise NotImplementedError
+        return self.base(inputs, states, masks)

     def act(self, inputs, states, masks, deterministic=False):
         value, hidden_actor, states = self(inputs, states, masks)

-        action = self.dist.sample(hidden_actor, deterministic=deterministic)
+        action = self.base.dist.sample(hidden_actor, deterministic=deterministic)

-        action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(hidden_actor, action)
+        action_log_probs, dist_entropy = self.base.dist.logprobs_and_entropy(hidden_actor, action)

         return value, action, action_log_probs, states
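The net effect of the hunk above is composition instead of inheritance: Policy owns a base module (CNNBase or MLPBase) and forwards to it, and the action head now lives at self.base.dist. A hedged usage sketch, assuming a flat observation space of size 4 and an already constructed gym action_space:

    import torch
    from model import Policy

    # Hypothetical shapes for illustration only.
    obs_shape = (4,)
    actor_critic = Policy(obs_shape, action_space, recurrent_policy=False)

    obs = torch.zeros(16, 4)                                 # batch of 16 observations
    states = torch.zeros(16, actor_critic.base.state_size)   # recurrent state slot
    masks = torch.ones(16, 1)                                 # episode-done masks

    # act() runs the base, then samples an action from base.dist.
    value, action, action_log_probs, states = actor_critic.act(obs, states, masks)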

@@ -47,14 +51,14 @@ def get_value(self, inputs, states, masks):
     def evaluate_actions(self, inputs, states, masks, actions):
         value, hidden_actor, states = self(inputs, states, masks)

-        action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(hidden_actor, actions)
+        action_log_probs, dist_entropy = self.base.dist.logprobs_and_entropy(hidden_actor, actions)

         return value, action_log_probs, dist_entropy, states


-class CNNPolicy(Policy):
+class CNNBase(nn.Module):
     def __init__(self, num_inputs, action_space, use_gru):
-        super(CNNPolicy, self).__init__()
+        super(CNNBase, self).__init__()

         self.main = nn.Sequential(
             nn.Conv2d(num_inputs, 32, 8, stride=4),
@@ -95,6 +99,8 @@ def mult_gain(m):
                 m.weight.data.mul_(relu_gain)

         self.main.apply(mult_gain)

+        self.apply(zero_bias_init)
+
         if hasattr(self, 'gru'):
             nn.init.orthogonal_(self.gru.weight_ih.data)
@@ -128,13 +134,11 @@ def weights_init_mlp(m):
     if classname.find('Linear') != -1:
         m.weight.data.normal_(0, 1)
         m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
-        if m.bias is not None:
-            m.bias.data.fill_(0)


-class MLPPolicy(Policy):
+class MLPBase(nn.Module):
     def __init__(self, num_inputs, action_space):
-        super(MLPPolicy, self).__init__()
+        super(MLPBase, self).__init__()

         self.action_space = action_space
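For reference, weights_init_mlp above draws Gaussian weights and then rescales each row (one row per output unit in a PyTorch Linear layer) to unit L2 norm, in the spirit of the normalized-column initializer used in OpenAI Baselines. A small sketch of the effect (the toy layer is illustrative):

    import torch
    import torch.nn as nn

    def weights_init_mlp(m):
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            m.weight.data.normal_(0, 1)
            # Divide each row by its L2 norm so every output unit starts
            # with weights of the same overall magnitude.
            m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))

    layer = nn.Linear(10, 3)
    layer.apply(weights_init_mlp)
    print(layer.weight.data.pow(2).sum(1))  # approximately tensor([1., 1., 1.])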

@@ -164,6 +168,8 @@ def state_size(self):

     def reset_parameters(self):
         self.apply(weights_init_mlp)
+        self.apply(zero_bias_init)
+
         if self.dist.__class__.__name__ == "DiagGaussian":
             self.dist.fc_mean.weight.data.mul_(0.01)
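Taken together with the main.py change, each base module is expected to expose forward(inputs, states, masks) returning (value, hidden_actor, states), a dist attribute for the action head, and a state_size property. A minimal skeleton of that contract (simplified; the real CNNBase/MLPBase in model.py carry the full layer stacks, and the get_distribution arguments here are assumed from how it is imported above):

    import torch.nn as nn
    from distributions import get_distribution

    class TinyBase(nn.Module):
        """Illustrative only: the smallest module Policy could wrap."""

        def __init__(self, num_inputs, action_space):
            super(TinyBase, self).__init__()
            self.actor = nn.Sequential(nn.Linear(num_inputs, 64), nn.Tanh())
            self.critic = nn.Sequential(nn.Linear(num_inputs, 64), nn.Tanh())
            self.critic_linear = nn.Linear(64, 1)
            # Action head; assumed signature: (feature size, gym action space).
            self.dist = get_distribution(64, action_space)

        @property
        def state_size(self):
            # No recurrent state here; RolloutStorage still needs a size, so use 1.
            return 1

        def forward(self, inputs, states, masks):
            hidden_actor = self.actor(inputs)
            hidden_critic = self.critic(inputs)
            return self.critic_linear(hidden_critic), hidden_actor, states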
