-
-
Notifications
You must be signed in to change notification settings - Fork 11
/
model.py
765 lines (619 loc) · 26.2 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
""" Configuration model for the dataset.
All paths must include the protocol prefix. For local files,
it's sufficient to just start with a '/'. For aws, start with 's3://',
for gcp start with 'gs://'.
This file is mostly about _configuring_ the DataSources.
Separate Pydantic models in
`nowcasting_dataset/data_sources/<data_source_name>/<data_source_name>_model.py`
are used to validate the values of the data itself.
"""
import logging
from datetime import datetime
from typing import Dict, List, Optional, Union
import git
import numpy as np
from nowcasting_datamodel.models.pv import providers, pv_output, solar_sheffield_passiv
from pathy import Pathy
from pydantic import BaseModel, Field, root_validator, validator
# nowcasting_dataset imports
from ocf_datapipes.utils.consts import (
AWOS_VARIABLE_NAMES,
NWP_PROVIDERS,
NWP_VARIABLE_NAMES,
RSS_VARIABLE_NAMES,
)
# Default ROI (region of interest) sizes shared by several data-source models below.
IMAGE_SIZE_PIXELS = 64
IMAGE_SIZE_PIXELS_FIELD = Field(
    IMAGE_SIZE_PIXELS, description="The number of pixels of the region of interest."
)
METERS_PER_PIXEL_FIELD = Field(2000, description="The number of meters per pixel.")
METERS_PER_ROI = Field(128_000, description="The number of meters of region of interest.")
# Default number of GSP / PV-system samples packed into a single example
# (examples with fewer systems in the ROI are zero-padded; see the Field descriptions below).
DEFAULT_N_GSP_PER_EXAMPLE = 32
DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE = 2048
logger = logging.getLogger(__name__)
# add SV to list of providers
providers.append("SV")
class Base(BaseModel):
    """Pydantic Base model where no extras can be added"""
    class Config:
        """config class"""
        # Reject unknown keys so typos in configuration files fail loudly.
        extra = "forbid"  # forbid use of extra kwargs
class General(Base):
    """General pydantic model"""
    # Human-readable identifiers for a configuration file; not used programmatically here.
    name: str = Field("example", description="The name of this configuration file.")
    description: str = Field(
        "example configuration", description="Description of this configuration file"
    )
class Git(Base):
    """Git model"""
    # All three fields are required (pydantic `...`); they are filled in from the
    # repository HEAD by `set_git_commit` at the bottom of this file.
    hash: str = Field(
        ..., description="The git hash of nowcasting_dataset when a dataset is created."
    )
    message: str = Field(..., description="The git message for when a dataset is created.")
    committed_date: datetime = Field(
        ..., description="The git datestamp for when a dataset is created."
    )
class DataSourceMixin(Base):
    """Mixin class, to add forecast and history minutes.

    Also exposes helper properties that convert those durations into numbers of
    timesteps at 5 / 30 / 60 minute sampling periods.
    """

    forecast_minutes: int = Field(
        None,
        ge=0,
        description="how many minutes to forecast in the future. "
        "If set to None, the value is defaulted to InputData.default_forecast_minutes",
    )
    history_minutes: int = Field(
        None,
        ge=0,
        description="how many historic minutes to use. "
        "If set to None, the value is defaulted to InputData.default_history_minutes",
    )
    log_level: str = Field(
        "DEBUG",
        description="The logging level for this data source. T"
        "his is the default value and can be set in each data source",
    )

    def _total_steps(self, period_minutes):
        """Steps covering history + forecast (inclusive of t0) at `period_minutes`."""
        total_minutes = self.history_minutes + self.forecast_minutes
        return int(np.ceil(total_minutes / period_minutes + 1))

    def _history_steps(self, period_minutes):
        """Steps covering just the history at `period_minutes`."""
        return int(np.ceil(self.history_minutes / period_minutes))

    @property
    def seq_length_30_minutes(self):
        """How many steps are there in 30 minute datasets"""
        return self._total_steps(30)

    @property
    def seq_length_5_minutes(self):
        """How many steps are there in 5 minute datasets"""
        return self._total_steps(5)

    @property
    def seq_length_60_minutes(self):
        """How many steps are there in 60 minute datasets"""
        return self._total_steps(60)

    @property
    def history_seq_length_5_minutes(self):
        """How many historical steps are there in 5 minute datasets"""
        return self._history_steps(5)

    @property
    def history_seq_length_30_minutes(self):
        """How many historical steps are there in 30 minute datasets"""
        return self._history_steps(30)

    @property
    def history_seq_length_60_minutes(self):
        """How many historical steps are there in 60 minute datasets"""
        return self._history_steps(60)
