feat: add meta device initialization for pretrained models, 5x faster load times #501

Draft · wants to merge 3 commits into main
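
The public call path is unchanged; a minimal usage sketch (the model name and pretrained tag below are the standard open_clip examples, not taken from this diff) of what this change speeds up:

```python
import open_clip

# Same API as before: with this change, the model skeleton is created on the
# meta device and checkpoint tensors are materialized directly on the target
# device, instead of allocating and randomly initializing weights first.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32', pretrained='laion2b_s34b_b79k')
```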
193 changes: 182 additions & 11 deletions src/open_clip/factory.py
@@ -1,12 +1,21 @@
import importlib
import importlib_metadata
import inspect
import json
import logging
import os
import pathlib
import re
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
@@ -24,6 +33,142 @@
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs

def set_module_tensor_to_device(
    module: nn.Module,
    tensor_name: str,
    device: Union[int, str, torch.device],
    value: Optional[torch.Tensor] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
):
    """
    A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
    `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function).

    Args:
        module (`torch.nn.Module`):
            The module in which the tensor we want to move lives.
        tensor_name (`str`):
            The full name of the parameter/buffer.
        device (`int`, `str` or `torch.device`):
            The device on which to set the tensor.
        value (`torch.Tensor`, *optional*):
            The value of the tensor (useful when going from the meta device to any other device).
        dtype (`torch.dtype`, *optional*):
            If passed along, the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast
            to the dtype of the existing parameter in the model.
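
    Example:

    An illustrative sketch (the `layer` module and the zero tensors are placeholders for a real model and real
    checkpoint values; `init_on_device` is defined below in this file):

    ```python
    import torch
    import torch.nn as nn

    # build a module whose parameters live on the meta device (no storage allocated)
    with init_on_device(torch.device("meta")):
        layer = nn.Linear(4, 4)

    # materialize the parameters on CPU from real tensors
    set_module_tensor_to_device(layer, "weight", "cpu", value=torch.zeros(4, 4))
    set_module_tensor_to_device(layer, "bias", "cpu", value=torch.zeros(4))
    ```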
"""
# Recurse if needed
if "." in tensor_name:
splits = tensor_name.split(".")
for split in splits[:-1]:
new_module = getattr(module, split)
if new_module is None:
raise ValueError(f"{module} has no attribute {split}.")
module = new_module
tensor_name = splits[-1]

if tensor_name not in module._parameters and tensor_name not in module._buffers:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
is_buffer = tensor_name in module._buffers
old_value = getattr(module, tensor_name)

if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")

if value is not None:
if dtype is None:
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
value = value.to(old_value.dtype)
elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
value = value.to(dtype)

with torch.no_grad():
if value is None:
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to(device)
else:
new_value = torch.tensor(value, device=device)

if is_buffer:
module._buffers[tensor_name] = new_value
elif value is not None or torch.device(device) != module._parameters[tensor_name].device:
param_cls = type(module._parameters[tensor_name])
kwargs = module._parameters[tensor_name].__dict__
if param_cls.__name__ == "Int8Params":
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
else:
new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
module._parameters[tensor_name] = new_value


@contextmanager
def init_on_device(device: torch.device, include_buffers: bool = False):
    """
    A context manager under which models are initialized with all parameters on the specified device.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
        include_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to also put all buffers on the meta device while initializing.

    Example:

    ```python
    import torch
    import torch.nn as nn

    with init_on_device(device=torch.device("cuda")):
        tst = nn.Linear(100, 100)  # on `cuda` device
    ```
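
    With `include_buffers=True`, buffers and the patched `torch.empty/zeros/ones/full` constructors are also
    redirected to `device`. A second illustrative sketch (the `WithBuffer` module is hypothetical):

    ```python
    import torch
    import torch.nn as nn

    class WithBuffer(nn.Module):
        def __init__(self):
            super().__init__()
            # torch.zeros is patched inside the context, so this buffer lands on the target device too
            self.register_buffer("scale", torch.zeros(1))

    with init_on_device(torch.device("meta"), include_buffers=True):
        m = WithBuffer()

    assert m.scale.device.type == "meta"
    ```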
"""
old_register_parameter = nn.Module.register_parameter
if include_buffers:
old_register_buffer = nn.Module.register_buffer

def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

def register_empty_buffer(module, name, buffer):
old_register_buffer(module, name, buffer)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)

# Patch tensor creation
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}

def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)

return wrapper

try:
nn.Module.register_parameter = register_empty_parameter
if include_buffers:
nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
nn.Module.register_parameter = old_register_parameter
if include_buffers:
nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)


def _natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
@@ -94,13 +239,35 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):
    return state_dict


def load_checkpoint(model, checkpoint_path, strict=True, device='cpu', dtype=torch.float32):
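    """Load a checkpoint into `model`, writing each tensor directly onto `device`.

    Tensors are materialized via `set_module_tensor_to_device`, so this also works for models whose
    parameters were created on the meta device; floating-point tensors are cast to `dtype`.
    """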
    state_dict = load_state_dict(checkpoint_path)
    # detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)
    resize_pos_embed(state_dict, model)

    incompatible_keys = model.load_state_dict(state_dict, strict=strict)

    param_device = device
    torch_dtype = dtype
    empty_state_dict = model.state_dict()
    accepts_dtype = "dtype" in inspect.signature(set_module_tensor_to_device).parameters
    for param_name, param in state_dict.items():
        if empty_state_dict[param_name].shape != param.shape:
            raise ValueError(
                f"Cannot load checkpoint: {param_name} expected shape "
                f"{empty_state_dict[param_name].shape}, but got {param.shape}."
            )

        if accepts_dtype:
            set_module_tensor_to_device(
                model, param_name, param_device, value=param, dtype=torch_dtype
            )
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param)

    return incompatible_keys


@@ -183,15 +350,19 @@ def create_model(
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        with init_on_device(torch.device("meta")) if (pretrained or has_hf_hub_prefix) else nullcontext():
            if custom_text:
                if is_hf_model:
                    model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
                if "coca" in model_name:
                    model = CoCa(**model_cfg, cast_dtype=cast_dtype)
                else:
                    model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CLIP(**model_cfg, cast_dtype=cast_dtype)


        pretrained_loaded = False
        if pretrained:
@@ -204,7 +375,7 @@

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path, device=device, dtype=cast_dtype)
            else:
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'
@@ -214,7 +385,7 @@
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            load_checkpoint(model, checkpoint_path, device=device, dtype=cast_dtype)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded: