Skip to content

Commit

Permalink
fix handler to use another global step func
Browse files Browse the repository at this point in the history
  • Loading branch information
disktnk committed Oct 9, 2019
1 parent 0557210 commit a7aea77
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 21 deletions.
82 changes: 67 additions & 15 deletions chainerui/contrib/ignite/handler.py
@@ -1,21 +1,45 @@
from ignite.contrib.handlers.base_logger import BaseLogger
from ignite.contrib.handlers.base_logger import BaseOutputHandler
from ignite.engine import Events

import chainerui
from chainerui.utils.log_report import _get_time


class OutputHandler(BaseOutputHandler):
"""Handler for ChainerUI
"""Handler for ChainerUI logger
A helper for handler to log engine's output, specialized for ChainerUI.
This handler sets 'epoch', 'iteration' and 'elapsed_time' automatically,
these are default x axis to show.
.. code-block:: python
from chainerui.contrib.ignite.handler import OutputHandler
train_handler = OutputHandler(
'train', output_transform=lambda o: {'param': o})
val_handler = OutputHandler('val', metric_names='all')
Args:
tag (str): use for a prefix of parameter name, will show as
{tag}/{param}
metric_names (str or list): keys names of ``list`` to monitor. set
``'all'`` to get all metrics monitored by the engine.
output_transform (func): if set, use this function to convert output
from ``engine.state.output``
another_engine (``ignite.engine.Engine``): if set, use for getting
global step. This option is deprecated from 0.3.
global_step_transform (func): if set, use this to get global step.
interval_step (int): interval step for posting metrics to ChainerUI
server.
"""
def __init__(
self, tag, metric_names=None, output_transform=None,
another_engine=None, global_step_transform=None,
interval=-1, validation_mode=False):
another_engine=None, global_step_transform=None, interval_step=-1):
super(OutputHandler, self).__init__(
tag, metric_names, output_transform, another_engine,
global_step_transform)
self.interval = interval
self.validation_mode = validation_mode
self.interval = interval_step

def __call__(self, engine, logger, event_name):
if not isinstance(logger, ChainerUILogger):
Expand All @@ -26,21 +50,18 @@ def __call__(self, engine, logger, event_name):
metrics = self._setup_output_metrics(engine)
if not metrics:
return
iteration = engine.state.iteration
epoch = engine.state.epoch
iteration = self.global_step_transform(
engine, Events.ITERATION_COMPLETED)
epoch = self.global_step_transform(engine, Events.EPOCH_COMPLETED)

# convert metrics name
rendered_metrics = {}
for k, v in metrics.items():
rendered_metrics['{}/{}'.format(self.tag, k)] = v
if not self.validation_mode:
logger.previous_count['iteration'] = iteration
logger.previous_count['epoch'] = epoch
rendered_metrics['iteration'] = iteration
rendered_metrics['epoch'] = epoch
else:
rendered_metrics['iteration'] = logger.previous_count['iteration']
rendered_metrics['epoch'] = logger.previous_count['epoch']
rendered_metrics['iteration'] = iteration
rendered_metrics['epoch'] = epoch
if 'elapsed_time' not in rendered_metrics:
rendered_metrics['elapsed_time'] = _get_time() - logger.start_at

if self.interval <= 0:
logger.post_log([rendered_metrics])
Expand All @@ -56,6 +77,31 @@ def __call__(self, engine, logger, event_name):


class ChainerUILogger(BaseLogger):
"""Logger handler for ChainerUI
A helper logger to post metrics to ChainerUI server. Attached handlers
are expected using ``chainerui.contrib.ignite.handler.OutputHandler``.
A tag name of handler must be unique when attach several handlers.
.. code-block:: python
from chainerui.contrib.ignite.handler import OutputHandler
train_handler = OutputHandler(...)
val_handler = OutputHandler(...)
from ignite.engine.engine import Engine
train_engine = Engine(...)
eval_engine = Engine(...)
from chainerui.contrib.ignite.handler import ChainerUILogger
logger = ChainerUILogger()
logger.attach(
train_engine, log_handler=train_handler,
event_name=Events.EPOCH_COMPLETED)
logger.attach(
eval_engine, log_handler=val_handler,
event_name=Event.EPOCH_COMPLETED)
"""
def __init__(self):
web_client = chainerui.client.client._client
if not web_client:
Expand All @@ -67,6 +113,7 @@ def __init__(self):
self.client = web_client
self.cache = {} # key is tag name, value is list of metrics
self.previous_count = {}
self.start_at = _get_time()

def attach(self, engine, log_handler, event_name):
if log_handler.tag in self.attached_tags:
Expand All @@ -77,6 +124,11 @@ def attach(self, engine, log_handler, event_name):
def post_log(self, metrics):
self.client.post_log(metrics)

def __enter__(self):
# overwrite start timestamp when engine is used with 'with' block
self.start_at = _get_time()
super(ChainerUILogger, self).__enter__()

def close(self):
for k, v in self.cache.items():
if v:
Expand Down
26 changes: 20 additions & 6 deletions examples/train_mnist_ignite.py
@@ -1,3 +1,5 @@
# This example is based on https://github.com/pytorch/ignite/blob/v0.2.1/examples/mnist/mnist.py
# Please see [ChainerUI] commend on the below code.
from argparse import ArgumentParser

from torch import nn
Expand Down Expand Up @@ -99,15 +101,25 @@ def log_validation_results(engine):

pbar.n = pbar.last_print_n = 0

# [ChainerUI] import logger and handler
from chainerui.contrib.ignite.handler import OutputHandler
from chainerui.contrib.ignite.handler import ChainerUILogger

from ignite.contrib.handlers.base_logger import global_step_from_engine
# [ChainerUI] setup logger, this logger manages ChainerUI web client.
logger = ChainerUILogger()
train_handler = OutputHandler('train', output_transform=lambda o: {'nll': o}, interval=10)
logger.attach(trainer, log_handler=train_handler, event_name=Events.ITERATION_COMPLETED)
val_handler = OutputHandler('val', metric_names=['accuracy', 'nll'], validation_mode=True)
logger.attach(evaluator, log_handler=val_handler, event_name=Events.EPOCH_COMPLETED)

# [ChainerUI] to ease requests of metrics posting, set interval option
train_handler = OutputHandler(
'train', output_transform=lambda o: {'nll': o}, interval_step=20)
logger.attach(
trainer, log_handler=train_handler,
event_name=Events.ITERATION_COMPLETED)
# [ChainerUI] to set same value of x axis, use global_step_transform
val_handler = OutputHandler(
'val', metric_names='all',
global_step_transform=global_step_from_engine(trainer))
logger.attach(
evaluator, log_handler=val_handler, event_name=Events.EPOCH_COMPLETED)
# [ChainerUI] to post remainder of metrics caused by interval, use "with"
with logger:
trainer.run(train_loader, max_epochs=epochs)
pbar.close()
Expand All @@ -130,6 +142,8 @@ def log_validation_results(engine):

args = parser.parse_args()

# [ChainerUI] To use ChainerUI web client, must initialize
# args will be showed as parameter of this experiment.
chainerui.init(conditions=args)

run(args.batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum, args.log_interval)

0 comments on commit a7aea77

Please sign in to comment.