Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Misc. minor fixes and improvements to examples #2033

Merged
merged 6 commits into from Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -1,16 +1,16 @@
# flake8: noqa

import os
from os.path import join

import albumentations as A

from rastervision.core.rv_pipeline import *
from rastervision.core.backend import *
from rastervision.core.data import *
from rastervision.core.analyzer import *
from rastervision.pytorch_backend import *
from rastervision.pytorch_learner import *
from rastervision.core.rv_pipeline import ChipClassificationConfig
from rastervision.core.data import (
ChipClassificationLabelSourceConfig, ClassConfig,
ClassInferenceTransformerConfig, DatasetConfig, GeoJSONVectorSourceConfig,
RasterioSourceConfig, SceneConfig)
from rastervision.pytorch_backend import PyTorchChipClassificationConfig
from rastervision.pytorch_learner import (
Backbone, ClassificationGeoDataConfig, ClassificationImageDataConfig,
ClassificationModelConfig, ExternalModuleConfig, GeoDataWindowConfig,
GeoDataWindowMethod, SolverConfig)
from rastervision.pytorch_backend.examples.utils import (get_scene_info,
save_image_crop)

Expand Down
@@ -1,14 +1,19 @@
# flake8: noqa

import os
from os.path import join

from rastervision.core.rv_pipeline import *
from rastervision.core.backend import *
from rastervision.core.data import *
from rastervision.core.analyzer import *
from rastervision.pytorch_backend import *
from rastervision.pytorch_learner import *
from rastervision.core.rv_pipeline import (ObjectDetectionConfig,
ObjectDetectionChipOptions,
ObjectDetectionPredictOptions)
from rastervision.core.data import (
ClassConfig, ClassInferenceTransformerConfig, DatasetConfig,
GeoJSONVectorSourceConfig, ObjectDetectionLabelSourceConfig,
RasterioSourceConfig, SceneConfig)
from rastervision.pytorch_backend import PyTorchObjectDetectionConfig
from rastervision.pytorch_learner import (
Backbone, ExternalModuleConfig, GeoDataWindowMethod,
ObjectDetectionGeoDataConfig, ObjectDetectionGeoDataWindowConfig,
ObjectDetectionImageDataConfig, ObjectDetectionModelConfig, PlotOptions,
SolverConfig)
from rastervision.pytorch_backend.examples.utils import save_image_crop

