forked from chainer/chainerui
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request chainer#327 from disktnk/fix/ignite-handler
Support ignite handler
- Loading branch information
Showing
7 changed files
with
443 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from ignite.contrib.handlers.base_logger import BaseLogger | ||
from ignite.contrib.handlers.base_logger import BaseOutputHandler | ||
from ignite.engine import Events | ||
|
||
import chainerui | ||
from chainerui.utils.log_report import _get_time | ||
|
||
|
||
class OutputHandler(BaseOutputHandler): | ||
"""Handler for ChainerUI logger | ||
A helper for handler to log engine's output, specialized for ChainerUI. | ||
This handler sets 'epoch', 'iteration' and 'elapsed_time' automatically, | ||
these are default x axis to show. | ||
.. code-block:: python | ||
from chainerui.contrib.ignite.handler import OutputHandler | ||
train_handler = OutputHandler( | ||
'train', output_transform=lambda o: {'param': o}) | ||
val_handler = OutputHandler('val', metric_names='all') | ||
Args: | ||
tag (str): use for a prefix of parameter name, will show as | ||
{tag}/{param} | ||
metric_names (str or list): keys names of ``list`` to monitor. set | ||
``'all'`` to get all metrics monitored by the engine. | ||
output_transform (func): if set, use this function to convert output | ||
from ``engine.state.output`` | ||
another_engine (``ignite.engine.Engine``): if set, use for getting | ||
global step. This option is deprecated from 0.3. | ||
global_step_transform (func): if set, use this to get global step. | ||
interval_step (int): interval step for posting metrics to ChainerUI | ||
server. | ||
""" | ||
|
||
def __init__( | ||
self, tag, metric_names=None, output_transform=None, | ||
another_engine=None, global_step_transform=None, interval_step=-1): | ||
super(OutputHandler, self).__init__( | ||
tag, metric_names, output_transform, another_engine, | ||
global_step_transform) | ||
self.interval = interval_step | ||
|
||
def __call__(self, engine, logger, event_name): | ||
if not isinstance(logger, ChainerUILogger): | ||
raise RuntimeError( | ||
'`chainerui.contrib.ignite.handler.OutputHandler` works only ' | ||
'with ChainerUILogger, but set {}'.format(type(logger))) | ||
|
||
metrics = self._setup_output_metrics(engine) | ||
if not metrics: | ||
return | ||
iteration = self.global_step_transform( | ||
engine, Events.ITERATION_COMPLETED) | ||
epoch = self.global_step_transform(engine, Events.EPOCH_COMPLETED) | ||
|
||
# convert metrics name | ||
rendered_metrics = {} | ||
for k, v in metrics.items(): | ||
rendered_metrics['{}/{}'.format(self.tag, k)] = v | ||
rendered_metrics['iteration'] = iteration | ||
rendered_metrics['epoch'] = epoch | ||
if 'elapsed_time' not in rendered_metrics: | ||
rendered_metrics['elapsed_time'] = _get_time() - logger.start_at | ||
|
||
if self.interval <= 1: | ||
logger.post_log([rendered_metrics]) | ||
return | ||
|
||
# enable interval, cache metrics | ||
logger.cache.setdefault(self.tag, []).append(rendered_metrics) | ||
# select appropriate even set by handler init | ||
global_count = self.global_step_transform(engine, event_name) | ||
if global_count % self.interval == 0: | ||
logger.post_log(logger.cache[self.tag]) | ||
logger.cache[self.tag].clear() | ||
|
||
|
||
class ChainerUILogger(BaseLogger): | ||
"""Logger handler for ChainerUI | ||
A helper logger to post metrics to ChainerUI server. Attached handlers | ||
are expected using ``chainerui.contrib.ignite.handler.OutputHandler``. | ||
A tag name of handler must be unique when attach several handlers. | ||
.. code-block:: python | ||
from chainerui.contrib.ignite.handler import OutputHandler | ||
train_handler = OutputHandler(...) | ||
val_handler = OutputHandler(...) | ||
from ignite.engine.engine import Engine | ||
train_engine = Engine(...) | ||
eval_engine = Engine(...) | ||
from chainerui.contrib.ignite.handler import ChainerUILogger | ||
logger = ChainerUILogger() | ||
logger.attach( | ||
train_engine, log_handler=train_handler, | ||
event_name=Events.EPOCH_COMPLETED) | ||
logger.attach( | ||
eval_engine, log_handler=val_handler, | ||
event_name=Event.EPOCH_COMPLETED) | ||
""" | ||
|
||
def __init__(self): | ||
web_client = chainerui.client.client._client | ||
if not web_client: | ||
raise RuntimeError( | ||
'fail to setup ChainerUI logger, please check ' | ||
'`chainerui.init()` is success to execute.') | ||
|
||
self.attached_tags = set() | ||
self.client = web_client | ||
self.cache = {} # key is tag name, value is list of metrics | ||
self.start_at = _get_time() | ||
|
||
def attach(self, engine, log_handler, event_name): | ||
if log_handler.tag in self.attached_tags: | ||
raise RuntimeError('attached handlers must have unique tag name') | ||
self.attached_tags.add(log_handler.tag) | ||
super(ChainerUILogger, self).attach(engine, log_handler, event_name) | ||
|
||
def post_log(self, metrics): | ||
self.client.post_log(metrics) | ||
|
||
def __enter__(self): | ||
# overwrite start timestamp when engine is used with 'with' block | ||
self.start_at = _get_time() | ||
super(ChainerUILogger, self).__enter__() | ||
|
||
def close(self): | ||
for k, v in self.cache.items(): | ||
if v: | ||
self.post_log(v) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
chainer>=3.0.0 | ||
pytorch-ignite>=0.2.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
# This example is based on https://github.com/pytorch/ignite/blob/v0.2.1/examples/mnist/mnist.py # NOQA | ||
# Please see [ChainerUI] commend on the below code. | ||
from argparse import ArgumentParser | ||
|
||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
from torch.optim import SGD | ||
from torch.utils.data import DataLoader | ||
from torchvision.datasets import MNIST | ||
from torchvision.transforms import Compose | ||
from torchvision.transforms import Normalize | ||
from torchvision.transforms import ToTensor | ||
|
||
from ignite.contrib.handlers.base_logger import global_step_from_engine | ||
from ignite.engine import create_supervised_evaluator | ||
from ignite.engine import create_supervised_trainer | ||
from ignite.engine import Events | ||
from ignite.metrics import Accuracy | ||
from ignite.metrics import Loss | ||
|
||
from tqdm import tqdm | ||
|
||
import chainerui | ||
from chainerui.contrib.ignite.handler import ChainerUILogger | ||
from chainerui.contrib.ignite.handler import OutputHandler | ||
|
||
|
||
class Net(nn.Module): | ||
def __init__(self): | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5) | ||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5) | ||
self.conv2_drop = nn.Dropout2d() | ||
self.fc1 = nn.Linear(320, 50) | ||
self.fc2 = nn.Linear(50, 10) | ||
|
||
def forward(self, x): | ||
x = F.relu(F.max_pool2d(self.conv1(x), 2)) | ||
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) | ||
x = x.view(-1, 320) | ||
x = F.relu(self.fc1(x)) | ||
x = F.dropout(x, training=self.training) | ||
x = self.fc2(x) | ||
return F.log_softmax(x, dim=-1) | ||
|
||
|
||
def get_data_loaders(train_batch_size, val_batch_size): | ||
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) | ||
|
||
train_loader = DataLoader( | ||
MNIST(download=True, root=".", transform=data_transform, train=True), | ||
batch_size=train_batch_size, shuffle=True) | ||
|
||
val_loader = DataLoader( | ||
MNIST(download=False, root=".", transform=data_transform, train=False), | ||
batch_size=val_batch_size, shuffle=False) | ||
return train_loader, val_loader | ||
|
||
|
||
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval): | ||
train_loader, val_loader = get_data_loaders( | ||
train_batch_size, val_batch_size) | ||
model = Net() | ||
device = 'cpu' | ||
|
||
if torch.cuda.is_available(): | ||
device = 'cuda' | ||
|
||
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum) | ||
trainer = create_supervised_trainer( | ||
model, optimizer, F.nll_loss, device=device) | ||
evaluator = create_supervised_evaluator(model, | ||
metrics={'accuracy': Accuracy(), | ||
'nll': Loss(F.nll_loss)}, | ||
device=device) | ||
|
||
desc = "ITERATION - loss: {:.2f}" | ||
pbar = tqdm( | ||
initial=0, leave=False, total=len(train_loader), | ||
desc=desc.format(0) | ||
) | ||
|
||
@trainer.on(Events.ITERATION_COMPLETED) | ||
def log_training_loss(engine): | ||
iter = (engine.state.iteration - 1) % len(train_loader) + 1 | ||
|
||
if iter % log_interval == 0: | ||
pbar.desc = desc.format(engine.state.output) | ||
pbar.update(log_interval) | ||
|
||
@trainer.on(Events.EPOCH_COMPLETED) | ||
def log_training_results(engine): | ||
pbar.refresh() | ||
evaluator.run(train_loader) | ||
metrics = evaluator.state.metrics | ||
avg_accuracy = metrics['accuracy'] | ||
avg_nll = metrics['nll'] | ||
tqdm.write( | ||
"Training Results - " | ||
"Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}" | ||
.format(engine.state.epoch, avg_accuracy, avg_nll) | ||
) | ||
|
||
@trainer.on(Events.EPOCH_COMPLETED) | ||
def log_validation_results(engine): | ||
evaluator.run(val_loader) | ||
metrics = evaluator.state.metrics | ||
avg_accuracy = metrics['accuracy'] | ||
avg_nll = metrics['nll'] | ||
tqdm.write( | ||
"Validation Results " | ||
"- Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}" | ||
.format(engine.state.epoch, avg_accuracy, avg_nll)) | ||
|
||
pbar.n = pbar.last_print_n = 0 | ||
|
||
# [ChainerUI] import logger and handler | ||
# [ChainerUI] setup logger, this logger manages ChainerUI web client. | ||
logger = ChainerUILogger() | ||
# [ChainerUI] to ease requests of metrics posting, set interval option | ||
train_handler = OutputHandler( | ||
'train', output_transform=lambda o: {'nll': o}, interval_step=20) | ||
logger.attach( | ||
trainer, log_handler=train_handler, | ||
event_name=Events.ITERATION_COMPLETED) | ||
# [ChainerUI] to set same value of x axis, use global_step_transform | ||
val_handler = OutputHandler( | ||
'val', metric_names='all', | ||
global_step_transform=global_step_from_engine(trainer)) | ||
logger.attach( | ||
evaluator, log_handler=val_handler, event_name=Events.EPOCH_COMPLETED) | ||
# [ChainerUI] to post remainder of metrics caused by interval, use "with" | ||
with logger: | ||
trainer.run(train_loader, max_epochs=epochs) | ||
pbar.close() | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser() | ||
parser.add_argument('--batch_size', type=int, default=64, | ||
help='input batch size for training (default: 64)') | ||
parser.add_argument('--val_batch_size', type=int, default=1000, | ||
help='input batch size for validation (default: 1000)') | ||
parser.add_argument('--epochs', type=int, default=10, | ||
help='number of epochs to train (default: 10)') | ||
parser.add_argument('--lr', type=float, default=0.01, | ||
help='learning rate (default: 0.01)') | ||
parser.add_argument('--momentum', type=float, default=0.5, | ||
help='SGD momentum (default: 0.5)') | ||
parser.add_argument('--log_interval', type=int, default=10, | ||
help='how many batches to wait before logging ' | ||
'training status') | ||
|
||
args = parser.parse_args() | ||
|
||
# [ChainerUI] To use ChainerUI web client, must initialize | ||
# args will be showed as parameter of this experiment. | ||
chainerui.init(conditions=args) | ||
|
||
run( | ||
args.batch_size, args.val_batch_size, args.epochs, args.lr, | ||
args.momentum, args.log_interval) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,6 +42,10 @@ | |
'matplotlib', | ||
'scipy', | ||
'chainer', | ||
], | ||
'test-ci-contrib': [ | ||
'-r test-ci', | ||
'pytorch-ignite', | ||
] | ||
} | ||
|
||
|
Oops, something went wrong.