class TimeResolutionMixin(Base):
    """Time resolution mix in"""

    # TODO: Issue #584: Rename to `sample_period_minutes`
    time_resolution_minutes: int = Field(
        5,
        description="The temporal resolution (in minutes) of the satellite images."
        "Note that this needs to be divisible by 5.",
    )

    @validator("time_resolution_minutes")
    def time_resolution_minutes_divide_by_5(cls, v):
        """Validate that 'time_resolution_minutes' is divisible by 5.

        (The validator was previously misnamed/mis-documented as validating
        'forecast_minutes' — a copy-paste error; it validates the field above.)
        """
        assert v % 5 == 0, f"The time resolution ({v}) is not divisible by 5"
        return v
class XYDimensionalNames(Base):
    """X and Y dimensions names"""

    x_dim_name: str = Field(
        "x_osgb",
        description="The x dimension name. Should be either x_osgb or longitude",
    )
    y_dim_name: str = Field(
        "y_osgb",
        description="The y dimension name. Should be either y_osgb or latitude",
    )

    @root_validator
    def check_x_y_dimension_names(cls, values):
        """Check that the x and y dimension names pair up correctly."""
        # Each allowed x dimension name has exactly one matching y dimension name.
        matching_y = {"x": "y", "x_osgb": "y_osgb", "longitude": "latitude"}
        x_dim_name = values["x_dim_name"]
        y_dim_name = values["y_dim_name"]
        assert x_dim_name in matching_y
        assert y_dim_name in matching_y.values()
        assert y_dim_name == matching_y[x_dim_name]
        return values
class StartEndDatetimeMixin(Base):
    """Mixin class to add start and end date"""

    start_datetime: datetime = Field(
        datetime(2020, 1, 1),
        description="Load date from data sources from this date. "
        "If None, this will get overwritten by InputData.start_date. ",
    )
    end_datetime: datetime = Field(
        datetime(2021, 9, 1),
        description="Load date from data sources up to this date. "
        "If None, this will get overwritten by InputData.start_date. ",
    )

    @root_validator
    def check_start_and_end_datetime(cls, values):
        """
        Make sure start datetime is strictly before end datetime.

        Raises:
            ValueError: if start_datetime >= end_datetime.
        """
        start_datetime = values["start_datetime"]
        end_datetime = values["end_datetime"]
        # check start datetime is less than end datetime
        if start_datetime >= end_datetime:
            message = (
                f"Start datetime ({start_datetime}) "
                f"should be less than end datetime ({end_datetime})"
            )
            logger.error(message)
            # Bug fix: this used to be `assert Exception(message)`, which asserts a
            # (truthy) exception *instance* and therefore never failed — invalid
            # date ranges were silently accepted. Raise so they are rejected.
            raise ValueError(message)
        return values
class PVFiles(BaseModel):
    """Model to hold pv file and metadata file"""

    pv_filename: str = Field(
        "gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_timeseries_batch.nc",
        description="The NetCDF files holding the solar PV power timeseries.",
    )
    pv_metadata_filename: str = Field(
        "gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv",
        description="Tthe CSV files describing each PV system.",
    )
    inferred_metadata_filename: str = Field(
        None,
        description="Tthe CSV files describing inferred PV metadata for each system.",
    )
    label: str = Field(pv_output, description="Label of where the pv data came from")

    @validator("label")
    def v_label0(cls, v):
        """Validate that 'label' is one of the known PV data providers."""
        if v not in providers:
            message = f"provider {v} not in {providers}"
            logger.error(message)
            # Raise ValueError (the pydantic-idiomatic validation error) rather than a
            # bare Exception. ValueError is an Exception subclass, so any existing
            # `except Exception` callers keep working.
            raise ValueError(message)
        return v
