Refactor State API's persistence methods (#4)
Summary:
Pull Request resolved: pytorch/elastic#4

Makes the State API's persistence (rollback and serialization) more coherent, consistent, and natural. Does the following:

* Renames `deep_copy` and `rollback` to `snapshot` and `apply`
* The semantics of `snapshot` and `apply` are that the state is recoverable by:

  ```python
  any_user_defined_snapshot_obj = state.snapshot()
  modify_state(state)
  state.apply(any_user_defined_snapshot_obj)
  state.sync()
  ```
* Renames `serialize` and `deserialize` to `save` and `load` (to be consistent with torch)
* `State` provides a default implementation of `save` and `load` using `snapshot` and `apply` (see the sketch after this list).
* Removes the redundant `supports_rollback()` method from `State`. Leaving `snapshot`/`apply` unimplemented now indicates that rollback is not supported on the `State` object. A user who wants checkpointing but not rollback implements `save`/`load` and leaves `snapshot`/`apply` unimplemented. A user who wants rollback support pays no extra cost for checkpointing, so they get checkpointing for free.
* Makes changes to the `test_mock` and `elastic classy_vision` code to be compliant with the new API.
* Makes imagenet example compliant with the new API.

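To make the last two points concrete, here is a minimal sketch of how a base `State` can wire the default `save`/`load` to `snapshot`/`apply`, and how a subclass opts into checkpointing without rollback. This is illustrative only: the method bodies, the use of `pickle`, and the `CheckpointOnlyState` class are assumptions for the sketch, not the actual torchelastic implementation.

```python
import pickle


class State:
    """Sketch of the State base class described above (details assumed)."""

    def snapshot(self):
        # Return a user-defined, serializable snapshot object, or None to
        # signal that rollback is not supported on this State.
        return None

    def apply(self, snapshot):
        # Restore this State from an object produced by snapshot().
        raise NotImplementedError

    def sync(self, world_size, rank):
        # Re-synchronize the state across workers (e.g. after a re-rendezvous).
        pass

    def save(self, stream):
        # Default checkpointing in terms of snapshot(): a State that
        # supports rollback gets checkpointing for free.
        pickle.dump(self.snapshot(), stream)

    def load(self, stream):
        self.apply(pickle.load(stream))


class CheckpointOnlyState(State):
    """Checkpoint but no rollback: override save/load, skip snapshot/apply."""

    def __init__(self):
        self.epoch = 0

    def save(self, stream):
        pickle.dump({"epoch": self.epoch}, stream)

    def load(self, stream):
        self.epoch = pickle.load(stream)["epoch"]
```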
NOTE: This change breaks the imagenet example under `//fblearner/flow/projects/pytorch/elastic/imagenet`. However, that example was already broken and has zero users; the task to fix it is T57831531.

(Note: this ignores all push-blocking failures!)

Reviewed By: vreis

Differential Revision: D18672302

fbshipit-source-id: 849b6cdcc5cb21e95406b42fd22d5b3d6d9a6f66
Kiuk Chung authored and facebook-github-bot committed Dec 1, 2019
1 parent fdd4c2b commit e6b8f30
Showing 1 changed file with 11 additions and 13 deletions: `classy_vision/trainer/elastic_trainer.py`
```diff
@@ -135,7 +135,8 @@ def _run_step(self, state, local_variables, use_gpu):
 
 class _ClassyWorkerStats(WorkerStats):
     """
-    ClassyVision-specific implementation of WorkerStats, which is used by PET loop
+    ClassyVision-specific implementation of WorkerStats,
+    which is used by torchelastic train_loop
     to detect (and correct) stragglers, or other progress-impeding issues.
     """
 
@@ -146,26 +147,28 @@ def get_progress_rate(self) -> Optional[float]:
         return self.progress_rate
 
 class _ClassyElasticState(torchelastic.State):
+    """
+    Rollback is disabled on this state since currently, data loaders are
+    too expensive to snapshot on every train_step
+    """
+
     def __init__(self, task: ClassyTask, input_args: Any):
         self.task = task
         self.input_args = input_args if input_args else {}
         self.advance_to_next_phase = True
         self.skip_current_phase = False
 
-    def deep_copy(self):
-        raise RuntimeError("Unsupported method")
-
     def broadcast_state(self, rank, src_rank):
         data = None
         if rank == src_rank:
             save_stream = io.BytesIO()
-            self.serialize(save_stream)
+            self.save(save_stream)
             # Note: save_stream.getbuffer() will return a memoryview, which
             # cannot be converted to a tensor, so convert it to an np array first
             data = numpy.asarray(save_stream.getbuffer())
         data = dist.broadcast_binary(data, src_rank)
         load_stream = io.BytesIO(data)
-        self.deserialize(load_stream)
+        self.load(load_stream)
 
     def sync(self, world_size, rank):
         self._recreate_ddp_model()
@@ -209,10 +212,6 @@ def sync(self, world_size, rank):
         # Set up pytorch module in train vs eval mode, update optimizer.
         self.task._set_model_train_mode()
 
-    def supports_rollback(self):
-        # Dataloaders are too expensive to deep copy on every train iter for now
-        return False
-
     def should_save_checkpoint(self, rank):
         # should_save_checkpoint needs to return the same value for all trainers;
         # we take a checkpoint when a phase completes
@@ -222,18 +221,17 @@ def should_save_checkpoint(self, rank):
         # considering the cost, it is not necessary to checkpoint the test phase
         return self.task.train and self.advance_to_next_phase
 
-    def serialize(self, stream):
+    def save(self, stream):
         checkpoint_state = get_checkpoint_dict(self.task, self.input_args)
         checkpoint_state["advance_to_next_phase"] = self.advance_to_next_phase
         torch.save(checkpoint_state, stream)
 
-    def deserialize(self, stream):
+    def load(self, stream):
         checkpoint_state = torch.load(stream)
         state = checkpoint_state["classy_state_dict"]
         self.task.set_classy_state(state)
         if "advance_to_next_phase" in checkpoint_state:
             self.advance_to_next_phase = checkpoint_state["advance_to_next_phase"]
-        return self
 
     def _recreate_ddp_model(self):
         # Delete & re-create the DDP module wrapper. This is required because
```
