-
Notifications
You must be signed in to change notification settings - Fork 106
/
modules.py
219 lines (187 loc) · 8.02 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import os
from datetime import datetime
from math import ceil, log10
from typing import Any, Iterator, List, Optional, Tuple
import pytorch_lightning as pl
from torch import save as torch_save
from torch.utils.data import DataLoader
from torchgeo.datasets import BoundingBox, GeoDataset
from torchgeo.datasets.utils import stack_samples
from torchgeo.samplers import GridGeoSampler, RandomGeoSampler
from torchgeo.samplers.single import GeoSampler
from tqdm import tqdm
from vibe_core.data import CategoricalRaster, Raster
from .constants import CROP_INDICES
from .datasets import CDLMask, NDVIDataset
def save_chips_locally(dataloader: DataLoader, output_dir: str) -> None:
    """Write every NDVI/CDL chip pair yielded by `dataloader` to disk as tensors.

    Sample ``i`` (1-based, in iteration order) is saved to
    ``<output_dir>/ndvi/<i>.pt`` and ``<output_dir>/cdl/<i>.pt``. Indices are
    zero-padded to a fixed width so filenames sort lexicographically.

    Args:
        dataloader: yields batches with an "image" tensor (stacked NDVI) and a
            "mask" tensor (CDL), both indexed as [batch, channel, H, W].
        output_dir: directory under which "ndvi" and "cdl" subdirs are created.
    """
    ndvi_path = os.path.join(output_dir, "ndvi")
    os.makedirs(ndvi_path, exist_ok=True)
    cdl_path = os.path.join(output_dir, "cdl")
    os.makedirs(cdl_path, exist_ok=True)
    # Padding width is fixed from the first batch (rather than pre-fetching a
    # throwaway batch, or recomputing per batch — which would shrink the width
    # on a smaller final batch and break the sort order of the filenames).
    zfill = 0
    sample_idx = 1
    for batch in tqdm(dataloader):
        ndvi_batch, cdl_batch = batch["image"], batch["mask"]
        if not zfill:
            # Upper bound on total samples. len(str(...)) avoids the
            # ceil(log10(...)) off-by-one at exact powers of ten.
            zfill = len(str(len(dataloader) * ndvi_batch.size(0)))
        for i in range(ndvi_batch.size(0)):
            torch_save(
                ndvi_batch[i, :, :, :].clone(),
                os.path.join(ndvi_path, f"{sample_idx + i:0>{zfill}}.pt"),
            )
            torch_save(
                cdl_batch[i, :, :, :].clone(),
                os.path.join(cdl_path, f"{sample_idx + i:0>{zfill}}.pt"),
            )
        sample_idx += ndvi_batch.size(0)
def year_bbox(bbox: BoundingBox) -> BoundingBox:
    """Return a copy of ``bbox`` whose time span covers its whole calendar year.

    The spatial extent is preserved; mint/maxt become the first and last
    second of the year that ``bbox.mint`` falls in.
    """
    start_of_year = datetime(datetime.fromtimestamp(bbox.mint).year, 1, 1)
    start_of_next_year = datetime(start_of_year.year + 1, 1, 1)
    return BoundingBox(
        bbox.minx,
        bbox.maxx,
        bbox.miny,
        bbox.maxy,
        start_of_year.timestamp(),
        start_of_next_year.timestamp() - 1,
    )
class YearRandomGeoSampler(RandomGeoSampler):
    """Random region-of-interest sampler whose boxes span a full calendar year.

    Identical to RandomGeoSampler except that every sampled bounding box has
    its time range widened to the whole year, as required to sample a stacked
    NDVI from NDVIDataset.
    """

    def __iter__(self) -> Iterator[BoundingBox]:
        yield from map(year_bbox, super().__iter__())
class YearGridGeoSampler(GridGeoSampler):
    """Grid sampler whose boxes span a full calendar year.

    Identical to GridGeoSampler except that every sampled bounding box has
    its time range widened to the whole year, as required to sample a stacked
    NDVI from NDVIDataset.
    """

    def __iter__(self) -> Iterator[BoundingBox]:
        yield from map(year_bbox, super().__iter__())