class PV(DataSourceMixin, StartEndDatetimeMixin, TimeResolutionMixin, XYDimensionalNames):
    """PV configuration model"""
    # New-style file configuration: a list of PVFiles groups (data file + metadata + label).
    pv_files_groups: List[PVFiles] = [PVFiles()]
    n_pv_systems_per_example: int = Field(
        DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE,
        description="The number of PV systems samples per example. "
        "If there are less in the ROI then the data is padded with zeros. ",
    )
    # ROI size in meters (METERS_PER_ROI default is 128 km).
    pv_image_size_meters_height: int = METERS_PER_ROI
    pv_image_size_meters_width: int = METERS_PER_ROI
    get_center: bool = Field(
        False,
        description="If the batches are centered on one PV system (or not). "
        "The other options is to have one GSP at the center of a batch. "
        "Typically, get_center would be set to true if and only if "
        "PVDataSource is used to define the geospatial positions of each example.",
    )
    # Old-style single-file fields; `model_validation` below migrates them into
    # `pv_files_groups` and resets them to None.
    pv_filename: str = Field(
        None,
        description="The NetCDF files holding the solar PV power timeseries.",
    )
    pv_metadata_filename: str = Field(
        None,
        description="Tthe CSV files describing each PV system.",
    )
    pv_ml_ids: List[int] = Field(
        None,
        description="List of the ML IDs of the PV systems you'd like to filter to.",
    )
    is_live: bool = Field(
        False, description="Option if to use live data from the nowcasting pv database"
    )
    live_interpolate_minutes: int = Field(
        30, description="The number of minutes we allow PV data to interpolate"
    )
    live_load_extra_minutes: int = Field(
        0,
        description="The number of extra minutes in the past we should load. Then the recent "
        "values can be interpolated, and the extra minutes removed. This is "
        "because some live data takes ~1 hour to come in.",
    )
    @classmethod
    def model_validation(cls, v):
        """Move old way of storing filenames to new way"""
        # NOTE(review): this is a plain classmethod, not registered as a pydantic
        # validator, so it is presumably invoked explicitly by the config-loading
        # code with a PV instance `v` — verify against the caller.
        if (v.pv_filename is not None) and (v.pv_metadata_filename is not None):
            logger.warning(
                "Loading pv files the old way, and moving them the new way. "
                "Please update configuration file"
            )
            # PVOutput.org files get the pv_output label; anything else is assumed
            # to be Sheffield Solar Passiv data.
            label = pv_output if "pvoutput" in v.pv_filename.lower() else solar_sheffield_passiv
            pv_file = PVFiles(
                pv_filename=v.pv_filename, pv_metadata_filename=v.pv_metadata_filename, label=label
            )
            v.pv_files_groups = [pv_file]
            # Clear the deprecated fields so the old and new styles cannot disagree.
            v.pv_filename = None
            v.pv_metadata_filename = None
        return v
class Sensor(DataSourceMixin, StartEndDatetimeMixin, TimeResolutionMixin, XYDimensionalNames):
    """Sensor configuration model"""
    # ROI size in meters (METERS_PER_ROI default is 128 km).
    sensor_image_size_meters_height: int = METERS_PER_ROI
    sensor_image_size_meters_width: int = METERS_PER_ROI
    get_center: bool = Field(
        False,
        description="If the batches are centered on one Sensor system (or not). "
        "The other options is to have one GSP at the center of a batch. "
        "Typically, get_center would be set to true if and only if "
        "SensorDataSource is used to define the geospatial positions of each example.",
    )
    sensor_filename: str = Field(
        None,
        description="The NetCDF files holding the Sensor timeseries.",
    )
    sensor_ml_ids: List[int] = Field(
        None,
        description="List of the ML IDs of the PV systems you'd like to filter to.",
    )
    is_live: bool = Field(
        False, description="Option if to use live data from the nowcasting pv database"
    )
    live_interpolate_minutes: int = Field(
        30, description="The number of minutes we allow PV data to interpolate"
    )
    live_load_extra_minutes: int = Field(
        0,
        description="The number of extra minutes in the past we should load. Then the recent "
        "values can be interpolated, and the extra minutes removed. This is "
        "because some live data takes ~1 hour to come in.",
    )
    # Which sensor variables (e.g. AWOS weather-observation fields) to load.
    sensor_variables: tuple = Field(
        AWOS_VARIABLE_NAMES, description="the sensor variables that are used"
    )
