Skip to content
Permalink
Browse files

Enhance Data Augmentation with Albumentations. cf #14

  • Loading branch information...
ocourtin committed Mar 30, 2019
1 parent 261ad51 commit 416d435195c41e4694dda2f89ab03c16da5c8441
@@ -11,20 +11,18 @@
#
# sub: dataset subdirectory name
# bands: bands to keep from sub source. Order is meaningful
# mean: bands mean value
# std: bands std value
# Nota: (default mean and std are based on ImageNet DataSet, cf pretrained model)
# mean: bands mean value [default, if model pretrained: [0.485, 0.456, 0.406] ]
# std: bands std value [default, if model pretrained: [0.229, 0.224, 0.225] ]

[[channels]]
name = "images"
bands = [1, 2, 3]
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]



# Output Classes configuration
# Nota: available colors are either CSS3 colors names or #RRGGBB hexadecimal representation.
# Nota: only support binary classification for now.

[[classes]]
title = "background"
color = "white"
@@ -51,11 +49,14 @@
# Learning rate for the optimizer
lr = 0.000025

# Data augmentation
data_augmentation = 0.75

# Model input tile size
# Model internal input tile size
tile_size = 512

# Dataset loader name
loader = "SemSegTiles"

# Kind of data augmentation to apply while training
da = "GeoSpatial"

# Data Augmentation probability
dap = 0.75
@@ -13,20 +13,18 @@
#
# sub: dataset subdirectory name
# bands: bands to keep from sub source. Order is meaningful
# mean: bands mean value
# std: bands std value
# Nota: (default mean and std are based on ImageNet DataSet, cf pretrained model)
# mean: bands mean value [default, if model pretrained: [0.485, 0.456, 0.406] ]
# std: bands std value [default, if model pretrained: [0.229, 0.224, 0.225] ]
[[channels]]
name = "images"
bands = [1, 2, 3]
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Output Classes configuration
# Nota: available colors are either CSS3 colors names or #RRGGBB hexadecimal representation.
# Nota: only support binary classification for now.
[[classes]]
title = "background"
color = "white"
@@ -53,12 +51,15 @@
# Learning rate for the optimizer
lr = 0.000025
# Data augmentation
data_augmentation = 0.75
# Model input tile size
# Model internal input tile size
tile_size = 512
# Dataset loader name
loader = "SemSegTiles"
# Kind of data augmentation to apply while training
da = "GeoSpatial"
# Data Augmentation probability
dap = 0.75
```
@@ -234,8 +234,8 @@ Web UI:
```
usage: rsp train [-h] [--config CONFIG] [--loader LOADER] [--workers WORKERS]
[--batch_size BATCH_SIZE] [--lr LR] [--model MODEL]
[--loss LOSS] [--epochs EPOCHS] [--resume]
[--checkpoint CHECKPOINT]
[--loss LOSS] [--da DA] [--dap DAP] [--epochs EPOCHS]
[--resume] [--checkpoint CHECKPOINT]
dataset out
optional arguments:
@@ -252,6 +252,8 @@ Hyper Parameters [if set override config file value]:
--lr LR learning rate
--model MODEL model name
--loss LOSS model loss
--da DA kind of data augmentation
--dap DAP data augmentation probability
Model Training:
--epochs EPOCHS number of epochs to train [default 10]
@@ -1,5 +1,6 @@
numpy
pillow
albumentations>=0.2.2
opencv-python-headless
torchvision>=0.2.2
tqdm>=4.29.0
@@ -43,9 +43,9 @@ def check_channels(config):

# TODO Add name check

for channel in config["channels"]:
if not (len(channel["bands"]) == len(channel["mean"]) == len(channel["std"])):
sys.exit("CONFIG ERROR: Inconsistent channel bands, mean or std lenght in config file")
# for channel in config["channels"]:
# if not (len(channel["bands"]) == len(channel["mean"]) == len(channel["std"])):
# sys.exit("CONFIG ERROR: Inconsistent channel bands, mean or std lenght in config file")


def check_classes(config):
@@ -74,7 +74,7 @@ def check_model(config):
"batch_size": "int",
"lr": "float",
"tile_size": "int",
"data_augmentation": "float",
"da": "str",
}

for hp in hps:
@@ -5,20 +5,22 @@
import torch.utils.data

from robosat_pink.tiles import tiles_from_slippy_map, tile_image_buffer, tile_image_from_file, tile_label_from_file
from robosat_pink.transforms.core import to_normalized_tensor


class SemSegTiles(torch.utils.data.Dataset):
def __init__(self, config, root, transform, mode, overlap=0):
def __init__(self, config, root, mode, overlap=0):
super().__init__()

self.root = os.path.expanduser(root)
self.transform = transform
self.overlap = overlap
self.config = config
self.mode = mode

self.tiles = {}

assert mode == "train" or mode == "predict"