class CropSegDataModule(pl.LightningDataModule):
    """LightningDataModule that pairs stacked-NDVI inputs with binary CDL masks.

    Training samples chips at random positions inside a train RoI; validation,
    test, and predict walk a non-overlapping grid over a disjoint RoI carved
    off the same intersection dataset (the split is spatial, not by dataset).
    """

    def __init__(
        self,
        ndvi_rasters: List[Raster],
        cdl_rasters: List[CategoricalRaster],
        ndvi_stack_bands: int = 37,
        img_size: Tuple[int, int] = (256, 256),
        epoch_size: int = 1024,
        batch_size: int = 16,
        num_workers: int = 4,
        val_ratio: float = 0.2,
        positive_indices: Optional[List[int]] = None,
        train_years: Optional[List[int]] = None,
        val_years: Optional[List[int]] = None,
    ):
        """
        Init a CropSegDataModule instance
        Args:
            ndvi_rasters: NDVI rasters generated by TerraVibes workflow
            cdl_rasters: CDL maps downloaded by TerraVibes workflow
            ndvi_stack_bands: how many daily NDVI maps will be stacked as training input.
            img_size: tuple that defines the size of each chip that is fed to the network.
            epoch_size: how many samples are sampled during training for one epoch.
            batch_size: how many samples are fed to the network in a single batch.
            num_workers: how many worker processes to use in the data loader.
            val_ratio: how much of the data to separate for validation.
            positive_indices: which CDL indices are considered as positive samples
                (defaults to CROP_INDICES when None).
                Crop types with a minimum of 1e5 pixels in the RoI are available
                in the module `notebook_lib.constants`. You can combine multiple
                constants by adding them (e.g., `constants.POTATO_INDEX + constants.CORN_INDEX`)
            train_years: years used for training (defaults to [2020] when None).
            val_years: years used for validation (defaults to [2020] when None).
        """
        super().__init__()
        self.ndvi_rasters = ndvi_rasters
        self.cdl_rasters = cdl_rasters
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.epoch_size = epoch_size
        self.val_ratio = val_ratio
        # None sentinels instead of mutable default arguments, so the default
        # lists are never shared between instances.
        self.positive_indices = (
            CROP_INDICES if positive_indices is None else positive_indices
        )
        self.train_years = [2020] if train_years is None else train_years
        self.val_years = [2020] if val_years is None else val_years
        self.years = list(set(self.train_years) | set(self.val_years))
        self.ndvi_stack_bands = ndvi_stack_bands

    def prepare_data(self) -> None:
        # Skipping prepare_data as TerraVibes has already downloaded it
        pass

    def setup(self, stage: Optional[str] = None):
        """Build the NDVI/CDL intersection dataset used by every split."""
        input_dataset = NDVIDataset(
            self.ndvi_rasters,
            self.ndvi_stack_bands,
        )
        target_dataset = CDLMask(
            self.cdl_rasters,
            positive_indices=self.positive_indices,
        )
        self.train_dataset = input_dataset & target_dataset  # Intersection dataset
        # Use the same dataset for training and validation, use different RoIs
        self.val_dataset = self.train_dataset
        self.test_dataset = self.train_dataset

    def _get_dataloader(self, dataset: GeoDataset, sampler: GeoSampler) -> DataLoader:
        """Wrap `dataset` in a DataLoader with the module's loading settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            # prefetch_factor is only meaningful with workers; 2 is the default
            prefetch_factor=5 if self.num_workers else 2,
            collate_fn=stack_samples,
        )

    def _get_split_roi(self, ref_dataset: GeoDataset):
        """Split the dataset bounds into disjoint train/val RoIs.

        The split axis is the longer spatial dimension; the validation RoI
        takes a `val_ratio` slice off one end. Time ranges are set from
        `train_years`/`val_years` independently of the spatial split.

        Returns:
            Tuple of (train_roi, val_roi) bounding boxes.
        """
        minx, maxx, miny, maxy, _, _ = ref_dataset.bounds
        width = maxx - minx
        height = maxy - miny
        if height > width:
            # Split along y: train keeps the bottom (1 - val_ratio), val the top.
            train_x = maxx
            val_x = minx
            train_y = maxy - self.val_ratio * height
            val_y = maxy - self.val_ratio * height
        else:
            # Split along x: train keeps the left (1 - val_ratio), val the right.
            train_x = maxx - self.val_ratio * width
            val_x = maxx - self.val_ratio * width
            train_y = maxy
            val_y = miny
        train_mint = datetime(min(self.train_years), 1, 1).timestamp()
        train_maxt = datetime(max(self.train_years) + 1, 1, 1).timestamp() - 1
        val_mint = datetime(min(self.val_years), 1, 1).timestamp()
        val_maxt = datetime(max(self.val_years) + 1, 1, 1).timestamp() - 1
        train_roi = BoundingBox(minx, train_x, miny, train_y, train_mint, train_maxt)
        val_roi = BoundingBox(val_x, maxx, val_y, maxy, val_mint, val_maxt)
        return train_roi, val_roi

    def train_dataloader(self) -> DataLoader:
        """Random year-wide chips from the train RoI, `epoch_size` per epoch."""
        # Use the first dataset as index source
        train_roi, _ = self._get_split_roi(self.train_dataset)
        sampler = YearRandomGeoSampler(
            self.train_dataset,
            size=self.img_size,
            length=self.epoch_size,
            roi=train_roi,
        )
        return self._get_dataloader(self.train_dataset, sampler)

    def val_dataloader(self) -> DataLoader:
        """Non-overlapping grid of year-wide chips covering the val RoI."""
        _, val_roi = self._get_split_roi(self.val_dataset)
        sampler = YearGridGeoSampler(
            self.val_dataset,
            size=self.img_size,
            stride=self.img_size,  # stride == size -> no chip overlap
            roi=val_roi,
        )
        return self._get_dataloader(self.val_dataset, sampler)

    def test_dataloader(self) -> DataLoader:
        """Test reuses the validation grid loader."""
        return self.val_dataloader()

    def predict_dataloader(self) -> DataLoader:
        """Prediction reuses the validation grid loader."""
        return self.val_dataloader()

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int):
        # NOTE(review): `(a for a in b)` builds a *generator* over each bbox's
        # fields, not a tuple — presumably so Lightning's device transfer skips
        # the bbox metadata. Confirm this is intentional before changing it.
        batch["bbox"] = [(a for a in b) for b in batch["bbox"]]
        return batch