This repository has been archived by the owner on Jan 27, 2023. It is now read-only.

Add IMPALA Resnet model #7

Merged (4 commits) on Mar 10, 2019
3 changes: 3 additions & 0 deletions rainy/agents/a2c.py
@@ -18,6 +18,9 @@ def __init__(self, config: Config) -> None:
def members_to_save(self) -> Tuple[str, ...]:
return ("net",)

def set_mode(self, train: bool = True) -> None:
self.net.train(mode=train)

def eval_action(self, state_: Array) -> Action:
state = self.config.eval_env.state_to_array(state_)
if len(state.shape) == len(self.net.state_dim):
7 changes: 5 additions & 2 deletions rainy/agents/base.py
@@ -3,7 +3,7 @@
from pathlib import Path
import torch
from torch import nn
from typing import Callable, Generic, Iterable, List, NamedTuple, Optional, Tuple
from typing import Any, Callable, Generic, Iterable, List, NamedTuple, Optional, Tuple
import warnings
from ..config import Config
from ..lib.rollout import RolloutStorage
@@ -16,7 +16,7 @@ class EpisodeResult(NamedTuple):
length: np.int32

def __repr__(self) -> str:
return 'EpisodeResult(reward: {}, episode_length: {})'.format(self.reward, self.length)
return 'Result: reward: {}, length: {}'.format(self.reward, self.length)


class Agent(ABC):
@@ -55,6 +55,9 @@ def eval_action(self, state: Array) -> Action:
def update_steps(self) -> int:
pass

def set_mode(self, train: bool = True) -> None:
pass

def report_loss(self, **kwargs) -> None:
if self.update_steps % self.config.network_log_freq == 0:
kwargs['update-steps'] = self.update_steps
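The new set_mode hook is a no-op on the base Agent; the A2C and DQN agents override it to forward to nn.Module.train. That matters for this PR because the ResNetBody below can contain batch normalization, whose behavior differs between training and evaluation. A minimal sketch of what the override delegates to, using a plain torch module rather than a rainy agent:

import torch
from torch import nn

net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
net.train(mode=False)  # what set_mode(train=False) delegates to in the agents above
with torch.no_grad():
    out = net(torch.zeros(1, 3, 8, 8))  # BatchNorm now uses running statistics
net.train(mode=True)  # restore training behavior, as _eval_common does after evaluation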
3 changes: 3 additions & 0 deletions rainy/agents/dqn.py
@@ -26,6 +26,9 @@ def __init__(self, config: Config) -> None:
dtype=torch.long
)

def set_mode(self, train: bool = True) -> None:
self.net.train(mode=train)

def members_to_save(self) -> Tuple[str, ...]:
return "net", "target_net", "policy", "total_steps"

2 changes: 2 additions & 0 deletions rainy/config.py
@@ -79,6 +79,8 @@ def __init__(self) -> None:
}

# Environments
self.eval_parallel = False
self.eval_times = 1
self.__env = lambda: ClassicalControl()
self.__eval_env: Optional[EnvExt] = None
self.__paralle_env = lambda env_gen, num_w: DummyParallelEnv(env_gen, num_w)
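The two new config fields drive the shared evaluation helper added in run.py below: eval_times sets how many evaluation episodes are averaged, and eval_parallel switches to the agent's parallel evaluation path. A minimal sketch of setting them; the rest of the training configuration is assumed and not shown in this diff:

from rainy.config import Config

config = Config()
config.eval_times = 4         # report the mean over 4 evaluation episodes
config.eval_parallel = False  # set True to run those episodes via eval_parallel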
2 changes: 1 addition & 1 deletion rainy/net/__init__.py
@@ -1,5 +1,5 @@
from .actor_critic import ActorCriticNet
from .block import Activator, ConvBody, DqnConv, FcBody, LinearHead, NetworkBlock
from .block import Activator, ConvBody, DqnConv, FcBody, ResNetBody, LinearHead, NetworkBlock
from .init import InitFn, Initializer
from .policy import Policy, PolicyHead
from .value import ValueNet, ValuePredictor
18 changes: 17 additions & 1 deletion rainy/net/actor_critic.py
@@ -2,7 +2,7 @@
from numpy import ndarray
from torch import nn, Tensor
from typing import Callable, Tuple, Union
from .block import DqnConv, FcBody, LinearHead, NetworkBlock
from .block import DqnConv, FcBody, ResNetBody, LinearHead, NetworkBlock
from .init import Initializer, orthogonal
from .policy import CategoricalHead, Policy, PolicyHead
from ..utils import Device
@@ -87,6 +87,22 @@ def ac_conv(
return ActorCriticNet(body, ac_head, cr_head, device=device, policy_head=policy_head)


def impala_conv(
state_dim: Tuple[int, int, int],
action_dim: int,
device: Device,
policy: Callable[[int], PolicyHead] = CategoricalHead,
) -> ActorCriticNet:
"""Convolutuion network used in IMPALA
"""
body = ResNetBody(state_dim, channels=[16, 32, 32], use_batch_norm=False)
policy_head = policy(action_dim, device)
policy_dim = policy_head.calc_input_dim(action_dim)
ac_head = LinearHead(body.output_dim, policy_dim, Initializer(weight_init=orthogonal(0.01)))
cr_head = LinearHead(body.output_dim, 1)
return ActorCriticNet(body, ac_head, cr_head, device=device, policy_head=policy_head)


def fc(state_dim: Tuple[int, ...],
action_dim: int,
device: Device = Device(),
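impala_conv wires the new ResNetBody (three stages of 16, 32, and 32 channels, batch norm disabled) to separate linear heads for the policy and the value, mirroring the residual network from the IMPALA paper. A minimal sketch of constructing it directly; the (4, 84, 84) state shape and the 6 actions are illustrative values, not something fixed by this PR:

from rainy.net.actor_critic import impala_conv
from rainy.utils import Device

# Four stacked 84x84 frames in, six discrete actions out.
net = impala_conv(state_dim=(4, 84, 84), action_dim=6, device=Device())
print(net.state_dim)  # (4, 84, 84)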
93 changes: 92 additions & 1 deletion rainy/net/block.py
@@ -12,7 +12,7 @@


class NetworkBlock(nn.Module, ABC):
"""Defines a NN block
"""Defines a NN block which returns 1-dimension Tensor
"""
@property
@abstractmethod
@@ -25,6 +25,11 @@ def output_dim(self) -> int:
pass


class DummyBlock(nn.Module):
def forward(self, x: Tensor) -> Tensor:
return x


class LinearHead(NetworkBlock):
"""One FC layer
"""
@@ -110,6 +115,92 @@ def __init__(
super().__init__(F.relu, init, dim, hidden, fc, conv1, conv2, conv3)


class ResBlock(nn.Module):
def __init__(
self,
channel: int,
stride: int = 1,
use_batch_norm: bool = True,
) -> None:
super().__init__()
self.net = nn.Sequential(
nn.ReLU(inplace=True),
self._conv3x3(channel, channel, stride),
self._batch_norm(use_batch_norm, channel),
nn.ReLU(inplace=True),
self._conv3x3(channel, channel, stride),
self._batch_norm(use_batch_norm, channel),
)

@staticmethod
def _batch_norm(use_batch_norm: bool, channel: int) -> nn.Module:
if use_batch_norm:
return nn.BatchNorm2d(channel)
else:
return DummyBlock()

@staticmethod
def _conv3x3(in_channel: int, out_channel: int, stride: int = 1) -> nn.Conv2d:
return nn.Conv2d(
in_channel,
out_channel,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)

def forward(self, x: Tensor) -> Tensor:
residual = x
out = self.net(x)
return out + residual


class ResNetBody(NetworkBlock):
"""Convolutuion Network used in IMPALA
"""
def __init__(
self,
input_dim: Tuple[int, int, int],
channels: List[int],
use_batch_norm: bool = True,
fc_out: int = 256,
init: Initializer = Initializer(nonlinearity='relu'),
) -> None:
def make_layer(in_channel: int, out_channel: int) -> nn.Sequential:
return nn.Sequential(
ResBlock._conv3x3(in_channel, out_channel),
ResBlock._batch_norm(use_batch_norm, out_channel),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
ResBlock(out_channel, use_batch_norm=use_batch_norm),
ResBlock(out_channel, use_batch_norm=use_batch_norm)
)

super().__init__()
self._input_dim = input_dim
_channels = zip([input_dim[0]] + channels, channels)
self.res_blocks = init.make_list(*[make_layer(*c) for c in _channels])
self.relu = nn.ReLU(inplace=True)
width, height = calc_cnn_hidden([(3, 2, 1)] * len(channels), *input_dim[1:])
fc_in = iter_prod([channels[-1], width, height])
self.fc = nn.Linear(fc_in, fc_out)

def forward(self, x: Tensor) -> Tensor:
for block in self.res_blocks:
x = block(x)
x = self.relu(x)
x = self.fc(x.view(x.size(0), -1))
return self.relu(x)

@property
def input_dim(self) -> Tuple[int, ...]:
return self._input_dim

@property
def output_dim(self) -> int:
return self.fc.out_features


class FcBody(NetworkBlock):
def __init__(
self,
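In ResNetBody only the stride-2 max pool in each stage changes the spatial resolution (the residual blocks are stride 1 and padded), which is why calc_cnn_hidden is called with (3, 2, 1) once per channel stage. A quick check of that arithmetic for a hypothetical 84x84 input with the [16, 32, 32] channels used by impala_conv:

# Spatial size after each max pool (kernel 3, stride 2, padding 1).
size = 84
for _ in range(3):                      # one pool per channel stage
    size = (size + 2 * 1 - 3) // 2 + 1  # 84 -> 42 -> 21 -> 11
print(32 * size * size)                 # 3872 features feed the final fc (fc_out defaults to 256)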
46 changes: 35 additions & 11 deletions rainy/run.py
@@ -1,13 +1,37 @@
from .agents import Agent, EpisodeResult
import numpy as np
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
from .agents import Agent, EpisodeResult
from .prelude import Array


SAVE_FILE_DEFAULT = 'rainy-agent.save'
ACTION_FILE_DEFAULT = 'actions.json'


def _eval_common(
ag: Agent,
save_file: Optional[Path],
render: bool = False
) -> List[EpisodeResult]:
n = ag.config.eval_times
ag.set_mode(train=False)
if save_file is not None:
res = [ag.eval_and_save(save_file, render=render) for _ in range(n)]
elif not ag.config.eval_parallel:
res = [ag.eval_episode(render=render) for _ in range(n)]
else:
res = ag.eval_parallel(n)
ag.set_mode(train=True)
return res


def _reward_and_length(results: List[EpisodeResult]) -> Tuple[Array[float], Array[float]]:
rewards = np.array(list(map(lambda t: t.reward, results)))
length = np.array(list(map(lambda t: t.length, results)))
return rewards, length


def train_agent(
ag: Agent,
save_file_name: str = SAVE_FILE_DEFAULT,
@@ -19,8 +43,7 @@ def train_agent(
action_file = Path(action_file_name)

def log_episode(episodes: int, res: List[EpisodeResult]) -> None:
rewards = np.array(list(map(lambda t: t.reward, res)))
length = np.array(list(map(lambda t: t.length, res)))
rewards, length = _reward_and_length(res)
ag.logger.exp('train', {
'episodes': episodes,
'update-steps': ag.update_steps,
@@ -39,14 +62,15 @@ def log_eval(episodes: int):
episodes,
action_file.suffix
))
res = ag.eval_and_save(fname.as_posix())
res = _eval_common(ag, fname)
else:
res = ag.eval_episode()
res = _eval_common(ag, None)
rewards, length = _reward_and_length(res)
ag.logger.exp('eval', {
'episodes': episodes,
'update-steps': ag.update_steps,
'reward': res.reward,
'length': res.length,
'reward-mean': float(np.mean(rewards)),
'length-mean': float(np.mean(length)),
})

def interval(turn: int, width: int, freq: Optional[int]) -> bool:
@@ -82,10 +106,10 @@ def eval_agent(
) -> None:
path = Path(log_dir)
ag.load(path.joinpath(load_file_name).as_posix())
if action_file:
res = ag.eval_and_save(path.joinpath(action_file).as_posix(), render=render)
if action_file is not None and len(action_file) > 0:
res = _eval_common(ag, path.joinpath(action_file), render=render)
else:
res = ag.eval_episode(render=render)
res = _eval_common(ag, None, render=render)
print('{}'.format(res))
if render:
input('--Press Enter to exit--')
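_eval_common funnels the three evaluation paths (eval_and_save, plain eval_episode, eval_parallel) through one helper and switches the agent out of and back into training mode around them; _reward_and_length then turns the per-episode results into arrays so log_eval can report means. A small sketch of that aggregation step with fabricated results, assuming the reward field of EpisodeResult comes before length:

import numpy as np
from rainy.agents import EpisodeResult

# Two made-up evaluation results, just to show the mean computation log_eval performs.
results = [EpisodeResult(np.float32(2.0), np.int32(100)),
           EpisodeResult(np.float32(4.0), np.int32(120))]
rewards = np.array([r.reward for r in results])
lengths = np.array([r.length for r in results])
print(float(np.mean(rewards)), float(np.mean(lengths)))  # 3.0 110.0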