Skip to content

Commit

Permalink
Merge pull request chainer#327 from disktnk/fix/ignite-handler
Browse files Browse the repository at this point in the history
Support ignite handler
  • Loading branch information
ofk committed Nov 20, 2019
2 parents 69332e5 + 3d0818d commit 9986488
Show file tree
Hide file tree
Showing 7 changed files with 443 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .travis.yml
Expand Up @@ -38,7 +38,7 @@ before_install:
if [[ $CHAINERUI_PLAIN_INSTALL == 1 ]]; then
export CHAINERUI_TEST_DEPENDS="[test-ci-plain]";
else
export CHAINERUI_TEST_DEPENDS="[test-ci]";
export CHAINERUI_TEST_DEPENDS="[test-ci-contrib]";
fi
- npm install -g npm@6
- npm -v
Expand Down
136 changes: 136 additions & 0 deletions chainerui/contrib/ignite/handler.py
@@ -0,0 +1,136 @@
from ignite.contrib.handlers.base_logger import BaseLogger
from ignite.contrib.handlers.base_logger import BaseOutputHandler
from ignite.engine import Events

import chainerui
from chainerui.utils.log_report import _get_time


class OutputHandler(BaseOutputHandler):
    """Handler for ChainerUI logger

    A helper handler that logs an engine's output, specialized for
    ChainerUI. It automatically attaches ``'epoch'``, ``'iteration'`` and
    ``'elapsed_time'`` to every posted record; these are the default x
    axes shown by ChainerUI.

    .. code-block:: python

        from chainerui.contrib.ignite.handler import OutputHandler
        train_handler = OutputHandler(
            'train', output_transform=lambda o: {'param': o})
        val_handler = OutputHandler('val', metric_names='all')

    Args:
        tag (str): prefix for every parameter name, shown as
            ``{tag}/{param}``.
        metric_names (str or list): metric keys to monitor. Set ``'all'``
            to monitor every metric tracked by the engine.
        output_transform (func): if set, used to convert
            ``engine.state.output`` into a metrics dict.
        another_engine (``ignite.engine.Engine``): if set, used to obtain
            the global step. Deprecated from ignite 0.3.
        global_step_transform (func): if set, used to obtain the global
            step.
        interval_step (int): interval, in events, between posts of
            metrics to the ChainerUI server.
    """

    def __init__(
            self, tag, metric_names=None, output_transform=None,
            another_engine=None, global_step_transform=None, interval_step=-1):
        super(OutputHandler, self).__init__(
            tag, metric_names, output_transform, another_engine,
            global_step_transform)
        self.interval = interval_step

    def __call__(self, engine, logger, event_name):
        # This handler only knows how to talk to a ChainerUILogger.
        if not isinstance(logger, ChainerUILogger):
            raise RuntimeError(
                '`chainerui.contrib.ignite.handler.OutputHandler` works only '
                'with ChainerUILogger, but set {}'.format(type(logger)))

        metrics = self._setup_output_metrics(engine)
        if not metrics:
            return

        # Prefix every metric name with the tag, then attach the default
        # x-axis values expected by ChainerUI.
        record = {
            '{}/{}'.format(self.tag, name): value
            for name, value in metrics.items()
        }
        record['iteration'] = self.global_step_transform(
            engine, Events.ITERATION_COMPLETED)
        record['epoch'] = self.global_step_transform(
            engine, Events.EPOCH_COMPLETED)
        if 'elapsed_time' not in record:
            record['elapsed_time'] = _get_time() - logger.start_at

        # Without a meaningful interval, post each record immediately.
        if self.interval <= 1:
            logger.post_log([record])
            return

        # Interval mode: accumulate records per tag and post in batches.
        cached = logger.cache.setdefault(self.tag, [])
        cached.append(record)
        # Select the appropriate event counter set by handler init.
        global_count = self.global_step_transform(engine, event_name)
        if global_count % self.interval == 0:
            logger.post_log(cached)
            cached.clear()


