-
Notifications
You must be signed in to change notification settings - Fork 28
/
data.py
250 lines (208 loc) · 8.91 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import numpy as np
import pandas as pd
import torch
from dateutil.relativedelta import relativedelta
from openmapflow.bbox import BBox
from openmapflow.constants import CLASS_PROB, END, EO_DATA, LAT, LON, MONTHS, START
from torch.utils.data import Dataset
from tqdm import tqdm
class CropDataset(Dataset):
def __init__(
    self,
    df: pd.DataFrame,
    subset: str,
    cache: bool,
    upsample: bool,
    target_bbox: BBox,
    wandb_logger,
    start_month: str = "April",
    probability_threshold: float = 0.5,
    input_months: int = 12,
    normalizing_dict: Optional[Dict] = None,
    up_to_year: Optional[int] = None,
) -> None:
    """Build a crop/non-crop dataset from a labeled DataFrame.

    Args:
        df: labeled examples; must contain the openmapflow CLASS_PROB, LAT,
            LON, START, END and EO_DATA columns plus a "dataset" column.
        subset: e.g. "training"/"validation"/"testing". Non-training subsets
            must lie entirely inside target_bbox (asserted below).
        cache: if True, eagerly convert every example to tensors in memory.
        upsample: if True, duplicate local minority-class rows (sampled with
            replacement, fixed seed) until local crop/non-crop counts match.
        target_bbox: bounding box defining which points count as "local".
        wandb_logger: optional; dataset sizes/percentages are written to its
            experiment config when truthy.
        start_month: first month of the input window.
        probability_threshold: CLASS_PROB at or above this counts as crop.
        input_months: length of the input time series in months.
        normalizing_dict: precomputed {"mean", "std"}; computed from df's
            EO data when falsy.
        up_to_year: training only — drop rows whose START year is later.
    """
    df = df.copy()
    if subset == "training" and up_to_year is not None:
        df = df[pd.to_datetime(df[START]).dt.year <= up_to_year]

    self.start_month_index = MONTHS.index(start_month)
    self.input_months = input_months

    df["is_crop"] = df[CLASS_PROB] >= probability_threshold
    df["is_local"] = (
        (df[LAT] >= target_bbox.min_lat)
        & (df[LAT] <= target_bbox.max_lat)
        & (df[LON] >= target_bbox.min_lon)
        & (df[LON] <= target_bbox.max_lon)
    )
    if subset != "training":
        outside_model_bbox = (~df["is_local"]).sum()
        assert outside_model_bbox == 0, (
            f"{outside_model_bbox} points outside model bbox: "
            + f"({df[LAT].min()}, {df[LON].min()}, {df[LAT].max()}, {df[LON].max()})"
        )

    local_crop = len(df[df["is_local"] & df["is_crop"]])
    local_non_crop = len(df[df["is_local"] & ~df["is_crop"]])
    local_difference = np.abs(local_crop - local_non_crop)

    self.num_timesteps = self._compute_num_timesteps(df=df)

    if wandb_logger:
        to_log: Dict[str, Union[float, int]] = {}
        if df["is_local"].any():
            to_log[f"local_{subset}_original_size"] = len(df[df["is_local"]])
            to_log[f"local_{subset}_crop_percentage"] = round(
                local_crop / len(df[df["is_local"]]), 4
            )
        if not df["is_local"].all():
            to_log[f"global_{subset}_original_size"] = len(df[~df["is_local"]])
            to_log[f"global_{subset}_crop_percentage"] = round(
                len(df[~df["is_local"] & df["is_crop"]]) / len(df[~df["is_local"]]), 4
            )
        if upsample:
            to_log[f"{subset}_upsampled_size"] = len(df) + local_difference
        wandb_logger.experiment.config.update(to_log)

    # Balance local crop/non-crop counts by resampling the minority class.
    # Fixes two defects in the original: DataFrame.append was removed in
    # pandas 2.0 (replaced by pd.concat), and "arrow" was unbound — a
    # NameError at the print — when the classes were already balanced.
    if upsample and local_crop != local_non_crop:
        if local_crop > local_non_crop:
            arrow = "<-"
            minority = df[df["is_local"] & ~df["is_crop"]]
        else:
            arrow = "->"
            minority = df[df["is_local"] & df["is_crop"]]
        df = pd.concat(
            [df, minority.sample(n=local_difference, replace=True, random_state=42)],
            ignore_index=True,
        )
        print(f"Upsampling: local crop{arrow}non-crop: {local_crop}{arrow}{local_non_crop}")

    self.normalizing_dict: Dict = (
        normalizing_dict
        if normalizing_dict
        else self._calculate_normalizing_dict(df[EO_DATA].to_list())
    )
    self.df = df

    # Set parameters needed for __getitem__
    self.probability_threshold = probability_threshold
    self.target_bbox = target_bbox

    # Cache dataset if necessary. self.cache stays False while to_array()
    # runs so __getitem__ reads from the DataFrame during caching.
    self.x: Optional[torch.Tensor] = None
    self.y: Optional[torch.Tensor] = None
    self.weights: Optional[torch.Tensor] = None
    self.cache = False
    if cache:
        self.x, self.y, self.weights = self.to_array()
        self.cache = cache
