From 9ec88235c050f1499ccf13184d492e6c24149bca Mon Sep 17 00:00:00 2001 From: Samuel Gratzl Date: Mon, 2 Aug 2021 09:58:16 -0400 Subject: [PATCH 1/2] feat: adapt for required api_key --- delphi_epidata/_endpoints.py | 82 ++++++++++++++------------------ delphi_epidata/_model.py | 6 ++- delphi_epidata/async_requests.py | 45 ++++++++++-------- delphi_epidata/requests.py | 42 +++++++++------- smoke_test.py | 7 +-- smoke_test_async.py | 3 +- 6 files changed, 97 insertions(+), 88 deletions(-) diff --git a/delphi_epidata/_endpoints.py b/delphi_epidata/_endpoints.py index 43f44e4..a7e1b12 100644 --- a/delphi_epidata/_endpoints.py +++ b/delphi_epidata/_endpoints.py @@ -33,11 +33,11 @@ def _create_call( ) -> CALL_TYPE: raise NotImplementedError() - def afhsb(self, auth: str, locations: StringParam, epiweeks: EpiRangeParam, flu_types: StringParam) -> CALL_TYPE: + def afhsb(self, locations: StringParam, epiweeks: EpiRangeParam, flu_types: StringParam) -> CALL_TYPE: """Fetch AFHSB data (point data, no min/max).""" - if auth is None or locations is None or epiweeks is None or flu_types is None: - raise InvalidArgumentException("`auth`, `locations`, `epiweeks` and `flu_types` are all required") + if locations is None or epiweeks is None or flu_types is None: + raise InvalidArgumentException("`locations`, `epiweeks` and `flu_types` are all required") loc_exception = ( "Location parameter `{}` is invalid. Valid `location` parameters are: " @@ -67,7 +67,7 @@ def afhsb(self, auth: str, locations: StringParam, epiweeks: EpiRangeParam, flu_ return self._create_call( "afhsb/", - dict(auth=auth, locations=locations, epiweeks=epiweeks, flu_types=flu_types), + dict(locations=locations, epiweeks=epiweeks, flu_types=flu_types), [ EpidataFieldInfo("location", EpidataFieldType.text), EpidataFieldInfo("flu_type", EpidataFieldType.text), @@ -76,15 +76,15 @@ def afhsb(self, auth: str, locations: StringParam, epiweeks: EpiRangeParam, flu_ ], ) - def cdc(self, auth: str, epiweeks: EpiRangeParam, locations: StringParam) -> CALL_TYPE: + def cdc(self, epiweeks: EpiRangeParam, locations: StringParam) -> CALL_TYPE: """Fetch CDC page hits.""" - if auth is None or epiweeks is None or locations is None: - raise InvalidArgumentException("`auth`, `epiweeks`, and `locations` are all required") + if epiweeks is None or locations is None: + raise InvalidArgumentException("`epiweeks`, and `locations` are all required") return self._create_call( "cdc/", - dict(auth=auth, epiweeks=epiweeks, locations=locations), + dict(epiweeks=epiweeks, locations=locations), [ EpidataFieldInfo("location", EpidataFieldType.text), EpidataFieldInfo("epiweek", EpidataFieldType.epiweek), @@ -500,17 +500,15 @@ def dengue_nowcast(self, locations: StringParam, epiweeks: EpiRangeParam) -> CAL ], ) - def dengue_sensors( - self, auth: str, names: StringParam, locations: StringParam, epiweeks: EpiRangeParam - ) -> CALL_TYPE: + def dengue_sensors(self, names: StringParam, locations: StringParam, epiweeks: EpiRangeParam) -> CALL_TYPE: """Fetch Delphi's digital surveillance sensors.""" - if auth is None or names is None or locations is None or epiweeks is None: - raise InvalidArgumentException("`auth`, `names`, `locations`, and `epiweeks` are all required") + if names is None or locations is None or epiweeks is None: + raise InvalidArgumentException("`names`, `locations`, and `epiweeks` are all required") return self._create_call( "dengue_sensors/", - dict(auth=auth, names=names, locations=locations, epiweeks=epiweeks), + dict(names=names, locations=locations, epiweeks=epiweeks), [ EpidataFieldInfo("name", EpidataFieldType.text), EpidataFieldInfo("location", EpidataFieldType.text), @@ -623,7 +621,6 @@ def fluview( epiweeks: EpiRangeParam, issues: Optional[EpiRangeParam] = None, lag: Optional[int] = None, - auth: Optional[str] = None, ) -> CALL_TYPE: if regions is None or epiweeks is None: raise InvalidArgumentException("`regions` and `epiweeks` are both required") @@ -631,7 +628,7 @@ def fluview( raise InvalidArgumentException("`issues` and `lag` are mutually exclusive") return self._create_call( "fluview/", - dict(regions=regions, epiweeks=epiweeks, issues=issues, lag=lag, auth=auth), + dict(regions=regions, epiweeks=epiweeks, issues=issues, lag=lag), [ EpidataFieldInfo("release_date", EpidataFieldType.text), EpidataFieldInfo("region", EpidataFieldType.text), @@ -665,13 +662,13 @@ def gft(self, locations: StringParam, epiweeks: EpiRangeParam) -> CALL_TYPE: ], ) - def ght(self, auth: str, locations: StringParam, epiweeks: EpiRangeParam, query: str) -> CALL_TYPE: + def ght(self, locations: StringParam, epiweeks: EpiRangeParam, query: str) -> CALL_TYPE: """Fetch Google Health Trends data.""" - if auth is None or locations is None or epiweeks is None or query is None: - raise InvalidArgumentException("`auth`, `locations`, `epiweeks`, and `query` are all required") + if locations is None or epiweeks is None or query is None: + raise InvalidArgumentException("`locations`, `epiweeks`, and `query` are all required") return self._create_call( "ght/", - dict(auth=auth, locations=locations, epiweeks=epiweeks, query=query), + dict(locations=locations, epiweeks=epiweeks, query=query), [ EpidataFieldInfo("location", EpidataFieldType.text), EpidataFieldInfo("epiweek", EpidataFieldType.epiweek), @@ -704,25 +701,19 @@ def kcdc_ili( ], ) - def meta_afhsb(self, auth: str) -> CALL_TYPE: + def meta_afhsb(self) -> CALL_TYPE: """Fetch AFHSB metadata.""" - - if auth is None: - raise InvalidArgumentException("`auth` is required") - return self._create_call( "meta_afhsb/", - dict(auth=auth), + dict(), ) - def meta_norostat(self, auth: str) -> CALL_TYPE: + def meta_norostat(self) -> CALL_TYPE: """Fetch NoroSTAT metadata.""" - if auth is None: - raise InvalidArgumentException("`auth` is required") return self._create_call( "meta_norostat/", - dict(auth=auth), + dict(), ) def meta(self) -> CALL_TYPE: @@ -776,14 +767,14 @@ def nidss_flu( ], ) - def norostat(self, auth: str, location: str, epiweeks: EpiRangeParam) -> CALL_TYPE: + def norostat(self, location: str, epiweeks: EpiRangeParam) -> CALL_TYPE: """Fetch NoroSTAT data (point data, no min/max).""" - if auth is None or location is None or epiweeks is None: - raise InvalidArgumentException("`auth`, `location`, and `epiweeks` are all required") + if location is None or epiweeks is None: + raise InvalidArgumentException("`location`, and `epiweeks` are all required") return self._create_call( "norostat/", - dict(auth=auth, epiweeks=epiweeks, location=location), + dict(epiweeks=epiweeks, location=location), [ EpidataFieldInfo("release_date", EpidataFieldType.text), EpidataFieldInfo("epiweek", EpidataFieldType.epiweek), @@ -839,15 +830,15 @@ def paho_dengue( ], ) - def quidel(self, auth: str, epiweeks: EpiRangeParam, locations: StringParam) -> CALL_TYPE: + def quidel(self, epiweeks: EpiRangeParam, locations: StringParam) -> CALL_TYPE: """Fetch Quidel data.""" - if auth is None or epiweeks is None or locations is None: - raise InvalidArgumentException("`auth`, `epiweeks`, and `locations` are all required") + if epiweeks is None or locations is None: + raise InvalidArgumentException("`epiweeks`, and `locations` are all required") return self._create_call( "quidel/", - dict(auth=auth, epiweeks=epiweeks, locations=locations), + dict(epiweeks=epiweeks, locations=locations), [ EpidataFieldInfo("location", EpidataFieldType.text), EpidataFieldInfo("epiweek", EpidataFieldType.epiweek), @@ -855,14 +846,14 @@ def quidel(self, auth: str, epiweeks: EpiRangeParam, locations: StringParam) -> ], ) - def sensors(self, auth: str, names: StringParam, locations: StringParam, epiweeks: EpiRangeParam) -> CALL_TYPE: + def sensors(self, names: StringParam, locations: StringParam, epiweeks: EpiRangeParam) -> CALL_TYPE: """Fetch Delphi's digital surveillance sensors.""" - if auth is None or names is None or locations is None or epiweeks is None: - raise InvalidArgumentException("`auth`, `names`, `locations`, and `epiweeks` are all required") + if names is None or locations is None or epiweeks is None: + raise InvalidArgumentException("`names`, `locations`, and `epiweeks` are all required") return self._create_call( "sensors/", - dict(auth=auth, names=names, locations=locations, epiweeks=epiweeks), + dict(names=names, locations=locations, epiweeks=epiweeks), [ EpidataFieldInfo("name", EpidataFieldType.text), EpidataFieldInfo("location", EpidataFieldType.text), @@ -873,20 +864,19 @@ def sensors(self, auth: str, names: StringParam, locations: StringParam, epiweek def twitter( self, - auth: str, locations: StringParam, dates: Optional[EpiRangeParam] = None, epiweeks: Optional[EpiRangeParam] = None, ) -> CALL_TYPE: """Fetch HealthTweets data.""" - if auth is None or locations is None: - raise InvalidArgumentException("`auth` and `locations` are both required") + if locations is None: + raise InvalidArgumentException("`locations` is required") if not (dates is None) ^ (epiweeks is None): raise InvalidArgumentException("exactly one of `dates` and `epiweeks` is required") return self._create_call( "twitter/", - dict(auth=auth, locations=locations, dates=dates, epiweeks=epiweeks), + dict(locations=locations, dates=dates, epiweeks=epiweeks), [ EpidataFieldInfo("location", EpidataFieldType.text), EpidataFieldInfo("date", EpidataFieldType.date) diff --git a/delphi_epidata/_model.py b/delphi_epidata/_model.py index 54c63dc..09e4c72 100644 --- a/delphi_epidata/_model.py +++ b/delphi_epidata/_model.py @@ -120,6 +120,7 @@ class AEpiDataCall: """ _base_url: Final[str] + _api_key: Final[str] _endpoint: Final[str] _params: Final[Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]]] meta: Final[Sequence[EpidataFieldInfo]] @@ -128,11 +129,13 @@ class AEpiDataCall: def __init__( self, base_url: str, + api_key: str, endpoint: str, params: Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]], meta: Optional[Sequence[EpidataFieldInfo]] = None, ) -> None: self._base_url = base_url + self._api_key = api_key self._endpoint = endpoint self._params = params self.meta = meta or [] @@ -149,6 +152,7 @@ def _formatted_paramters( all_params["format"] = format_type if fields: all_params["fields"] = fields + all_params["token"] = self._api_key return {k: format_list(v) for k, v in all_params.items() if v is not None} def request_arguments( @@ -182,7 +186,7 @@ def request_url( u, p = self.request_arguments(format_type, fields) query = urlencode(p) if query: - return f"{u}?{query}" + return f"{u}?{query}&token={self._api_key}" return u def __repr__(self) -> str: diff --git a/delphi_epidata/async_requests.py b/delphi_epidata/async_requests.py index c8c788d..cc2f52f 100644 --- a/delphi_epidata/async_requests.py +++ b/delphi_epidata/async_requests.py @@ -17,6 +17,7 @@ from asyncio import get_event_loop, gather from aiohttp import TCPConnector, ClientSession, ClientResponse +from aiohttp.helpers import BasicAuth from pandas import DataFrame from ._model import EpiRangeLike, AEpiDataCall, EpiDataFormatType, EpiDataResponse, EpiRange, EpidataFieldInfo @@ -25,12 +26,12 @@ async def _async_request( - url: str, params: Mapping[str, str], session: Optional[ClientSession] = None + url: str, params: Mapping[str, str], api_key: str, session: Optional[ClientSession] = None ) -> ClientResponse: async def call_impl(s: ClientSession) -> ClientResponse: - res = await s.get(url, params=params, headers=HTTP_HEADERS) + res = await s.get(url, params=params, headers=HTTP_HEADERS, auth=BasicAuth("epidata", api_key)) if res.status == 414: - return await s.post(url, params=params, headers=HTTP_HEADERS) + return await s.post(url, params=params, headers=HTTP_HEADERS, auth=BasicAuth("epidata", api_key)) return res if session: @@ -50,19 +51,23 @@ class EpiDataAsyncCall(AEpiDataCall): def __init__( self, base_url: str, + api_key: str, session: Optional[ClientSession], endpoint: str, params: Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]], meta: Optional[Sequence[EpidataFieldInfo]] = None, ) -> None: - super().__init__(base_url, endpoint, params, meta) + super().__init__(base_url, api_key, endpoint, params, meta) self._session = session def with_base_url(self, base_url: str) -> "EpiDataAsyncCall": - return EpiDataAsyncCall(base_url, self._session, self._endpoint, self._params) + return EpiDataAsyncCall(base_url, self._api_key, self._session, self._endpoint, self._params) def with_session(self, session: ClientSession) -> "EpiDataAsyncCall": - return EpiDataAsyncCall(self._base_url, session, self._endpoint, self._params) + return EpiDataAsyncCall(self._base_url, self._api_key, session, self._endpoint, self._params) + + def with_api_key(self, api_key: str) -> "EpiDataAsyncCall": + return EpiDataAsyncCall(self._base_url, api_key, self._session, self._endpoint, self._params) async def _call( self, @@ -70,7 +75,7 @@ async def _call( fields: Optional[Iterable[str]] = None, ) -> ClientResponse: url, params = self.request_arguments(format_type, fields) - return await _async_request(url, params, self._session) + return await _async_request(url, params, self._api_key, self._session) async def classic(self, fields: Optional[Iterable[str]] = None) -> EpiDataResponse: """Request and parse epidata in CLASSIC message format.""" @@ -123,24 +128,29 @@ async def __(self) -> AsyncGenerator[Mapping[str, Union[str, int, float, date, N return self.iter() -class EpiDataAsyncContext(AEpiDataEndpoints[EpiDataAsyncCall]): +class Epidata(AEpiDataEndpoints[EpiDataAsyncCall]): """ - sync epidata call class + async epidata call class """ _base_url: Final[str] + _api_key: Final[str] _session: Final[Optional[ClientSession]] - def __init__(self, base_url: str = BASE_URL, session: Optional[ClientSession] = None) -> None: + def __init__(self, api_key: str, base_url: str = BASE_URL, session: Optional[ClientSession] = None) -> None: super().__init__() + self._api_key = api_key self._base_url = base_url self._session = session - def with_base_url(self, base_url: str) -> "EpiDataAsyncContext": - return EpiDataAsyncContext(base_url, self._session) + def with_base_url(self, base_url: str) -> "Epidata": + return Epidata(self._api_key, base_url, self._session) + + def with_session(self, session: ClientSession) -> "Epidata": + return Epidata(self._api_key, self._base_url, session) - def with_session(self, session: ClientSession) -> "EpiDataAsyncContext": - return EpiDataAsyncContext(self._base_url, session) + def with_api_key(self, api_key: str) -> "Epidata": + return Epidata(api_key, self._base_url, self._session) def _create_call( self, @@ -148,7 +158,7 @@ def _create_call( params: Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]], meta: Optional[Sequence[EpidataFieldInfo]] = None, ) -> EpiDataAsyncCall: - return EpiDataAsyncCall(self._base_url, self._session, endpoint, params, meta) + return EpiDataAsyncCall(self._base_url, self._api_key, self._session, endpoint, params, meta) @staticmethod def all( @@ -216,7 +226,4 @@ def call_api(call: EpiDataAsyncCall, session: ClientSession) -> Coroutine: return self.all(calls, call_api, batch_size) -Epidata = EpiDataAsyncContext() - - -__all__ = ["Epidata", "EpiDataAsyncCall", "EpiDataAsyncContext", "EpiRange"] +__all__ = ["Epidata", "EpiDataAsyncCall", "EpiRange"] diff --git a/delphi_epidata/requests.py b/delphi_epidata/requests.py index 22c3730..d7358fe 100644 --- a/delphi_epidata/requests.py +++ b/delphi_epidata/requests.py @@ -13,14 +13,14 @@ @retry(reraise=True, stop=stop_after_attempt(2)) def _request_with_retry( - url: str, params: Mapping[str, str], session: Optional[Session] = None, stream: bool = False + url: str, params: Mapping[str, str], api_key: str, session: Optional[Session] = None, stream: bool = False ) -> Response: """Make request with a retry if an exception is thrown.""" def call_impl(s: Session) -> Response: - res = s.get(url, params=params, headers=HTTP_HEADERS, stream=stream) + res = s.get(url, params=params, headers=HTTP_HEADERS, stream=stream, auth=("epidata", api_key)) if res.status_code == 414: - return s.post(url, params=params, headers=HTTP_HEADERS, stream=stream) + return s.post(url, params=params, headers=HTTP_HEADERS, stream=stream, auth=("epidata", api_key)) return res if session: @@ -40,19 +40,23 @@ class EpiDataCall(AEpiDataCall): def __init__( self, base_url: str, + api_key: str, session: Optional[Session], endpoint: str, params: Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]], meta: Optional[Sequence[EpidataFieldInfo]] = None, ) -> None: - super().__init__(base_url, endpoint, params, meta) + super().__init__(base_url, api_key, endpoint, params, meta) self._session = session def with_base_url(self, base_url: str) -> "EpiDataCall": - return EpiDataCall(base_url, self._session, self._endpoint, self._params) + return EpiDataCall(base_url, self._api_key, self._session, self._endpoint, self._params) def with_session(self, session: Session) -> "EpiDataCall": - return EpiDataCall(self._base_url, session, self._endpoint, self._params) + return EpiDataCall(self._base_url, self._api_key, session, self._endpoint, self._params) + + def with_api_key(self, api_key: str) -> "EpiDataCall": + return EpiDataCall(self._base_url, api_key, self._session, self._endpoint, self._params) def _call( self, @@ -61,7 +65,7 @@ def _call( stream: bool = False, ) -> Response: url, params = self.request_arguments(format_type, fields) - return _request_with_retry(url, params, self._session, stream) + return _request_with_retry(url, params, self._api_key, self._session, stream) def classic(self, fields: Optional[Iterable[str]] = None) -> EpiDataResponse: """Request and parse epidata in CLASSIC message format.""" @@ -110,24 +114,29 @@ def __iter__(self) -> Generator[Mapping[str, Union[str, int, float, date, None]] return self.iter() -class EpiDataContext(AEpiDataEndpoints[EpiDataCall]): +class Epidata(AEpiDataEndpoints[EpiDataCall]): """ sync epidata call class """ _base_url: Final[str] + _api_key: Final[str] _session: Final[Optional[Session]] - def __init__(self, base_url: str = BASE_URL, session: Optional[Session] = None) -> None: + def __init__(self, api_key: str, base_url: str = BASE_URL, session: Optional[Session] = None) -> None: super().__init__() + self._api_key = api_key self._base_url = base_url self._session = session - def with_base_url(self, base_url: str) -> "EpiDataContext": - return EpiDataContext(base_url, self._session) + def with_base_url(self, base_url: str) -> "Epidata": + return Epidata(self._api_key, base_url, self._session) + + def with_session(self, session: Session) -> "Epidata": + return Epidata(self._api_key, self._base_url, session) - def with_session(self, session: Session) -> "EpiDataContext": - return EpiDataContext(self._base_url, session) + def with_api_key(self, api_key: str) -> "Epidata": + return Epidata(api_key, self._base_url, self._session) def _create_call( self, @@ -135,10 +144,7 @@ def _create_call( params: Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]], meta: Optional[Sequence[EpidataFieldInfo]] = None, ) -> EpiDataCall: - return EpiDataCall(self._base_url, self._session, endpoint, params, meta) - - -Epidata = EpiDataContext() + return EpiDataCall(self._base_url, self._api_key, self._session, endpoint, params, meta) -__all__ = ["Epidata", "EpiDataCall", "EpiDataContext", "EpiRange"] +__all__ = ["Epidata", "EpiDataCall", "EpiRange"] diff --git a/smoke_test.py b/smoke_test.py index c29ccbd..dd17924 100644 --- a/smoke_test.py +++ b/smoke_test.py @@ -1,7 +1,8 @@ from datetime import date from delphi_epidata.requests import Epidata, EpiRange -apicall = Epidata.covidcast("fb-survey", "smoothed_cli", "day", "nation", EpiRange(20210405, 20210410), "us") +epidata = Epidata("test") +apicall = epidata.covidcast("fb-survey", "smoothed_cli", "day", "nation", EpiRange(20210405, 20210410), "us") print(apicall) @@ -21,9 +22,9 @@ for row in apicall.iter(): print(row) -StagingEpidata = Epidata.with_base_url("https://staging.delphi.cmu.edu/epidata/") +stagingEpidata = epidata.with_base_url("https://staging.delphi.cmu.edu/epidata/") -df = StagingEpidata.covidcast( +df = stagingEpidata.covidcast( "fb-survey", "smoothed_cli", "day", "nation", EpiRange(date(2021, 4, 5), date(2021, 4, 10)), "*" ).df() print(df.shape) diff --git a/smoke_test_async.py b/smoke_test_async.py index e228a96..c0dc3bf 100644 --- a/smoke_test_async.py +++ b/smoke_test_async.py @@ -3,7 +3,8 @@ async def main() -> None: - apicall = Epidata.covidcast("fb-survey", "smoothed_cli", "day", "nation", Epidata.range(20210405, 20210410), "us") + epidata = Epidata("test") + apicall = epidata.covidcast("fb-survey", "smoothed_cli", "day", "nation", Epidata.range(20210405, 20210410), "us") classic = await apicall.classic() print(classic) From 253b9989aa7cc62c2ce55fa71859c30b5c0e7de5 Mon Sep 17 00:00:00 2001 From: Samuel Gratzl Date: Mon, 23 Aug 2021 09:01:09 -0400 Subject: [PATCH 2/2] feat: adapt test --- smoke_covid_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smoke_covid_test.py b/smoke_covid_test.py index d0fd29a..3a8bec6 100644 --- a/smoke_covid_test.py +++ b/smoke_covid_test.py @@ -1,6 +1,6 @@ from delphi_epidata.requests import CovidcastEpidata, EpiRange -epidata = CovidcastEpidata() +epidata = CovidcastEpidata("test") print(list(epidata.source_names)) apicall = epidata[("fb-survey", "smoothed_cli")].call( "nation",