Skip to content

Commit

Permalink
fix circular import problems (#2059)
Browse files Browse the repository at this point in the history
Co-authored-by: Adeel Hassan <ahassan@element84.com>
  • Loading branch information
AdeelH and AdeelH committed Feb 12, 2024
1 parent cf95e94 commit e6510d7
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from rastervision.pytorch_backend.pytorch_learner_backend import (
PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_backend.utils import chip_collate_fn_cc
from rastervision.pytorch_learner import (
ClassificationGeoDataConfig, ClassificationSlidingWindowGeoDataset)
from rastervision.pytorch_learner.dataset import (
ClassificationSlidingWindowGeoDataset)
from rastervision.core.data import ChipClassificationLabels

if TYPE_CHECKING:
import numpy as np
from rastervision.core.data import DatasetConfig, Scene
from rastervision.core.rv_pipeline import ChipOptions, PredictOptions
from rastervision.pytorch_learner import ClassificationGeoDataConfig


class PyTorchChipClassificationSampleWriter(PyTorchLearnerSampleWriter):
Expand Down Expand Up @@ -89,7 +90,8 @@ def predict_scene(self, scene: 'Scene', predict_options: 'PredictOptions'

def _make_chip_data_config(
self, dataset: 'DatasetConfig',
chip_options: 'ChipOptions') -> ClassificationGeoDataConfig:
chip_options: 'ChipOptions') -> 'ClassificationGeoDataConfig':
from rastervision.pytorch_learner import (ClassificationGeoDataConfig)
data_config = ClassificationGeoDataConfig(
scene_dataset=dataset, sampling=chip_options.sampling)
return data_config
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
from rastervision.core.data.utils.misc import save_img
from rastervision.core.data_sample import DataSample
from rastervision.pytorch_learner.learner import Learner
from rastervision.pytorch_learner.learner_config import DataConfig

if TYPE_CHECKING:
from torch.utils.data import Dataset
from rastervision.core.data import ClassConfig, DatasetConfig, Scene
from rastervision.core.rv_pipeline import RVPipelineConfig, ChipOptions
from rastervision.pytorch_learner.learner_config import LearnerConfig
from rastervision.pytorch_learner import DataConfig, LearnerConfig

SPLITS = ['train', 'valid', 'test']

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from rastervision.pytorch_backend.utils import chip_collate_fn_od
from rastervision.pytorch_learner.dataset import (
ObjectDetectionSlidingWindowGeoDataset)
from rastervision.pytorch_learner.object_detection_learner_config import (
ObjectDetectionGeoDataConfig)

if TYPE_CHECKING:
from rastervision.core.data import DatasetConfig, Scene
from rastervision.core.rv_pipeline import (ChipOptions,
ObjectDetectionPredictOptions)
from rastervision.pytorch_learner.object_detection_utils import BoxList
from rastervision.pytorch_learner.object_detection_learner_config import (
ObjectDetectionGeoDataConfig)


class PyTorchObjectDetectionSampleWriter(PyTorchLearnerSampleWriter):
Expand Down Expand Up @@ -154,7 +154,8 @@ def predict_scene(self, scene: 'Scene',

def _make_chip_data_config(
self, dataset: 'DatasetConfig',
chip_options: 'ChipOptions') -> ObjectDetectionGeoDataConfig:
chip_options: 'ChipOptions') -> 'ObjectDetectionGeoDataConfig':
from rastervision.pytorch_learner import (ObjectDetectionGeoDataConfig)
data_config = ObjectDetectionGeoDataConfig(
scene_dataset=dataset, sampling=chip_options.sampling)
return data_config
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from rastervision.pytorch_backend.utils import chip_collate_fn_ss
from rastervision.pytorch_learner.dataset import (
SemanticSegmentationSlidingWindowGeoDataset)
from rastervision.pytorch_learner import SemanticSegmentationGeoDataConfig

if TYPE_CHECKING:
from rastervision.core.data import (DatasetConfig, Scene,
SemanticSegmentationLabelStore)
from rastervision.core.rv_pipeline import (
ChipOptions, SemanticSegmentationPredictOptions)
from rastervision.pytorch_learner import SemanticSegmentationGeoDataConfig


class PyTorchSemanticSegmentationSampleWriter(PyTorchLearnerSampleWriter):
Expand Down Expand Up @@ -118,9 +118,11 @@ def predict_scene(self, scene: 'Scene',

return labels

def _make_chip_data_config(
self, dataset: 'DatasetConfig',
chip_options: 'ChipOptions') -> SemanticSegmentationGeoDataConfig:
def _make_chip_data_config(self, dataset: 'DatasetConfig',
chip_options: 'ChipOptions'
) -> 'SemanticSegmentationGeoDataConfig':
from rastervision.pytorch_learner import (
SemanticSegmentationGeoDataConfig)
data_config = SemanticSegmentationGeoDataConfig(
scene_dataset=dataset, sampling=chip_options.sampling)
return data_config

0 comments on commit e6510d7

Please sign in to comment.