class Satellite(DataSourceMixin, TimeResolutionMixin):
    """Satellite configuration model"""
    satellite_zarr_path: Union[str, tuple[str], list[str]] = Field(
        "gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr",  # noqa: E501
        description="The path or list of paths which hold the satellite zarr.",
    )
    # Index 0 of RSS_VARIABLE_NAMES is the HRV channel (see HRVSatellite), so the
    # non-HRV channels start at index 1.
    satellite_channels: tuple = Field(
        RSS_VARIABLE_NAMES[1:], description="the satellite channels that are used"
    )
    # Non-HRV imagery is 1/3 the resolution of HRV: 1/3 the pixels at 3x the meters
    # per pixel, covering the same area on the ground.
    satellite_image_size_pixels_height: int = Field(
        IMAGE_SIZE_PIXELS_FIELD.default // 3,
        description="The number of pixels of the height of the region of interest"
        " for non-HRV satellite channels.",
    )
    satellite_image_size_pixels_width: int = Field(
        IMAGE_SIZE_PIXELS_FIELD.default // 3,
        description="The number of pixels of the width of the region "
        "of interest for non-HRV satellite channels.",
    )
    satellite_meters_per_pixel: int = Field(
        METERS_PER_PIXEL_FIELD.default * 3,
        description="The number of meters per pixel for non-HRV satellite channels.",
    )
    is_live: bool = Field(
        False,
        description="Option if to use live data from the satelite consumer. "
        "This is useful becasuse the data is about ~30 mins behind, "
        "so we need to expect that",
    )
    live_delay_minutes: int = Field(
        30, description="The expected delay in minutes of the satellite data"
    )
class HRVSatellite(DataSourceMixin, TimeResolutionMixin):
    """Satellite configuration model for HRV data"""
    hrvsatellite_zarr_path: Union[str, tuple[str], list[str]] = Field(
        "gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr",  # noqa: E501
        description="The path or list of paths which hold the satellite zarr.",
    )
    # Only the first RSS variable (the HRV channel) is used here.
    hrvsatellite_channels: tuple = Field(
        RSS_VARIABLE_NAMES[0:1], description="the satellite channels that are used"
    )
    # HRV is 3x the resolution, so to cover the same area, its 1/3 the meters per pixel and 3
    # time the number of pixels
    hrvsatellite_image_size_pixels_height: int = IMAGE_SIZE_PIXELS_FIELD
    hrvsatellite_image_size_pixels_width: int = IMAGE_SIZE_PIXELS_FIELD
    hrvsatellite_meters_per_pixel: int = METERS_PER_PIXEL_FIELD
    is_live: bool = Field(
        False,
        description="Option if to use live data from the satelite consumer. "
        "This is useful becasuse the data is about ~30 mins behind, "
        "so we need to expect that",
    )
    live_delay_minutes: int = Field(
        30, description="The expected delay in minutes of the satellite data"
    )
class OpticalFlow(DataSourceMixin, TimeResolutionMixin):
    """Optical Flow configuration model"""
    opticalflow_zarr_path: str = Field(
        "",
        description=(
            "The satellite Zarr data to use. If in doubt, use the same value as"
            " satellite.satellite_zarr_path."
        ),
    )
    # history_minutes, set in DataSourceMixin.
    # Duration of historical data to use when computing the optical flow field.
    # For example, set to 5 to use just two images: the t-1 and t0 images. Set to 10 to
    # compute the optical flow field separately for the image pairs (t-2, t-1), and
    # (t-1, t0) and to use the mean optical flow field.
    # forecast_minutes, set in DataSourceMixin.
    # Duration of the optical flow predictions.
    opticalflow_meters_per_pixel: int = METERS_PER_PIXEL_FIELD
    # Input images are loaded at 2x the output size so that, after flowing, the
    # output center-crop contains only real pixel values (see descriptions below).
    opticalflow_input_image_size_pixels_height: int = Field(
        IMAGE_SIZE_PIXELS * 2,
        description=(
            "The *input* image height (i.e. the image size to load off disk)."
            " This should be larger than output_image_size_pixels to provide sufficient border to"
            " mean that, even after the image has been flowed, all edges of the output image are"
            " real pixels values, and not NaNs."
        ),
    )
    opticalflow_output_image_size_pixels_height: int = Field(
        IMAGE_SIZE_PIXELS,
        description=(
            "The height of the images after optical flow has been applied. The output image is a"
            " center-crop of the input image, after it has been flowed."
        ),
    )
    opticalflow_input_image_size_pixels_width: int = Field(
        IMAGE_SIZE_PIXELS * 2,
        description=(
            "The *input* image width (i.e. the image size to load off disk)."
            " This should be larger than output_image_size_pixels to provide sufficient border to"
            " mean that, even after the image has been flowed, all edges of the output image are"
            " real pixels values, and not NaNs."
        ),
    )
    opticalflow_output_image_size_pixels_width: int = Field(
        IMAGE_SIZE_PIXELS,
        description=(
            "The width of the images after optical flow has been applied. The output image is a"
            " center-crop of the input image, after it has been flowed."
        ),
    )
    # Non-HRV channels by default (index 0 is the HRV channel).
    opticalflow_channels: tuple = Field(
        RSS_VARIABLE_NAMES[1:], description="the satellite channels that are used"
    )
    opticalflow_source_data_source_class_name: str = Field(
        "SatelliteDataSource",
        description=(
            "Either SatelliteDataSource or HRVSatelliteDataSource."
            " The name of the DataSource that will load the satellite images."
        ),
    )
