Add in checkpointing capability #255

Merged
merged 30 commits on Mar 2, 2022
Changes from all commits
Commits (30)
4483ce2
Finish save implementation, need to write load + tests
muellerzr Feb 16, 2022
e2983c3
Initial save and load complete
muellerzr Feb 16, 2022
820d74f
Start of tests
muellerzr Feb 17, 2022
71625c7
Finish save implementation, need to write load + tests
muellerzr Feb 16, 2022
c241aa4
Initial save and load complete
muellerzr Feb 16, 2022
46db8b0
Start of tests
muellerzr Feb 17, 2022
c974928
Working tests!
muellerzr Feb 17, 2022
7ba0400
Merge
muellerzr Feb 17, 2022
9c7fc75
Merge
muellerzr Feb 17, 2022
e63f515
Add in unwrapping
muellerzr Feb 17, 2022
12e4b18
More unwrapping
muellerzr Feb 17, 2022
be9cbd1
Apply typo-fixes and naming suggestions
muellerzr Feb 18, 2022
d4b3d19
Make constants
muellerzr Feb 22, 2022
0c1a87b
Refactor pt 1
muellerzr Feb 28, 2022
b64945c
Cleaned + more nits
muellerzr Feb 28, 2022
3ec7a88
New state of tests
muellerzr Feb 28, 2022
5ccbead
Apply nits from Sylvain
muellerzr Mar 1, 2022
0499159
Use save instead of torch.save
muellerzr Mar 1, 2022
f653836
Docstring fixes
muellerzr Mar 1, 2022
ddf35f9
Expand name in tests
muellerzr Mar 1, 2022
19af1e3
Final nits
muellerzr Mar 1, 2022
7073c3f
Rm args
muellerzr Mar 1, 2022
f853c45
Wrap under is_available
muellerzr Mar 1, 2022
6b0024c
Push current state
muellerzr Mar 1, 2022
8eb2943
More states of things
muellerzr Mar 1, 2022
5b5b343
More nit
muellerzr Mar 1, 2022
199e7ec
As func
muellerzr Mar 1, 2022
a762971
Working refactor
muellerzr Mar 2, 2022
ba7bbb4
Include test for rands
muellerzr Mar 2, 2022
f90b577
Try with fix
muellerzr Mar 2, 2022
32 changes: 32 additions & 0 deletions src/accelerate/accelerator.py
@@ -22,6 +22,7 @@

from packaging import version

from .checkpointing import load_accelerator_state, save_accelerator_state
from .data_loader import prepare_data_loader
from .kwargs_handlers import DistributedDataParallelKwargs, GradScalerKwargs, InitProcessGroupKwargs, KwargsHandler
from .optimizer import AcceleratedOptimizer
@@ -40,6 +41,7 @@

if is_deepspeed_available():
import deepspeed

from .deepspeed_utils import DeepSpeedEngineWrapper, DeepSpeedOptimizerWrapper

import logging
@@ -560,6 +562,36 @@ def save(self, obj, f):
"""
save(obj, f)

def save_state(self, output_dir: str):
"""
Saves the current states of the model, optimizer, scaler, and RNG generators.

Args:
output_dir (:obj:`str` or :obj:`os.PathLike`):
The name of the folder to save all relevant weights and states.
"""
# Check if folder exists
output_dir = os.path.expanduser(output_dir)
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving current state to {output_dir}")
weights = [self.get_state_dict(m) for m in self._models]
return save_accelerator_state(output_dir, weights, self._optimizers, self.state.process_index, self.scaler)

def load_state(self, input_dir: str):
"""
Loads the current states of the model, optimizer, scaler, and RNG generators.

Args:
input_dir (:obj:`str` or :obj:`os.PathLike`):
The name of the folder all relevant weights and states were saved in.
"""
# Check if folder exists
input_dir = os.path.expanduser(input_dir)
if not os.path.isdir(input_dir):
raise ValueError(f"Tried to find {input_dir} but folder does not exist")
logger.info(f"Loading states from {input_dir}")
load_accelerator_state(input_dir, self._models, self._optimizers, self.state.process_index, self.scaler)

def free_memory(self):
"""
Will release all references to the internal objects stored and call the garbage collector. You should call this
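Taken together, save_state and load_state give the Accelerator a checkpoint/resume API covering model weights, optimizer states, the gradient scaler, and RNG states. A minimal sketch of how they could be used in a training script (the model, optimizer, and folder name are illustrative, not taken from this PR):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

# Write model weights, optimizer state, scaler state (if any), and RNG states to a folder
accelerator.save_state("my_checkpoint")

# ... train, or restart the script ...

# Restore everything that was saved back into the prepared objects, in place
accelerator.load_state("my_checkpoint")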
134 changes: 134 additions & 0 deletions src/accelerate/checkpointing.py
@@ -0,0 +1,134 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
from typing import List

import numpy as np
import torch
from torch.cuda.amp import GradScaler

from .state import is_tpu_available
from .utils import MODEL_NAME, OPTIMIZER_NAME, RNG_STATE_NAME, SCALER_NAME, save


if is_tpu_available():
import torch_xla.core.xla_model as xm

import logging


logger = logging.getLogger(__name__)


def save_accelerator_state(
output_dir: str, model_states: List[dict], optimizers: list, process_index: int, scaler: GradScaler = None
):
"""
Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.

Args:
output_dir (:obj:`str` or :obj:`os.PathLike`):
The name of the folder to save all relevant weights and states.
model_states (:obj:`List[dict]`):
A list of model state dicts
optimizers (:obj:`List[torch.optim.Optimizer]`):
A list of optimizer instances
process_index (:obj:`int`):
The current process index in the Accelerator state
scaler (:obj:`torch.cuda.amp.GradScaler`, `optional`):
An optional gradient scaler instance to save
"""
# Model states
for i, state in enumerate(model_states):
weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
output_model_file = os.path.join(output_dir, weights_name)
save(state, output_model_file)
logger.info(f"Model weights saved in {output_model_file}")
# Optimizer states
for i, opt in enumerate(optimizers):
state = opt.state_dict()
optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
output_optimizer_file = os.path.join(output_dir, optimizer_name)
save(state, output_optimizer_file)
logger.info(f"Optimizer state saved in {output_optimizer_file}")
# GradScaler state
if scaler is not None:
state = scaler.state_dict()
output_scaler_file = os.path.join(output_dir, SCALER_NAME)
torch.save(state, output_scaler_file)
logger.info(f"Gradient scaler state saved in {output_scaler_file}")
# Random number generator states
states = {}
states_name = f"{RNG_STATE_NAME}_{process_index}.pkl"
states["random_state"] = random.getstate()
states["numpy_random_seed"] = np.random.get_state()
states["torch_manual_seed"] = torch.get_rng_state()
states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
# ^^ safe to call this function even if cuda is not available
if is_tpu_available():
states["xm_seed"] = torch.tensor(xm.get_rng_state())
output_states_file = os.path.join(output_dir, states_name)
torch.save(states, output_states_file)
logger.info(f"Random states saved in {output_states_file}")
return output_dir


def load_accelerator_state(input_dir, models, optimizers, process_index, scaler=None):
"""
Loads states of the models, optimizers, scaler, and RNG generators from a given directory.

Args:
input_dir (:obj:`str` or :obj:`os.PathLike`):
The name of the folder to load all relevant weights and states from.
models (:obj:`List[torch.nn.Module]`):
A list of model instances
optimizers (:obj:`List[torch.optim.Optimizer]`):
A list of optimizer instances
process_index (:obj:`int`):
The current process index in the Accelerator state
scaler (:obj:`torch.cuda.amp.GradScaler`, `optional`):
An optional `GradScaler` instance to load
"""
# Model states
for i, model in enumerate(models):
weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
input_model_file = os.path.join(input_dir, weights_name)
models[i].load_state_dict(torch.load(input_model_file))
logger.info("All model weights loaded successfully")

# Optimizer states
for i, opt in enumerate(optimizers):
optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
input_optimizer_file = os.path.join(input_dir, optimizer_name)
optimizers[i].load_state_dict(torch.load(input_optimizer_file))
logger.info("All optimizer states loaded successfully")

# GradScaler state
if scaler is not None:
input_scaler_file = os.path.join(input_dir, SCALER_NAME)
scaler.load_state_dict(torch.load(input_scaler_file))
logger.info("GradScaler state loaded successfully")

# Random states
states = torch.load(os.path.join(input_dir, f"{RNG_STATE_NAME}_{process_index}.pkl"))
random.setstate(states["random_state"])
np.random.set_state(states["numpy_random_seed"])
torch.set_rng_state(states["torch_manual_seed"])
torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
# ^^ safe to call this function even if cuda is not available
if is_tpu_available():
xm.set_rng_state(states["xm_seed"])
logger.info("All random states loaded successfully")
5 changes: 5 additions & 0 deletions src/accelerate/utils.py
@@ -43,6 +43,11 @@ def is_sagemaker_available():
if is_deepspeed_available():
from deepspeed import DeepSpeedEngine

SCALER_NAME = "scaler.pt"
MODEL_NAME = "pytorch_model"
RNG_STATE_NAME = "random_states"
OPTIMIZER_NAME = "optimizer"


class RNGType(Enum):
TORCH = "torch"
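As used by checkpointing.py above, these constants fix the on-disk layout of a checkpoint folder. A sketch of what a two-process run with one model, one optimizer, and mixed precision enabled would produce (folder name illustrative):

my_checkpoint/
    pytorch_model.bin      # MODEL_NAME + ".bin"; additional models get an _1, _2, ... suffix
    optimizer.bin          # OPTIMIZER_NAME + ".bin"; same suffix rule for additional optimizers
    scaler.pt              # SCALER_NAME, written only when a GradScaler is passed
    random_states_0.pkl    # RNG_STATE_NAME + "_{process_index}.pkl", one file per process
    random_states_1.pkl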
125 changes: 125 additions & 0 deletions tests/test_state_checkpointing.py
@@ -0,0 +1,125 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import random
import tempfile
import unittest

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import set_seed


logger = logging.getLogger(__name__)


def dummy_dataloaders(a=2, b=3, batch_size=16, n_train_batches: int = 10, n_valid_batches: int = 2):
"Generates a tuple of dummy DataLoaders to test with"

def get_dataset(n_batches):
x = torch.randn(batch_size * n_batches, 1)
return TensorDataset(x, a * x + b + 0.1 * torch.randn(batch_size * n_batches, 1))

train_dataset = get_dataset(n_train_batches)
valid_dataset = get_dataset(n_valid_batches)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=4)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=batch_size, num_workers=4)
return (train_dataloader, valid_dataloader)


def train(num_epochs, model, dataloader, optimizer, accelerator):
"Trains for `num_epochs`"
rands = []
for epoch in range(num_epochs):
# Train quickly
model.train()
for step, batch in enumerate(dataloader):
x, y = batch
outputs = model(x)
loss = torch.nn.functional.mse_loss(outputs, y)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
rands.append(random.random()) # Introduce some randomness
return rands


class DummyModel(nn.Module):
"Simple model to do y=mx+b"

def __init__(self):
super().__init__()
self.a = nn.Parameter(torch.randn(1))
self.b = nn.Parameter(torch.randn(1))

def forward(self, x):
return x * self.a + self.b


class CheckpointTest(unittest.TestCase):
def test_can_resume_training(self):
with tempfile.TemporaryDirectory() as tmpdir:
set_seed(42)
model = DummyModel()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
train_dataloader, valid_dataloader = dummy_dataloaders()
# Train baseline
accelerator = Accelerator()
model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, valid_dataloader
)
# Save initial
initial = os.path.join(tmpdir, "initial")
accelerator.save_state(initial)
(a, b) = model.a.item(), model.b.item()
opt_state = optimizer.state_dict()
ground_truth_rands = train(3, model, train_dataloader, optimizer, accelerator)
(a1, b1) = model.a.item(), model.b.item()
opt_state1 = optimizer.state_dict()

# Train partially
set_seed(42)
model = DummyModel()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
train_dataloader, valid_dataloader = dummy_dataloaders()
accelerator = Accelerator()
model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, valid_dataloader
)
accelerator.load_state(initial)
(a2, b2) = model.a.item(), model.b.item()
opt_state2 = optimizer.state_dict()
self.assertEqual(a, a2)
self.assertEqual(b, b2)
self.assertEqual(opt_state, opt_state2)

test_rands = train(2, model, train_dataloader, optimizer, accelerator)
# Save everything
checkpoint = os.path.join(tmpdir, "checkpoint")
accelerator.save_state(checkpoint)

# Load everything back in and make sure all states work
accelerator.load_state(checkpoint)
test_rands += train(1, model, train_dataloader, optimizer, accelerator)
(a3, b3) = model.a.item(), model.b.item()
opt_state3 = optimizer.state_dict()
self.assertEqual(a1, a3)
self.assertEqual(b1, b3)
self.assertEqual(opt_state1, opt_state3)
self.assertEqual(ground_truth_rands, test_rands)
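Assuming the repository's standard pytest setup (the exact invocation is not part of this PR), the new test can be run locally with:

python -m pytest tests/test_state_checkpointing.py -v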