
enable exporting individual models in disjoint multitask (#197)

Summary:
Pull Request resolved: #197

Currently, model export for DisjointMultitask is not implemented. Here we allow configuring an exporter under each task and export the sub-models individually.

Reviewed By: mwu1993

Differential Revision: D13612990

fbshipit-source-id: abf63db775f91367ee19c3132e71bd0d5115cf5e
borguz authored and facebook-github-bot committed Jan 10, 2019
1 parent ac40f39 commit ec4b85185cd013c54652b7c68fa50b953fa1f5c4
Showing with 54 additions and 4 deletions.
  1. +2 −1 demo/configs/multitask_sst_lm.json
  2. +2 −1 demo/configs/sst2.json
  3. +50 −2 pytext/task/disjoint_multitask.py
demo/configs/multitask_sst_lm.json
@@ -36,7 +36,8 @@
                   }
                 }
               }
-            }
+            },
+            "exporter": {}
           }
         },
         "LM": {
demo/configs/sst2.json
@@ -18,7 +18,8 @@
       },
       "trainer": {
         "epochs": 15
-      }
+      },
+      "exporter": {}
     }
   }
 }
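
The two config diffs above add an empty "exporter" object under a single task. As a rough illustration only (most fields are elided, the task-type wrapper names are placeholders, and an empty object presumably just selects that task's default exporter config), the resulting layout looks like the following Python sketch:

# Sketch of the multitask config shape after this change; not an actual
# demo file. Only "SST" configures an exporter, "LM" leaves it unset.
multitask_config = {
    "task": {
        "DisjointMultitask": {
            "tasks": {
                "SST": {
                    "DocClassificationTask": {  # placeholder task type
                        # ... features, model, trainer, etc. ...
                        "exporter": {},  # this sub-model will be exported
                    }
                },
                "LM": {
                    "LMTask": {  # placeholder task type
                        # no "exporter" entry: create_exporter is skipped and
                        # this sub-model is ignored by export()
                    }
                },
            }
        }
    }
}
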
pytext/task/disjoint_multitask.py
@@ -7,6 +7,7 @@
 from pytext.config import config_to_json
 from pytext.config.component import (
     create_data_handler,
+    create_exporter,
     create_featurizer,
     create_metric_reporter,
     create_model,
@@ -40,6 +41,7 @@ def from_config(cls, task_config, metadata=None, model_state=None):
         pprint(config_to_json(type(task_config), task_config))

         data_handlers = OrderedDict()
+        exporters = OrderedDict()
         for name, task in task_config.tasks.items():
             featurizer = create_featurizer(task.featurizer, task.features)
             data_handlers[name] = create_data_handler(
@@ -54,7 +56,20 @@ def from_config(cls, task_config, metadata=None, model_state=None):
         else:
             data_handler.init_metadata()
         metadata = data_handler.metadata

+        exporters = {
+            name: (
+                create_exporter(
+                    task.exporter,
+                    task.features,
+                    task.labels,
+                    data_handler.data_handlers[name].metadata,
+                    task.model,
+                )
+                if task.exporter
+                else None
+            )
+            for name, task in task_config.tasks.items()
+        }
         metric_reporter = DisjointMultitaskMetricReporter(
             OrderedDict(
                 (name, create_metric_reporter(task.metric_reporter, metadata[name]))
@@ -76,6 +91,7 @@ def from_config(cls, task_config, metadata=None, model_state=None):

         optimizers = create_optimizer(model, task_config.optimizer)
         return cls(
+            exporters=exporters,
             trainer=create_trainer(task_config.trainer),
             data_handler=data_handler,
             model=model,
@@ -84,5 +100,37 @@ def from_config(cls, task_config, metadata=None, model_state=None):
             lr_scheduler=Scheduler(
                 optimizers, task_config.scheduler, metric_reporter.lower_is_better
             ),
-            exporter=None,
         )

+    def __init__(self, exporters, **kwargs):
+        super().__init__(exporter=None, **kwargs)
+        self.exporters = exporters
+
+    def export(
+        self, multitask_model, export_path, summary_writer=None, export_onnx_path=None
+    ):
+        """
+        Wrapper method to export PyTorch model to Caffe2 model using :class:`~Exporter`.
+
+        Args:
+            export_path (str): file path of exported caffe2 model
+            summary_writer: TensorBoard SummaryWriter, used to output the PyTorch
+                model's execution graph to TensorBoard, default is None.
+            export_onnx_path (str): file path of exported onnx model
+        """
+        # Make sure to put the model on CPU and disable CUDA before exporting to
+        # ONNX to disable any data_parallel pieces
+        cuda_utils.CUDA_ENABLED = False
+        for name, model in multitask_model.models.items():
+            model = model.cpu()
+            if self.exporters[name]:
+                if summary_writer is not None:
+                    self.exporters[name].export_to_tensorboard(model, summary_writer)
+                model_export_path = f"{export_path}-{name}"
+                model_export_onnx_path = (
+                    f"{export_onnx_path}-{name}" if export_onnx_path else None
+                )
+                print("Saving caffe2 model to: " + model_export_path)
+                self.exporters[name].export_to_caffe2(
+                    model, model_export_path, model_export_onnx_path
+                )
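
To make the new per-task dispatch concrete, here is a small standalone sketch of the same pattern in plain Python (not PyText code; class names and paths are made up): each named sub-model is exported only if its task configured an exporter, and each export path gets a "-<task name>" suffix, exactly as in export() above.

from collections import OrderedDict

class FakeExporter:
    """Stand-in for a PyText exporter, used only for this illustration."""
    def export_to_caffe2(self, model, export_path, export_onnx_path=None):
        print(f"exporting {model!r} to {export_path} (onnx: {export_onnx_path})")

def export_submodels(models, exporters, export_path, export_onnx_path=None):
    # Mirrors DisjointMultitask.export(): skip tasks without an exporter,
    # suffix the output path with the task name for the rest.
    for name, model in models.items():
        exporter = exporters.get(name)
        if not exporter:
            continue
        model_export_path = f"{export_path}-{name}"
        model_export_onnx_path = (
            f"{export_onnx_path}-{name}" if export_onnx_path else None
        )
        exporter.export_to_caffe2(model, model_export_path, model_export_onnx_path)

models = OrderedDict([("SST", "sst_model"), ("LM", "lm_model")])
exporters = {"SST": FakeExporter(), "LM": None}  # matches the demo configs above
export_submodels(models, exporters, "/tmp/multitask_model.c2")
# Only the SST sub-model is exported, to /tmp/multitask_model.c2-SST.

On the real class, the equivalent call goes through DisjointMultitask.export(multitask_model, export_path), with the exporters mapping built per task in from_config() as shown in the diff above.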

1 comment on commit ec4b851

puttkraidej commented on ec4b851 Jan 11, 2019

Thanks