class NWP(DataSourceMixin, StartEndDatetimeMixin, TimeResolutionMixin, XYDimensionalNames):
    """NWP configuration model"""

    # TODO change to nwp_path, as it could be a netcdf now.
    # https://github.com/openclimatefix/nowcasting_dataset/issues/582
    nwp_zarr_path: Union[str, tuple[str], list[str]] = Field(
        "gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr",  # noqa: E501
        description="The path which holds the NWP zarr.",
    )
    nwp_channels: tuple = Field(
        NWP_VARIABLE_NAMES["ukv"], description="the channels used in the nwp data"
    )
    nwp_image_size_pixels_height: int = IMAGE_SIZE_PIXELS_FIELD
    nwp_image_size_pixels_width: int = IMAGE_SIZE_PIXELS_FIELD
    nwp_meters_per_pixel: int = METERS_PER_PIXEL_FIELD
    nwp_provider: str = Field("ukv", description="The provider of the NWP data")
    index_by_id: bool = Field(
        False, description="If the NWP data has an id coordinate, not x and y."
    )

    @validator("nwp_provider")
    def validate_nwp_provider(cls, v):
        """Validate that 'nwp_provider' is one of the known NWP_PROVIDERS."""
        if v.lower() not in NWP_PROVIDERS:
            message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
            logger.warning(message)
            # Bug fix: this used to be `assert Exception(message)`, which asserts a
            # (truthy) exception *instance* and never failed — unknown providers were
            # only logged, never rejected. Raise so pydantic reports the error.
            raise ValueError(message)
        return v
class MultiNWP(Base):
    """Configuration for multiple NWPs"""
    # Custom root type: behaves like a read-only mapping of provider name -> NWP config.
    __root__: Dict[str, NWP]
    def __getattr__(self, item):
        # Attribute-style access, e.g. `multi_nwp.ukv`.
        # NOTE(review): raises KeyError (not AttributeError) for unknown names, which
        # also means `getattr(multi_nwp, name, default)` does not fall back to default.
        return self.__root__[item]
    def __getitem__(self, item):
        # Dict-style access, e.g. `multi_nwp["ukv"]`.
        return self.__root__[item]
    def __len__(self):
        # Number of configured NWP sources.
        return len(self.__root__)
    def __iter__(self):
        # Iterates over the NWP source names, like a dict.
        return iter(self.__root__)
    def keys(self):
        """Returns dictionary-like keys"""
        return self.__root__.keys()
    def items(self):
        """Returns dictionary-like items"""
        return self.__root__.items()