TRAIN_IDS = [
Expand Down
@@ -1,14 +1,18 @@
# flake8: noqa

import os
from os.path import join

from rastervision.core.rv_pipeline import *
from rastervision.core.backend import *
from rastervision.core.data import *
from rastervision.core.analyzer import *
from rastervision.pytorch_backend import *
from rastervision.pytorch_learner import *
from rastervision.core.rv_pipeline import (ObjectDetectionConfig,
ObjectDetectionChipOptions,
ObjectDetectionPredictOptions)
from rastervision.core.data import (
ClassConfig, ClassInferenceTransformerConfig, DatasetConfig,
GeoJSONVectorSourceConfig, ObjectDetectionLabelSourceConfig,
RasterioSourceConfig, SceneConfig)
from rastervision.pytorch_backend import PyTorchObjectDetectionConfig
from rastervision.pytorch_learner import (
Backbone, GeoDataWindowMethod, ObjectDetectionGeoDataConfig,
ObjectDetectionGeoDataWindowConfig, ObjectDetectionImageDataConfig,
ObjectDetectionModelConfig, SolverConfig)
from rastervision.pytorch_backend.examples.utils import (get_scene_info,
save_image_crop)

Expand Down
@@ -1,16 +1,22 @@
# flake8: noqa

import os
from typing import Optional
from os.path import join, basename

from rastervision.core.rv_pipeline import *
from rastervision.core.backend import *
from rastervision.core.data import *
from rastervision.core.analyzer import *
from rastervision.pytorch_backend import *
from rastervision.pytorch_learner import *
from rastervision.pytorch_backend.examples.utils import (get_scene_info,
save_image_crop)
import albumentations as A

from rastervision.core.rv_pipeline import (SemanticSegmentationConfig,
SemanticSegmentationChipOptions,
SemanticSegmentationWindowMethod)
from rastervision.core.data import (
ClassConfig, DatasetConfig, PolygonVectorOutputConfig,
RasterioSourceConfig, RGBClassTransformerConfig, SceneConfig,
SemanticSegmentationLabelSourceConfig,
SemanticSegmentationLabelStoreConfig)
from rastervision.pytorch_backend import PyTorchSemanticSegmentationConfig
from rastervision.pytorch_learner import (
Backbone, ExternalModuleConfig, GeoDataWindowConfig, GeoDataWindowMethod,
PlotOptions, SolverConfig, SemanticSegmentationGeoDataConfig,
SemanticSegmentationImageDataConfig, SemanticSegmentationModelConfig)
from rastervision.pytorch_backend.examples.utils import save_image_crop
from rastervision.pytorch_backend.examples.semantic_segmentation.utils import (
example_multiband_transform, example_rgb_transform, imagenet_stats,
Unnormalize)
Expand All @@ -32,22 +38,24 @@

def get_config(runner,
raw_uri: str,
processed_uri: str,
root_uri: str,
processed_uri: Optional[str] = None,
multiband: bool = False,
external_model: bool = True,
augment: bool = False,
nochip: bool = True,
test: bool = False):
num_epochs: int = 10,
batch_sz: int = 8,
test: bool = False) -> SemanticSegmentationConfig:
"""Generate the pipeline config for this task. This function will be called
by RV, with arguments from the command line, when this example is run.

Args:
runner (Runner): Runner for the pipeline. Will be provided by RV.
raw_uri (str): Directory where the raw data resides
processed_uri (str): Directory for storing processed data.
E.g. crops for testing.
root_uri (str): Directory where all the output will be written.
processed_uri (str): Directory for storing processed data.
E.g. crops for testing. Defaults to None.
multiband (bool, optional): If True, all 4 channels (R, G, B, & IR)
available in the raster source will be used. If False, only
IR, R, G (in that order) will be used. Defaults to False.
Expand All @@ -61,6 +69,8 @@ def get_config(runner,
training instead of from pre-generated chips. The analyze and chip
commands should not be run, if this is set to True. Defaults to
True.
num_epochs (int): Number of epochs to train for.
batch_sz (int): Batch size.
test (bool, optional): If True, does the following simplifications:
(1) Uses only the first 2 scenes
(2) Uses only a 600x600 crop of the scenes
Expand Down Expand Up @@ -203,7 +213,7 @@ def make_scene(id) -> SceneConfig:
num_classes = len(class_config)
model = SemanticSegmentationModelConfig(
external_def=ExternalModuleConfig(
github_repo='AdeelH/pytorch-fpn:0.2',
github_repo='AdeelH/pytorch-fpn:0.3',
name='fpn',
entrypoint='make_fpn_resnet',
entrypoint_kwargs={
Expand All @@ -217,11 +227,15 @@ def make_scene(id) -> SceneConfig:
else:
model = SemanticSegmentationModelConfig(backbone=Backbone.resnet50)

num_epochs = 2 if test else int(num_epochs)
batch_sz = 2 if test else int(batch_sz)
solver = SolverConfig(
lr=1e-4, num_epochs=num_epochs, batch_sz=batch_sz, one_cycle=True)

backend = PyTorchSemanticSegmentationConfig(
data=data,
model=model,
solver=SolverConfig(
lr=1e-4, num_epochs=10, batch_sz=8, one_cycle=True),
solver=solver,
log_tensorboard=True,
run_tensorboard=False,
)
Expand Down
@@ -1,17 +1,24 @@
# flake8: noqa

from typing import Optional
import re
import random
import os
from abc import abstractmethod

from rastervision.pipeline.file_system import list_paths
from rastervision.core.rv_pipeline import *
from rastervision.core.backend import *
from rastervision.core.data import *
from rastervision.pytorch_backend import *
from rastervision.pytorch_learner import *
from rastervision.pipeline.file_system.utils import list_paths
from rastervision.core.rv_pipeline import (SemanticSegmentationConfig,
SemanticSegmentationChipOptions,
SemanticSegmentationWindowMethod)
from rastervision.core.data import (
BufferTransformerConfig, ClassConfig, ClassInferenceTransformerConfig,
DatasetConfig, GeoJSONVectorSourceConfig, PolygonVectorOutputConfig,
RasterioSourceConfig, RasterizedSourceConfig, RasterizerConfig,
SceneConfig, SemanticSegmentationLabelSourceConfig,
SemanticSegmentationLabelStoreConfig, StatsTransformerConfig)
from rastervision.pytorch_backend import PyTorchSemanticSegmentationConfig
from rastervision.pytorch_learner import (
Backbone, GeoDataWindowConfig, GeoDataWindowMethod, SolverConfig,
SemanticSegmentationGeoDataConfig, SemanticSegmentationImageDataConfig,
SemanticSegmentationModelConfig)

BUILDINGS = 'buildings'
ROADS = 'roads'
Expand All @@ -28,21 +35,22 @@ def create(raw_uri, target):
elif target.lower() == ROADS:
return VegasRoads(raw_uri)
else:
raise ValueError('{} is not a valid target.'.format(target))
raise ValueError(f'{target} is not a valid target.')

def get_raster_source_uri(self, id):
filename = f'{self.raster_fn_prefix}{id}.tif'
return os.path.join(self.raw_uri, self.base_dir, self.raster_dir,
'{}{}.tif'.format(self.raster_fn_prefix, id))
filename)

def get_geojson_uri(self, id):
filename = f'{self.label_fn_prefix}{id}.geojson'
return os.path.join(self.raw_uri, self.base_dir, self.label_dir,
'{}{}.geojson'.format(self.label_fn_prefix, id))
filename)

def get_scene_ids(self):
label_dir = os.path.join(self.raw_uri, self.base_dir, self.label_dir)
label_paths = list_paths(label_dir, ext='.geojson')
label_re = re.compile(r'.*{}(\d+)\.geojson'.format(
self.label_fn_prefix))
label_re = re.compile(rf'.*{self.label_fn_prefix}(\d+)\.geojson')
scene_ids = [
label_re.match(label_path).group(1) for label_path in label_paths
]
Expand Down
Expand Up @@ -3,6 +3,7 @@
import csv
from io import StringIO

from rastervision.pipeline.file_system.utils import file_exists
from rastervision.core.data import (RasterioSource, GeoJSONVectorSource,
ClassInferenceTransformer)
from rastervision.core.data.utils import geoms_to_geojson, crop_geotiff
Expand Down Expand Up @@ -48,6 +49,8 @@ def save_image_crop(
ValueError if cannot find a crop satisfying min_features constraint.
"""
print(f'Saving test crop to {image_crop_uri}...')
if file_exists(image_crop_uri):
print(f'Already exists. Skipping.')
old_environ = os.environ.copy()
try:
request_payer = S3FileSystem.get_request_payer()
Expand Down