def _compute_num_timesteps(self, df) -> List[int]:
    """Return the distinct monthly-timestep counts available across rows.

    Each row's window opens at the configured start month of its START
    year and spans at most self.input_months, capped by the row's END
    date. NOTE: also writes a "timesteps" column into df as a side effect.
    """
    # Shift every start date to the configured start month of its year.
    window_start = pd.to_datetime(df[START]).apply(
        lambda d: d.replace(month=self.start_month_index + 1)
    )
    # Ideal window end, input_months after the start...
    full_window_end = window_start.apply(
        lambda d: d + relativedelta(months=+self.input_months)
    )
    # ...capped by when the row's EO data actually ends.
    actual_end = pd.to_datetime(df[END])
    window_end = pd.DataFrame({"a": actual_end, "b": full_window_end}).min(axis=1)
    months_available = ((window_end - window_start) / np.timedelta64(1, "M")).round()
    df["timesteps"] = months_available.astype(int)
    timesteps = df["timesteps"].unique().tolist()
    if len(timesteps) > 1:
        # Show which source datasets contribute each timestep count.
        timesteps_w_dataset = (
            df[["dataset", "timesteps"]]
            .groupby("timesteps")
            .agg({"dataset": lambda ds: ",".join(ds.unique())})
        )
        print(
            "WARNING: Datasets have different amounts of timesteps available. "
            + "Forecaster will be used to fill gaps."
            + f"\n{timesteps_w_dataset}"
        )
    return timesteps
@staticmethod
def _update_normalizing_values(
norm_dict: Dict[str, Union[int, Any]], array: np.ndarray
) -> None:
# given an input array of shape [timesteps, bands]
# update the normalizing dict
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
# https://www.johndcook.com/blog/standard_deviation/
if array is None:
raise ValueError("Array is None")
# initialize
if "mean" not in norm_dict:
num_bands = array.shape[1]
norm_dict["mean"] = np.zeros(num_bands)
norm_dict["M2"] = np.zeros(num_bands)
for time_idx in range(array.shape[0]):
norm_dict["n"] += 1
x = array[time_idx, :]
delta = x - norm_dict["mean"]
norm_dict["mean"] += delta / norm_dict["n"]
norm_dict["M2"] += delta * (x - norm_dict["mean"])
@staticmethod
def _calculate_normalizing_dict(
    eo_data_list: List[np.ndarray],
) -> Dict[str, Union[int, np.ndarray]]:
    """Compute per-band mean and sample std over every timestep of every array.

    Args:
        eo_data_list: list of arrays, each of shape [timesteps, bands].

    Returns:
        {"mean": ndarray[bands], "std": ndarray[bands]}.

    Raises:
        ValueError: if fewer than two timesteps are available in total —
            the sample variance (n - 1 denominator) is undefined. The
            original code failed here with an opaque KeyError on "M2".
    """
    norm_dict_interim: Dict[str, Any] = {"n": 0}
    for eo_data in tqdm(eo_data_list, desc="Calculating normalizing_dict"):
        CropDataset._update_normalizing_values(norm_dict_interim, eo_data)
    if norm_dict_interim["n"] < 2:
        raise ValueError(
            "Need at least 2 timesteps of EO data to calculate a normalizing dict, "
            f"got {norm_dict_interim['n']}"
        )
    variance = norm_dict_interim["M2"] / (norm_dict_interim["n"] - 1)
    return {"mean": norm_dict_interim["mean"], "std": np.sqrt(variance)}
def _normalize(self, array: np.ndarray) -> np.ndarray:
if self.normalizing_dict is None:
return array
else:
return (array - self.normalizing_dict["mean"]) / self.normalizing_dict["std"]
def __len__(self) -> int:
return len(self.df)
def to_array(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.x is not None:
assert self.y is not None
assert self.weights is not None
return self.x, self.y, self.weights
else:
x_list: List[torch.Tensor] = []
y_list: List[torch.Tensor] = []
weight_list: List[torch.Tensor] = []
print("Loading data into memory")
for i in tqdm(range(len(self)), desc="Caching files"):
x, y, weight = self[i]
x_list.append(x)
y_list.append(y)
weight_list.append(weight)
return torch.stack(x_list), torch.stack(y_list), torch.stack(weight_list)
@property
def num_input_features(self) -> int:
# assumes the first value in the tuple is x
assert len(self.df) > 0, "No files to load!"
output = self[0]
if isinstance(output, tuple):
return output[0].shape[1]
else:
return output.shape[1]
@property
def num_output_classes(self) -> Tuple[int, int]:
return 1, 1
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if (self.cache) & (self.x is not None):
# if we upsample, the caching might not have happened yet
return (
cast(torch.Tensor, self.x)[index],
cast(torch.Tensor, self.y)[index],
cast(torch.Tensor, self.weights)[index],
)
row = self.df.iloc[index]
x = row[EO_DATA][self.start_month_index : self.start_month_index + self.input_months]
x = self._normalize(x)
# If x is a partial time series, pad it to full length
if x.shape[0] < self.input_months:
x = np.concatenate([x, np.full((self.input_months - x.shape[0], x.shape[1]), np.nan)])
crop_int = int(row["is_crop"])
is_global = int(not row["is_local"])
return (
torch.from_numpy(x).float(),
torch.tensor(crop_int).float(),
torch.tensor(is_global).float(),
)