class GSP(DataSourceMixin, StartEndDatetimeMixin, TimeResolutionMixin):
    """GSP configuration model"""
    gsp_zarr_path: str = Field("gs://solar-pv-nowcasting-data/PV/GSP/v2/pv_gsp.zarr")
    n_gsp_per_example: int = Field(
        DEFAULT_N_GSP_PER_EXAMPLE,
        description="The number of GSP samples per example. "
        "If there are less in the ROI then the data is padded with zeros. ",
    )
    gsp_image_size_pixels_height: int = IMAGE_SIZE_PIXELS_FIELD
    gsp_image_size_pixels_width: int = IMAGE_SIZE_PIXELS_FIELD
    gsp_meters_per_pixel: int = METERS_PER_PIXEL_FIELD
    metadata_only: bool = Field(False, description="Option to only load metadata.")
    is_live: bool = Field(
        False, description="Option if to use live data from the nowcasting GSP/Forecast database"
    )
    live_interpolate_minutes: int = Field(
        60, description="The number of minutes we allow GSP data to be interpolated"
    )
    live_load_extra_minutes: int = Field(
        60,
        description="The number of extra minutes in the past we should load. Then the recent "
        "values can be interpolated, and the extra minutes removed. This is "
        "because some live data takes ~1 hour to come in.",
    )
    # Both durations must be whole multiples of 30 minutes (presumably because GSP
    # data is half-hourly — verify against the data source).
    @validator("history_minutes")
    def history_minutes_divide_by_30(cls, v):
        """Validate 'history_minutes'"""
        assert v % 30 == 0  # this means it also divides by 5
        return v
    @validator("forecast_minutes")
    def forecast_minutes_divide_by_30(cls, v):
        """Validate 'forecast_minutes'"""
        assert v % 30 == 0  # this means it also divides by 5
        return v
class Topographic(DataSourceMixin):
    """Topographic configuration model"""
    topographic_filename: str = Field(
        "gs://solar-pv-nowcasting-data/Topographic/europe_dem_1km_osgb.tif",
        description="Path to the GeoTIFF Topographic data source",
    )
    # ROI size in pixels / meters, shared module-level defaults.
    topographic_image_size_pixels_height: int = IMAGE_SIZE_PIXELS_FIELD
    topographic_image_size_pixels_width: int = IMAGE_SIZE_PIXELS_FIELD
    topographic_meters_per_pixel: int = METERS_PER_PIXEL_FIELD
class Sun(DataSourceMixin):
    """Sun configuration model"""
    sun_zarr_path: str = Field(
        "gs://solar-pv-nowcasting-data/Sun/v1/sun.zarr/",
        description="Path to the Sun data source i.e Azimuth and Elevation",
    )
    load_live: bool = Field(
        False, description="Option to load sun data on the fly, rather than from file"
    )
    # Examples whose sun elevation is below this threshold (degrees, presumably) are
    # skipped — TODO confirm units against the data source.
    elevation_limit: int = Field(
        10,
        description="The limit to the elevations for examples. "
        "Datetimes below this limits will be ignored",
    )
