diff --git a/delphi_epidata/_endpoints.py b/delphi_epidata/_endpoints.py index 8537e29..a4fa4ab 100644 --- a/delphi_epidata/_endpoints.py +++ b/delphi_epidata/_endpoints.py @@ -34,13 +34,11 @@ def _create_call( ) -> CALL_TYPE: raise NotImplementedError() - def pvt_afhsb( - self, auth: str, locations: StringParam, epiweeks: EpiRangeParam, flu_types: StringParam - ) -> CALL_TYPE: + def pvt_afhsb(self, locations: StringParam, epiweeks: EpiRangeParam, flu_types: StringParam) -> CALL_TYPE: """Fetch AFHSB data (point data, no min/max).""" - if auth is None or locations is None or epiweeks is None or flu_types is None: - raise InvalidArgumentException("`auth`, `locations`, `epiweeks` and `flu_types` are all required") + if locations is None or epiweeks is None or flu_types is None: + raise InvalidArgumentException("`locations`, `epiweeks` and `flu_types` are all required") loc_exception = ( "Location parameter `{}` is invalid. Valid `location` parameters are: " @@ -70,7 +68,7 @@ def pvt_afhsb( return self._create_call( "afhsb/", - dict(auth=auth, locations=locations, epiweeks=epiweeks, flu_types=flu_types), + dict(locations=locations, epiweeks=epiweeks, flu_types=flu_types), [ EpidataFieldInfo("location", EpidataFieldType.text), EpidataFieldInfo("flu_type", EpidataFieldType.text), @@ -79,15 +77,15 @@ def pvt_afhsb( ], ) - def pvt_cdc(self, auth: str, epiweeks: EpiRangeParam, locations: StringParam) -> CALL_TYPE: + def pvt_cdc(self, epiweeks: EpiRangeParam, locations: StringParam) -> CALL_TYPE: """Fetch CDC page hits.""" - if auth is None or epiweeks is None or locations is None: - raise InvalidArgumentException("`auth`, `epiweeks`, and `locations` are all required") + if epiweeks is None or locations is None: + raise InvalidArgumentException("`epiweeks`, and `locations` are all required") return self._create_call( "cdc/", - dict(auth=auth, epiweeks=epiweeks, locations=locations), + dict(epiweeks=epiweeks, locations=locations), [ EpidataFieldInfo("location", EpidataFieldType.text), EpidataFieldInfo("epiweek", EpidataFieldType.epiweek), @@ -484,17 +482,15 @@ def dengue_nowcast(self, locations: StringParam, epiweeks: EpiRangeParam) -> CAL ], ) - def pvt_dengue_sensors( - self, auth: str, names: StringParam, locations: StringParam, epiweeks: EpiRangeParam - ) -> CALL_TYPE: + def pvt_dengue_sensors(self, names: StringParam, locations: StringParam, epiweeks: EpiRangeParam) -> CALL_TYPE: """Fetch Delphi's digital surveillance sensors.""" - if auth is None or names is None or locations is None or epiweeks is None: - raise InvalidArgumentException("`auth`, `names`, `locations`, and `epiweeks` are all required") + if names is None or locations is None or epiweeks is None: + raise InvalidArgumentException("`names`, `locations`, and `epiweeks` are all required") return self._create_call( "dengue_sensors/", - dict(auth=auth, names=names, locations=locations, epiweeks=epiweeks), + dict(names=names, locations=locations, epiweeks=epiweeks), [ EpidataFieldInfo("name", EpidataFieldType.text), EpidataFieldInfo("location", EpidataFieldType.text), @@ -607,7 +603,6 @@ def fluview( epiweeks: EpiRangeParam, issues: Optional[EpiRangeParam] = None, lag: Optional[int] = None, - auth: Optional[str] = None, ) -> CALL_TYPE: if regions is None or epiweeks is None: raise InvalidArgumentException("`regions` and `epiweeks` are both required") @@ -615,7 +610,7 @@ def fluview( raise InvalidArgumentException("`issues` and `lag` are mutually exclusive") return self._create_call( "fluview/", - dict(regions=regions, epiweeks=epiweeks, issues=issues, lag=lag, auth=auth), + dict(regions=regions, epiweeks=epiweeks, issues=issues, lag=lag), [ EpidataFieldInfo("release_date", EpidataFieldType.text), EpidataFieldInfo("region", EpidataFieldType.text), @@ -649,13 +644,13 @@ def gft(self, locations: StringParam, epiweeks: EpiRangeParam) -> CALL_TYPE: ], ) - def pvt_ght(self, auth: str, locations: StringParam, epiweeks: EpiRangeParam, query: str) -> CALL_TYPE: + def pvt_ght(self, locations: StringParam, epiweeks: EpiRangeParam, query: str) -> CALL_TYPE: """Fetch Google Health Trends data.""" - if auth is None or locations is None or epiweeks is None or query is None: - raise InvalidArgumentException("`auth`, `locations`, `epiweeks`, and `query` are all required") + if locations is None or epiweeks is None or query is None: + raise InvalidArgumentException("`locations`, `epiweeks`, and `query` are all required") return self._create_call( "ght/", - dict(auth=auth, locations=locations, epiweeks=epiweeks, query=query), + dict(locations=locations, epiweeks=epiweeks, query=query), [ EpidataFieldInfo("location", EpidataFieldType.text), EpidataFieldInfo("epiweek", EpidataFieldType.epiweek), @@ -688,28 +683,18 @@ def kcdc_ili( ], ) - def pvt_meta_afhsb(self, auth: str) -> CALL_TYPE: + def pvt_meta_afhsb(self) -> CALL_TYPE: """Fetch AFHSB metadata.""" - - if auth is None: - raise InvalidArgumentException("`auth` is required") - return self._create_call( "meta_afhsb/", - dict(auth=auth), + {}, only_supports_classic=True, ) - def pvt_meta_norostat(self, auth: str) -> CALL_TYPE: + def pvt_meta_norostat(self) -> CALL_TYPE: """Fetch NoroSTAT metadata.""" - if auth is None: - raise InvalidArgumentException("`auth` is required") - return self._create_call( - "meta_norostat/", - dict(auth=auth), - only_supports_classic=True, - ) + return self._create_call("meta_norostat/", {}, only_supports_classic=True) def meta(self) -> CALL_TYPE: """Fetch API metadata.""" @@ -763,14 +748,14 @@ def nidss_flu( ], ) - def pvt_norostat(self, auth: str, location: str, epiweeks: EpiRangeParam) -> CALL_TYPE: + def pvt_norostat(self, location: str, epiweeks: EpiRangeParam) -> CALL_TYPE: """Fetch NoroSTAT data (point data, no min/max).""" - if auth is None or location is None or epiweeks is None: - raise InvalidArgumentException("`auth`, `location`, and `epiweeks` are all required") + if location is None or epiweeks is None: + raise InvalidArgumentException("`location`, and `epiweeks` are all required") return self._create_call( "norostat/", - dict(auth=auth, epiweeks=epiweeks, location=location), + dict(epiweeks=epiweeks, location=location), [ EpidataFieldInfo("release_date", EpidataFieldType.text), EpidataFieldInfo("epiweek", EpidataFieldType.epiweek), @@ -826,15 +811,15 @@ def paho_dengue( ], ) - def pvt_quidel(self, auth: str, epiweeks: EpiRangeParam, locations: StringParam) -> CALL_TYPE: + def pvt_quidel(self, epiweeks: EpiRangeParam, locations: StringParam) -> CALL_TYPE: """Fetch Quidel data.""" - if auth is None or epiweeks is None or locations is None: - raise InvalidArgumentException("`auth`, `epiweeks`, and `locations` are all required") + if epiweeks is None or locations is None: + raise InvalidArgumentException("`epiweeks`, and `locations` are all required") return self._create_call( "quidel/", - dict(auth=auth, epiweeks=epiweeks, locations=locations), + dict(epiweeks=epiweeks, locations=locations), [ EpidataFieldInfo("location", EpidataFieldType.text), EpidataFieldInfo("epiweek", EpidataFieldType.epiweek), @@ -842,14 +827,14 @@ def pvt_quidel(self, auth: str, epiweeks: EpiRangeParam, locations: StringParam) ], ) - def pvt_sensors(self, auth: str, names: StringParam, locations: StringParam, epiweeks: EpiRangeParam) -> CALL_TYPE: + def pvt_sensors(self, names: StringParam, locations: StringParam, epiweeks: EpiRangeParam) -> CALL_TYPE: """Fetch Delphi's digital surveillance sensors.""" - if auth is None or names is None or locations is None or epiweeks is None: - raise InvalidArgumentException("`auth`, `names`, `locations`, and `epiweeks` are all required") + if names is None or locations is None or epiweeks is None: + raise InvalidArgumentException("`names`, `locations`, and `epiweeks` are all required") return self._create_call( "sensors/", - dict(auth=auth, names=names, locations=locations, epiweeks=epiweeks), + dict(names=names, locations=locations, epiweeks=epiweeks), [ EpidataFieldInfo("name", EpidataFieldType.text), EpidataFieldInfo("location", EpidataFieldType.text), @@ -860,20 +845,19 @@ def pvt_sensors(self, auth: str, names: StringParam, locations: StringParam, epi def pvt_twitter( self, - auth: str, locations: StringParam, dates: Optional[EpiRangeParam] = None, epiweeks: Optional[EpiRangeParam] = None, ) -> CALL_TYPE: """Fetch HealthTweets data.""" - if auth is None or locations is None: - raise InvalidArgumentException("`auth` and `locations` are both required") + if locations is None: + raise InvalidArgumentException("`locations` is required") if not (dates is None) ^ (epiweeks is None): raise InvalidArgumentException("exactly one of `dates` and `epiweeks` is required") return self._create_call( "twitter/", - dict(auth=auth, locations=locations, dates=dates, epiweeks=epiweeks), + dict(locations=locations, dates=dates, epiweeks=epiweeks), [ EpidataFieldInfo("location", EpidataFieldType.text), EpidataFieldInfo("date", EpidataFieldType.date) diff --git a/delphi_epidata/_model.py b/delphi_epidata/_model.py index 87623de..9286b82 100644 --- a/delphi_epidata/_model.py +++ b/delphi_epidata/_model.py @@ -154,6 +154,7 @@ class AEpiDataCall: """ _base_url: Final[str] + _api_key: Final[str] _endpoint: Final[str] _params: Final[Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]]] meta: Final[Sequence[EpidataFieldInfo]] @@ -163,12 +164,14 @@ class AEpiDataCall: def __init__( self, base_url: str, + api_key: str, endpoint: str, params: Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]], meta: Optional[Sequence[EpidataFieldInfo]] = None, only_supports_classic: bool = False, ) -> None: self._base_url = base_url + self._api_key = api_key self._endpoint = endpoint self._params = params self.only_supports_classic = only_supports_classic @@ -190,6 +193,7 @@ def _formatted_paramters( all_params["format"] = format_type if fields: all_params["fields"] = fields + all_params["token"] = self._api_key return {k: format_list(v) for k, v in all_params.items() if v is not None} def request_arguments( @@ -220,7 +224,7 @@ def request_url( u, p = self.request_arguments(format_type, fields) query = urlencode(p) if query: - return f"{u}?{query}" + return f"{u}?{query}&token={self._api_key}" return u def __repr__(self) -> str: diff --git a/delphi_epidata/async_requests.py b/delphi_epidata/async_requests.py index 7f7993d..c2cc32a 100644 --- a/delphi_epidata/async_requests.py +++ b/delphi_epidata/async_requests.py @@ -17,6 +17,7 @@ from asyncio import get_event_loop, gather from aiohttp import TCPConnector, ClientSession, ClientResponse +from aiohttp.helpers import BasicAuth from pandas import DataFrame from ._model import ( @@ -35,12 +36,12 @@ async def _async_request( - url: str, params: Mapping[str, str], session: Optional[ClientSession] = None + url: str, params: Mapping[str, str], api_key: str, session: Optional[ClientSession] = None ) -> ClientResponse: async def call_impl(s: ClientSession) -> ClientResponse: - res = await s.get(url, params=params, headers=HTTP_HEADERS) + res = await s.get(url, params=params, headers=HTTP_HEADERS, auth=BasicAuth("epidata", api_key)) if res.status == 414: - return await s.post(url, params=params, headers=HTTP_HEADERS) + return await s.post(url, params=params, headers=HTTP_HEADERS, auth=BasicAuth("epidata", api_key)) return res if session: @@ -60,20 +61,24 @@ class EpiDataAsyncCall(AEpiDataCall): def __init__( self, base_url: str, + api_key: str, session: Optional[ClientSession], endpoint: str, params: Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]], meta: Optional[Sequence[EpidataFieldInfo]] = None, only_supports_classic: bool = False, ) -> None: - super().__init__(base_url, endpoint, params, meta, only_supports_classic) + super().__init__(base_url, api_key, endpoint, params, meta, only_supports_classic) self._session = session def with_base_url(self, base_url: str) -> "EpiDataAsyncCall": - return EpiDataAsyncCall(base_url, self._session, self._endpoint, self._params) + return EpiDataAsyncCall(base_url, self._api_key, self._session, self._endpoint, self._params) def with_session(self, session: ClientSession) -> "EpiDataAsyncCall": - return EpiDataAsyncCall(self._base_url, session, self._endpoint, self._params) + return EpiDataAsyncCall(self._base_url, self._api_key, session, self._endpoint, self._params) + + def with_api_key(self, api_key: str) -> "EpiDataAsyncCall": + return EpiDataAsyncCall(self._base_url, api_key, self._session, self._endpoint, self._params) async def _call( self, @@ -81,7 +86,7 @@ async def _call( fields: Optional[Iterable[str]] = None, ) -> ClientResponse: url, params = self.request_arguments(format_type, fields) - return await _async_request(url, params, self._session) + return await _async_request(url, params, self._api_key, self._session) async def classic( self, fields: Optional[Iterable[str]] = None, disable_date_parsing: Optional[bool] = False @@ -153,24 +158,29 @@ async def __(self) -> AsyncGenerator[Mapping[str, Union[str, int, float, date, N return self.iter() -class EpiDataAsyncContext(AEpiDataEndpoints[EpiDataAsyncCall]): +class Epidata(AEpiDataEndpoints[EpiDataAsyncCall]): """ - sync epidata call class + async epidata call class """ _base_url: Final[str] + _api_key: Final[str] _session: Final[Optional[ClientSession]] - def __init__(self, base_url: str = BASE_URL, session: Optional[ClientSession] = None) -> None: + def __init__(self, api_key: str, base_url: str = BASE_URL, session: Optional[ClientSession] = None) -> None: super().__init__() + self._api_key = api_key self._base_url = base_url self._session = session - def with_base_url(self, base_url: str) -> "EpiDataAsyncContext": - return EpiDataAsyncContext(base_url, self._session) + def with_base_url(self, base_url: str) -> "Epidata": + return Epidata(self._api_key, base_url, self._session) + + def with_session(self, session: ClientSession) -> "Epidata": + return Epidata(self._api_key, self._base_url, session) - def with_session(self, session: ClientSession) -> "EpiDataAsyncContext": - return EpiDataAsyncContext(self._base_url, session) + def with_api_key(self, api_key: str) -> "Epidata": + return Epidata(api_key, self._base_url, self._session) def _create_call( self, @@ -179,7 +189,9 @@ def _create_call( meta: Optional[Sequence[EpidataFieldInfo]] = None, only_supports_classic: bool = False, ) -> EpiDataAsyncCall: - return EpiDataAsyncCall(self._base_url, self._session, endpoint, params, meta, only_supports_classic) + return EpiDataAsyncCall( + self._base_url, self._api_key, self._session, endpoint, params, meta, only_supports_classic + ) @staticmethod def all( @@ -247,21 +259,18 @@ def call_api(call: EpiDataAsyncCall, session: ClientSession) -> Coroutine: return self.all(calls, call_api, batch_size) -Epidata = EpiDataAsyncContext() - - async def CovidcastEpidata( - base_url: str = BASE_URL, session: Optional[ClientSession] = None + api_key: str, base_url: str = BASE_URL, session: Optional[ClientSession] = None ) -> CovidcastDataSources[EpiDataAsyncCall]: url = add_endpoint_to_url(base_url, "covidcast/meta") - meta_data_res = await _async_request(url, {}, session) + meta_data_res = await _async_request(url, {}, api_key, session) meta_data_res.raise_for_status() meta_data = await meta_data_res.json() def create_call(params: Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]]) -> EpiDataAsyncCall: - return EpiDataAsyncCall(base_url, session, "covidcast", params, define_covidcast_fields()) + return EpiDataAsyncCall(base_url, api_key, session, "covidcast", params, define_covidcast_fields()) return CovidcastDataSources.create(meta_data, create_call) -__all__ = ["Epidata", "EpiDataAsyncCall", "EpiDataAsyncContext", "EpiRange", "CovidcastEpidata"] +__all__ = ["Epidata", "EpiDataAsyncCall", "EpiRange", "CovidcastEpidata"] diff --git a/delphi_epidata/requests.py b/delphi_epidata/requests.py index 974bebe..c004c3b 100644 --- a/delphi_epidata/requests.py +++ b/delphi_epidata/requests.py @@ -23,14 +23,14 @@ @retry(reraise=True, stop=stop_after_attempt(2)) def _request_with_retry( - url: str, params: Mapping[str, str], session: Optional[Session] = None, stream: bool = False + url: str, params: Mapping[str, str], api_key: str, session: Optional[Session] = None, stream: bool = False ) -> Response: """Make request with a retry if an exception is thrown.""" def call_impl(s: Session) -> Response: - res = s.get(url, params=params, headers=HTTP_HEADERS, stream=stream) + res = s.get(url, params=params, headers=HTTP_HEADERS, stream=stream, auth=("epidata", api_key)) if res.status_code == 414: - return s.post(url, params=params, headers=HTTP_HEADERS, stream=stream) + return s.post(url, params=params, headers=HTTP_HEADERS, stream=stream, auth=("epidata", api_key)) return res if session: @@ -50,20 +50,24 @@ class EpiDataCall(AEpiDataCall): def __init__( self, base_url: str, + api_key: str, session: Optional[Session], endpoint: str, params: Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]], meta: Optional[Sequence[EpidataFieldInfo]] = None, only_supports_classic: bool = False, ) -> None: - super().__init__(base_url, endpoint, params, meta, only_supports_classic) + super().__init__(base_url, api_key, endpoint, params, meta, only_supports_classic) self._session = session def with_base_url(self, base_url: str) -> "EpiDataCall": - return EpiDataCall(base_url, self._session, self._endpoint, self._params) + return EpiDataCall(base_url, self._api_key, self._session, self._endpoint, self._params) def with_session(self, session: Session) -> "EpiDataCall": - return EpiDataCall(self._base_url, session, self._endpoint, self._params) + return EpiDataCall(self._base_url, self._api_key, session, self._endpoint, self._params) + + def with_api_key(self, api_key: str) -> "EpiDataCall": + return EpiDataCall(self._base_url, api_key, self._session, self._endpoint, self._params) def _call( self, @@ -72,7 +76,7 @@ def _call( stream: bool = False, ) -> Response: url, params = self.request_arguments(format_type, fields) - return _request_with_retry(url, params, self._session, stream) + return _request_with_retry(url, params, self._api_key, self._session, stream) def classic( self, fields: Optional[Iterable[str]] = None, disable_date_parsing: Optional[bool] = False @@ -143,24 +147,29 @@ def __iter__(self) -> Generator[Mapping[str, Union[str, int, float, date, None]] return self.iter() -class EpiDataContext(AEpiDataEndpoints[EpiDataCall]): +class Epidata(AEpiDataEndpoints[EpiDataCall]): """ sync epidata call class """ _base_url: Final[str] + _api_key: Final[str] _session: Final[Optional[Session]] - def __init__(self, base_url: str = BASE_URL, session: Optional[Session] = None) -> None: + def __init__(self, api_key: str, base_url: str = BASE_URL, session: Optional[Session] = None) -> None: super().__init__() + self._api_key = api_key self._base_url = base_url self._session = session - def with_base_url(self, base_url: str) -> "EpiDataContext": - return EpiDataContext(base_url, self._session) + def with_base_url(self, base_url: str) -> "Epidata": + return Epidata(self._api_key, base_url, self._session) + + def with_session(self, session: Session) -> "Epidata": + return Epidata(self._api_key, self._base_url, session) - def with_session(self, session: Session) -> "EpiDataContext": - return EpiDataContext(self._base_url, session) + def with_api_key(self, api_key: str) -> "Epidata": + return Epidata(api_key, self._base_url, self._session) def _create_call( self, @@ -169,22 +178,21 @@ def _create_call( meta: Optional[Sequence[EpidataFieldInfo]] = None, only_supports_classic: bool = False, ) -> EpiDataCall: - return EpiDataCall(self._base_url, self._session, endpoint, params, meta, only_supports_classic) - - -Epidata = EpiDataContext() + return EpiDataCall(self._base_url, self._api_key, self._session, endpoint, params, meta, only_supports_classic) -def CovidcastEpidata(base_url: str = BASE_URL, session: Optional[Session] = None) -> CovidcastDataSources[EpiDataCall]: +def CovidcastEpidata( + api_key: str, base_url: str = BASE_URL, session: Optional[Session] = None +) -> CovidcastDataSources[EpiDataCall]: url = add_endpoint_to_url(base_url, "covidcast/meta") - meta_data_res = _request_with_retry(url, {}, session, False) + meta_data_res = _request_with_retry(url, {}, api_key, session, False) meta_data_res.raise_for_status() meta_data = meta_data_res.json() def create_call(params: Mapping[str, Union[None, EpiRangeLike, Iterable[EpiRangeLike]]]) -> EpiDataCall: - return EpiDataCall(base_url, session, "covidcast", params, define_covidcast_fields()) + return EpiDataCall(base_url, api_key, session, "covidcast", params, define_covidcast_fields()) return CovidcastDataSources.create(meta_data, create_call) -__all__ = ["Epidata", "EpiDataCall", "EpiDataContext", "EpiRange", "CovidcastEpidata"] +__all__ = ["Epidata", "EpiDataCall", "EpiRange", "CovidcastEpidata"] diff --git a/smoke_covid_test.py b/smoke_covid_test.py index d0fd29a..3a8bec6 100644 --- a/smoke_covid_test.py +++ b/smoke_covid_test.py @@ -1,6 +1,6 @@ from delphi_epidata.requests import CovidcastEpidata, EpiRange -epidata = CovidcastEpidata() +epidata = CovidcastEpidata("test") print(list(epidata.source_names)) apicall = epidata[("fb-survey", "smoothed_cli")].call( "nation", diff --git a/smoke_test.py b/smoke_test.py index 6ff2720..ab46ee9 100644 --- a/smoke_test.py +++ b/smoke_test.py @@ -1,7 +1,8 @@ from datetime import date from delphi_epidata.requests import Epidata, EpiRange -apicall = Epidata.covidcast("fb-survey", "smoothed_cli", "day", "nation", EpiRange(20210405, 20210410), "us") +epidata = Epidata("test") +apicall = epidata.covidcast("fb-survey", "smoothed_cli", "day", "nation", EpiRange(20210405, 20210410), "us") print(apicall) @@ -26,9 +27,9 @@ for row in apicall.iter(): print(row) -StagingEpidata = Epidata.with_base_url("https://staging.delphi.cmu.edu/epidata/") +stagingEpidata = epidata.with_base_url("https://staging.delphi.cmu.edu/epidata/") -df = StagingEpidata.covidcast( +df = stagingEpidata.covidcast( "fb-survey", "smoothed_cli", "day", "nation", EpiRange(date(2021, 4, 5), date(2021, 4, 10)), "*" ).df() print(df.shape) diff --git a/smoke_test_async.py b/smoke_test_async.py index e228a96..c0dc3bf 100644 --- a/smoke_test_async.py +++ b/smoke_test_async.py @@ -3,7 +3,8 @@ async def main() -> None: - apicall = Epidata.covidcast("fb-survey", "smoothed_cli", "day", "nation", Epidata.range(20210405, 20210410), "us") + epidata = Epidata("test") + apicall = epidata.covidcast("fb-survey", "smoothed_cli", "day", "nation", Epidata.range(20210405, 20210410), "us") classic = await apicall.classic() print(classic)