diff --git a/weather_dl/download_pipeline/clients.py b/weather_dl/download_pipeline/clients.py
index aedd326a..7a2fa504 100644
--- a/weather_dl/download_pipeline/clients.py
+++ b/weather_dl/download_pipeline/clients.py
@@ -26,6 +26,7 @@
 import cdsapi
 from ecmwfapi import ECMWFService
 
+from .config import Config
 
 warnings.simplefilter(
     "ignore", category=urllib3.connectionpool.InsecureRequestWarning)
@@ -42,7 +43,7 @@ class Client(abc.ABC):
         level: Default log level for the client.
     """
 
-    def __init__(self, config: t.Dict, level: int = logging.INFO) -> None:
+    def __init__(self, config: Config, level: int = logging.INFO) -> None:
         """Clients are initialized with the general CLI configuration."""
         self.config = config
         self.logger = logging.getLogger(f'{__name__}.{type(self).__name__}')
@@ -87,11 +88,11 @@ class CdsClient(Client):
     """Name patterns of datasets that are hosted internally on CDS servers."""
     cds_hosted_datasets = {'reanalysis-era'}
 
-    def __init__(self, config: t.Dict, level: int = logging.INFO) -> None:
+    def __init__(self, config: Config, level: int = logging.INFO) -> None:
         super().__init__(config, level)
         self.c = cdsapi.Client(
-            url=config['parameters'].get('api_url', os.environ.get('CDSAPI_URL')),
-            key=config['parameters'].get('api_key', os.environ.get('CDSAPI_KEY')),
+            url=config.kwargs.get('api_url', os.environ.get('CDSAPI_URL')),
+            key=config.kwargs.get('api_key', os.environ.get('CDSAPI_KEY')),
             debug_callback=self.logger.debug,
             info_callback=self.logger.info,
             warning_callback=self.logger.warning,
@@ -171,13 +172,13 @@ class MarsClient(Client):
         level: Default log level for the client.
     """
 
-    def __init__(self, config: t.Dict, level: int = logging.INFO) -> None:
+    def __init__(self, config: Config, level: int = logging.INFO) -> None:
         super().__init__(config, level)
         self.c = ECMWFService(
             "mars",
-            key=config['parameters'].get('api_key', os.environ.get("MARSAPI_KEY")),
-            url=config['parameters'].get('api_url', os.environ.get("MARSAPI_URL")),
-            email=config['parameters'].get('api_email', os.environ.get("MARSAPI_EMAIL")),
+            key=config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")),
+            url=config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")),
+            email=config.kwargs.get('api_email', os.environ.get("MARSAPI_EMAIL")),
             log=self.logger.debug,
             verbose=True
         )
diff --git a/weather_dl/download_pipeline/clients_test.py b/weather_dl/download_pipeline/clients_test.py
index 5ccb366b..f11f8ea6 100644
--- a/weather_dl/download_pipeline/clients_test.py
+++ b/weather_dl/download_pipeline/clients_test.py
@@ -15,26 +15,27 @@
 import unittest
 
 from .clients import FakeClient, CdsClient, MarsClient
+from .config import Config
 
 
 class MaxWorkersTest(unittest.TestCase):
     def test_cdsclient_internal(self):
-        client = CdsClient({'parameters': {'api_url': 'url', 'api_key': 'key'}})
+        client = CdsClient(Config.from_dict({'parameters': {'api_url': 'url', 'api_key': 'key'}}))
 
         self.assertEqual(
             client.num_requests_per_key("reanalysis-era5-some-data"), 5)
 
     def test_cdsclient_mars_hosted(self):
-        client = CdsClient({'parameters': {'api_url': 'url', 'api_key': 'key'}})
+        client = CdsClient(Config.from_dict({'parameters': {'api_url': 'url', 'api_key': 'key'}}))
 
         self.assertEqual(
             client.num_requests_per_key("reanalysis-carra-height-levels"), 2)
 
     def test_marsclient(self):
-        client = MarsClient({'parameters': {}})
+        client = MarsClient(Config.from_dict({'parameters': {}}))
 
         self.assertEqual(
             client.num_requests_per_key("reanalysis-era5-some-data"), 2)
 
     def test_fakeclient(self):
-        client = FakeClient({'parameters': {}})
+        client = FakeClient(Config.from_dict({'parameters': {}}))
 
         self.assertEqual(
             client.num_requests_per_key("reanalysis-era5-some-data"), 1)
diff --git a/weather_dl/download_pipeline/config.py b/weather_dl/download_pipeline/config.py
new file mode 100644
index 00000000..80309066
--- /dev/null
+++ b/weather_dl/download_pipeline/config.py
@@ -0,0 +1,71 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import typing as t
+import dataclasses
+
+Values = t.Union[t.List['Values'], t.Dict[str, 'Values'], bool, int, float, str]  # pytype: disable=not-supported-yet
+
+
+@dataclasses.dataclass
+class Config:
+    """Contains pipeline parameters.
+
+    Attributes:
+        client:
+            Name of the Weather-API-client. Supported clients are mentioned in the 'CLIENTS' variable.
+        dataset (optional):
+            Name of the target dataset. Allowed options are dictated by the client.
+        partition_keys (optional):
+            Choose the keys from the selection section to partition the data request.
+            This will compute a cartesian cross product of the selected keys
+            and assign each as their own download.
+        target_path:
+            Download artifact filename template. Can make use of Python's standard string formatting.
+            It can contain format symbols to be replaced by partition keys; if this is used,
+            the total number of format symbols must match the number of partition keys.
+        subsection_name:
+            Name of the particular subsection. 'default' if there is no subsection.
+        force_download:
+            Force redownload of partitions that were previously downloaded.
+        user_id:
+            Username from the environment variables.
+        kwargs (optional):
+            For representing subsections or any other parameters.
+        selection:
+            Contains parameters used to select desired data.
+ """ + + client: str = "" + dataset: t.Optional[str] = "" + target_path: str = "" + partition_keys: t.Optional[t.List[str]] = dataclasses.field(default_factory=list) + subsection_name: str = "default" + force_download: bool = False + user_id: str = "unknown" + kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) + selection: t.Dict[str, Values] = dataclasses.field(default_factory=dict) + + @classmethod + def from_dict(cls, config: t.Dict) -> 'Config': + config_instance = cls() + for section_key, section_value in config.items(): + if section_key == "parameters": + for key, value in section_value.items(): + if hasattr(config_instance, key): + setattr(config_instance, key, value) + else: + config_instance.kwargs[key] = value + if section_key == "selection": + config_instance.selection = section_value + return config_instance diff --git a/weather_dl/download_pipeline/fetcher.py b/weather_dl/download_pipeline/fetcher.py index c2058f7c..7fb0a861 100644 --- a/weather_dl/download_pipeline/fetcher.py +++ b/weather_dl/download_pipeline/fetcher.py @@ -23,6 +23,7 @@ from .clients import CLIENTS, Client from .manifest import Manifest, NoOpManifest, Location from .parsers import prepare_target_name +from .config import Config from .partition import skip_partition from .stores import Store, FSStore from .util import retry_with_exponential_backoff @@ -64,7 +65,7 @@ def retrieve(self, client: Client, dataset: str, selection: t.Dict, dest: str) - """Retrieve from download client, with retries.""" client.retrieve(dataset, selection, dest) - def fetch_data(self, config: t.Dict, *, worker_name: str = 'default') -> None: + def fetch_data(self, config: Config, *, worker_name: str = 'default') -> None: """Download data from a client to a temp file, then upload to Cloud Storage.""" if not config: return @@ -74,14 +75,11 @@ def fetch_data(self, config: t.Dict, *, worker_name: str = 'default') -> None: client = CLIENTS[self.client_name](config) target = prepare_target_name(config) - dataset = config['parameters'].get('dataset', '') - selection = config['selection'] - user = config['parameters'].get('user_id', 'unknown') - with self.manifest.transact(selection, target, user): + with self.manifest.transact(config.selection, target, config.user_id): with tempfile.NamedTemporaryFile() as temp: logger.info(f'[{worker_name}] Fetching data for {target!r}.') - self.retrieve(client, dataset, selection, temp.name) + self.retrieve(client, config.dataset, config.selection, temp.name) logger.info(f'[{worker_name}] Uploading to store for {target!r}.') self.upload(temp, target) diff --git a/weather_dl/download_pipeline/fetcher_test.py b/weather_dl/download_pipeline/fetcher_test.py index 74bc7e1b..0108cc3e 100644 --- a/weather_dl/download_pipeline/fetcher_test.py +++ b/weather_dl/download_pipeline/fetcher_test.py @@ -22,6 +22,7 @@ from .fetcher import Fetcher from .manifest import MockManifest, Location from .stores import InMemoryStore, FSStore +from .config import Config class UploadTest(unittest.TestCase): @@ -69,7 +70,7 @@ def setUp(self) -> None: @patch('weather_dl.download_pipeline.stores.InMemoryStore.open', return_value=io.StringIO()) @patch('cdsapi.Client.retrieve') def test_fetch_data(self, mock_retrieve, mock_gcs_file): - config = { + config = Config.from_dict({ 'parameters': { 'dataset': 'reanalysis-era5-pressure-levels', 'partition_keys': ['year', 'month'], @@ -82,7 +83,7 @@ def test_fetch_data(self, mock_retrieve, mock_gcs_file): 'month': ['12'], 'year': ['01'] } - } + }) fetcher = 
Fetcher('cds', self.dummy_manifest, InMemoryStore()) fetcher.fetch_data(config) @@ -94,13 +95,13 @@ def test_fetch_data(self, mock_retrieve, mock_gcs_file): mock_retrieve.assert_called_with( 'reanalysis-era5-pressure-levels', - config['selection'], + config.selection, ANY) @patch('weather_dl.download_pipeline.stores.InMemoryStore.open', return_value=io.StringIO()) @patch('cdsapi.Client.retrieve') def test_fetch_data__manifest__returns_success(self, mock_retrieve, mock_gcs_file): - config = { + config = Config.from_dict({ 'parameters': { 'dataset': 'reanalysis-era5-pressure-levels', 'partition_keys': ['year', 'month'], @@ -113,13 +114,13 @@ def test_fetch_data__manifest__returns_success(self, mock_retrieve, mock_gcs_fil 'month': ['12'], 'year': ['01'] } - } + }) fetcher = Fetcher('cds', self.dummy_manifest, InMemoryStore()) fetcher.fetch_data(config) self.assertDictContainsSubset(dict( - selection=config['selection'], + selection=config.selection, location='gs://weather-dl-unittest/download-01-12.nc', status='success', error=None, @@ -128,7 +129,7 @@ def test_fetch_data__manifest__returns_success(self, mock_retrieve, mock_gcs_fil @patch('cdsapi.Client.retrieve') def test_fetch_data__manifest__records_retrieve_failure(self, mock_retrieve): - config = { + config = Config.from_dict({ 'parameters': { 'dataset': 'reanalysis-era5-pressure-levels', 'partition_keys': ['year', 'month'], @@ -141,7 +142,7 @@ def test_fetch_data__manifest__records_retrieve_failure(self, mock_retrieve): 'month': ['12'], 'year': ['01'] } - } + }) error = IOError("We don't have enough permissions to download this.") mock_retrieve.side_effect = error @@ -153,7 +154,7 @@ def test_fetch_data__manifest__records_retrieve_failure(self, mock_retrieve): actual = list(self.dummy_manifest.records.values())[0]._asdict() self.assertDictContainsSubset(dict( - selection=config['selection'], + selection=config.selection, location='gs://weather-dl-unittest/download-01-12.nc', status='failure', user='unknown', @@ -165,7 +166,7 @@ def test_fetch_data__manifest__records_retrieve_failure(self, mock_retrieve): @patch('weather_dl.download_pipeline.stores.InMemoryStore.open', return_value=io.StringIO()) @patch('cdsapi.Client.retrieve') def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve, mock_gcs_file): - config = { + config = Config.from_dict({ 'parameters': { 'dataset': 'reanalysis-era5-pressure-levels', 'partition_keys': ['year', 'month'], @@ -178,7 +179,7 @@ def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve, mock_gcs 'month': ['12'], 'year': ['01'] } - } + }) error = IOError("Can't open gcs file.") mock_gcs_file.side_effect = error @@ -189,7 +190,7 @@ def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve, mock_gcs actual = list(self.dummy_manifest.records.values())[0]._asdict() self.assertDictContainsSubset(dict( - selection=config['selection'], + selection=config.selection, location='gs://weather-dl-unittest/download-01-12.nc', status='failure', user='unknown', @@ -201,7 +202,7 @@ def test_fetch_data__manifest__records_gcs_failure(self, mock_retrieve, mock_gcs @patch('weather_dl.download_pipeline.stores.InMemoryStore.open', return_value=io.StringIO()) @patch('cdsapi.Client.retrieve') def test_fetch_data__skips_existing_download(self, mock_retrieve, mock_gcs_file): - config = { + config = Config.from_dict({ 'parameters': { 'dataset': 'reanalysis-era5-pressure-levels', 'partition_keys': ['year', 'month'], @@ -214,7 +215,7 @@ def test_fetch_data__skips_existing_download(self, 
mock_retrieve, mock_gcs_file) 'month': ['12'], 'year': ['01'] } - } + }) # target file already exists in store... store = InMemoryStore() diff --git a/weather_dl/download_pipeline/parsers.py b/weather_dl/download_pipeline/parsers.py index dd29bac0..e47a908f 100644 --- a/weather_dl/download_pipeline/parsers.py +++ b/weather_dl/download_pipeline/parsers.py @@ -21,14 +21,12 @@ import string import textwrap import typing as t +from .config import Config from urllib.parse import urlparse from collections import OrderedDict from .clients import CLIENTS from .manifest import MANIFESTS, Manifest, Location, NoOpManifest -Values = t.Union[t.List['Values'], t.Dict[str, 'Values'], bool, int, float, str] # pytype: disable=not-supported-yet -Config = t.Dict[str, t.Dict[str, Values]] - def date(candidate: str) -> datetime.date: """Converts ECMWF-format date strings into a `datetime.date`. @@ -144,7 +142,7 @@ def typecast(key: str, value: t.Any) -> t.Any: return converted -def parse_config(file: t.IO) -> Config: +def parse_config(file: t.IO) -> t.Dict: """Parses a `*.json` or `*.cfg` file into a configuration dictionary.""" try: # TODO(b/175429166): JSON files do not support MARs range syntax. @@ -298,16 +296,10 @@ def _number_of_replacements(s: t.Text): def parse_subsections(config: t.Dict) -> t.Dict: - """Interprets [section.subsection] as nested dictionaries in `.cfg` files. - - Also counts number of 'api_key' fields found. - """ + """Interprets [section.subsection] as nested dictionaries in `.cfg` files.""" copy = cp.deepcopy(config) - num_api_keys = 0 for key, val in copy.items(): path = key.split('.') - if val.get('api_key', ''): - num_api_keys += 1 runner = copy parent = {} p = None @@ -321,8 +313,6 @@ def parse_subsections(config: t.Dict) -> t.Dict: for_cleanup = [key for key, _ in copy.items() if '.' in key] for target in for_cleanup: del copy[target] - if num_api_keys: - copy['parameters']['num_api_keys'] = num_api_keys return copy @@ -340,7 +330,7 @@ def require(condition: bool, message: str, error_type: t.Type[Exception] = Value """ 'parameters' section required in configuration file. - The 'parameters' section specifies the 'dataset', 'target_path', and + The 'parameters' section specifies the 'client', 'dataset', 'target_path', and 'partition_key' for the API client. Please consult the documentation for more information.""") @@ -411,18 +401,13 @@ def require(condition: bool, message: str, error_type: t.Type[Exception] = Value # Ensure consistent lookup. 
     config['parameters']['partition_keys'] = partition_keys
-    return config
+    return Config.from_dict(config)
 
 
 def prepare_target_name(config: Config) -> str:
     """Returns name of target location."""
-    parameters = config['parameters']
-    target_path = t.cast(str, parameters.get('target_path', ''))
-    partition_keys = t.cast(t.List[str],
-                            cp.copy(parameters.get('partition_keys', list())))
-
-    partition_dict = OrderedDict((key, typecast(key, config['selection'][key][0])) for key in partition_keys)
-    target = target_path.format(*partition_dict.values(), **partition_dict)
+    partition_dict = OrderedDict((key, typecast(key, config.selection[key][0])) for key in config.partition_keys)
+    target = config.target_path.format(*partition_dict.values(), **partition_dict)
 
     return target
 
@@ -448,5 +433,5 @@ def get_subsections(config: Config) -> t.List[t.Tuple[str, t.Dict]]:
         api_url=UUUUU3
     ```
     """
-    return [(name, params) for name, params in config['parameters'].items()
+    return [(name, params) for name, params in config.kwargs.items()
             if isinstance(params, dict)] or [('default', {})]
diff --git a/weather_dl/download_pipeline/parsers_test.py b/weather_dl/download_pipeline/parsers_test.py
index d6f40373..1eba03fb 100644
--- a/weather_dl/download_pipeline/parsers_test.py
+++ b/weather_dl/download_pipeline/parsers_test.py
@@ -25,6 +25,7 @@
     parse_subsections,
     prepare_target_name,
 )
+from .config import Config
 
 
 class DateTest(unittest.TestCase):
@@ -376,7 +377,6 @@ def test_cfg_parses_parameter_subsections(self):
                 'api_url': 'https://google.com/',
                 'alice': {'api_key': '123'},
                 'bob': {'api_key': '456'},
-                'num_api_keys': 2,
             },
         })
 
@@ -418,8 +418,7 @@ def test_api_keys(self):
         self.assertEqual(actual, {'parameters': {'a': 1,
                                                  'b': 2,
                                                  'param1': {'api_key': 'key1'},
-                                                 'param2': {'api_key': 'key2'},
-                                                 'num_api_keys': 2}})
+                                                 'param2': {'api_key': 'key2'}}})
 
 
 class ProcessConfigTest(unittest.TestCase):
@@ -439,7 +438,7 @@ def test_require_params_section(self):
         with self.assertRaises(ValueError) as ctx:
             with io.StringIO(
                 """
-                [otherSection]
+                [selection]
                 key=value
                 """
             ) as f:
@@ -615,37 +614,7 @@ def test_treats_partition_keys_as_list(self):
             """
         ) as f:
             config = process_config(f)
-            params = config.get('parameters', {})
-            self.assertIsInstance(params['partition_keys'], list)
-
-    def test_params_in_config(self):
-        with io.StringIO(
-            """
-            [parameters]
-            dataset=foo
-            client=cds
-            target_path=bar-{}-{}
-            partition_keys=
-                year
-                month
-            [selection]
-            month=
-                01
-                02
-                03
-            year=
-                1950
-                1960
-                1970
-                1980
-                1990
-                2000
-                2010
-                2020
-            """
-        ) as f:
-            config = process_config(f)
-            self.assertIn('parameters', config)
+            self.assertIsInstance(config.partition_keys, list)
 
     def test_mismatched_template_partition_keys(self):
         with self.assertRaises(ValueError) as ctx:
@@ -862,7 +831,7 @@ def setUp(self) -> None:
 
     def test_target_name(self):
         for it in self.TEST_CASES:
             with self.subTest(msg=it['case'], **it):
-                actual = prepare_target_name(it['config'])
+                actual = prepare_target_name(Config.from_dict(it['config']))
 
                 self.assertEqual(actual, it['expected'])
diff --git a/weather_dl/download_pipeline/partition.py b/weather_dl/download_pipeline/partition.py
index 513c0002..69867514 100644
--- a/weather_dl/download_pipeline/partition.py
+++ b/weather_dl/download_pipeline/partition.py
@@ -20,7 +20,8 @@
 import apache_beam as beam
 
 from .manifest import Manifest
-from .parsers import Config, prepare_target_name
+from .parsers import prepare_target_name
+from .config import Config
 from .stores import Store, FSStore
 
 Partition = t.Tuple[str, t.Dict, Config]
@@ -101,21 +102,19 @@ def _create_partition_config(option: t.Tuple, config: Config) -> Config:
     Returns:
         A configuration with that selects a single download partition.
     """
-    partition_keys = config.get('parameters', {}).get('partition_keys', [])
-    selection = config.get('selection', {})
-    copy = cp.deepcopy(selection)
+    copy = cp.deepcopy(config.selection)
     out = cp.deepcopy(config)
-    for idx, key in enumerate(partition_keys):
+    for idx, key in enumerate(config.partition_keys):
         copy[key] = [option[idx]]
 
-    out['selection'] = copy
+    out.selection = copy
     return out
 
 
 def skip_partition(config: Config, store: Store) -> bool:
     """Return true if partition should be skipped."""
-    if config['parameters'].get('force_download', False):
+    if config.force_download:
         return False
 
     target = prepare_target_name(config)
@@ -138,10 +137,7 @@ def prepare_partitions(config: Config) -> t.Iterator[Config]:
     Returns:
         An iterator of `Config`s.
     """
-    partition_keys = config.get('parameters', {}).get('partition_keys', [])
-    selection = config.get('selection', {})
-
-    for option in itertools.product(*[selection[key] for key in partition_keys]):
+    for option in itertools.product(*[config.selection[key] for key in config.partition_keys]):
         yield _create_partition_config(option, config)
 
 
@@ -175,12 +171,12 @@ def assemble_config(partition: Partition, manifest: Manifest) -> Config:
         An `Config` assembled out of subsection parameters and config shards.
     """
     name, params, out = partition
-    out['parameters'].update(params)
-    out['parameters']['__subsection__'] = name
+    out.kwargs.update(params)
+    out.subsection_name = name
 
     location = prepare_target_name(out)
-    user = out['parameters'].get('user_id', 'unknown')
-    manifest.schedule(out['selection'], location, user)
+    user = out.user_id
+    manifest.schedule(out.selection, location, user)
 
     logger.info(f'[{name}] Created partition {location!r}.')
     beam.metrics.Metrics.counter('Subsection', name).inc()
diff --git a/weather_dl/download_pipeline/partition_test.py b/weather_dl/download_pipeline/partition_test.py
index a2657ac8..8131f420 100644
--- a/weather_dl/download_pipeline/partition_test.py
+++ b/weather_dl/download_pipeline/partition_test.py
@@ -25,6 +25,7 @@
 from .parsers import get_subsections
 from .partition import skip_partition, PartitionConfig
 from .stores import InMemoryStore, Store
+from .config import Config
 
 
 class OddFilesDoNotExistStore(InMemoryStore):
@@ -43,7 +44,7 @@ class PreparePartitionTest(unittest.TestCase):
     def setUp(self) -> None:
         self.dummy_manifest = MockManifest(Location('mock://dummy'))
 
-    def create_partition_configs(self, config, store: t.Optional[Store] = None) -> t.List[t.Dict]:
+    def create_partition_configs(self, config, store: t.Optional[Store] = None) -> t.List[Config]:
         subsections = get_subsections(config)
         params_cycle = itertools.cycle(subsections)
 
@@ -64,9 +65,10 @@ def test_partition_single_key(self):
             }
         }
 
-        actual = self.create_partition_configs(config)
+        config_obj = Config.from_dict(config)
+        actual = self.create_partition_configs(config_obj)
 
-        self.assertListEqual([d['selection'] for d in actual], [
+        self.assertListEqual([d.selection for d in actual], [
             {**config['selection'], **{'year': [str(i)]}} for i in range(2015, 2021)
         ])
 
@@ -84,9 +86,10 @@ def test_partition_multi_key(self):
             }
         }
 
-        actual = self.create_partition_configs(config)
+        config_obj = Config.from_dict(config)
+        actual = self.create_partition_configs(config_obj)
 
-        self.assertListEqual([d['selection'] for d in actual], [
+        self.assertListEqual([d.selection for d in actual], [
             {**config['selection'], **{'year': ['2015'], 'month': ['1']}},
             {**config['selection'], **{'year': ['2015'], 'month': ['2']}},
             {**config['selection'], **{'year': ['2016'], 'month': ['1']}},
@@ -118,22 +121,23 @@ def test_partition_multi_params_multi_key(self):
             }
         }
 
-        actual = self.create_partition_configs(config)
-
-        expected = [
-            {'parameters': dict(config['parameters'], api_key='KKKK1', api_url='UUUU1', __subsection__='research'),
-             'selection': {**config['selection'],
-                           **{'year': ['2015'], 'month': ['1']}}},
-            {'parameters': dict(config['parameters'], api_key='KKKK2', api_url='UUUU2', __subsection__='cloud'),
-             'selection': {**config['selection'],
-                           **{'year': ['2015'], 'month': ['2']}}},
-            {'parameters': dict(config['parameters'], api_key='KKKK3', api_url='UUUU3', __subsection__='deepmind'),
-             'selection': {**config['selection'],
-                           **{'year': ['2016'], 'month': ['1']}}},
-            {'parameters': dict(config['parameters'], api_key='KKKK1', api_url='UUUU1', __subsection__='research'),
-             'selection': {**config['selection'],
-                           **{'year': ['2016'], 'month': ['2']}}},
-        ]
+        config_obj = Config.from_dict(config)
+        actual = self.create_partition_configs(config_obj)
+
+        expected = [Config.from_dict(it) for it in [
+            {'parameters': dict(config['parameters'], api_key='KKKK1', api_url='UUUU1', subsection_name='research'),
+             'selection': {**config['selection'],
+                           **{'year': ['2015'], 'month': ['1']}}},
+            {'parameters': dict(config['parameters'], api_key='KKKK2', api_url='UUUU2', subsection_name='cloud'),
+             'selection': {**config['selection'],
+                           **{'year': ['2015'], 'month': ['2']}}},
+            {'parameters': dict(config['parameters'], api_key='KKKK3', api_url='UUUU3', subsection_name='deepmind'),
+             'selection': {**config['selection'],
+                           **{'year': ['2016'], 'month': ['1']}}},
+            {'parameters': dict(config['parameters'], api_key='KKKK1', api_url='UUUU1', subsection_name='research'),
+             'selection': {**config['selection'],
+                           **{'year': ['2016'], 'month': ['2']}}},
+        ]]
 
         self.assertListEqual(actual, expected)
 
@@ -150,10 +154,12 @@ def test_prepare_partition_records_download_status_to_manifest(self):
             }
         }
 
+        config_obj = Config.from_dict(config)
+
         with tempfile.TemporaryDirectory() as tmpdir:
             self.dummy_manifest = LocalManifest(Location(tmpdir))
 
-            self.create_partition_configs(config)
+            self.create_partition_configs(config_obj)
 
             with open(self.dummy_manifest.location, 'r') as f:
                 actual = json.load(f)
@@ -190,9 +196,10 @@ def test_skip_partitions__never_unbalances_licenses(self):
             }
         }
 
-        actual = self.create_partition_configs(config, store=skip_odd_files)
-        research_configs = [cfg for cfg in actual if cfg and cfg['parameters']['api_url'].endswith('1')]
-        cloud_configs = [cfg for cfg in actual if cfg and cfg['parameters']['api_url'].endswith('2')]
+        config_obj = Config.from_dict(config)
+        actual = self.create_partition_configs(config_obj, store=skip_odd_files)
+        research_configs = [cfg for cfg in actual if cfg and t.cast('str', cfg.kwargs.get('api_url', "")).endswith('1')]
+        cloud_configs = [cfg for cfg in actual if cfg and t.cast('str', cfg.kwargs.get('api_url', "")).endswith('2')]
 
         self.assertEqual(len(research_configs), len(cloud_configs))
 
@@ -215,7 +222,8 @@ def test_skip_partition_missing_force_download(self):
             }
         }
 
-        actual = skip_partition(config, self.mock_store)
+        config_obj = Config.from_dict(config)
+        actual = skip_partition(config_obj, self.mock_store)
 
         self.assertEqual(actual, False)
 
@@ -233,7 +241,8 @@ def test_skip_partition_force_download_true(self):
             }
         }
 
-        actual = skip_partition(config, self.mock_store)
+        config_obj = Config.from_dict(config)
+        actual = skip_partition(config_obj, self.mock_store)
 
         self.assertEqual(actual, False)
 
@@ -251,9 +260,11 @@ def test_skip_partition_force_download_false(self):
             }
         }
 
+        config_obj = Config.from_dict(config)
+
         self.mock_store.exists = MagicMock(return_value=True)
 
-        actual = skip_partition(config, self.mock_store)
+        actual = skip_partition(config_obj, self.mock_store)
 
         self.assertEqual(actual, True)
diff --git a/weather_dl/download_pipeline/pipeline.py b/weather_dl/download_pipeline/pipeline.py
index 5d707b8a..fc9022c1 100644
--- a/weather_dl/download_pipeline/pipeline.py
+++ b/weather_dl/download_pipeline/pipeline.py
@@ -35,10 +35,10 @@
     NoOpManifest,
 )
 from .parsers import (
-    Config,
     parse_manifest,
     process_config,
     get_subsections,
 )
+from .config import Config
 from .partition import PartitionConfig
 from .stores import TempFileStore, LocalFileStore
@@ -91,7 +91,7 @@ def pipeline(args: PipelineArgs) -> None:
     request_idxs = {name: itertools.cycle(range(args.num_requesters_per_key)) for name, _ in subsections}
 
     def subsection_and_request(it: Config) -> t.Tuple[str, int]:
-        subsection = t.cast(builtins.str, it.get('parameters', {}).get('__subsection__', 'default'))
+        subsection = it.subsection_name
         return subsection, builtins.next(request_idxs[subsection])
 
     subsections_cycle = itertools.cycle(subsections)
@@ -138,23 +138,23 @@ def run(argv: t.List[str], save_main_session: bool = True) -> PipelineArgs:
     with open(known_args.config, 'r', encoding='utf-8') as f:
         config = process_config(f)
 
-    config['parameters']['force_download'] = known_args.force_download
-    config['parameters']['user_id'] = getpass.getuser()
+    config.force_download = known_args.force_download
+    config.user_id = getpass.getuser()
 
     # We use the save_main_session option because one or more DoFn's in this
     # workflow rely on global context (e.g., a module imported at module level).
     save_main_session_args = ['--save_main_session'] + ['True' if save_main_session else 'False']
     pipeline_options = PipelineOptions(pipeline_args + save_main_session_args)
 
-    client_name = config['parameters']['client']
+    client_name = config.client
     store = None  # will default to using FileSystems()
-    config['parameters']['force_download'] = known_args.force_download
+    config.force_download = known_args.force_download
     manifest = parse_manifest(known_args.manifest_location, pipeline_options.get_all_options())
 
     if known_args.dry_run:
         client_name = 'fake'
         store = TempFileStore('dry_run')
-        config['parameters']['force_download'] = True
+        config.force_download = True
         manifest = NoOpManifest(Location('noop://dry-run'))
 
     if known_args.local_run:
@@ -167,7 +167,7 @@ def run(argv: t.List[str], save_main_session: bool = True) -> PipelineArgs:
         client = CLIENTS[client_name](config)
         if num_requesters_per_key == -1:
             num_requesters_per_key = client.num_requests_per_key(
-                config.get('parameters', {}).get('dataset', "")
+                config.dataset
             )
 
     logger.warning(f'By using {client_name} datasets, '
diff --git a/weather_dl/download_pipeline/pipeline_test.py b/weather_dl/download_pipeline/pipeline_test.py
index 32487600..8c72a2d5 100644
--- a/weather_dl/download_pipeline/pipeline_test.py
+++ b/weather_dl/download_pipeline/pipeline_test.py
@@ -25,17 +25,10 @@
 from .manifest import FirestoreManifest, Location, NoOpManifest, LocalManifest
 from .pipeline import run, PipelineArgs
 from .stores import TempFileStore, LocalFileStore
+from .config import Config
 
 PATH_TO_CONFIG = os.path.join(os.path.dirname(list(weather_dl.__path__)[0]), 'configs', 'era5_example_config.cfg')
-DEFAULT_ARGS = PipelineArgs(
-    known_args=argparse.Namespace(config=PATH_TO_CONFIG,
-                                  force_download=False,
-                                  dry_run=False,
-                                  local_run=False,
-                                  manifest_location='fs://downloader-manifest',
-                                  num_requests_per_key=-1),
-    pipeline_options=PipelineOptions('--save_main_session True'.split()),
-    config={
+CONFIG = {
     'parameters': {'client': 'cds',
                    'dataset': 'reanalysis-era5-pressure-levels',
                    'target_path': 'gs://ecmwf-output-test/era5/{year:04d}/{month:02d}/{day:02d}'
@@ -50,7 +43,16 @@
                   'year': ['2015', '2016', '2017'],
                   'month': ['01'], 'day': ['01', '15'],
                   'time': ['00:00', '06:00', '12:00', '18:00']}
-    },
+    }
+DEFAULT_ARGS = PipelineArgs(
+    known_args=argparse.Namespace(config=PATH_TO_CONFIG,
+                                  force_download=False,
+                                  dry_run=False,
+                                  local_run=False,
+                                  manifest_location='fs://downloader-manifest',
+                                  num_requests_per_key=-1),
+    pipeline_options=PipelineOptions('--save_main_session True'.split()),
+    config=Config.from_dict(CONFIG),
     client_name='cds',
     store=None,
     manifest=FirestoreManifest(Location('fs://downloader-manifest?projectId=None')),
@@ -67,9 +69,10 @@ def default_args(parameters: t.Optional[t.Dict] = None, selection: t.Optional[t.
     if known_args is None:
         known_args = {}
     args = dataclasses.replace(DEFAULT_ARGS, **kwargs)
-    args.config = copy.deepcopy(args.config)
-    args.config['parameters'].update(parameters)
-    args.config['selection'].update(selection)
+    temp_config = copy.deepcopy(CONFIG)
+    temp_config['parameters'].update(parameters)
+    temp_config['selection'].update(selection)
+    args.config = Config.from_dict(temp_config)
     args.known_args = copy.deepcopy(args.known_args)
     for k, v in known_args.items():
         setattr(args.known_args, k, v)
diff --git a/weather_dl/setup.py b/weather_dl/setup.py
index 9f1ba96e..1c2380ed 100644
--- a/weather_dl/setup.py
+++ b/weather_dl/setup.py
@@ -32,7 +32,7 @@
 setup(
     name='download_pipeline',
     packages=find_packages(),
-    version='0.1.4',
+    version='0.1.5',
     author='Anthromets',
     author_email='anthromets-ecmwf@google.com',
     url='https://weather-tools.readthedocs.io/en/latest/weather_dl/',
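
Reviewer note (not part of the patch): as a quick reference for how the new `Config.from_dict` in `weather_dl/download_pipeline/config.py` behaves, the minimal sketch below maps a parsed configuration dictionary onto the dataclass. Typed fields such as `client` are set directly, any other key under 'parameters' falls through to `kwargs`, and the 'selection' section is stored as-is. The dictionary values and target path here are illustrative only, not taken from the repository's example configs, and the import assumes the `weather_dl` package is on the Python path.

from weather_dl.download_pipeline.config import Config

# Illustrative input; keys mirror the '[parameters]' and '[selection]'
# sections of a parsed .cfg file.
config = Config.from_dict({
    'parameters': {
        'client': 'cds',
        'dataset': 'reanalysis-era5-pressure-levels',
        'target_path': 'gs://my-bucket/era5-{}-{}.nc',  # hypothetical bucket
        'partition_keys': ['year', 'month'],
        'api_key': 'XXXX',  # not a dataclass field, so it lands in kwargs
    },
    'selection': {'year': ['2015'], 'month': ['01']},
})

assert config.client == 'cds'
assert config.kwargs['api_key'] == 'XXXX'   # unknown parameters fall through to kwargs
assert config.subsection_name == 'default'  # default when no subsection is set
assert config.selection == {'year': ['2015'], 'month': ['01']}

This is the same conversion exercised throughout the updated tests (e.g. `CdsClient(Config.from_dict({...}))` in `clients_test.py`), so the downstream code can read `config.client`, `config.selection`, etc. instead of indexing nested dictionaries.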