class InputData(Base):
    """
    Input data model.

    Holds the (optional) configuration for every data source, plus the default
    forecast/history durations applied to any source that does not set its own.
    """

    pv: Optional[PV] = None
    satellite: Optional[Satellite] = None
    hrvsatellite: Optional[HRVSatellite] = None
    opticalflow: Optional[OpticalFlow] = None
    nwp: Optional[MultiNWP] = None
    gsp: Optional[GSP] = None
    topographic: Optional[Topographic] = None
    sun: Optional[Sun] = None
    sensor: Optional[Sensor] = None

    default_forecast_minutes: int = Field(
        60,
        ge=0,
        description="how many minutes to forecast in the future. "
        "This sets the default for all the data sources if they are not set.",
    )
    default_history_minutes: int = Field(
        30,
        ge=0,
        description="how many historic minutes are used. "
        "This sets the default for all the data sources if they are not set.",
    )
    data_source_which_defines_geospatial_locations: str = Field(
        "gsp",
        description=(
            "The name of the DataSource which will define the geospatial position of each example."
        ),
    )

    @property
    def default_seq_length_5_minutes(self):
        """How many steps are there in 5 minute datasets"""
        return int((self.default_history_minutes + self.default_forecast_minutes) / 5 + 1)

    @root_validator
    def set_forecast_and_history_minutes(cls, values):
        """
        Set default history and forecast values, if needed.

        Run through the different data sources and if the forecast or history minutes are not set,
        then set them to the default values.
        """
        # It would be much better to use nowcasting_dataset.data_sources.ALL_DATA_SOURCE_NAMES,
        # but that causes a circular import.
        ALL_DATA_SOURCE_NAMES = (
            "pv",
            "hrvsatellite",
            "satellite",
            # "nwp", # nwp is treated separately
            "gsp",
            "topographic",
            "sun",
            "opticalflow",
            "sensor",
        )
        # Robustness fix: use values.get() rather than values[...]. A non-pre root
        # validator still runs when an individual field failed validation, in which
        # case that field's key is absent from `values` and plain indexing would
        # raise an unrelated KeyError instead of pydantic's own error report.
        default_forecast_minutes = values.get("default_forecast_minutes")
        default_history_minutes = values.get("default_history_minutes")
        enabled_data_sources = [
            data_source_name
            for data_source_name in ALL_DATA_SOURCE_NAMES
            if values.get(data_source_name) is not None
        ]
        for data_source_name in enabled_data_sources:
            if values[data_source_name].forecast_minutes is None:
                values[data_source_name].forecast_minutes = default_forecast_minutes
            if values[data_source_name].history_minutes is None:
                values[data_source_name].history_minutes = default_history_minutes
        # nwp is a MultiNWP (a mapping of NWP configs), so each entry is defaulted
        # individually.
        if values.get("nwp") is not None:
            for k in values["nwp"].keys():
                if values["nwp"][k].forecast_minutes is None:
                    values["nwp"][k].forecast_minutes = default_forecast_minutes
                if values["nwp"][k].history_minutes is None:
                    values["nwp"][k].history_minutes = default_history_minutes
        return values

    @classmethod
    def set_all_to_defaults(cls):
        """Returns an InputData instance with all fields set to their default values.

        Used for unittests.
        """
        return cls(
            pv=PV(),
            satellite=Satellite(),
            hrvsatellite=HRVSatellite(),
            nwp=dict(UKV=NWP()),
            gsp=GSP(),
            topographic=Topographic(),
            sun=Sun(),
            opticalflow=OpticalFlow(),
            sensor=Sensor(),
        )
class Configuration(Base):
    """Configuration model for the dataset"""

    general: General = General()
    input_data: InputData = InputData()
    git: Optional[Git] = None

    def set_base_path(self, base_path: str):
        """Append base_path to all configured paths. Mostly used for testing.

        Bug fixes over the previous version:
        - data sources that are disabled (None) — which is the default for all of
          them — are now skipped instead of raising AttributeError;
        - `nwp` is a MultiNWP (a mapping of NWP configs), so each entry's
          `nwp_zarr_path` is prefixed individually instead of raising KeyError via
          MultiNWP.__getattr__.
        """
        base_path = Pathy(base_path)
        path_attrs = [
            "pv.pv_filename",
            "pv.pv_metadata_filename",
            "satellite.satellite_zarr_path",
            "hrvsatellite.hrvsatellite_zarr_path",
            "gsp.gsp_zarr_path",
            "sensor.sensor_filename",
        ]
        for cls_and_attr_name in path_attrs:
            cls_name, attribute = cls_and_attr_name.split(".")
            data_source = getattr(self.input_data, cls_name)
            if data_source is None:
                continue  # this data source is not configured
            path = getattr(data_source, attribute)
            if path is None:
                continue  # no path set, nothing to prefix
            setattr(data_source, attribute, base_path / path)
            setattr(self.input_data, cls_name, data_source)
        if self.input_data.nwp is not None:
            for _, nwp_config in self.input_data.nwp.items():
                # NOTE(review): assumes a single str path; tuple/list paths (which the
                # field type allows) would need per-element handling.
                if nwp_config.nwp_zarr_path is not None:
                    nwp_config.nwp_zarr_path = base_path / nwp_config.nwp_zarr_path
def set_git_commit(configuration: Configuration):
    """
    Set the git information in the configuration file

    Args:
        configuration: configuration object

    Returns: configuration object with git information
    """
    # Fix: point GitPython at the git executable *before* any repo operations are
    # attempted; previously refresh() ran after Repo() had already been constructed.
    # NOTE(review): the hard-coded "/usr/bin/git" assumes a Linux deployment — verify.
    git.refresh("/usr/bin/git")
    repo = git.Repo(search_parent_directories=True)
    git_details = Git(
        hash=repo.head.object.hexsha,
        committed_date=datetime.fromtimestamp(repo.head.object.committed_date),
        message=repo.head.object.message,
    )
    configuration.git = git_details
    return configuration