diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index bfe26b86..bd794953 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -12,6 +12,7 @@ While the new TargetCategory class supports subtypes, only reading them is curre * `TargetIds(ComponentIds(1), ComponentIds(2), ComponentIds(3))` * `TargetCategories` can be used to specify one or more target categories: * `TargetCategories(ComponentCategory.BATTERY, ComponentCategory.INVERTER)` +* Dispatch ids and microgrid ids are no longer simple `int` types but are now wrapped in `DispatchId` and `MicrogridId` classes, respectively. This allows for better type safety and clarity in the codebase. ## New Features diff --git a/pyproject.toml b/pyproject.toml index 3d93fe50..2b59919c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,10 +36,11 @@ classifiers = [ ] requires-python = ">= 3.11, < 4" dependencies = [ - "typing-extensions >= 4.6.1, < 5", + "typing-extensions >= 4.13.0, < 5", "frequenz-api-dispatch == 1.0.0-rc2", "frequenz-client-base >= 0.8.0, < 0.12.0", - "frequenz-client-common >= 0.1.0, < 0.4.0", + "frequenz-client-common >= 0.3.2, < 0.4.0", + "frequenz-core >= 1.0.2, < 2.0.0", "grpcio >= 1.70.0, < 2", "python-dateutil >= 2.8.2, < 3.0", ] diff --git a/src/frequenz/client/dispatch/__main__.py b/src/frequenz/client/dispatch/__main__.py index 1cefc892..36d2cfd5 100644 --- a/src/frequenz/client/dispatch/__main__.py +++ b/src/frequenz/client/dispatch/__main__.py @@ -18,6 +18,8 @@ from prompt_toolkit.patch_stdout import patch_stdout from prompt_toolkit.shortcuts import CompleteStyle +from frequenz.client.common.microgrid import MicrogridId + from ._cli_types import ( FuzzyDateTime, FuzzyIntRange, @@ -27,7 +29,7 @@ ) from ._client import DispatchApiClient from .recurrence import EndCriteria, Frequency, RecurrenceRule, Weekday -from .types import Dispatch, DispatchEvent +from .types import Dispatch, DispatchEvent, DispatchId def format_datetime(dt: datetime | None) -> str: @@ -260,7 +262,7 @@ async def list_(ctx: click.Context, /, **filters: Any) -> None: @cli.command("stream") @click.pass_context @click.argument("microgrid-id", required=True, type=int) -async def stream(ctx: click.Context, microgrid_id: int) -> None: +async def stream(ctx: click.Context, microgrid_id: MicrogridId) -> None: """Stream dispatches.""" event_stream: Receiver[DispatchEvent] = ctx.obj["client"].stream( microgrid_id=microgrid_id @@ -452,8 +454,8 @@ async def create( async def update( ctx: click.Context, /, - microgrid_id: int, - dispatch_id: int, + microgrid_id: MicrogridId, + dispatch_id: DispatchId, **new_fields: dict[str, Any], ) -> None: """Update a dispatch.""" @@ -499,14 +501,17 @@ def skip_field(value: Any) -> bool: @click.argument("microgrid-id", required=True, type=int) @click.argument("dispatch_ids", type=int, nargs=-1) # Allow multiple IDs @click.pass_context -async def get(ctx: click.Context, microgrid_id: int, dispatch_ids: List[int]) -> None: +async def get( + ctx: click.Context, microgrid_id: MicrogridId, dispatch_ids: List[int] +) -> None: """Get one or multiple dispatches.""" num_failed = 0 for dispatch_id in dispatch_ids: try: dispatch = await ctx.obj["client"].get( - microgrid_id=microgrid_id, dispatch_id=dispatch_id + microgrid_id=microgrid_id, + dispatch_id=DispatchId(dispatch_id), ) if ctx.obj["raw"]: click.echo(pformat(dispatch, compact=True)) @@ -537,7 +542,7 @@ async def repl( @click.argument("dispatch_ids", type=FuzzyIntRange(), nargs=-1) # Allow multiple IDs @click.pass_context async def delete( - ctx: click.Context, microgrid_id: int, dispatch_ids: list[list[int]] + ctx: click.Context, microgrid_id: MicrogridId, dispatch_ids: list[list[int]] ) -> None: """Delete multiple dispatches. diff --git a/src/frequenz/client/dispatch/_client.py b/src/frequenz/client/dispatch/_client.py index ed6cc0ed..18d2db74 100644 --- a/src/frequenz/client/dispatch/_client.py +++ b/src/frequenz/client/dispatch/_client.py @@ -36,12 +36,14 @@ from frequenz.client.base.exception import ClientNotConnected from frequenz.client.base.retry import LinearBackoff from frequenz.client.base.streaming import GrpcStreamBroadcaster +from frequenz.client.common.microgrid import MicrogridId from ._internal_types import DispatchCreateRequest from .recurrence import RecurrenceRule from .types import ( Dispatch, DispatchEvent, + DispatchId, TargetComponents, _target_components_to_protobuf, ) @@ -83,7 +85,8 @@ def __init__( ) self._metadata = (("key", key),) self._streams: dict[ - int, GrpcStreamBroadcaster[StreamMicrogridDispatchesResponse, DispatchEvent] + MicrogridId, + GrpcStreamBroadcaster[StreamMicrogridDispatchesResponse, DispatchEvent], ] = {} """A dictionary of streamers, keyed by microgrid_id.""" @@ -114,7 +117,7 @@ def stub(self) -> dispatch_pb2_grpc.MicrogridDispatchServiceAsyncStub: # pylint: disable=too-many-arguments, too-many-locals async def list( self, - microgrid_id: int, + microgrid_id: MicrogridId, *, target_components: Iterator[TargetComponents] = iter(()), start_from: datetime | None = None, @@ -138,7 +141,7 @@ async def list( key="key", server_url="grpc://dispatch.url.goes.here.example.com" ) - async for page in client.list(microgrid_id=1): + async for page in client.list(microgrid_id=MicrogridId(1)): for dispatch in page: print(dispatch) ``` @@ -185,7 +188,7 @@ def to_interval( ) request = ListMicrogridDispatchesRequest( - microgrid_id=microgrid_id, + microgrid_id=int(microgrid_id), filter=filters, pagination_params=( PaginationParams(page_size=page_size) if page_size else None @@ -211,7 +214,7 @@ def to_interval( else: break - def stream(self, microgrid_id: int) -> channels.Receiver[DispatchEvent]: + def stream(self, microgrid_id: MicrogridId) -> channels.Receiver[DispatchEvent]: """Receive a stream of dispatch events. This function returns a receiver channel that can be used to receive @@ -238,7 +241,7 @@ def stream(self, microgrid_id: int) -> channels.Receiver[DispatchEvent]: return self._get_stream(microgrid_id).new_receiver() def _get_stream( - self, microgrid_id: int + self, microgrid_id: MicrogridId ) -> GrpcStreamBroadcaster[StreamMicrogridDispatchesResponse, DispatchEvent]: """Get an instance to the streaming helper.""" broadcaster = self._streams.get(microgrid_id) @@ -246,7 +249,7 @@ def _get_stream( del self._streams[microgrid_id] broadcaster = None if broadcaster is None: - request = StreamMicrogridDispatchesRequest(microgrid_id=microgrid_id) + request = StreamMicrogridDispatchesRequest(microgrid_id=int(microgrid_id)) broadcaster = GrpcStreamBroadcaster( stream_name="StreamMicrogridDispatches", stream_method=lambda: cast( @@ -266,7 +269,7 @@ def _get_stream( async def create( # pylint: disable=too-many-positional-arguments self, - microgrid_id: int, + microgrid_id: MicrogridId, type: str, # pylint: disable=redefined-builtin start_time: datetime | Literal["NOW"], duration: timedelta | None, @@ -334,8 +337,8 @@ async def create( # pylint: disable=too-many-positional-arguments async def update( self, *, - microgrid_id: int, - dispatch_id: int, + microgrid_id: MicrogridId, + dispatch_id: DispatchId, new_fields: dict[str, Any], ) -> Dispatch: """Update a dispatch. @@ -359,7 +362,7 @@ async def update( ValueError: If updating `type` or `dry_run`. """ msg = UpdateMicrogridDispatchRequest( - dispatch_id=dispatch_id, microgrid_id=microgrid_id + dispatch_id=int(dispatch_id), microgrid_id=int(microgrid_id) ) for key, val in new_fields.items(): @@ -423,7 +426,9 @@ async def update( return Dispatch.from_protobuf(response.dispatch) - async def get(self, *, microgrid_id: int, dispatch_id: int) -> Dispatch: + async def get( + self, *, microgrid_id: MicrogridId, dispatch_id: DispatchId + ) -> Dispatch: """Get a dispatch. Args: @@ -434,7 +439,7 @@ async def get(self, *, microgrid_id: int, dispatch_id: int) -> Dispatch: Dispatch: The dispatch. """ request = GetMicrogridDispatchRequest( - dispatch_id=dispatch_id, microgrid_id=microgrid_id + dispatch_id=int(dispatch_id), microgrid_id=int(microgrid_id) ) response = await cast( Awaitable[GetMicrogridDispatchResponse], @@ -444,7 +449,9 @@ async def get(self, *, microgrid_id: int, dispatch_id: int) -> Dispatch: ) return Dispatch.from_protobuf(response.dispatch) - async def delete(self, *, microgrid_id: int, dispatch_id: int) -> None: + async def delete( + self, *, microgrid_id: MicrogridId, dispatch_id: DispatchId + ) -> None: """Delete a dispatch. Args: @@ -452,7 +459,7 @@ async def delete(self, *, microgrid_id: int, dispatch_id: int) -> None: dispatch_id: The dispatch_id to delete. """ request = DeleteMicrogridDispatchRequest( - dispatch_id=dispatch_id, microgrid_id=microgrid_id + dispatch_id=int(dispatch_id), microgrid_id=int(microgrid_id) ) await cast( Awaitable[None], diff --git a/src/frequenz/client/dispatch/_internal_types.py b/src/frequenz/client/dispatch/_internal_types.py index 568c6784..c2d8f1d0 100644 --- a/src/frequenz/client/dispatch/_internal_types.py +++ b/src/frequenz/client/dispatch/_internal_types.py @@ -20,6 +20,7 @@ from google.protobuf.timestamp_pb2 import Timestamp from frequenz.client.base.conversion import to_datetime, to_timestamp +from frequenz.client.common.microgrid import MicrogridId from .recurrence import RecurrenceRule from .types import ( @@ -36,7 +37,7 @@ class DispatchCreateRequest: """Request to create a new dispatch.""" - microgrid_id: int + microgrid_id: MicrogridId """The identifier of the microgrid to which this dispatch belongs.""" type: str @@ -93,7 +94,7 @@ def from_protobuf( ) return DispatchCreateRequest( - microgrid_id=pb_object.microgrid_id, + microgrid_id=MicrogridId(pb_object.microgrid_id), type=pb_object.dispatch_data.type, start_time=( "NOW" @@ -118,7 +119,7 @@ def to_protobuf(self) -> PBDispatchCreateRequest: payload.update(self.payload) return PBDispatchCreateRequest( - microgrid_id=self.microgrid_id, + microgrid_id=int(self.microgrid_id), dispatch_data=DispatchData( type=self.type, start_time=( diff --git a/src/frequenz/client/dispatch/test/_service.py b/src/frequenz/client/dispatch/test/_service.py index 7bc755ef..f43ebc23 100644 --- a/src/frequenz/client/dispatch/test/_service.py +++ b/src/frequenz/client/dispatch/test/_service.py @@ -35,9 +35,10 @@ # pylint: enable=no-name-in-module from frequenz.client.base.conversion import to_datetime as _to_dt +from frequenz.client.common.microgrid import MicrogridId from .._internal_types import DispatchCreateRequest -from ..types import Dispatch, DispatchEvent, Event +from ..types import Dispatch, DispatchEvent, DispatchId, Event ALL_KEY = "all" """Key that has access to all resources in the FakeService.""" @@ -53,7 +54,7 @@ class FakeService: class StreamEvent: """Event for the stream.""" - microgrid_id: int + microgrid_id: MicrogridId """The microgrid id.""" event: DispatchEvent @@ -66,12 +67,21 @@ def __init__(self) -> None: ) self._stream_sender = self._stream_channel.new_sender() - self.dispatches: dict[int, list[Dispatch]] = {} + self.dispatches: dict[MicrogridId, list[Dispatch]] = {} """List of dispatches per microgrid.""" - self._last_id: int = 0 + self._last_id: DispatchId = DispatchId(0) """Last used dispatch id.""" + def refresh_last_id_for(self, microgrid_id: MicrogridId) -> None: + """Update last id to be the next highest number.""" + dispatches = self.dispatches.get(microgrid_id, []) + + if len(dispatches) == 0: + return + + self._last_id = max(self._last_id, max(dispatch.id for dispatch in dispatches)) + def _check_access(self, metadata: grpc.aio.Metadata) -> None: """Check if the access key is valid. @@ -129,7 +139,7 @@ async def ListMicrogridDispatches( """ self._check_access(metadata) - grid_dispatches = self.dispatches.get(request.microgrid_id, []) + grid_dispatches = self.dispatches.get(MicrogridId(request.microgrid_id), []) return ListMicrogridDispatchesResponse( dispatches=map( @@ -169,7 +179,7 @@ async def StreamMicrogridDispatches( async for message in receiver: logging.debug("Received message: %s", message) - if message.microgrid_id == request.microgrid_id: + if message.microgrid_id == MicrogridId(request.microgrid_id): response = StreamMicrogridDispatchesResponse( event=message.event.event.value, dispatch=message.event.dispatch.to_protobuf(), @@ -224,7 +234,8 @@ async def CreateMicrogridDispatch( ) -> CreateMicrogridDispatchResponse: """Create a new dispatch.""" self._check_access(metadata) - self._last_id += 1 + microgrid_id = MicrogridId(request.microgrid_id) + self._last_id = DispatchId(int(self._last_id) + 1) new_dispatch = _dispatch_from_request( DispatchCreateRequest.from_protobuf(request), @@ -234,11 +245,11 @@ async def CreateMicrogridDispatch( ) # implicitly create the list if it doesn't exist - self.dispatches.setdefault(request.microgrid_id, []).append(new_dispatch) + self.dispatches.setdefault(microgrid_id, []).append(new_dispatch) await self._stream_sender.send( self.StreamEvent( - request.microgrid_id, + microgrid_id, DispatchEvent(dispatch=new_dispatch, event=Event.CREATED), ) ) @@ -253,9 +264,15 @@ async def UpdateMicrogridDispatch( ) -> UpdateMicrogridDispatchResponse: """Update a dispatch.""" self._check_access(metadata) - grid_dispatches = self.dispatches[request.microgrid_id] + + microgrid_id = MicrogridId(request.microgrid_id) + grid_dispatches = self.dispatches.get(microgrid_id, []) index = next( - (i for i, d in enumerate(grid_dispatches) if d.id == request.dispatch_id), + ( + i + for i, d in enumerate(grid_dispatches) + if d.id == DispatchId(request.dispatch_id) + ), None, ) @@ -326,7 +343,7 @@ async def UpdateMicrogridDispatch( await self._stream_sender.send( self.StreamEvent( - request.microgrid_id, + microgrid_id, DispatchEvent(dispatch=dispatch, event=Event.UPDATED), ) ) @@ -341,9 +358,11 @@ async def GetMicrogridDispatch( ) -> GetMicrogridDispatchResponse: """Get a single dispatch.""" self._check_access(metadata) - grid_dispatches = self.dispatches.get(request.microgrid_id, []) + microgrid_id = MicrogridId(request.microgrid_id) + grid_dispatches = self.dispatches.get(microgrid_id, []) dispatch = next( - (d for d in grid_dispatches if d.id == request.dispatch_id), None + (d for d in grid_dispatches if d.id == DispatchId(request.dispatch_id)), + None, ) if dispatch is None: @@ -364,10 +383,12 @@ async def DeleteMicrogridDispatch( ) -> Empty: """Delete a given dispatch.""" self._check_access(metadata) - grid_dispatches = self.dispatches.get(request.microgrid_id, []) + microgrid_id = MicrogridId(request.microgrid_id) + grid_dispatches = self.dispatches.get(microgrid_id, []) dispatch_to_delete = next( - (d for d in grid_dispatches if d.id == request.dispatch_id), None + (d for d in grid_dispatches if d.id == DispatchId(request.dispatch_id)), + None, ) if dispatch_to_delete is None: @@ -382,7 +403,7 @@ async def DeleteMicrogridDispatch( await self._stream_sender.send( self.StreamEvent( - request.microgrid_id, + microgrid_id, DispatchEvent( dispatch=dispatch_to_delete, event=Event.DELETED, @@ -397,7 +418,7 @@ async def DeleteMicrogridDispatch( def _dispatch_from_request( _request: DispatchCreateRequest, - _id: int, + _id: DispatchId, create_time: datetime, update_time: datetime, ) -> Dispatch: diff --git a/src/frequenz/client/dispatch/test/client.py b/src/frequenz/client/dispatch/test/client.py index c4cc3241..7d739f38 100644 --- a/src/frequenz/client/dispatch/test/client.py +++ b/src/frequenz/client/dispatch/test/client.py @@ -5,6 +5,8 @@ from typing import Any +from frequenz.client.common.microgrid import MicrogridId + from .. import DispatchApiClient from ..types import Dispatch from ._service import ALL_KEY, NONE_KEY, FakeService @@ -34,7 +36,7 @@ def stub(self) -> FakeService: # type: ignore """ return self._stuba - def dispatches(self, microgrid_id: int) -> list[Dispatch]: + def dispatches(self, microgrid_id: MicrogridId) -> list[Dispatch]: """List of dispatches. Args: @@ -45,7 +47,7 @@ def dispatches(self, microgrid_id: int) -> list[Dispatch]: """ return self._service.dispatches.get(microgrid_id, []) - def set_dispatches(self, microgrid_id: int, value: list[Dispatch]) -> None: + def set_dispatches(self, microgrid_id: MicrogridId, value: list[Dispatch]) -> None: """Set the list of dispatches. Args: @@ -53,16 +55,7 @@ def set_dispatches(self, microgrid_id: int, value: list[Dispatch]) -> None: value: The list of dispatches to set. """ self._service.dispatches[microgrid_id] = value - - if len(value) == 0: - return - - # Max between last id and the max id in the list - # pylint: disable=protected-access - self._service._last_id = max( - self._service._last_id, max(dispatch.id for dispatch in value) - ) - # pylint: enable=protected-access + self._service.refresh_last_id_for(microgrid_id) @property def _service(self) -> FakeService: @@ -74,7 +67,7 @@ def _service(self) -> FakeService: return self._stuba -def to_create_params(microgrid_id: int, dispatch: Dispatch) -> dict[str, Any]: +def to_create_params(microgrid_id: MicrogridId, dispatch: Dispatch) -> dict[str, Any]: """Convert a dispatch to client.create parameters. Args: diff --git a/src/frequenz/client/dispatch/test/generator.py b/src/frequenz/client/dispatch/test/generator.py index 5bf054f5..ee77adee 100644 --- a/src/frequenz/client/dispatch/test/generator.py +++ b/src/frequenz/client/dispatch/test/generator.py @@ -13,6 +13,7 @@ from ..types import ( BatteryType, Dispatch, + DispatchId, EvChargerType, InverterType, TargetCategories, @@ -122,7 +123,7 @@ def generate_dispatch(self) -> Dispatch: ] return Dispatch( - id=self._last_id, + id=DispatchId(self._last_id), create_time=create_time, update_time=create_time + timedelta(seconds=self._rng.randint(0, 1000000)), type=str(self._rng.randint(0, 100_000)), diff --git a/src/frequenz/client/dispatch/types.py b/src/frequenz/client/dispatch/types.py index 62222d35..07d81a53 100644 --- a/src/frequenz/client/dispatch/types.py +++ b/src/frequenz/client/dispatch/types.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta, timezone from enum import Enum -from typing import Any, Self, SupportsInt, TypeAlias, cast +from typing import Any, Self, SupportsInt, TypeAlias, cast, final # pylint: enable=no-name-in-module from frequenz.api.common.v1.microgrid.components.battery_pb2 import ( @@ -28,6 +28,7 @@ StreamMicrogridDispatchesResponse, ) from frequenz.api.dispatch.v1.dispatch_pb2 import TargetComponents as PBTargetComponents +from frequenz.core.id import BaseId from google.protobuf.json_format import MessageToDict from google.protobuf.struct_pb2 import Struct @@ -37,6 +38,11 @@ from .recurrence import Frequency, RecurrenceRule, Weekday +@final +class DispatchId(BaseId, str_prefix="DID"): + """A unique identifier for a dispatch.""" + + class EvChargerType(Enum): """Enum representing the type of EV charger.""" @@ -359,7 +365,7 @@ class TimeIntervalFilter: class Dispatch: # pylint: disable=too-many-instance-attributes """Represents a dispatch operation within a microgrid system.""" - id: int + id: DispatchId """The unique identifier for the dispatch.""" type: str @@ -534,7 +540,7 @@ def from_protobuf(cls, pb_object: PBDispatch) -> "Dispatch": The converted dispatch. """ return Dispatch( - id=pb_object.metadata.dispatch_id, + id=DispatchId(pb_object.metadata.dispatch_id), type=pb_object.data.type, create_time=to_datetime(pb_object.metadata.create_time), update_time=to_datetime(pb_object.metadata.update_time), @@ -567,7 +573,7 @@ def to_protobuf(self) -> PBDispatch: return PBDispatch( metadata=DispatchMetadata( - dispatch_id=self.id, + dispatch_id=int(self.id), create_time=to_timestamp(self.create_time), update_time=to_timestamp(self.update_time), end_time=( diff --git a/tests/test_cli.py b/tests/test_cli.py index 1321162a..627916b3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,6 +12,7 @@ from asyncclick.testing import CliRunner from tzlocal import get_localzone +from frequenz.client.common.microgrid import MicrogridId from frequenz.client.common.microgrid.components import ComponentCategory from frequenz.client.dispatch.__main__ import cli from frequenz.client.dispatch.recurrence import ( @@ -23,6 +24,7 @@ from frequenz.client.dispatch.test.client import ALL_KEY, FakeClient from frequenz.client.dispatch.types import ( Dispatch, + DispatchId, TargetCategories, TargetComponents, TargetIds, @@ -69,7 +71,7 @@ def mock_client(fake_client: FakeClient) -> Generator[None, None, None]: { 1: [ Dispatch( - id=1, + id=DispatchId(1), type="test", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=3600), @@ -83,7 +85,7 @@ def mock_client(fake_client: FakeClient) -> Generator[None, None, None]: ) ] }, - 1, + MicrogridId(1), "1 dispatches, 0 filtered out", 0, ), @@ -92,7 +94,7 @@ def mock_client(fake_client: FakeClient) -> Generator[None, None, None]: { 2: [ Dispatch( - id=2, + id=DispatchId(2), type="test", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=3600), @@ -106,7 +108,7 @@ def mock_client(fake_client: FakeClient) -> Generator[None, None, None]: ) ] }, - 1, + MicrogridId(1), "0 dispatches, 0 filtered out", 0, ), @@ -114,7 +116,7 @@ def mock_client(fake_client: FakeClient) -> Generator[None, None, None]: { 1: [ Dispatch( - id=1, + id=DispatchId(1), type="test", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=3600), @@ -129,7 +131,7 @@ def mock_client(fake_client: FakeClient) -> Generator[None, None, None]: ], 2: [ Dispatch( - id=2, + id=DispatchId(2), type="test", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=3600), @@ -143,7 +145,7 @@ def mock_client(fake_client: FakeClient) -> Generator[None, None, None]: ), ], }, - 1, + MicrogridId(1), "1 dispatches, 0 filtered out", 0, ), @@ -157,7 +159,7 @@ def mock_client(fake_client: FakeClient) -> Generator[None, None, None]: { 1: [ Dispatch( - id=1, + id=DispatchId(1), type="test", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=3600), @@ -170,7 +172,7 @@ def mock_client(fake_client: FakeClient) -> Generator[None, None, None]: update_time=datetime(2023, 1, 1, 0, 0, 0), ), Dispatch( - id=2, + id=DispatchId(2), type="filtered", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=1800), @@ -184,7 +186,7 @@ def mock_client(fake_client: FakeClient) -> Generator[None, None, None]: ), ], }, - 1, + MicrogridId(1), "1 dispatches, 1 filtered out", 0, ), @@ -194,17 +196,21 @@ async def test_list_command( runner: CliRunner, fake_client: FakeClient, dispatches: dict[int, list[Dispatch]], - microgrid_id: int, + microgrid_id: MicrogridId, expected_output: str, expected_return_code: int, ) -> None: """Test the list command.""" for microgrid_id_, dispatch_list in dispatches.items(): - fake_client.set_dispatches(microgrid_id_, dispatch_list) + if isinstance(microgrid_id_, int): + fake_client.set_dispatches(MicrogridId(microgrid_id_), dispatch_list) + str_microgrid_id = str( + int(microgrid_id) if isinstance(microgrid_id, MicrogridId) else microgrid_id + ) result = await runner.invoke( cli, - ["--raw", "list", str(microgrid_id), "--type", "test"], + ["--raw", "list", str_microgrid_id, "--type", "test"], env=ENVIRONMENT_VARIABLES, ) assert expected_output in result.output @@ -228,7 +234,7 @@ async def test_list_command( "--active", "False", ], - 829, + MicrogridId(829), "test", timedelta(hours=1), timedelta(seconds=3600), @@ -248,7 +254,7 @@ async def test_list_command( "--dry-run", "true", ], - 1, + MicrogridId(1), "test", timedelta(hours=2), timedelta(seconds=3600), @@ -259,7 +265,7 @@ async def test_list_command( ), ( ["create", "x"], - 0, + MicrogridId(0), "", timedelta(), timedelta(), @@ -299,7 +305,7 @@ async def test_list_command( "--by-monthday", "17", ], - 1, + MicrogridId(1), "test", timedelta(hours=1), timedelta(seconds=3600), @@ -334,7 +340,7 @@ async def test_list_command( "--by-minute", "5", ], - 50, + MicrogridId(50), "test50", timedelta(hours=5), timedelta(seconds=3600), @@ -362,7 +368,7 @@ async def test_list_command( "now", "1h", ], - 1, + MicrogridId(1), "test_start_immediately", "NOW", timedelta(seconds=3600), @@ -377,7 +383,7 @@ async def test_create_command( runner: CliRunner, fake_client: FakeClient, args: list[str], - expected_microgrid_id: int, + expected_microgrid_id: MicrogridId, expected_type: str, expected_start_time_delta: timedelta | Literal["NOW"], expected_duration: timedelta, @@ -447,7 +453,7 @@ async def test_create_command( ( [ Dispatch( - id=1, + id=DispatchId(1), type="test", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=3600), @@ -471,7 +477,7 @@ async def test_create_command( ( [ Dispatch( - id=1, + id=DispatchId(1), type="test", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=3600), @@ -497,7 +503,7 @@ async def test_create_command( ( [ Dispatch( - id=1, + id=DispatchId(1), type="test", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=3600), @@ -529,7 +535,7 @@ async def test_create_command( ( [ Dispatch( - id=1, + id=DispatchId(1), type="test", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=3600), @@ -615,16 +621,16 @@ async def test_update_command( expected_output: str, ) -> None: """Test the update command.""" - fake_client.set_dispatches(1, dispatches) + fake_client.set_dispatches(MicrogridId(1), dispatches) result = await runner.invoke( cli, ["--raw", "update", "1", "1", *args], env=ENVIRONMENT_VARIABLES ) assert expected_output in result.output assert result.exit_code == expected_return_code if dispatches: - assert len(fake_client.dispatches(1)) == 1 + assert len(fake_client.dispatches(MicrogridId(1))) == 1 for key, value in fields.items(): - assert getattr(fake_client.dispatches(1)[0], key) == value + assert getattr(fake_client.dispatches(MicrogridId(1))[0], key) == value @pytest.mark.asyncio @@ -634,7 +640,7 @@ async def test_update_command( ( [ Dispatch( - id=1, + id=DispatchId(1), type="test", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=3600), @@ -648,7 +654,7 @@ async def test_update_command( ) ], 1, - "Dispatch(id=1,", + "Dispatch(id=DispatchId(1),", ), ([], 999, "Error"), ( @@ -662,13 +668,16 @@ async def test_get_command( runner: CliRunner, fake_client: FakeClient, dispatches: list[Dispatch], - dispatch_id: int, + dispatch_id: DispatchId, expected_in_output: str, ) -> None: """Test the get command.""" - fake_client.set_dispatches(1, dispatches) + fake_client.set_dispatches(MicrogridId(1), dispatches) + str_dispatch_id = str( + int(dispatch_id) if isinstance(dispatch_id, DispatchId) else dispatch_id + ) result = await runner.invoke( - cli, ["--raw", "get", "1", str(dispatch_id)], env=ENVIRONMENT_VARIABLES + cli, ["--raw", "get", "1", str_dispatch_id], env=ENVIRONMENT_VARIABLES ) assert result.exit_code == 0 if dispatches else 1 assert expected_in_output in result.output @@ -681,7 +690,7 @@ async def test_get_command( ( [ Dispatch( - id=1, + id=DispatchId(1), type="test", start_time=datetime(2023, 1, 1, 0, 0, 0), duration=timedelta(seconds=3600), @@ -694,7 +703,7 @@ async def test_get_command( update_time=datetime(2023, 1, 1, 0, 0, 0), ) ], - 1, + DispatchId(1), "Dispatches deleted: [1]", 0, ), @@ -711,16 +720,19 @@ async def test_delete_command( runner: CliRunner, fake_client: FakeClient, dispatches: list[Dispatch], - dispatch_id: int, + dispatch_id: DispatchId, expected_output: str, expected_return_code: int, ) -> None: """Test the delete command.""" - fake_client.set_dispatches(1, dispatches) + str_dispatch_id = str( + int(dispatch_id) if isinstance(dispatch_id, DispatchId) else dispatch_id + ) + fake_client.set_dispatches(MicrogridId(1), dispatches) result = await runner.invoke( - cli, ["delete", "1", str(dispatch_id)], env=ENVIRONMENT_VARIABLES + cli, ["delete", "1", str_dispatch_id], env=ENVIRONMENT_VARIABLES ) assert result.exit_code == expected_return_code assert expected_output in result.output if dispatches: - assert len(fake_client.dispatches(1)) == 0 + assert len(fake_client.dispatches(MicrogridId(1))) == 0 diff --git a/tests/test_client.py b/tests/test_client.py index df80ca1a..26a6b236 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -13,10 +13,11 @@ import pytest from pytest import raises +from frequenz.client.common.microgrid import MicrogridId from frequenz.client.dispatch.test.client import FakeClient, to_create_params from frequenz.client.dispatch.test.fixtures import client, generator, sample from frequenz.client.dispatch.test.generator import DispatchGenerator -from frequenz.client.dispatch.types import Dispatch, Event +from frequenz.client.dispatch.types import Dispatch, DispatchId, Event # Ignore flake8 error in the rest of the file to use the same fixture names # flake8: noqa[811] @@ -46,7 +47,7 @@ def _update_metadata(dispatch: Dispatch, created: Dispatch) -> Dispatch: async def test_create_dispatch(client: FakeClient, sample: Dispatch) -> None: """Test creating a dispatch.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) dispatch = await client.create(**to_create_params(microgrid_id, sample)) sample = _update_metadata(sample, dispatch) @@ -57,7 +58,7 @@ async def test_create_return_dispatch( client: FakeClient, generator: DispatchGenerator ) -> None: """Test creating a dispatch and returning the created dispatch.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) for _ in range(100): sample = generator.generate_dispatch() @@ -70,7 +71,7 @@ async def test_create_return_dispatch( async def test_create_duration_none(client: FakeClient, sample: Dispatch) -> None: """Test creating a dispatch with a None duration.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) sample = replace(sample, duration=None) dispatch = await client.create(**to_create_params(microgrid_id, sample)) sample = _update_metadata(sample, dispatch) @@ -79,7 +80,7 @@ async def test_create_duration_none(client: FakeClient, sample: Dispatch) -> Non async def test_create_duration_0(client: FakeClient, sample: Dispatch) -> None: """Test creating a dispatch with a 0 duration.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) sample = replace(sample, duration=timedelta(minutes=0)) dispatch = await client.create(**to_create_params(microgrid_id, sample)) sample = _update_metadata(sample, dispatch) @@ -90,21 +91,21 @@ async def test_list_dispatches( client: FakeClient, generator: DispatchGenerator ) -> None: """Test listing dispatches.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) client.set_dispatches( microgrid_id=microgrid_id, value=[generator.generate_dispatch() for _ in range(100)], ) - dispatches = client.list(microgrid_id=1) + dispatches = client.list(microgrid_id=MicrogridId(1)) async for page in dispatches: for dispatch in page: # First find matching id in client.dispatches, then compare service_side_dispatch = next( filter( partial(lambda d, md: d.id == md.id, md=dispatch), - client.dispatches(microgrid_id=1), + client.dispatches(microgrid_id=MicrogridId(1)), ), None, ) @@ -116,7 +117,7 @@ async def test_list_dispatches_no_duration( client: FakeClient, generator: DispatchGenerator ) -> None: """Test listing dispatches with a None duration.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) client.set_dispatches( microgrid_id=microgrid_id, @@ -125,17 +126,17 @@ async def test_list_dispatches_no_duration( ], ) - dispatches = client.list(microgrid_id=1) + dispatches = client.list(microgrid_id=MicrogridId(1)) async for page in dispatches: for dispatch in page: - assert dispatch in client.dispatches(microgrid_id=1) + assert dispatch in client.dispatches(microgrid_id=MicrogridId(1)) async def test_list_create_dispatches( client: FakeClient, generator: DispatchGenerator ) -> None: """Test listing dispatches.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) # Test with empty list page = await anext(client.list(microgrid_id=microgrid_id)) @@ -161,7 +162,7 @@ async def test_list_create_dispatches( async def test_update_dispatch(client: FakeClient, sample: Dispatch) -> None: """Test updating a dispatch.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) dispatch = await client.create(**to_create_params(microgrid_id, sample)) sample = _update_metadata(sample, dispatch) @@ -179,7 +180,7 @@ async def test_update_dispatch_to_no_duration( client: FakeClient, sample: Dispatch ) -> None: """Test updating the duration field of a dispatch to None.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) client.set_dispatches( microgrid_id=microgrid_id, value=[replace(sample, duration=timedelta(minutes=10))], @@ -197,7 +198,7 @@ async def test_update_dispatch_to_0_duration( client: FakeClient, sample: Dispatch ) -> None: """Test updating the duration field of a dispatch to 0.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) client.set_dispatches( microgrid_id=microgrid_id, value=[replace(sample, duration=timedelta(minutes=10))], @@ -215,7 +216,7 @@ async def test_update_dispatch_from_no_duration( client: FakeClient, sample: Dispatch ) -> None: """Test updating the duration field of a dispatch from None.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) client.set_dispatches( microgrid_id=microgrid_id, value=[replace(sample, duration=None)] ) @@ -230,7 +231,7 @@ async def test_update_dispatch_from_no_duration( async def test_update_dispatch_fail(client: FakeClient, sample: Dispatch) -> None: """Test updating the type and dry_run fields of a dispatch.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) dispatch = await client.create(**to_create_params(microgrid_id, sample)) assert dispatch is not None @@ -253,7 +254,7 @@ async def test_update_dispatch_fail(client: FakeClient, sample: Dispatch) -> Non async def test_get_dispatch(client: FakeClient, sample: Dispatch) -> None: """Test getting a dispatch.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) dispatch = await client.create(**to_create_params(microgrid_id, sample)) sample = _update_metadata(sample, dispatch) @@ -266,7 +267,7 @@ async def test_get_dispatch(client: FakeClient, sample: Dispatch) -> None: async def test_get_dispatch_no_duration(client: FakeClient, sample: Dispatch) -> None: """Test getting a dispatch with a None duration.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) sample = replace(sample, duration=None) dispatch = await client.create(**to_create_params(microgrid_id, sample)) @@ -281,12 +282,12 @@ async def test_get_dispatch_no_duration(client: FakeClient, sample: Dispatch) -> async def test_get_dispatch_fail(client: FakeClient) -> None: """Test getting a non-existent dispatch.""" with raises(grpc.RpcError): - await client.get(microgrid_id=1, dispatch_id=1) + await client.get(microgrid_id=MicrogridId(1), dispatch_id=DispatchId(1)) async def test_delete_dispatch(client: FakeClient, sample: Dispatch) -> None: """Test deleting a dispatch.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) dispatch = await client.create(**to_create_params(microgrid_id, sample)) sample = _update_metadata(sample, dispatch) @@ -300,7 +301,7 @@ async def test_delete_dispatch(client: FakeClient, sample: Dispatch) -> None: async def test_delete_dispatch_fail(client: FakeClient) -> None: """Test deleting a non-existent dispatch.""" with raises(grpc.RpcError): - await client.delete(microgrid_id=1, dispatch_id=1) + await client.delete(microgrid_id=MicrogridId(1), dispatch_id=DispatchId(1)) @pytest.mark.parametrize("call_twice", [True, False]) @@ -308,7 +309,7 @@ async def test_dispatch_stream( client: FakeClient, sample: Dispatch, call_twice: bool ) -> None: """Test dispatching a stream of dispatches.""" - microgrid_id = random.randint(1, 100) + microgrid_id = MicrogridId(random.randint(1, 100)) dispatches = [sample, sample, sample] stream = client.stream(microgrid_id) diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py index 0d7e83df..8baf3241 100644 --- a/tests/test_dispatch.py +++ b/tests/test_dispatch.py @@ -11,7 +11,12 @@ from frequenz.client.common.microgrid.components import ComponentCategory from frequenz.client.dispatch.recurrence import Frequency, RecurrenceRule, Weekday -from frequenz.client.dispatch.types import Dispatch, TargetCategories, TargetIds +from frequenz.client.dispatch.types import ( + Dispatch, + DispatchId, + TargetCategories, + TargetIds, +) # Define a fixed current time for testing to avoid issues with datetime.now() CURRENT_TIME = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) @@ -21,7 +26,7 @@ def dispatch_base() -> Dispatch: """Fixture to create a base Dispatch instance.""" return Dispatch( - id=1, + id=DispatchId(1), type="TypeA", start_time=CURRENT_TIME, duration=timedelta(minutes=20), diff --git a/tests/test_proto.py b/tests/test_proto.py index 70a65423..e6277bc0 100644 --- a/tests/test_proto.py +++ b/tests/test_proto.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta, timezone +from frequenz.client.common.microgrid import MicrogridId from frequenz.client.common.microgrid.components import ComponentCategory from frequenz.client.dispatch._internal_types import DispatchCreateRequest from frequenz.client.dispatch.recurrence import ( @@ -16,6 +17,7 @@ from frequenz.client.dispatch.types import ( BatteryType, Dispatch, + DispatchId, EvChargerType, InverterType, TargetCategories, @@ -108,7 +110,7 @@ def test_dispatch() -> None: """Test the dispatch.""" for dispatch in ( Dispatch( - id=123, + id=DispatchId(123), type="test", create_time=datetime(2023, 1, 1, tzinfo=timezone.utc), update_time=datetime(2023, 1, 1, tzinfo=timezone.utc), @@ -129,7 +131,7 @@ def test_dispatch() -> None: ), ), Dispatch( - id=124, + id=DispatchId(124), type="test-2", create_time=datetime(2024, 3, 10, tzinfo=timezone.utc), update_time=datetime(2024, 3, 11, tzinfo=timezone.utc), @@ -149,7 +151,7 @@ def test_dispatch() -> None: ), ), Dispatch( - id=125, + id=DispatchId(125), type="test-3", create_time=datetime(2024, 3, 10, tzinfo=timezone.utc), update_time=datetime(2024, 3, 11, tzinfo=timezone.utc), @@ -175,7 +177,7 @@ def test_dispatch() -> None: def test_dispatch_create_request_with_no_recurrence() -> None: """Test the dispatch create request with no recurrence.""" request = DispatchCreateRequest( - microgrid_id=123, + microgrid_id=MicrogridId(123), type="test", start_time=datetime(2024, 10, 10, tzinfo=timezone.utc), duration=timedelta(days=10), @@ -192,7 +194,7 @@ def test_dispatch_create_request_with_no_recurrence() -> None: def test_dispatch_create_start_immediately() -> None: """Test the dispatch create request with no start time.""" request = DispatchCreateRequest( - microgrid_id=123, + microgrid_id=MicrogridId(123), type="test", start_time="NOW", duration=timedelta(days=10),