class ChainerUILogger(BaseLogger):
    """Logger handler for ChainerUI

    A helper logger to post metrics to a ChainerUI server. Attached
    handlers are expected to be
    ``chainerui.contrib.ignite.handler.OutputHandler`` instances.
    The tag name of each handler must be unique when attaching several
    handlers.

    .. code-block:: python

        from chainerui.contrib.ignite.handler import OutputHandler
        train_handler = OutputHandler(...)
        val_handler = OutputHandler(...)

        from ignite.engine.engine import Engine
        train_engine = Engine(...)
        eval_engine = Engine(...)

        from chainerui.contrib.ignite.handler import ChainerUILogger
        logger = ChainerUILogger()
        logger.attach(
            train_engine, log_handler=train_handler,
            event_name=Events.EPOCH_COMPLETED)
        logger.attach(
            eval_engine, log_handler=val_handler,
            event_name=Events.EPOCH_COMPLETED)

    Raises:
        RuntimeError: if ``chainerui.init()`` has not set up a web client
            before this logger is created.
    """

    def __init__(self):
        # chainerui.init() must have run successfully to set this client.
        web_client = chainerui.client.client._client
        if not web_client:
            raise RuntimeError(
                'fail to setup ChainerUI logger, please check '
                '`chainerui.init()` is success to execute.')

        self.attached_tags = set()
        self.client = web_client
        self.cache = {}  # key is tag name, value is list of metrics
        self.start_at = _get_time()

    def attach(self, engine, log_handler, event_name):
        """Attach ``log_handler`` to ``engine``, enforcing unique tags.

        Tags key the metrics cache, so two handlers sharing a tag would
        silently mix their cached metrics.
        """
        if log_handler.tag in self.attached_tags:
            raise RuntimeError('attached handlers must have unique tag name')
        self.attached_tags.add(log_handler.tag)
        super(ChainerUILogger, self).attach(engine, log_handler, event_name)

    def post_log(self, metrics):
        """Post a list of metric dicts to the ChainerUI server."""
        self.client.post_log(metrics)

    def __enter__(self):
        # Overwrite start timestamp when the logger is used in a 'with'
        # block. Return the value from the base class so
        # ``with ChainerUILogger() as logger:`` binds the logger rather
        # than None (the original dropped super's return value).
        self.start_at = _get_time()
        return super(ChainerUILogger, self).__enter__()

    def close(self):
        """Flush metrics still cached by interval-based handlers."""
        for tag, cached in self.cache.items():
            if cached:
                self.post_log(cached)
                # Clear after posting so a repeated close() (e.g. explicit
                # close after a 'with' block) does not double-post.
                cached.clear()
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -1 +1,2 @@
chainer>=3.0.0
pytorch-ignite>=0.2.1
13 changes: 13 additions & 0 deletions docs/source/reference/module.rst
Expand Up @@ -66,3 +66,16 @@ Utilities
.. _module_save_args:

.. autofunction:: chainerui.utils.save_args


External library support
------------------------

.. _module_ignite_output_handler:

.. autoclass:: chainerui.contrib.ignite.handler.OutputHandler
:members:

.. _module_ignite_logger:

.. autoclass:: chainerui.contrib.ignite.handler.ChainerUILogger
163 changes: 163 additions & 0 deletions examples/train_mnist_ignite.py
@@ -0,0 +1,163 @@
# This example is based on https://github.com/pytorch/ignite/blob/v0.2.1/examples/mnist/mnist.py # NOQA
# Please see the [ChainerUI] comments in the code below.
from argparse import ArgumentParser

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose
from torchvision.transforms import Normalize
from torchvision.transforms import ToTensor

from ignite.contrib.handlers.base_logger import global_step_from_engine
from ignite.engine import create_supervised_evaluator
from ignite.engine import create_supervised_trainer
from ignite.engine import Events
from ignite.metrics import Accuracy
from ignite.metrics import Loss

from tqdm import tqdm

import chainerui
from chainerui.contrib.ignite.handler import ChainerUILogger
from chainerui.contrib.ignite.handler import OutputHandler


