Take train_workflow() out of ModelManager (#400)
Summary:
Pull Request resolved: #400

`train_workflow()` is essentially the same for every algorithm. The internal and OSS versions differ, so let's separate them (see the call-site sketch below).

Reviewed By: kaiwenw

Differential Revision: D26642559

fbshipit-source-id: 126fc202b519396eb9c3ba43d522a3ed7abad745
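In effect, `train_workflow()` moves from a `ModelManager` method to a module-level function in `reagent/workflow/training.py` that takes the manager as its first argument. A minimal before/after sketch of the call site, assuming the surrounding variables (`manager`, `train_dataset`, and so on) from the caller in `training.py`; the remaining keyword-only options (`setup_data`, `normalization_data_map`, `reward_options`, `reader_options`, `resource_options`, `warmstart_path`) default to None and are omitted here:

    from reagent.workflow.training import train_workflow

    # Before this commit: invoked as a method on the manager.
    results = manager.train_workflow(
        train_dataset,
        eval_dataset,
        num_epochs=num_epochs,
        use_gpu=use_gpu,
        named_model_ids=named_model_ids,
        child_workflow_id=child_workflow_id,
    )

    # After this commit: a free function that receives the manager explicitly.
    results = train_workflow(
        manager,
        train_dataset,
        eval_dataset,
        num_epochs=num_epochs,
        use_gpu=use_gpu,
        named_model_ids=named_model_ids,
        child_workflow_id=child_workflow_id,
    )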
kittipatv authored and facebook-github-bot committed Mar 3, 2021
1 parent a99d005 commit b898b63
Showing 2 changed files with 101 additions and 80 deletions.
102 changes: 23 additions & 79 deletions reagent/workflow/model_managers/model_manager.py
@@ -1,9 +1,7 @@
#!/usr/bin/env python3

import abc
import dataclasses
import logging
import time
from typing import Dict, List, Optional, Tuple

import pytorch_lightning as pl
@@ -12,19 +10,16 @@
from reagent.core.dataclasses import dataclass
from reagent.core.registry_meta import RegistryMeta
from reagent.parameters import NormalizationData
from reagent.tensorboardX import summary_writer_context
from reagent.training import ReAgentLightningModule, Trainer
from reagent.workflow.data import ReAgentDataModule
from reagent.workflow.types import (
Dataset,
ModuleNameToEntityId,
ReaderOptions,
ResourceOptions,
RewardOptions,
RLTrainingOutput,
TableSpec,
)
from torch.utils.tensorboard import SummaryWriter


logger = logging.getLogger(__name__)
@@ -85,15 +80,19 @@ def get_data_module(
saved_setup_data: Optional[Dict[str, bytes]] = None,
reader_options: Optional[ReaderOptions] = None,
) -> Optional[ReAgentDataModule]:
# Return the data module. If this is not None, then `run_feature_identification` &
# `query_data` will not be run.
"""
Return the data module. If this is not None, then `run_feature_identification` &
`query_data` will not be run.
"""
return None

@abc.abstractmethod
def run_feature_identification(
self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
"""
DEPRECATED: Implement get_data_module() instead
Derive preprocessing parameters from data. The keys of the dict should
match the keys from `required_normalization_keys()`
"""
@@ -131,6 +130,9 @@ def __getattr__(self, attr):
@property
@abc.abstractmethod
def should_generate_eval_dataset(self) -> bool:
"""
DEPRECATED: Implement get_data_module() instead
"""
pass

@abc.abstractmethod
@@ -141,6 +143,8 @@ def query_data(
reward_options: RewardOptions,
) -> Dataset:
"""
DEPRECATED: Implement get_data_module() instead
Massage input table into the format expected by the trainer
"""
pass
@@ -207,76 +211,6 @@ def build_trainer(self) -> Trainer:
def destroy_trainer(self):
self._trainer = None

def train_workflow(
self,
train_dataset: Optional[Dataset],
eval_dataset: Optional[Dataset],
*,
num_epochs: int,
use_gpu: bool,
named_model_ids: ModuleNameToEntityId,
child_workflow_id: int,
setup_data: Optional[Dict[str, bytes]] = None,
normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
reward_options: Optional[RewardOptions] = None,
reader_options: Optional[ReaderOptions] = None,
resource_options: Optional[ResourceOptions] = None,
warmstart_path: Optional[str] = None,
) -> RLTrainingOutput:
writer = SummaryWriter()
logger.info("TensorBoard logging location is: {}".format(writer.log_dir))

if setup_data is not None:
data_module = self.get_data_module(
setup_data=setup_data, reader_options=reader_options
)
assert data_module is not None
data_module.setup()
else:
data_module = None

if normalization_data_map is None:
assert data_module is not None
normalization_data_map = data_module.get_normalization_data_map(
self.required_normalization_keys
)

warmstart_input_path = warmstart_path or None
self.initialize_trainer(
use_gpu=use_gpu,
# pyre-fixme[6]: Expected `RewardOptions` for 2nd param but got
# `Optional[RewardOptions]`.
reward_options=reward_options,
normalization_data_map=normalization_data_map,
warmstart_path=warmstart_input_path,
)

if not reader_options:
reader_options = ReaderOptions()

if not resource_options:
resource_options = ResourceOptions()

with summary_writer_context(writer):
train_output = self.train(
train_dataset,
eval_dataset,
data_module,
num_epochs,
reader_options,
resource_options,
)

output_paths = {}
for module_name, serving_module in self.build_serving_modules().items():
# TODO: make this a parameter
torchscript_output_path = f"model_{round(time.time())}.torchscript"
serving_module = self.build_serving_module()
torch.jit.save(serving_module, torchscript_output_path)
logger.info(f"Saved {module_name} to {torchscript_output_path}")
output_paths[module_name] = torchscript_output_path
return dataclasses.replace(train_output, output_paths=output_paths)

@abc.abstractmethod
def train(
self,
@@ -288,6 +222,10 @@ def train(
resource_options: Optional[ResourceOptions],
) -> RLTrainingOutput:
"""
DEPRECATED: Delete this once every trainer is built on PyTorch Lightning &
every ModelManager implements get_data_module(). Then, we can just move the code
in train() of DiscreteDQNBase into the training workflow function.
Train the model
Arguments:
train/eval_dataset: what you'd expect
@@ -300,12 +238,18 @@

# TODO: make abstract
def build_serving_modules(self) -> Dict[str, torch.nn.Module]:
# eventually move to this method to be more generic
"""
Returns TorchScript for serving in production
"""
return {"default_model": self.build_serving_module()}

# TODO: make abstract
def serving_module_names(self) -> List[str]:
# should match sorted(self.build_serving_modules.keys())
"""
Returns the keys that would be returned in `build_serving_modules()`.
This method is required because we need to reserve entity IDs for
these serving modules before training starts.
"""
return ["default_model"]

def save_trainer(self, output_path: str) -> None:
79 changes: 78 additions & 1 deletion reagent/workflow/training.py
@@ -2,15 +2,19 @@

import dataclasses
import logging
import time
from typing import Dict, NamedTuple, Optional, Tuple

import torch
from reagent.parameters import NormalizationData
from reagent.publishers.union import ModelPublisher__Union
from reagent.tensorboardX import summary_writer_context
from reagent.validators.union import ModelValidator__Union
from reagent.workflow.env import get_new_named_entity_ids, get_workflow_id
from reagent.workflow.model_managers.model_manager import ModelManager
from reagent.workflow.model_managers.union import ModelManager__Union
from reagent.workflow.types import (
Dataset,
ModuleNameToEntityId,
ReaderOptions,
RecurringPeriod,
@@ -19,6 +23,7 @@
RLTrainingOutput,
TableSpec,
)
from torch.utils.tensorboard import SummaryWriter


logger = logging.getLogger(__name__)
@@ -189,7 +194,9 @@ def _maybe_get_bytes(v) -> bytes:
)

logger.info("Starting training")
results = manager.train_workflow(

results = train_workflow(
manager,
train_dataset,
eval_dataset,
num_epochs=num_epochs,
@@ -220,6 +227,76 @@ def _maybe_get_bytes(v) -> bytes:
return results


def train_workflow(
model_manager: ModelManager,
train_dataset: Optional[Dataset],
eval_dataset: Optional[Dataset],
*,
num_epochs: int,
use_gpu: bool,
named_model_ids: ModuleNameToEntityId,
child_workflow_id: int,
setup_data: Optional[Dict[str, bytes]] = None,
normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
reward_options: Optional[RewardOptions] = None,
reader_options: Optional[ReaderOptions] = None,
resource_options: Optional[ResourceOptions] = None,
warmstart_path: Optional[str] = None,
) -> RLTrainingOutput:
writer = SummaryWriter()
logger.info("TensorBoard logging location is: {}".format(writer.log_dir))

if setup_data is not None:
data_module = model_manager.get_data_module(
setup_data=setup_data, reader_options=reader_options
)
assert data_module is not None
data_module.setup()
else:
data_module = None

if normalization_data_map is None:
assert data_module is not None
normalization_data_map = data_module.get_normalization_data_map(
model_manager.required_normalization_keys
)

warmstart_input_path = warmstart_path or None
model_manager.initialize_trainer(
use_gpu=use_gpu,
# pyre-fixme[6]: Expected `RewardOptions` for 2nd param but got
# `Optional[RewardOptions]`.
reward_options=reward_options,
normalization_data_map=normalization_data_map,
warmstart_path=warmstart_input_path,
)

if not reader_options:
reader_options = ReaderOptions()

if not resource_options:
resource_options = ResourceOptions()

with summary_writer_context(writer):
train_output = model_manager.train(
train_dataset,
eval_dataset,
data_module,
num_epochs,
reader_options,
resource_options,
)

output_paths = {}
for module_name, serving_module in model_manager.build_serving_modules().items():
# TODO: make this a parameter
torchscript_output_path = f"model_{round(time.time())}.torchscript"
torch.jit.save(serving_module, torchscript_output_path)
logger.info(f"Saved {module_name} to {torchscript_output_path}")
output_paths[module_name] = torchscript_output_path
return dataclasses.replace(train_output, output_paths=output_paths)


def run_validator(
validator: ModelValidator__Union, training_output: RLTrainingOutput
) -> RLTrainingOutput:
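For reference, the export loop at the end of train_workflow() writes each serving module with torch.jit.save. A saved artifact can be loaded back for serving with torch.jit.load; a minimal sketch, where the exact file name (following the model_<timestamp>.torchscript pattern above) is illustrative:

    import torch

    # Load a TorchScript module exported by train_workflow(); the file name
    # is hypothetical and follows the f"model_{round(time.time())}.torchscript"
    # pattern used in the export loop.
    serving_module = torch.jit.load("model_1614760000.torchscript")
    serving_module.eval()  # switch to inference mode before serving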
