Pydantic v2 migration #5167

Merged 68 commits from mrwyattii/pydantic-2-support into master on Aug 22, 2024.

Commits (68)
99323fa  initial changes to migrate to pydantic V2 (mrwyattii, Feb 21, 2024)
3e0979c  update requirements (mrwyattii, Feb 21, 2024)
4571701  fix migration bug (mrwyattii, Feb 21, 2024)
643ae42  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Feb 21, 2024)
96fee35  fix inference config type annotations (mrwyattii, Feb 21, 2024)
dfe47eb  update RTD reqs (mrwyattii, Feb 21, 2024)
a6f8651  fix error in offload config (mrwyattii, Feb 21, 2024)
e780745  final fixes and updates to remove deprecated warnings from pydantic (mrwyattii, Feb 21, 2024)
9037e3c  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Feb 22, 2024)
c12cfe6  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Feb 26, 2024)
e2d075a  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Feb 27, 2024)
fea7c1d  Test with updating thinc version - fixes pydantic on a6000 (loadams, Feb 27, 2024)
5266568  Remove thinc (loadams, Feb 27, 2024)
65be824  Confirm uninstall of thinc (loadams, Feb 27, 2024)
ed08718  Also uninstall spacy (loadams, Feb 27, 2024)
a97e569  Reverting testing commits (loadams, Feb 27, 2024)
b398ba6  Update packages to support latest pydantic (loadams, Feb 27, 2024)
43e6367  further changes to support MII (mrwyattii, Feb 28, 2024)
1e8ba21  Merge branch 'master' into mrwyattii/pydantic-2-support (mrwyattii, Mar 2, 2024)
b9781e1  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Mar 4, 2024)
ae5fd5b  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Mar 8, 2024)
4969551  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Apr 4, 2024)
cf7bee9  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Apr 4, 2024)
91789b5  Update file that was modified in #5234 (loadams, Apr 4, 2024)
93d3d6a  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Apr 4, 2024)
203f5b7  Update container to newer version rather than updating specific packages (loadams, Apr 5, 2024)
aea6795  Revert "Update container to newer version rather than updating specif…" (loadams, Apr 5, 2024)
a8658ca  Add comment (loadams, Apr 5, 2024)
55193a5  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Apr 8, 2024)
4161028  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Apr 24, 2024)
2d5327d  Merge branch 'master' into mrwyattii/pydantic-2-support (tjruwase, Apr 27, 2024)
0978380  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, May 7, 2024)
a7ddc5e  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, May 9, 2024)
55d39c0  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, May 16, 2024)
fcee6a7  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, May 20, 2024)
d80508d  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, May 22, 2024)
8c0b98f  Merge branch 'master' into mrwyattii/pydantic-2-support (adk9, May 24, 2024)
ace913b  Fix a couple of failing CI tests (adk9, May 28, 2024)
aee5f9d  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, May 28, 2024)
62ca5f2  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, May 28, 2024)
4cb7ac3  Correct fix for dtype validation in DeepSpeedInferenceConfig (adk9, May 28, 2024)
45a9c25  Rename model_config to model_conf (adk9, May 28, 2024)
96edbbf  Revert "Rename model_config to model_conf" (adk9, May 28, 2024)
8c982d2  Merge branch 'master' into mrwyattii/pydantic-2-support (lekurile, May 30, 2024)
a04de7f  Temporarily checkout PR branch in the nv-accelerate-v100 pipeline (adk9, May 30, 2024)
08c16c1  Merge branch 'master' into mrwyattii/pydantic-2-support (adk9, Jun 3, 2024)
75640e3  PR 2814 is now merged into accelerate/master (adk9, Jun 6, 2024)
d72db03  Merge branch 'master' into mrwyattii/pydantic-2-support (adk9, Jun 6, 2024)
b97d514  Merge branch 'master' into mrwyattii/pydantic-2-support (adk9, Jun 10, 2024)
0ac9533  Merge branch 'master' into mrwyattii/pydantic-2-support (adk9, Jun 12, 2024)
437ecee  Merge branch 'master' into mrwyattii/pydantic-2-support (adk9, Jun 14, 2024)
ca9c8ef  Merge branch 'master' into mrwyattii/pydantic-2-support (adk9, Jun 18, 2024)
670ac94  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Jun 26, 2024)
f973393  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Jun 26, 2024)
09fa6b5  Merge branch 'master' into mrwyattii/pydantic-2-support (adk9, Jun 27, 2024)
5d8fb2d  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Jul 1, 2024)
1cbf3e1  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Jul 1, 2024)
b3804ad  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Jul 12, 2024)
41fc635  Merge branch 'master' into mrwyattii/pydantic-2-support (adk9, Jul 16, 2024)
9f65563  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Jul 23, 2024)
79c0835  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Jul 23, 2024)
295a806  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Jul 25, 2024)
1e3925e  Merge branch 'master' into mrwyattii/pydantic-2-support (adk9, Aug 1, 2024)
1eec90f  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Aug 14, 2024)
c82a73c  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Aug 15, 2024)
75a9288  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Aug 19, 2024)
e557489  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Aug 20, 2024)
628cf25  Merge branch 'master' into mrwyattii/pydantic-2-support (loadams, Aug 21, 2024)
3 changes: 2 additions & 1 deletion .github/workflows/nv-a6000.yml
@@ -47,7 +47,8 @@ jobs:
       - name: Install deepspeed
         run: |
           python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
-          python -m pip install pydantic==1.10.11
+          # Update packages included in the container that do not support pydantic 2+ to versions that do
+          python -m pip install thinc spacy confection --upgrade
           python -m pip install .[dev,1bit,autotuning,inf]
           ds_report
       - name: Python environment
14 changes: 3 additions & 11 deletions deepspeed/comm/config.py
@@ -3,20 +3,12 @@

 # DeepSpeed Team

-from .constants import *
-from ..pydantic_v1 import BaseModel
+from deepspeed.runtime.config_utils import DeepSpeedConfigModel

-
-class CommsConfig(BaseModel):
-
-    class Config:
-        validate_all = True
-        validate_assignment = True
-        use_enum_values = True
-        extra = 'forbid'
+from .constants import *


-class CommsLoggerConfig(CommsConfig):
+class CommsLoggerConfig(DeepSpeedConfigModel):
     enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
     prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
     prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
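
Note: the shared base model that absorbs the deleted per-class `class Config` lives in deepspeed/runtime/config_utils.py. A minimal sketch of its pydantic v2 shape, assuming it carries the same options removed from CommsConfig above (the real DeepSpeedConfigModel also includes deprecation-handling machinery not shown here):

from pydantic import BaseModel, ConfigDict

class DeepSpeedConfigModel(BaseModel):
    # pydantic v2 replaces the nested `class Config` with model_config.
    # These mirror the options deleted from CommsConfig above; note the
    # v1 -> v2 rename of validate_all to validate_default.
    model_config = ConfigDict(
        validate_default=True,
        validate_assignment=True,
        use_enum_values=True,
        extra="forbid",
    )

Centralizing these options in one base class is what lets the diff collapse CommsConfig and CommsLoggerConfig into a single subclass of DeepSpeedConfigModel.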
119 changes: 63 additions & 56 deletions deepspeed/inference/config.py
@@ -5,38 +5,25 @@

 import torch
 import deepspeed
-from deepspeed.pydantic_v1 import Field, validator
+from pydantic import Field, field_validator
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
-from typing import Dict, Union
+from typing import Dict, Union, Optional
 from enum import Enum


 class DtypeEnum(Enum):
-    # The torch dtype must always be the first value (so we return torch.dtype)
-    fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
-    fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
-    bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat"
-    int8 = torch.int8, "torch.int8", "int8"
-
-    # Copied from https://stackoverflow.com/a/43210118
-    # Allows us to use multiple values for each Enum index and returns first
-    # listed value when Enum is called
-    def __new__(cls, *values):
-        obj = object.__new__(cls)
-        # first value is canonical value
-        obj._value_ = values[0]
-        for other_value in values[1:]:
-            cls._value2member_map_[other_value] = obj
-        obj._all_values = values
-        return obj
-
-    def __repr__(self):
-        return "<%s.%s: %s>" % (
-            self.__class__.__name__,
-            self._name_,
-            ", ".join([repr(v) for v in self._all_values]),
-        )
+    fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half")
+    fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float")
+    bf16 = (torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat")
+    int8 = (torch.int8, "torch.int8", "int8")
+
+    @classmethod
+    def from_str(cls, value: str):
+        for dtype in cls:
+            if value in dtype.value:
+                return dtype
+        raise ValueError(f"'{value}' is not a valid DtypeEnum")


 class MoETypeEnum(str, Enum):
@@ -91,24 +78,24 @@ class QuantTypeEnum(str, Enum):


 class BaseQuantConfig(DeepSpeedConfigModel):
-    enabled = True
-    num_bits = 8
+    enabled: bool = True
+    num_bits: int = 8
     q_type: QuantTypeEnum = QuantTypeEnum.sym
     q_groups: int = 1


 class WeightQuantConfig(BaseQuantConfig):
-    enabled = True
+    enabled: bool = True
     quantized_initialization: Dict = {}
     post_init_quant: Dict = {}


 class ActivationQuantConfig(BaseQuantConfig):
-    enabled = True
+    enabled: bool = True


 class QKVQuantConfig(DeepSpeedConfigModel):
-    enabled = True
+    enabled: bool = True
@@ -120,9 +107,9 @@ class QuantizationConfig(DeepSpeedConfigModel):

 # todo: brainstorm on how to do ckpt loading for DS inference
 class InferenceCheckpointConfig(DeepSpeedConfigModel):
-    checkpoint_dir: str = None
-    save_mp_checkpoint_path: str = None
-    base_dir: str = None
+    checkpoint_dir: Optional[str] = None
+    save_mp_checkpoint_path: Optional[str] = None
+    base_dir: Optional[str] = None


 class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
@@ -136,7 +123,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     `(attention_output projection, transformer output projection)`
     """

-    dtype: DtypeEnum = torch.float16
+    dtype: torch.dtype = torch.float16
     """
     Desired model data type, will convert model to this type.
     Supported target types: `torch.half`, `torch.int8`, `torch.float`
@@ -198,7 +185,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):

     #todo: refactor the following 3 into the new checkpoint_config
-    checkpoint: Union[str, Dict] = None
+    checkpoint: Optional[Union[str, Dict]] = None
     """
     Path to deepspeed compatible checkpoint or path to JSON with load policy.
     """
@@ -214,7 +201,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     specifying whether the inference-module is created with empty or real Tensor
     """

-    save_mp_checkpoint_path: str = None
+    save_mp_checkpoint_path: Optional[str] = None
     """
     The path for which we want to save the loaded model with a checkpoint. This
     feature is used for adjusting the parallelism degree to help alleviate the
@@ -243,19 +230,21 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):

     replace_method: str = Field(
         "auto",
-        deprecated=True,
-        deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference")
+        json_schema_extra={
+            "deprecated": True,
+            "deprecated_msg": "This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
+        })

-    injection_policy: Dict = Field(None, alias="injection_dict")
+    injection_policy: Optional[Dict] = Field(None, alias="injection_dict")
     """
     Dictionary mapping a client nn.Module to its corresponding injection
     policy. e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}`
     """

-    injection_policy_tuple: tuple = None
+    injection_policy_tuple: Optional[tuple] = None
     """ TODO: Add docs """

-    config: Dict = Field(None, alias="args")  # todo: really no need for this field if we can refactor
+    config: Optional[Dict] = Field(None, alias="args")  # todo: really no need for this field if we can refactor

     max_out_tokens: int = Field(1024, alias="max_tokens")
     """
@@ -274,31 +263,49 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):

     transposed_mode: bool = Field(False, alias="transposed_mode")

-    mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
+    mp_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.tp_size"})
     """
     Desired model parallel size, default is 1 meaning no model parallelism.
     Deprecated, please use the ``tensor_parallel` config to control model
     parallelism.
     """
-    mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
-    ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
-    ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group")
-    ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group")
-    moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
-    moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
-
-    @validator("moe")
+    mpu: object = Field(None, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.mpu"})
+    ep_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "moe.ep_size"})
+    ep_group: object = Field(None,
+                             alias="expert_group",
+                             json_schema_extra={
+                                 "deprecated": True,
+                                 "new_param": "moe.ep_group"
+                             })
+    ep_mp_group: object = Field(None,
+                                alias="expert_mp_group",
+                                json_schema_extra={
+                                    "deprecated": True,
+                                    "new_param": "moe.ep_mp_group"
+                                })
+    moe_experts: list = Field([1], json_schema_extra={"deprecated": True, "new_param": "moe.moe_experts"})
+    moe_type: MoETypeEnum = Field(MoETypeEnum.standard,
+                                  json_schema_extra={
+                                      "deprecated": True,
+                                      "new_param": "moe.type"
+                                  })
+
+    @field_validator("dtype", mode="before")
+    def validate_dtype(cls, field_value, values):
+        if isinstance(field_value, str):
+            return DtypeEnum.from_str(field_value).value[0]
+        if isinstance(field_value, torch.dtype):
+            return field_value
+        raise TypeError(f"Invalid type for dtype: {type(field_value)}")
+
+    @field_validator("moe")
     def moe_backward_compat(cls, field_value, values):
         if isinstance(field_value, bool):
             return DeepSpeedMoEConfig(moe=field_value)
         return field_value

-    @validator("use_triton")
+    @field_validator("use_triton")
     def has_triton(cls, field_value, values):
         if field_value and not deepspeed.HAS_TRITON:
             raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
         return field_value
-
-    class Config:
-        # Get the str representation of the datatype for serialization
-        json_encoders = {torch.dtype: lambda x: str(x)}
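
Two migration patterns in this file are worth calling out. First, pydantic v2 deprecates arbitrary keyword arguments to `Field`, so the `deprecated`/`new_param` metadata moves into an explicit `json_schema_extra` dict. Second, `dtype` is now annotated as `torch.dtype`, with string aliases resolved by the `mode="before"` validator; the deleted `json_encoders` hook is also gone from pydantic v2, where `@field_serializer` is the analogous mechanism. A hedged usage sketch of the new dtype handling, assuming a DeepSpeed build that includes this PR:

import torch
from deepspeed.inference.config import DeepSpeedInferenceConfig

# Strings go through DtypeEnum.from_str, which matches any alias in the
# enum's value tuple and returns the canonical torch.dtype (element 0).
cfg = DeepSpeedInferenceConfig(dtype="fp16")
assert cfg.dtype is torch.float16

# torch.dtype values pass through the before-validator unchanged.
cfg = DeepSpeedInferenceConfig(dtype=torch.bfloat16)
assert cfg.dtype is torch.bfloat16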
3 changes: 2 additions & 1 deletion deepspeed/inference/v2/config_v2.py
@@ -3,8 +3,9 @@

 # DeepSpeed Team

+from pydantic import Field
 from typing import Optional
-from deepspeed.pydantic_v1 import Field
+
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from .ragged import DSStateManagerConfig
@@ -27,17 +27,17 @@ class TensorMetadata(DeepSpeedConfigModel):
     """
     A class to represent a tensor specification.
     """
-    dtype: Optional[str]
-    shape: Optional[Tuple[int, ...]]
-    strides: Optional[Tuple[int, ...]]
+    dtype: Optional[str] = None
+    shape: Optional[Tuple[int, ...]] = None
+    strides: Optional[Tuple[int, ...]] = None
     offset: int


 class ParameterMetadata(DeepSpeedConfigModel):
     """
     A class to represent a parameter specification.
     """
-    core_param: TensorMetadata = None
+    core_param: Optional[TensorMetadata] = None
     aux_params: Dict[str, TensorMetadata] = {}
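
The defaults added in this hunk track a real pydantic v2 behavior change: a field annotated `Optional[X]` is no longer implicitly optional, and without an explicit `= None` it becomes required. The `core_param` change is the mirror image: a `None` default now needs an `Optional` annotation. A minimal self-contained illustration (model names here are hypothetical, not part of the PR):

from typing import Optional
from pydantic import BaseModel, ValidationError

class TensorSpec(BaseModel):
    dtype: Optional[str]  # pydantic v2: required, though it may be None
    offset: int

try:
    TensorSpec(offset=0)  # fails: dtype was not provided
except ValidationError as err:
    print(err.error_count(), "validation error")

class TensorSpecFixed(BaseModel):
    dtype: Optional[str] = None  # explicit default restores v1 behavior
    offset: int

print(TensorSpecFixed(offset=0).dtype)  # None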
14 changes: 6 additions & 8 deletions deepspeed/inference/v2/ragged/manager_configs.py
@@ -6,7 +6,7 @@
 from enum import Enum
 from typing import Tuple

-from deepspeed.pydantic_v1 import PositiveInt, validator
+from pydantic import PositiveInt, model_validator

 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from ..inference_utils import DtypeEnum
@@ -173,11 +173,9 @@ class DSStateManagerConfig(DeepSpeedConfigModel):
     Enable tracking for offloading KV-cache to host memory. Currently unsupported.
     """

-    @validator("max_ragged_sequence_count")
-    def max_ragged_sequence_count_validator(cls, v: int, values: dict):
-        # If the attributes below failed their validation they won't appear in the values dict.
-        if "max_tracked_sequences" in values and v > values["max_tracked_sequences"]:
-            raise ValueError("max_ragged_sequence_count must be less than max_tracked_sequences")
-        if "max_ragged_batch_size" in values and v > values["max_ragged_batch_size"]:
-            raise ValueError("max_ragged_sequence_count must be less than max_ragged_batch_size")
-        return v
+    @model_validator(mode="after")
+    def max_ragged_sequence_count_validator(self):
+        assert self.max_ragged_sequence_count <= self.max_tracked_sequences, "max_ragged_sequence_count must be less than max_tracked_sequences"
+        assert self.max_ragged_sequence_count <= self.max_ragged_batch_size, "max_ragged_sequence_count must be less than max_ragged_batch_size"
+        return self
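
The v1 field validator relied on `values`, a dict holding only the previously validated fields, which forced the defensive `in` checks. The v2 `model_validator(mode="after")` runs on the fully constructed model, so every field is guaranteed present. A self-contained sketch of the pattern (class name and defaults are simplified stand-ins for DSStateManagerConfig):

from pydantic import BaseModel, model_validator

class RaggedLimits(BaseModel):
    max_tracked_sequences: int = 2048
    max_ragged_batch_size: int = 768
    max_ragged_sequence_count: int = 512

    @model_validator(mode="after")
    def check_counts(self):
        # All fields are already validated here; failed asserts are
        # wrapped into a pydantic ValidationError by the framework.
        assert self.max_ragged_sequence_count <= self.max_tracked_sequences
        assert self.max_ragged_sequence_count <= self.max_ragged_batch_size
        return self

RaggedLimits()                                # ok
RaggedLimits(max_ragged_sequence_count=4096)  # raises ValidationError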
16 changes: 8 additions & 8 deletions deepspeed/monitor/config.py
@@ -5,7 +5,7 @@

 from typing import Optional

-from deepspeed.pydantic_v1 import root_validator
+from pydantic import model_validator
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@@ -36,10 +36,10 @@ class WandbConfig(DeepSpeedConfigModel):
     enabled: bool = False
     """ Whether logging to WandB is enabled. Requires `wandb` package is installed. """

-    group: str = None
+    group: Optional[str] = None
     """ Name for the WandB group. This can be used to group together runs. """

-    team: str = None
+    team: Optional[str] = None
     """ Name for the WandB team. """

     project: str = "deepspeed"
@@ -137,8 +137,8 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
     csv_monitor: CSVConfig = {}
     """ Local CSV output of monitoring data. """

-    @root_validator
-    def check_enabled(cls, values):
-        values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get(
-            "csv_monitor").enabled or values.get("comet").enabled
-        return values
+    @model_validator(mode="after")
+    def check_enabled(self):
+        enabled = self.tensorboard.enabled or self.wandb.enabled or self.csv_monitor.enabled or self.comet.enabled
+        self.__dict__["enabled"] = enabled
+        return self
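
One subtlety in this hunk: DeepSpeed config models enable `validate_assignment`, so a plain `self.enabled = enabled` inside an after-validator would re-trigger model validation and re-enter the validator; writing through `self.__dict__` bypasses that. A hedged standalone sketch of the pattern:

from pydantic import BaseModel, ConfigDict, model_validator

class Monitor(BaseModel):  # simplified stand-in for DeepSpeedMonitorConfig
    model_config = ConfigDict(validate_assignment=True)

    tensorboard_enabled: bool = False
    csv_enabled: bool = False
    enabled: bool = False

    @model_validator(mode="after")
    def check_enabled(self):
        # Direct assignment would call this validator recursively under
        # validate_assignment=True; __dict__ writes skip validation.
        self.__dict__["enabled"] = self.tensorboard_enabled or self.csv_enabled
        return self

print(Monitor(csv_enabled=True).enabled)  # True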
16 changes: 0 additions & 16 deletions deepspeed/pydantic_v1.py

This file was deleted.
