Skip to content

Commit

Permalink
Add updater_creator
Browse files Browse the repository at this point in the history
  • Loading branch information
mitmul committed Jul 24, 2017
1 parent 561854a commit a05a670
Show file tree
Hide file tree
Showing 22 changed files with 129 additions and 340 deletions.
20 changes: 18 additions & 2 deletions chainercmd/bin/chainercmd.py
Expand Up @@ -29,12 +29,28 @@ def main():
parser_test.set_defaults(handler=test.test)

# init command
parser_init = subparsers.add_parser('init', help='Inference mode')
parser_init = subparsers.add_parser(
'init', help='Generate templates of dataset.py, loss.py, and model.py')
parser_init.add_argument(
'--create_subdirs', action='store_true', default=False,
help='If you want to create subdirs ("model", "loss", "dataset"), '
'give this flag.')
parser_init.set_defaults(handler=init.init)
parser_init.set_defaults(handler=init.init_basic)

# init_full command
parser_init = subparsers.add_parser(
'init_full', help='Generate templates of dataset.py, loss.py, '
'model.py, and updater_creator.py')
parser_init.add_argument(
'--create_subdirs', action='store_true', default=False,
help='If you want to create subdirs ("model", "loss", "dataset", '
'"updater_creator"), give this flag.')
parser_init.set_defaults(handler=init.init_full)

# init_config command
parser_init = subparsers.add_parser(
'init_config', help='Generate a config template')
parser_init.set_defaults(handler=init.init_config)

args = parser.parse_args()

Expand Down
1 change: 1 addition & 0 deletions chainercmd/config/__init__.py
@@ -1,3 +1,4 @@
from chainercmd.config.model import get_model_from_config # NOQA
from chainercmd.config.dataset import get_dataset_from_config # NOQA
from chainercmd.config.optimizer import get_optimizer_from_config # NOQA
from chainercmd.config.updater_creator import get_updater_creator_from_config # NOQA
21 changes: 0 additions & 21 deletions chainercmd/config/loss.py

This file was deleted.

35 changes: 20 additions & 15 deletions chainercmd/config/model.py
@@ -1,11 +1,12 @@
import os
import shutil
from importlib import import_module
from importlib.machinery import SourceFileLoader

import chainer

from chainercmd.config.base import ConfigBase

import chainer


class Model(ConfigBase):

Expand All @@ -25,19 +26,19 @@ class Loss(ConfigBase):

def __init__(self, **kwargs):
required_keys = [
'file',
'name',
]
optional_keys = [
'args'
'file',
'args',
]
super().__init__(
required_keys, optional_keys, kwargs, self.__class__.__name__)


def get_model(
model_file, model_name, model_args, loss_file, loss_name, loss_args,
result_dir):
result_dir, model_file, model_name, model_args, loss_file, loss_name,
loss_args):
loader = SourceFileLoader(model_name, model_file)
mod = loader.load_module()
model = getattr(mod, model_name)
Expand All @@ -54,25 +55,29 @@ def get_model(
else:
model = model()

if chainer.config.train:
loader = SourceFileLoader(loss_name, loss_file)
mod = loader.load_module()
if chainer.config.train and loss_name is not None:
if loss_file is not None:
loader = SourceFileLoader(loss_name, loss_file)
mod = loader.load_module()
else:
mod = import_module('chainer.links')
loss = getattr(mod, loss_name)
if loss_args is not None:
model = loss(model, **loss_args)
else:
model = loss(model)

# Copy loss file
dst = '{}/{}'.format(result_dir, os.path.basename(loss_file))
if not os.path.exists(dst):
shutil.copy(loss_file, dst)
if loss_file is not None:
# Copy loss file
dst = '{}/{}'.format(result_dir, os.path.basename(loss_file))
if not os.path.exists(dst):
shutil.copy(loss_file, dst)
return model


def get_model_from_config(config):
model = Model(**config['model'])
loss = Loss(**config['loss'])
return get_model(
model.file, model.name, model.args, loss.file, loss.name, loss.args,
config['result_dir'])
config['result_dir'], model.file, model.name, model.args,
loss.file, loss.name, loss.args)
36 changes: 36 additions & 0 deletions chainercmd/config/updater_creator.py
@@ -0,0 +1,36 @@
from functools import partial
from importlib.machinery import SourceFileLoader

from chainercmd.config.base import ConfigBase


class UpdaterCreator(ConfigBase):

def __init__(self, **kwargs):
required_keys = [
'file',
'name',
]
optional_keys = [
'args',
]
super().__init__(
required_keys, optional_keys, kwargs, self.__class__.__name__)


def get_updater_creator(file, name, args):
loader = SourceFileLoader(name, file)
mod = loader.load_module()
updater_creator = getattr(mod, name)
if args is not None:
return partial(updater_creator, **args)
else:
return updater_creator


def get_updater_creator_from_config(config):
updater_creator_config = UpdaterCreator(**config['updater_creator'])
updater_creator = get_updater_creator(
updater_creator_config.file, updater_creator_config.name,
updater_creator_config.args)
return updater_creator
18 changes: 17 additions & 1 deletion chainercmd/init.py
Expand Up @@ -4,7 +4,7 @@
from chainercmd import template


def init(args):
def init_basic(args):
model_template = os.path.abspath(template.model.__file__)
dataset_template = os.path.abspath(template.dataset.__file__)
loss_template = os.path.abspath(template.loss.__file__)
Expand All @@ -23,6 +23,22 @@ def init(args):
shutil.copy(model_template, './')
shutil.copy(dataset_template, './')
shutil.copy(loss_template, './')
init_config(args)


def init_config(args):
model_template = os.path.abspath(template.model.__file__)
dname = os.path.dirname(model_template)
shutil.copy('{}/config.yml'.format(dname), './')


def init_full(args):
init_basic(args)
updater_creator_template = os.path.abspath(
template.updater_creator.__file__)
if args.create_subdirs:
if not os.path.exists('updater'):
os.mkdir('updater')
shutil.copy(updater_creator_template, 'updater/')
else:
shutil.copy(updater_creator_template, './')
1 change: 1 addition & 0 deletions chainercmd/template/__init__.py
Expand Up @@ -4,6 +4,7 @@
from chainercmd.template import dataset # NOQA
from chainercmd.template import loss # NOQA
from chainercmd.template import model # NOQA
from chainercmd.template import updater_creator # NOQA

dname = os.path.dirname(__file__)
config_base = yaml.load(open('{}/config.yml'.format(dname)))
8 changes: 7 additions & 1 deletion chainercmd/template/config.yml
Expand Up @@ -25,7 +25,7 @@ model:
n_class: 10

loss:
file: loss.py
file: loss.py # If 'file' is ommitted, chainer.links.[name] is used
name: Loss

optimizer:
Expand All @@ -38,6 +38,12 @@ optimizer:
points: [10, 15]
unit: epoch

updater_creator:
file: updater_creator.py
name: MyUpdaterCreator
args:
print: True

trainer_extension:
- dump_graph:
root_name: main/loss
Expand Down
7 changes: 7 additions & 0 deletions chainercmd/template/updater_creator.py
@@ -0,0 +1,7 @@
from chainer.training import updater


def updater_creator(iterator, optimizer, devices, *args, **kwargs):
print(args)
print(kwargs)
return updater.StandardUpdater(iterator, optimizer, device=devices['main'])
8 changes: 7 additions & 1 deletion chainercmd/train.py
Expand Up @@ -19,6 +19,7 @@
from chainercmd.config import get_dataset_from_config
from chainercmd.config import get_model_from_config
from chainercmd.config import get_optimizer_from_config
from chainercmd.config import get_updater_creator_from_config

try:
HAVE_NCCL = updaters.MultiprocessParallelUpdater.available()
Expand Down Expand Up @@ -125,7 +126,11 @@ def train(args):
config['valid_batchsize'], devices)

# Create updater and trainer
updater = create_updater(train_iter, optimizer, devices)
if 'updater_creator' in config:
updater_creator = get_updater_creator_from_config(config)
updater = updater_creator(train_iter, optimizer, devices)
else:
updater = create_updater(train_iter, optimizer, devices)
trainer = training.Trainer(
updater, (config['stop_epoch'], 'epoch'), out=config['result_dir'])

Expand All @@ -146,6 +151,7 @@ def train(args):
extensions.Evaluator(valid_iter, model, device=args.gpus[0]),
trigger=config['valid_trigger'])

# Trainer extensions
for ext in config['trainer_extension']:
if isinstance(ext, dict):
ext, values = ext.popitem()
Expand Down
66 changes: 0 additions & 66 deletions examples/cifar10/config.yml

This file was deleted.

16 changes: 0 additions & 16 deletions examples/cifar10/dataset.py

This file was deleted.

32 changes: 0 additions & 32 deletions examples/cifar10/loss.py

This file was deleted.

0 comments on commit a05a670

Please sign in to comment.