for channel in config["channels"]:
path = os.path.join(self.root, channel["name"])
self.tiles[channel["name"]] = [(tile, path) for tile, path in tiles_from_slippy_map(path)]
@@ -63,12 +65,11 @@ def __getitem__(self, i):
mask = tile_label_from_file(self.tiles["labels"][i][1])
assert mask is not None, "Dataset mask not retrieved"

image, mask = self.transform(image, mask)
image, mask = to_normalized_tensor(self.config, self.mode, image, mask)
return image, mask, tile

if self.mode == "predict":
image = self.transform(image)
return image, torch.IntTensor([tile.x, tile.y, tile.z])
return to_normalized_tensor(self.config, self.mode, image), torch.IntTensor([tile.x, tile.y, tile.z])

def remove_overlap(self, probs):
C, W, H = probs.shape
@@ -8,15 +8,13 @@
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize

from tqdm import tqdm
from PIL import Image

from robosat_pink.tiles import tiles_from_slippy_map
from robosat_pink.config import load_config, check_model, check_classes, check_channels
from robosat_pink.colors import make_palette
from robosat_pink.transforms import ImageToTensor
from robosat_pink.web_ui import web_ui
from robosat_pink.logs import Logs

@@ -89,19 +87,9 @@ def map_location(storage, _):
except:
sys.exit("ERROR: Unable to load {} in {} model.".format(args.checkpoint, config["model"]["name"]))

std = []
mean = []
for channel in config["channels"]:
std.extend(channel["std"])
mean.extend(channel["mean"])

loader_module = import_module("robosat_pink.loaders.{}".format(config["model"]["loader"].lower()))
loader_predict = getattr(loader_module, config["model"]["loader"])(
config,
args.tiles,
transform=Compose([ImageToTensor(), Normalize(mean=mean, std=std)]),
mode="predict",
overlap=args.tile_overlap,
config, args.tiles, mode="predict", overlap=args.tile_overlap
)
loader = DataLoader(loader_predict, batch_size=config["model"]["batch_size"], num_workers=args.workers)
palette = make_palette(config["classes"][0]["color"], config["classes"][1]["color"])
@@ -1,24 +1,13 @@
import os
import sys
from tqdm import tqdm
from importlib import import_module

import torch
import torch.backends.cudnn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.transforms import Normalize

from tqdm import tqdm

from importlib import import_module

from robosat_pink.transforms import (
JointCompose,
JointResize,
JointTransform,
JointRandomFlipOrRotate,
ImageToTensor,
MaskToTensor,
)
from robosat_pink.metrics import Metrics
from robosat_pink.config import load_config, check_model, check_channels, check_classes, check_dataset
from robosat_pink.logs import Logs
@@ -39,6 +28,8 @@ def add_parser(subparser, formatter_class):
hp.add_argument("--lr", type=float, help="learning rate")
hp.add_argument("--model", type=str, help="model name")
hp.add_argument("--loss", type=str, help="model loss")
hp.add_argument("--da", type=str, help="kind of data augmentation")
hp.add_argument("--dap", type=str, help="data augmentation probability")

mt = parser.add_argument_group("Model Training")
mt.add_argument("--epochs", type=int, default=10, help="number of epochs to train [default 10]")
@@ -61,6 +52,8 @@ def main(args):
config["model"]["batch_size"] = args.batch_size if args.batch_size else config["model"]["batch_size"]
config["model"]["name"] = args.model if args.model else config["model"]["name"]
config["model"]["loss"] = args.loss if args.loss else config["model"]["loss"]
config["model"]["da"] = args.da if args.da else config["model"]["da"]
config["model"]["dap"] = args.dap if args.dap else config["model"]["dap"]
check_dataset(config)
check_classes(config)
check_channels(config)
@@ -252,24 +245,9 @@ def validate(loader, config, device, net, criterion):

def get_dataset_loaders(path, config, num_workers):

std = []
mean = []
for channel in config["channels"]:
std.extend(channel["std"])
mean.extend(channel["mean"])

transform = JointCompose(
[
JointResize(config["model"]["tile_size"]),
JointRandomFlipOrRotate(config["model"]["data_augmentation"]),
JointTransform(ImageToTensor(), MaskToTensor()),
JointTransform(Normalize(mean=mean, std=std), None),
]
)

loader = import_module("robosat_pink.loaders.{}".format(config["model"]["loader"].lower()))
loader_train = getattr(loader, config["model"]["loader"])(config, os.path.join(path, "training"), transform, "train")
loader_val = getattr(loader, config["model"]["loader"])(config, os.path.join(path, "validation"), transform, "train")
loader_train = getattr(loader, config["model"]["loader"])(config, os.path.join(path, "training"), "train")
loader_val = getattr(loader, config["model"]["loader"])(config, os.path.join(path, "validation"), "train")

batch_size = config["model"]["batch_size"]
train_loader = DataLoader(loader_train, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers)

This file was deleted.

Oops, something went wrong.
No changes.
Oops, something went wrong.

0 comments on commit 416d435

Please sign in to comment.
You can’t perform that action at this time.