models.py
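"""Zarr-backed models for OOI data streams.

``OOIDataset`` lazily opens a single data stream stored as a Zarr group on S3,
wraps each array as a dask-backed ``xarray.DataArray``, and assembles an
``xarray.Dataset`` for a requested time range via :meth:`OOIDataset.sel`.
"""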
import bisect
import copy
import itertools as it
from datetime import datetime
from typing import Union

import dask
import dask.array as da
import numpy as np
import xarray as xr
import zarr
from dask.utils import memory_repr
from dateutil import parser
from xarray.core.utils import either_dict_or_kwargs

TIME_DEFAULTS = {
    'units': 'seconds since 1900-01-01 00:00:00',
    'calendar': 'gregorian',
}

class OOIDataset:
    """Lazy, dask-backed view of a single OOI data stream stored as Zarr on S3."""

    def __init__(
        self,
        dataset_id,
        bucket_name="ooi-data",
        storage_options=None,
    ):
        self.dataset_id = dataset_id
        self.bucket_name = bucket_name
        # Default to anonymous S3 access; avoid a mutable default argument.
        self.storage_options = (
            storage_options if storage_options is not None else {'anon': True}
        )
        self.dimensions = []
        self.variables = {}
        self.global_attributes = {}
        self.dataset = None
        self._total_size = None
        self._total_size_repr = None
        self._zarr_group = None
        self._dataset_dict = {"variables": {}, "dims": []}

        # Private attributes
        self.__default_time_units = TIME_DEFAULTS['units']
        self.__default_calendar = TIME_DEFAULTS['calendar']
        self.__time_filter = []

        self._open_zarr()
        self._parse_zarr_group()
        self._set_variables()

    def __repr__(self):
        text_arr = [f"<{self.dataset_id}: {self._total_size_repr}>"]
        text_arr.append(f"Dimensions: ({', '.join(self.dimensions)})")
        variables_arr = "\n ".join([name for name in self.variables.keys()])
        text_arr.append(f"Data variables: \n {variables_arr}")
        return "\n".join(text_arr)

    def __getitem__(self, item):
        """Return a copy of the dataset restricted to the requested variables."""
        if isinstance(item, str):
            # A bare string would otherwise be treated as a substring match below.
            item = [item]
        new_self = copy.deepcopy(self)
        new_variables = {}
        for k, v in new_self.variables.items():
            if k in item:
                new_variables[k] = v
            else:
                delattr(new_self, k)
        new_self.variables = new_variables
        return new_self

    def _open_zarr(self):
        # Open read-only; the anonymous S3 credentials cannot write anyway.
        self._zarr_group = zarr.open_group(
            store=f's3://{self.bucket_name}/{self.dataset_id}',
            mode="r",
            storage_options=self.storage_options,
        )
        self._total_size = np.sum(
            [arr.nbytes for _, arr in self._zarr_group.items()]
        )
        self._total_size_repr = memory_repr(self._total_size)

    def _parse_zarr_group(self):
        """Wrap every array in the Zarr group as a lazy, dask-backed DataArray."""
        all_dims = []
        for k in self._zarr_group.array_keys():
            arr = self._zarr_group[k]
            dims = arr.attrs['_ARRAY_DIMENSIONS']
            attrs = arr.attrs.asdict()
            attrs.pop('_ARRAY_DIMENSIONS')
            self._dataset_dict['variables'][k] = xr.DataArray(
                data=da.from_zarr(arr).rechunk(),
                dims=dims,
                name=k,
                attrs=attrs,
            )
            all_dims.append(dims)
        self._dataset_dict['dims'] = list(
            set(it.chain.from_iterable(all_dims))
        )
        self.dimensions = self._dataset_dict['dims']
        self.variables = self._dataset_dict['variables']
        self.global_attributes = self._zarr_group.attrs.asdict()

    def _set_variables(self):
        for name, data_array in self.variables.items():
            setattr(self, name, data_array)

    def _get_dim_indexers(self, indexers) -> dict:
        # Retrieve positional indexers for each dimension covered by `indexers`
        pos_indexes = {}
        for dim in self._dataset_dict['dims']:
            dim_arr = self._dataset_dict['variables'][dim].data
            if dim in indexers:
                if indexers[dim] is not None:
                    start, end = indexers[dim]
                    pos_indexes[dim] = da.where(
                        (dim_arr >= start) & (dim_arr <= end)
                    )[0].compute()
                else:
                    pos_indexes[dim] = None
        return pos_indexes

    def _create_dataset_dict(self, pos_indexes) -> dict:
        data_vars = {}
        # Get data arrays, slicing each one along any dimension that was indexed
        for k, v in self.variables.items():
            key = {
                # +1 so the last matching index is included (slice stop is exclusive)
                dim: slice(pos_indexes[dim][0], pos_indexes[dim][-1] + 1)
                if pos_indexes.get(dim) is not None
                else slice(None)
                for dim in v.dims
            }
            with dask.config.set(**{'array.slicing.split_large_chunks': True}):
                data_vars[k] = v.isel(**key)
        return data_vars

    def _in_time_range(
        self, time_filter: Union[list, tuple], time_range: Union[list, tuple]
    ) -> bool:
        """Checks whether the time filter is within the data time range"""
        in_range_idx = 1
        sorted_time_range = sorted(time_range)
        in_range_list = [
            bisect.bisect(sorted_time_range, t) for t in sorted(time_filter)
        ]
        return any(i == in_range_idx for i in in_range_list)

    def _create_dataset(self, data_vars: dict) -> xr.Dataset:
        """Decode the time variable to datetimes and assemble an ``xr.Dataset``."""
        data_vars = {
            k: self._set_time_attrs(
                xr.apply_ufunc(
                    xr.coding.times.decode_cf_datetime,
                    v,
                    keep_attrs=True,
                    dask="parallelized",
                    kwargs={
                        'units': v.attrs.get('units', TIME_DEFAULTS['units']),
                        'calendar': v.attrs.get(
                            'calendar', TIME_DEFAULTS['calendar']
                        ),
                    },
                )
            )
            if k == 'time'
            else v
            for k, v in data_vars.items()
        }
        ds = xr.Dataset(data_vars)
        # new_attrs = self.global_attributes.copy()
        # new_attrs['time_coverage_start'] = ds.time.data[0]
        # new_attrs['time_coverage_end'] = ds.time.data[-1]
        # ds.attrs = new_attrs
        self.dataset = ds
        return ds

    @staticmethod
    def _set_time_attrs(data_array):
        # Drop the CF encoding attributes, which no longer apply after decoding.
        new_attrs = data_array.attrs.copy()
        new_attrs.pop('units', None)
        new_attrs.pop('calendar', None)
        data_array.attrs = new_attrs
        return data_array

    def reset(self):
        """Restore the full set of variables after a ``__getitem__`` subset."""
        self.variables = self._dataset_dict['variables']
        self._set_variables()
        return self

    def _time_range_check(
        self, arr: xr.DataArray, start_dt: datetime, end_dt: datetime
    ) -> bool:
        """Performs an initial time range check for data"""
        calendar = arr.attrs.get('calendar', TIME_DEFAULTS['calendar'])
        units = arr.attrs.get('units', TIME_DEFAULTS['units'])
        time_range = xr.coding.times.decode_cf_datetime(
            [arr[0], arr[-1]], units, calendar
        )
        time_filter = np.array(
            [start_dt, end_dt],
            dtype='datetime64[ns]',
        )
        return self._in_time_range(
            time_filter=time_filter, time_range=time_range
        )

    def sel(self, indexers: dict = None, **indexers_kwargs):
        """Subset the data stream along its dimensions, e.g. by a (start, end) time range."""
        # TODO: Figure out how to handle one indexer instead of start, end
        indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
        for k, v in indexers.items():
            if k not in self.dimensions:
                raise ValueError(
                    f"'{k}' is not a dimension in this data stream."
                )
            if k == "time":
                if isinstance(v, slice):
                    start_dt = parser.parse(v.start)
                    end_dt = parser.parse(v.stop)
                else:
                    start_dt, end_dt = (parser.parse(value) for value in v)
                arr = getattr(self, k)
                time_units = arr.attrs.get('units', self.__default_time_units)
                calendar = arr.attrs.get('calendar', self.__default_calendar)
                self.__time_filter, _, _ = xr.coding.times.encode_cf_datetime(
                    [start_dt, end_dt], time_units, calendar
                )
                # Performs an initial time range check against the data's time array
                in_time_range = self._time_range_check(arr, start_dt, end_dt)
                if not in_time_range:
                    return self
                indexers[k] = self.__time_filter
        pos_indexes = self._get_dim_indexers(indexers)
        data_vars = self._create_dataset_dict(pos_indexes)
        self._create_dataset(data_vars)
        return self
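

if __name__ == "__main__":
    # Minimal usage sketch (not part of the module itself). The dataset id below is a
    # placeholder; replace it with a real stream name published in the "ooi-data"
    # bucket. Running this requires anonymous network access to the public S3 bucket.
    ds = OOIDataset("example-stream-id")
    print(ds)

    # Select a (start, end) time window; ``sel`` also accepts a slice of date strings.
    ds = ds.sel(time=("2020-01-01", "2020-01-02"))
    print(ds.dataset)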