From 9b5983d506a020b0361bc03d218196602aa3c857 Mon Sep 17 00:00:00 2001 From: devblin Date: Thu, 28 Sep 2023 15:14:55 +0530 Subject: [PATCH] feat(subscriber): #59 - add bulk create for subscribers --- novu/api/base.py | 4 +- novu/api/subscriber.py | 18 ++++++ novu/dto/__init__.py | 2 + novu/dto/subscriber.py | 20 +++++++ tests/api/test_subscriber.py | 107 +++++++++++++++++++++++++++++++++++ 5 files changed, 149 insertions(+), 2 deletions(-) diff --git a/novu/api/base.py b/novu/api/base.py index 57ba3fb0..491f01be 100644 --- a/novu/api/base.py +++ b/novu/api/base.py @@ -3,7 +3,7 @@ import logging import os from json.decoder import JSONDecodeError -from typing import Generic, List, Optional, Type, TypeVar +from typing import Generic, List, Optional, Type, TypeVar, Union import requests @@ -110,7 +110,7 @@ def handle_request( self, method: str, url: str, - json: Optional[dict] = None, + json: Optional[Union[dict, list]] = None, payload: Optional[dict] = None, headers: Optional[dict] = None, **kwargs, diff --git a/novu/api/subscriber.py b/novu/api/subscriber.py index 664972fb..e60956ba 100644 --- a/novu/api/subscriber.py +++ b/novu/api/subscriber.py @@ -8,6 +8,7 @@ from novu.api.base import Api, PaginationIterator from novu.constants import SUBSCRIBERS_ENDPOINT from novu.dto.subscriber import ( + BulkResultSubscriberDto, PaginatedSubscriberDto, SubscriberDto, SubscriberPreferenceDto, @@ -65,6 +66,23 @@ def create(self, subscriber: SubscriberDto) -> SubscriberDto: self.handle_request("POST", self._subscriber_url, subscriber.to_camel_case()).get("data", {}) ) + def bulk_create(self, subscribers: Iterator[SubscriberDto]) -> BulkResultSubscriberDto: + """Using this endpoint you can create multiple subscribers at once, to avoid multiple calls to the API. + + The bulk API is limited to 500 subscribers per request. + + Args: + subscribers: An iterator of subscribers instance to push to Novu + + Returns: + Result of the bulk operation in Novu. + """ + return BulkResultSubscriberDto.from_camel_case( + self.handle_request( + "POST", f"{self._subscriber_url}/bulk", [subscriber.to_camel_case() for subscriber in subscribers] + ).get("data", {}) + ) + def get(self, subscriber_id: str) -> SubscriberDto: """Method to get a subscriber using his identifier diff --git a/novu/dto/__init__.py b/novu/dto/__init__.py index 882535ea..600be966 100644 --- a/novu/dto/__init__.py +++ b/novu/dto/__init__.py @@ -44,6 +44,7 @@ ) from novu.dto.step_filter import StepFilterDto from novu.dto.subscriber import ( + BulkResultSubscriberDto, PaginatedSubscriberDto, SubscriberChannelSettingsCredentialsDto, SubscriberChannelSettingsDto, @@ -58,6 +59,7 @@ __all__ = [ "BlueprintDto", + "BulkResultSubscriberDto", "ChangeDetailDto", "ChangeDto", "EnvironmentApiKeyDto", diff --git a/novu/dto/subscriber.py b/novu/dto/subscriber.py index e6872400..fd92f662 100644 --- a/novu/dto/subscriber.py +++ b/novu/dto/subscriber.py @@ -179,3 +179,23 @@ class PaginatedSubscriberDto(CamelCaseDto["PaginatedSubscriberDto"]): default_factory=list, item_cls=SubscriberDto ) """Data""" + + +@dataclasses.dataclass +class BulkResultSubscriberDto(CamelCaseDto["BulkResultSubscriberDto"]): + """Definition of paginated subscribers""" + + created: DtoIterableDescriptor[SubscriberDto] = DtoIterableDescriptor[SubscriberDto]( + default_factory=list, item_cls=SubscriberDto + ) + """List of subscribers that were created during the operation.""" + + updated: DtoIterableDescriptor[SubscriberDto] = DtoIterableDescriptor[SubscriberDto]( + default_factory=list, item_cls=SubscriberDto + ) + """List of subscribers that were updated during the operation.""" + + failed: DtoIterableDescriptor[SubscriberDto] = DtoIterableDescriptor[SubscriberDto]( + default_factory=list, item_cls=SubscriberDto + ) + """List of subscribers whose creation (or update) failed.""" diff --git a/tests/api/test_subscriber.py b/tests/api/test_subscriber.py index 227cc980..a2945999 100644 --- a/tests/api/test_subscriber.py +++ b/tests/api/test_subscriber.py @@ -5,6 +5,7 @@ from novu.api.base import PaginationIterator from novu.config import NovuConfig from novu.dto.subscriber import ( + BulkResultSubscriberDto, PaginatedSubscriberDto, SubscriberChannelSettingsCredentialsDto, SubscriberChannelSettingsDto, @@ -161,6 +162,112 @@ def test_create_subscriber(self, mock_request: mock.MagicMock) -> None: timeout=5, ) + @mock.patch("requests.request") + def test_bulk_create_subscribers(self, mock_request: mock.MagicMock) -> None: + mock_request.return_value = MockResponse( + 201, + { + "data": { + "created": [ + { + "_organizationId": None, + "_environmentId": None, + "firstName": None, + "lastName": None, + "phone": None, + "subscriberId": "subscriber-id", + "email": "subscriber@sample.com", + "avatar": None, + "locale": None, + "channels": [], + "_id": "63e2cc7151af34c4b2f2b5d1", + "deleted": None, + "__v": 0, + "id": "63e2cc7151af34c4b2f2b5d1", + }, + { + "_organizationId": None, + "_environmentId": None, + "firstName": None, + "lastName": None, + "phone": None, + "subscriberId": "subscriber1-id", + "email": "subscriber1@sample.com", + "avatar": None, + "locale": None, + "channels": [], + "_id": "63e2cc7151af34c4b2f2b5d2", + "deleted": None, + "__v": 0, + "id": "63e2cc7151af34c4b2f2b5d2", + }, + ], + "updated": [], + "failed": [], + } + }, + ) + + subscribers = [ + SubscriberDto(subscriber_id="subscriber-id", email="subscriber@sample.com"), + SubscriberDto(subscriber_id="subscriber1-id", email="subscriber1@sample.com"), + ] + + res = self.api.bulk_create(subscribers) + + self.assertIsInstance(res, BulkResultSubscriberDto) + self.assertEqual( + res, + BulkResultSubscriberDto( + created=[ + SubscriberDto( + subscriber_id="subscriber-id", + email="subscriber@sample.com", + _id="63e2cc7151af34c4b2f2b5d1", + channels=[], + ), + SubscriberDto( + subscriber_id="subscriber1-id", + email="subscriber1@sample.com", + _id="63e2cc7151af34c4b2f2b5d2", + channels=[], + ), + ], + updated=[], + failed=[], + ), + ) + + mock_request.assert_called_once_with( + method="POST", + url="sample.novu.com/v1/subscribers/bulk", + headers={"Authorization": "ApiKey api-key"}, + json=[ + { + "subscriberId": "subscriber-id", + "email": "subscriber@sample.com", + "firstName": None, + "lastName": None, + "phone": None, + "avatar": None, + "locale": None, + "channels": None, + }, + { + "subscriberId": "subscriber1-id", + "email": "subscriber1@sample.com", + "firstName": None, + "lastName": None, + "phone": None, + "avatar": None, + "locale": None, + "channels": None, + }, + ], + params=None, + timeout=5, + ) + @mock.patch("requests.request") def test_get_subscriber(self, mock_request: mock.MagicMock) -> None: mock_request.return_value = MockResponse(200, self.response_get)