Skip to content

Commit

Permalink
Refactor some parts in utils (#380)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger committed May 20, 2022
1 parent 41427c5 commit 8b8c534
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 30 deletions.
10 changes: 9 additions & 1 deletion src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,18 @@
get_max_memory,
infer_auto_device_map,
load_checkpoint_in_model,
load_offloaded_weights,
named_module_tensors,
set_module_tensor_to_device,
)
from .offload import OffloadedWeightsLoader, PrefixedDataset, extract_submodules_state_dict, offload_state_dict
from .offload import (
OffloadedWeightsLoader,
PrefixedDataset,
extract_submodules_state_dict,
offload_state_dict,
offload_weight,
save_offload_index,
)
from .operations import (
broadcast,
broadcast_object_list,
Expand Down
48 changes: 19 additions & 29 deletions src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import torch
import torch.nn as nn

from .offload import offload_weight, save_offload_index


WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"

Expand Down Expand Up @@ -312,6 +314,18 @@ def clean_device_map(device_map: Dict[str, Union[int, str, torch.device]], modul
return device_map


def load_offloaded_weights(model, index, offload_folder):
    """
    Load the weights listed in `index` from `offload_folder` back into `model` on the CPU.

    Args:
        model: The model handed to `set_module_tensor_to_device` for every entry of `index`.
        index: Mapping from parameter name to a dict with `"dtype"` and `"shape"` entries. When it is `None` or
            empty, the function returns without doing anything.
        offload_folder: Folder containing one `{param_name}.dat` file per entry of `index`.
    """
    if index is None or not len(index):
        # Nothing was offloaded, so there is nothing to load back.
        return

    for name, info in index.items():
        data_path = os.path.join(offload_folder, f"{name}.dat")
        dims = tuple(info["shape"])
        # Map the file read-only, using the dtype and shape recorded in the index.
        mapped = np.memmap(data_path, dtype=info["dtype"], mode="r", shape=dims)
        as_tensor = torch.tensor(mapped)
        set_module_tensor_to_device(model, name, "cpu", value=as_tensor)


def infer_auto_device_map(
model: nn.Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
Expand Down Expand Up @@ -581,44 +595,20 @@ def load_checkpoint_in_model(

if param_device == "disk":
set_module_tensor_to_device(model, param_name, "meta")
tensor_file = os.path.join(offload_folder, f"{param_name}.dat")
array = param.numpy()
offload_index[param_name] = {"dtype": str(array.dtype), "shape": list(array.shape)}
file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape)
file_array[:] = array[:]
file_array.flush()
offload_weight(param, param_name, offload_folder, index=offload_index)
elif param_device == "cpu" and offload_state_dict:
set_module_tensor_to_device(model, param_name, "meta")
tensor_file = os.path.join(state_dict_folder, f"{param_name}.dat")
array = param.numpy()
state_dict_index[param_name] = {"dtype": str(array.dtype), "shape": list(array.shape)}
file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape)
file_array[:] = array[:]
file_array.flush()
offload_weight(param, param_name, state_dict_folder, index=state_dict_index)
else:
set_module_tensor_to_device(model, param_name, param_device, value=param)

# Force Python to clean up.
del checkpoint
gc.collect()

if len(offload_index) > 0:
offload_index_file = os.path.join(offload_folder, "index.json")
if os.path.isfile(offload_index_file):
with open(offload_index_file, "r", encoding="utf-8") as f:
current_offload_index = json.load(f)
else:
current_offload_index = {}
current_offload_index.update(offload_index)

with open(offload_index_file, "w", encoding="utf-8") as f:
json.dump(current_offload_index, f, indent=2)
save_offload_index(offload_index, offload_folder)

# Load back offloaded state dict on CPU
if offload_state_dict and len(state_dict_index) > 0:
for param_name, metadata in state_dict_index.items():
tensor_file = os.path.join(state_dict_folder, f"{param_name}.dat")
shape = tuple(metadata["shape"])
weight = np.memmap(tensor_file, dtype=metadata["dtype"], mode="r", shape=shape)
set_module_tensor_to_device(model, param_name, "cpu", value=torch.tensor(weight))
if offload_state_dict:
load_offloaded_weights(model, state_dict_index, state_dict_folder)
shutil.rmtree(state_dict_folder)
28 changes: 28 additions & 0 deletions src/accelerate/utils/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,34 @@ def offload_state_dict(save_dir: Union[str, os.PathLike], state_dict: Dict[str,
json.dump(current_index, f, indent=2)


def offload_weight(weight, weight_name, offload_folder, index=None):
    """
    Write `weight` to `{offload_folder}/{weight_name}.dat` as a NumPy memory-mapped file.

    Args:
        weight: Tensor to offload; it is converted with `weight.numpy()`.
        weight_name: Name of the weight, used as the file stem and as the key in `index`.
        offload_folder: Existing folder in which the `.dat` file is created.
        index: Optional dict updated in place with the `"dtype"` and `"shape"` of the weight.

    Returns:
        The `index` that was passed in (`None` when no index was given).
    """
    array = weight.numpy()
    tensor_file = os.path.join(offload_folder, f"{weight_name}.dat")
    if index is not None:
        # Record the true shape (`[]` for a scalar) before any reshaping below.
        index[weight_name] = {"dtype": str(array.dtype), "shape": list(array.shape)}
    if array.ndim == 0:
        # A 0-d array cannot be sliced with `[:]`, so write the scalar through a 1-element view. The bytes on
        # disk are the same single item.
        array = array[None]
    file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape)
    file_array[:] = array[:]
    file_array.flush()
    return index


def save_offload_index(index, offload_folder):
    """
    Merge `index` into the `index.json` file of `offload_folder`, creating the file if needed.

    Args:
        index: Dict of entries to save. When it is `None` or empty, nothing is written.
        offload_folder: Folder holding (or receiving) the `index.json` file.

    Entries already present in an existing `index.json` are kept, except where `index` provides the same key, in
    which case the new value wins.
    """
    if index is None or not len(index):
        # Nothing to save
        return

    index_path = os.path.join(offload_folder, "index.json")

    # Start from what is already on disk so earlier entries are preserved.
    merged = {}
    if os.path.isfile(index_path):
        with open(index_path, "r", encoding="utf-8") as reader:
            merged = json.load(reader)
    merged.update(index)

    with open(index_path, "w", encoding="utf-8") as writer:
        json.dump(merged, writer, indent=2)


class PrefixedDataset(Mapping):
"""
Will access keys in a given dataset by adding a prefix.
Expand Down

0 comments on commit 8b8c534

Please sign in to comment.