class Net(nn.Module):
    """Small CNN for MNIST: two conv blocks followed by two linear layers.

    Input is expected as ``(N, 1, 28, 28)`` images; output is
    ``(N, 10)`` log-probabilities over the digit classes.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # First conv block: conv -> 2x2 max-pool -> ReLU.
        h = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Second conv block adds 2D dropout before pooling.
        h = self.conv2_drop(self.conv2(h))
        h = F.relu(F.max_pool2d(h, 2))
        # Flatten the 20x4x4 feature maps for the classifier head.
        h = h.view(-1, 320)
        h = F.relu(self.fc1(h))
        h = F.dropout(h, training=self.training)
        return F.log_softmax(self.fc2(h), dim=-1)


def get_data_loaders(train_batch_size, val_batch_size):
    """Build MNIST train/validation ``DataLoader`` pairs.

    The train split is downloaded on first use; the validation split
    reuses the already-downloaded data.
    """
    transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    # Construct the train set first: it performs the download.
    train_set = MNIST(download=True, root=".", transform=transform,
                      train=True)
    val_set = MNIST(download=False, root=".", transform=transform,
                    train=False)

    train_loader = DataLoader(
        train_set, batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(
        val_set, batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader


def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    """Train the MNIST model with ignite, logging metrics to ChainerUI.

    Args:
        train_batch_size (int): batch size for the training loader.
        val_batch_size (int): batch size for the validation loader.
        epochs (int): number of training epochs.
        lr (float): SGD learning rate.
        momentum (float): SGD momentum.
        log_interval (int): iterations between progress-bar updates.
    """
    train_loader, val_loader = get_data_loaders(
        train_batch_size, val_batch_size)
    model = Net()
    device = 'cpu'

    if torch.cuda.is_available():
        device = 'cuda'

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(
        model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            device=device)

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(
        initial=0, leave=False, total=len(train_loader),
        desc=desc.format(0)
    )

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # 1-based iteration index within the current epoch.
        # (renamed from `iter`, which shadowed the builtin)
        iteration = (engine.state.iteration - 1) % len(train_loader) + 1

        if iteration % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Training Results - "
            "Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll)
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Validation Results "
            "- Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        # Reset the progress bar for the next epoch.
        pbar.n = pbar.last_print_n = 0

    # [ChainerUI] import logger and handler
    # [ChainerUI] setup logger, this logger manages ChainerUI web client.
    logger = ChainerUILogger()
    # [ChainerUI] to ease requests of metrics posting, set interval option
    train_handler = OutputHandler(
        'train', output_transform=lambda o: {'nll': o}, interval_step=20)
    logger.attach(
        trainer, log_handler=train_handler,
        event_name=Events.ITERATION_COMPLETED)
    # [ChainerUI] to set same value of x axis, use global_step_transform
    val_handler = OutputHandler(
        'val', metric_names='all',
        global_step_transform=global_step_from_engine(trainer))
    logger.attach(
        evaluator, log_handler=val_handler, event_name=Events.EPOCH_COMPLETED)
    # [ChainerUI] to post remainder of metrics caused by interval, use "with"
    with logger:
        trainer.run(train_loader, max_epochs=epochs)
    pbar.close()


if __name__ == "__main__":
    # Build the CLI from a declarative spec: (flag, type, default, help).
    parser = ArgumentParser()
    arg_specs = [
        ('--batch_size', int, 64,
         'input batch size for training (default: 64)'),
        ('--val_batch_size', int, 1000,
         'input batch size for validation (default: 1000)'),
        ('--epochs', int, 10,
         'number of epochs to train (default: 10)'),
        ('--lr', float, 0.01,
         'learning rate (default: 0.01)'),
        ('--momentum', float, 0.5,
         'SGD momentum (default: 0.5)'),
        ('--log_interval', int, 10,
         'how many batches to wait before logging training status'),
    ]
    for flag, arg_type, default, help_text in arg_specs:
        parser.add_argument(
            flag, type=arg_type, default=default, help=help_text)

    args = parser.parse_args()

    # [ChainerUI] To use ChainerUI web client, must initialize
    # args will be showed as parameter of this experiment.
    chainerui.init(conditions=args)

    run(
        args.batch_size, args.val_batch_size, args.epochs, args.lr,
        args.momentum, args.log_interval)
4 changes: 4 additions & 0 deletions setup.py
Expand Up @@ -42,6 +42,10 @@
'matplotlib',
'scipy',
'chainer',
],
'test-ci-contrib': [
'-r test-ci',
'pytorch-ignite',
]
}

Expand Down

0 comments on commit 9986488

Please sign in to comment.