Skip to content

Commit

Permalink
refactor DataConfig and implement DataConfig.build_dataset() (#2023)
Browse files Browse the repository at this point in the history
Co-authored-by: Adeel Hassan <ahassan@element84.com>
  • Loading branch information
AdeelH and AdeelH committed Jan 4, 2024
1 parent 14d2d98 commit 2516e62
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 79 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from enum import Enum
import logging

Expand All @@ -15,6 +15,9 @@
ClassificationRandomWindowGeoDataset)
from rastervision.pytorch_learner.utils import adjust_conv_channels

if TYPE_CHECKING:
from rastervision.core.data import SceneConfig

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -53,11 +56,13 @@ class ClassificationGeoDataConfig(ClassificationDataConfig, GeoDataConfig):
See :mod:`rastervision.pytorch_learner.dataset.classification_dataset`.
"""

def build_scenes(self, tmp_dir: str):
for s in self.scene_dataset.all_scenes:
def build_scenes(self,
scene_configs: Iterable['SceneConfig'],
tmp_dir: Optional[str] = None):
for s in scene_configs:
if s.label_source is not None:
s.label_source.lazy = True
return super().build_scenes(tmp_dir=tmp_dir)
return super().build_scenes(scene_configs, tmp_dir=tmp_dir)

def scene_to_dataset(self,
scene: Scene,
Expand Down

0 comments on commit 2516e62

Please sign in to comment.