Skip to content

Commit

Permalink
Switch to using OverSamplingCallback in fastai
Browse files Browse the repository at this point in the history
  • Loading branch information
lewfish committed Jul 5, 2019
1 parent 54db684 commit d681f29
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 30 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Expand Up @@ -61,9 +61,9 @@ RUN pip install ptvsd==4.2.*
# See https://github.com/mapbox/rasterio/issues/1289
ENV CURL_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt

ARG FASTAI_COMMIT=39b3a3f85489b250c286d22dde4bfe1c52730c60
ARG FASTAI_COMMIT=585d107709c9af8d88ddf2e20eb06b4ad7f4f70f
RUN cd /tmp && \
wget https://github.com/azavea/fastai/archive/$FASTAI_COMMIT.zip && \
wget https://github.com/fastai/fastai/archive/$FASTAI_COMMIT.zip && \
unzip $FASTAI_COMMIT.zip && \
cd fastai-$FASTAI_COMMIT && \
pip install . && \
Expand Down
24 changes: 24 additions & 0 deletions examples/semantic_segmentation/potsdam.py
Expand Up @@ -147,6 +147,30 @@ def exp_better_resnet18(self, raw_uri, processed_uri, root_uri, test=False):
return self.get_exp(exp_id, config, raw_uri, processed_uri, root_uri,
test=test, pred_chip_size=pred_chip_size)

def exp_better_resnet18_oversample_test(
        self, raw_uri, processed_uri, root_uri, test=False):
    # Experiment that trains a resnet18 model with oversampling of rare
    # classes switched on.
    # NOTE(review): this id does not mention oversampling -- confirm it
    # cannot collide with the output of another experiment in this set.
    exp_id = 'resnet18_better'

    # Oversampling options: class id 2 is listed as the rare class, and 1.0
    # is the desired probability of sampling a chip covering it.
    oversample_opts = {
        'rare_class_ids': [2],
        'rare_target_prop': 1.0
    }

    # Training options; the key order matches the original literal.
    config = {
        'batch_sz': 16,
        'num_epochs': 20,
        'debug': False,
        'lr': 1e-4,
        'one_cycle': True,
        'sync_interval': 1,
        'tta': True,
        'model_arch': 'resnet18',
        'flip_vert': True,
        'oversample': oversample_opts
    }

    # Predictions are made with a chip size of 1200.
    return self.get_exp(
        exp_id, config, raw_uri, processed_uri, root_uri,
        test=test, pred_chip_size=1200)


def exp_resnet50(self, raw_uri, processed_uri, root_uri, test=False):
exp_id = 'resnet50'
config = {
Expand Down
54 changes: 26 additions & 28 deletions fastai_plugin/semantic_segmentation_backend.py
Expand Up @@ -13,10 +13,9 @@
import torch
from fastai.vision import (
SegmentationItemList, get_transforms, models, unet_learner, Image)
from fastai.callbacks import TrackEpochCallback
from fastai.callbacks import TrackEpochCallback, OverSamplingCallback
from fastai.basic_train import load_learner
from fastai.vision.transform import dihedral
from torch.utils.data.sampler import WeightedRandomSampler

from rastervision.utils.files import (
get_local_path, make_dir, upload_or_copy, list_paths,
Expand Down Expand Up @@ -95,13 +94,14 @@ def _make_debug_chips(split):
_make_debug_chips('val')


def get_weighted_sampler(dataset, rare_class_ids, rare_target_prop):
"""Return a WeightedRandomSampler to oversample chips with rare classes.
def get_oversampling_weights(dataset, rare_class_ids, rare_target_prop):
"""Return weight vector for oversampling chips with rare classes.
Args:
dataset: PyTorch DataSet with semantic segmentation data
rare_class_ids: list of rare class ids
rare_target_prop: probability of sampling a chip covering the rare classes
rare_target_prop: desired probability of sampling a chip covering the
rare classes
"""
def filter_chip_inds():
chip_inds = []
Expand All @@ -125,8 +125,7 @@ def get_sample_weights(num_samples, rare_chip_inds, rare_target_prob):
chip_inds = filter_chip_inds()
print('prop of rare chips before oversampling: ', len(chip_inds) / len(dataset))
weights = get_sample_weights(len(dataset), chip_inds, rare_target_prop)
sampler = WeightedRandomSampler(weights, len(weights))
return sampler
return weights


def tta_predict(learner, im_arr):
Expand Down Expand Up @@ -351,27 +350,13 @@ def get_label_path(im_path):
train_img_dir = subset_training_data(
chip_dir, self.train_opts.train_count, self.train_opts.train_prop)

def get_data(train_sampler=None):
data = (SegmentationItemList.from_folder(chip_dir)
.split_by_folder(train=train_img_dir, valid='val-img')
.label_from_func(get_label_path, classes=classes)
.transform(get_transforms(flip_vert=self.train_opts.flip_vert),
size=size, tfm_y=True)
.databunch(bs=self.train_opts.batch_sz,
num_workers=num_workers,
train_sampler=train_sampler))
return data

data = get_data()
oversample = self.train_opts.oversample
if oversample:
sampler = get_weighted_sampler(
data.train_ds, oversample['rare_class_ids'],
oversample['rare_target_prop'])
data = get_data(train_sampler=sampler)

if self.train_opts.debug:
make_debug_chips(data, class_map, tmp_dir, train_uri)
data = (SegmentationItemList.from_folder(chip_dir)
.split_by_folder(train=train_img_dir, valid='val-img')
.label_from_func(get_label_path, classes=classes)
.transform(get_transforms(flip_vert=self.train_opts.flip_vert),
size=size, tfm_y=True)
.databunch(bs=self.train_opts.batch_sz,
num_workers=num_workers))

# Setup learner.
ignore_idx = 0
Expand Down Expand Up @@ -413,6 +398,19 @@ def get_data(train_sampler=None):
self.train_opts.sync_interval)
]

oversample = self.train_opts.oversample
if oversample:
weights = get_oversampling_weights(
data.train_ds, oversample['rare_class_ids'],
oversample['rare_target_prop'])
oversample_callback = OverSamplingCallback(learn, weights=weights)
callbacks.append(oversample_callback)

if self.train_opts.debug:
if oversample:
oversample_callback.on_train_begin()
make_debug_chips(data, class_map, tmp_dir, train_uri)

lr = self.train_opts.lr
num_epochs = self.train_opts.num_epochs
if self.train_opts.one_cycle:
Expand Down

0 comments on commit d681f29

Please sign in to comment.