Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing

import betterproto2
from typing_extensions import Self

from betterproto2_compiler.lib.google.protobuf import Any as VanillaAny

Expand Down Expand Up @@ -60,7 +61,7 @@ def to_dict(self, **kwargs) -> dict[str, typing.Any]:

# TODO typing
@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
value = dict(value) # Make a copy

type_url = value.pop("@type", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing

import betterproto2
from typing_extensions import Self

from betterproto2_compiler.lib.google.protobuf import Duration as VanillaDuration

Expand Down Expand Up @@ -30,13 +31,13 @@ def delta_to_json(delta: datetime.timedelta) -> str:

# TODO typing
@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
if isinstance(value, str):
if not re.match(r"^\d+(\.\d+)?s$", value):
raise ValueError(f"Invalid duration string: {value}")

seconds = float(value[:-1])
return Duration(seconds=int(seconds), nanos=int((seconds - int(seconds)) * 1e9))
return cls(seconds=int(seconds), nanos=int((seconds - int(seconds)) * 1e9))

return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing

import betterproto2
from typing_extensions import Self

from betterproto2_compiler.lib.google.protobuf import (
BoolValue as VanillaBoolValue,
Expand All @@ -24,9 +25,9 @@ def to_wrapped(self) -> bool:
return self.value

@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
if isinstance(value, bool):
return BoolValue(value=value)
return cls(value=value)
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)

def to_dict(
Expand All @@ -48,9 +49,9 @@ def to_wrapped(self) -> int:
return self.value

@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
if isinstance(value, int):
return Int32Value(value=value)
return cls(value=value)
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)

def to_dict(
Expand All @@ -72,9 +73,9 @@ def to_wrapped(self) -> int:
return self.value

@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
if isinstance(value, int):
return Int64Value(value=value)
return cls(value=value)
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)

def to_dict(
Expand All @@ -96,9 +97,9 @@ def to_wrapped(self) -> int:
return self.value

@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
if isinstance(value, int):
return UInt32Value(value=value)
return cls(value=value)
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)

def to_dict(
Expand All @@ -120,9 +121,9 @@ def to_wrapped(self) -> int:
return self.value

@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
if isinstance(value, int):
return UInt64Value(value=value)
return cls(value=value)
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)

def to_dict(
Expand All @@ -144,9 +145,9 @@ def to_wrapped(self) -> float:
return self.value

@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
if isinstance(value, float):
return FloatValue(value=value)
return cls(value=value)
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)

def to_dict(
Expand All @@ -168,9 +169,9 @@ def to_wrapped(self) -> float:
return self.value

@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
if isinstance(value, float):
return DoubleValue(value=value)
return cls(value=value)
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)

def to_dict(
Expand All @@ -192,9 +193,9 @@ def to_wrapped(self) -> str:
return self.value

@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
if isinstance(value, str):
return StringValue(value=value)
return cls(value=value)
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)

def to_dict(
Expand All @@ -216,9 +217,9 @@ def to_wrapped(self) -> bytes:
return self.value

@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
if isinstance(value, bytes):
return BytesValue(value=value)
return cls(value=value)
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)

def to_dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing

import betterproto2
from typing_extensions import Self

from betterproto2_compiler.lib.google.protobuf import (
ListValue as VanillaListValue,
Expand All @@ -13,7 +14,7 @@
class Struct(VanillaStruct):
# TODO typing
@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
assert isinstance(value, dict)

fields: dict[str, Value] = {}
Expand Down Expand Up @@ -47,7 +48,7 @@ def to_dict(
class Value(VanillaValue):
# TODO typing
@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
match value:
case bool() as b:
return cls(bool_value=b)
Expand Down Expand Up @@ -94,7 +95,7 @@ def to_dict(
class ListValue(VanillaListValue):
# TODO typing
@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
return cls(values=[Value.from_dict(v) for v in value])

# TODO typing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

import betterproto2
import dateutil.parser
from typing_extensions import Self

from betterproto2_compiler.lib.google.protobuf import Timestamp as VanillaTimestamp


class Timestamp(VanillaTimestamp):
@classmethod
def from_datetime(cls, dt: datetime.datetime) -> "Timestamp":
def from_datetime(cls, dt: datetime.datetime) -> Self:
if not dt.tzinfo:
raise ValueError("datetime must be timezone aware")

Expand Down Expand Up @@ -55,11 +56,11 @@ def timestamp_to_json(dt: datetime.datetime) -> str:

# TODO typing
@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
if isinstance(value, str):
dt = dateutil.parser.isoparse(value)
dt = dt.astimezone(datetime.timezone.utc)
return Timestamp.from_datetime(dt)
return cls.from_datetime(dt)

return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)

Expand Down
Loading