
First move to use Classes in config file

ocourtin committed Dec 22, 2018
1 parent 0179985 commit 314521186cc0d3c95919fdcd3d3e50442b7bb57e
@@ -1,24 +1,29 @@
# RoboSat.pink Configuration


[dataset]
# The slippy map dataset's base directory.
path = "~/rsp_dataset"

[classes]
# Human representation for classes.
titles = ["background", "building"]

# Color map for visualization and representing classes in masks.
# Nota: available colors are either CSS3 color names or #RRGGBB hexadecimal representations.
colors = ["white", "deeppink"]
# Classes configuration.
# Nota: available colors are either CSS3 color names or #RRGGBB hexadecimal representations.
[[classes]]
title = "background"
color = "white"

[[classes]]
title = "building"
color = "deeppink"


# Channels configuration lets you indicate which dataset sub-directory and bands to take as input.
# Indicate which dataset sub-directory and bands to take as input.
# You can thus add several channels blocks to compose your input Tensor. Order is meaningful.
[[channels]]
sub = "images"
bands = [1,2,3]

sub = "images"
bands = [1, 2, 3]
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


[model]
@@ -48,6 +53,6 @@

# Data augmentation, Flip or Rotate probability.
data_augmentation = 0.75
#

# Weight decay l2 penalty for the optimizer.
decay = 0.0
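
With classes and channels expressed as TOML arrays of tables, every tool now reads plain lists of dicts instead of the old parallel titles/colors arrays. A minimal sketch of that access pattern (illustration only, values copied from the sample config above):

import toml

config = toml.loads("""
[dataset]
path = "~/rsp_dataset"

[[classes]]
title = "background"
color = "white"

[[classes]]
title = "building"
color = "deeppink"

[[channels]]
sub = "images"
bands = [1, 2, 3]
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
""")

titles = [classe["title"] for classe in config["classes"]]  # ["background", "building"]
colors = [classe["color"] for classe in config["classes"]]  # ["white", "deeppink"]
num_classes = len(config["classes"])                        # 2
num_channels = sum(len(channel["bands"]) for channel in config["channels"])  # 3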
@@ -1,30 +1,36 @@
"""Configuration handling.
""" Dictionary-based configuration with a TOML-based on-disk representation. Cf: https://github.com/toml-lang/toml """

Dictionary-based configuration with a TOML-based on-disk represenation.
import sys
import toml

See https://github.com/toml-lang/toml
"""

import toml
def check_config(config):
"""Check if config file is consistent. Exit on error if not."""

if not "dataset" in config.keys() or not "path" in config["dataset"].keys():
sys.exit("dataset path is mandatory in config file")

def load_config(path):
"""Loads a dictionary from configuration file.
if not "classes" in config.keys():
sys.exit("At least one class is mandatory in config file")

if not "channels" in config.keys():
sys.exit("At least one channel is mandatory in config file")

Args:
path: the path to load the configuration from.
for channel in config["channels"]:
if not (len(channel["bands"]) == len(channel["mean"]) == len(channel["std"])):
sys.exit("Inconsistent channel bands, mean or std lenght in config file")


def load_config(path):
"""Loads a dictionary from configuration file."""

Returns:
The configuration dictionary loaded from the file.
"""
config = toml.load(path)
check_config(config)

return toml.load(path)
return config


def save_config(attrs, path):
"""Saves a configuration dictionary to a file.
Args:
path: the path to save the configuration dictionary to.
"""
"""Saves a configuration dictionary to a file."""

toml.dump(attrs, path)
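
A short sketch of how the new consistency check behaves; the import path robosat_pink.config is an assumption here, and check_config exits the process on an invalid config, as above:

from robosat_pink.config import check_config  # import path assumed

config = {
    "dataset": {"path": "~/rsp_dataset"},
    "classes": [{"title": "background", "color": "white"}, {"title": "building", "color": "deeppink"}],
    "channels": [{"sub": "images", "bands": [1, 2, 3], "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}],
}
check_config(config)  # passes: dataset path, classes and channels are all present and consistent

config["channels"][0]["std"] = [0.229]
check_config(config)  # exits with "Inconsistent channel bands, mean or std length in config file"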
@@ -99,8 +99,8 @@ def main(args):
x, y, z = list(map(str, tile))

if args.masks and args.labels and args.config:
classes = load_config(args.config)["classes"]["titles"]
dist, fg_ratio, qod = compare(args.masks, args.labels, tile, classes)
titles = [classe["title"] for classe in load_config(args.config)["classes"]]
dist, fg_ratio, qod = compare(args.masks, args.labels, tile, titles)
if not args.minimum_fg <= fg_ratio <= args.maximum_fg or not args.minimum_qod <= qod <= args.maximum_qod:
continue

@@ -49,8 +49,6 @@ def main(args):
log = Logs(os.path.join(args.out, "log"), out=sys.stderr)
log.log("Begin download from {}".format(args.url))

# tqdm has problems with concurrent.futures.ThreadPoolExecutor; explicitly call `.update`
# https://github.com/tqdm/tqdm/issues/97
progress = tqdm(total=len(tiles), ascii=True, unit="image")

with futures.ThreadPoolExecutor(num_workers) as executor:
@@ -64,6 +62,7 @@ def worker(tile):
path = os.path.join(args.out, z, x, "{}.{}".format(y, args.ext))

if os.path.isfile(path):
    progress.update()
    return tile, None, True

if args.type == "XYZ":
@@ -82,6 +81,7 @@ def worker(tile):
try:
    image = Image.open(res)
    image.save(path, optimize=True)
    progress.update()
except OSError:
    return tile, url, False

@@ -93,8 +93,6 @@ def worker(tile):
if time_for_req < time_per_worker:
time.sleep(time_per_worker - time_for_req)

progress.update()

return tile, url, True

for tile, url, ok in executor.map(worker, tiles):
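
The progress.update() calls move into the worker branches, the usual way to keep a tqdm bar accurate under concurrent.futures (cf the removed comment and https://github.com/tqdm/tqdm/issues/97). A self-contained sketch of that pattern, with a stand-in workload rather than the actual tile download:

import time
from concurrent import futures

from tqdm import tqdm

tiles = list(range(16))  # stand-in for the tile list
progress = tqdm(total=len(tiles), ascii=True, unit="image")

def worker(tile):
    time.sleep(0.01)   # stand-in for the HTTP request and image save
    progress.update()  # updated explicitly from inside the worker
    return tile, None, True

with futures.ThreadPoolExecutor(4) as executor:
    for tile, url, ok in executor.map(worker, tiles):
        pass

progress.close()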
@@ -36,7 +36,7 @@ def main(args):
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Workaround: PyTorch ONNX, DataParallel with GPU issue, cf https://github.com/pytorch/pytorch/issues/5315

num_classes = len(config["classes"]["titles"])
num_classes = len(config["classes"])
num_channels = 0
for channel in config["channels"]:
num_channels += len(channel["bands"])
@@ -48,7 +48,7 @@ def add_parser(subparser):

def main(args):
config = load_config(args.config)
num_classes = len(config["classes"]["titles"])
num_classes = len(config["classes"])
batch_size = args.batch_size if args.batch_size else config["model"]["batch_size"]
tile_size = args.tile_size if args.tile_size else config["model"]["tile_size"]

@@ -92,7 +92,7 @@ def map_location(storage, _):
directory = BufferedSlippyMapDirectory(args.tiles, transform=transform, size=tile_size, overlap=args.overlap)
loader = DataLoader(directory, batch_size=batch_size, num_workers=args.workers)

palette = make_palette(config["classes"]["colors"][0], config["classes"]["colors"][1])
palette = make_palette(config["classes"][0]["color"], config["classes"][1]["color"])

# don't track tensors with autograd during prediction
with torch.no_grad():
@@ -82,11 +82,7 @@ def burn(tile, features, tile_size, burn_value=1):
def main(args):
config = load_config(args.config)
tile_size = args.tile_size if args.tile_size else config["model"]["tile_size"]

classes = config["classes"]["titles"]
colors = config["classes"]["colors"]
assert len(classes) == len(colors), "classes and colors coincide"
assert len(colors) == 2, "only binary models supported right now"
colors = [classe["color"] for classe in config["classes"]]

os.makedirs(args.out, exist_ok=True)

@@ -41,11 +41,7 @@ def add_parser(subparser):

def main(args):

config = load_config(args.config)
classes = config["classes"]["titles"]
colors = config["classes"]["colors"]
assert len(classes) == len(colors), "classes and colors coincide"
assert len(colors) == 2, "only binary models supported right now"
colors = [classe["color"] for classe in load_config(args.config)["classes"]]

try:
raster = rasterio_open(args.raster)
@@ -38,8 +38,9 @@ def add_parser(subparser):
parser.add_argument("--checkpoint", type=str, required=False, help="path to a model checkpoint (to retrain)")
parser.add_argument("--resume", action="store_true", help="resume training (imply to provide a checkpoint)")
parser.add_argument("--workers", type=int, default=0, help="number of workers pre-processing images")
parser.add_argument("--dataset", type=int, help="if set, override dataset path value from config file")
parser.add_argument("--dataset", type=str, help="if set, override dataset path value from config file")
parser.add_argument("--epochs", type=int, help="if set, override epochs value from config file")
parser.add_argument("--batch_size", type=int, help="if set, override batch_size value from config file")
parser.add_argument("--lr", type=float, help="if set, override learning rate value from config file")
parser.add_argument("out", type=str, help="directory to save checkpoint .pth files and log")

@@ -48,9 +49,10 @@ def add_parser(subparser):

def main(args):
config = load_config(args.config)
lr = args.lr if args.lr else config["model"]["lr"]
dataset_path = args.dataset if args.dataset else config["dataset"]["path"]
num_epochs = args.epochs if args.epochs else config["model"]["epochs"]
config["dataset"]["path"] = args.dataset if args.dataset else config["dataset"]["path"]
config["model"]["lr"] = args.lr if args.lr else config["model"]["lr"]
config["model"]["epochs"] = args.epochs if args.epochs else config["model"]["epochs"]
config["model"]["batch_size"] = args.batch_size if args.batch_size else config["model"]["batch_size"]

log = Logs(os.path.join(args.out, "log"))

@@ -63,7 +65,7 @@ def main(args):
device = torch.device("cpu")
log.log("RoboSat - training on CPU, with {} workers".format(args.workers))

num_classes = len(config["classes"]["titles"])
num_classes = len(config["classes"])
num_channels = 0
for channel in config["channels"]:
num_channels += len(channel["bands"])
@@ -80,7 +82,7 @@ def main(args):
).to(device)

net = torch.nn.DataParallel(net)
optimizer = Adam(net.parameters(), lr=lr, weight_decay=config["model"]["decay"])
optimizer = Adam(net.parameters(), lr=config["model"]["lr"], weight_decay=config["model"]["decay"])

resume = 0
if args.checkpoint:
@@ -104,13 +106,13 @@ def map_location(storage, _):
loss_module = import_module("robosat_pink.losses.{}".format(config["model"]["loss"]))
criterion = getattr(loss_module, "{}".format(config["model"]["loss"].title()))().to(device)

train_loader, val_loader = get_dataset_loaders(dataset_path, config, args.workers)
train_loader, val_loader = get_dataset_loaders(config["dataset"]["path"], config, args.workers)

if resume >= num_epochs:
sys.exit("Error: Epoch {} set in {} already reached by the checkpoint provided".format(num_epochs, args.config))
if resume >= config["model"]["epochs"]:
sys.exit("Error: Epoch {} set in {} already reached by the checkpoint provided".format(config["model"]["epochs"], args.config))

log.log("")
log.log("--- Input tensor from Dataset: {} ---".format(dataset_path))
log.log("--- Input tensor from Dataset: {} ---".format(config["dataset"]["path"]))
num_channel = 1
for channel in config["channels"]:
for band in channel["bands"]:
@@ -125,21 +127,21 @@ def map_location(storage, _):
log.log("Batch Size:\t\t {}".format(config["model"]["batch_size"]))
log.log("Tile Size:\t\t {}".format(config["model"]["tile_size"]))
log.log("Data Augmentation:\t {}".format(config["model"]["data_augmentation"]))
log.log("Learning Rate:\t\t {}".format(lr))
log.log("Learning Rate:\t\t {}".format(config["model"]["lr"]))
log.log("Weight Decay:\t\t {}".format(config["model"]["decay"]))
log.log("")

for epoch in range(resume, num_epochs):
for epoch in range(resume, config["model"]["epochs"]):

log.log("---")
log.log("Epoch: {}/{}".format(epoch + 1, num_epochs))
log.log("Epoch: {}/{}".format(epoch + 1, config["model"]["epochs"]))

train_hist = train(train_loader, num_classes, device, net, optimizer, criterion)
log.log(
"Train loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format(
train_hist["loss"],
train_hist["miou"],
config["classes"]["titles"][1],
config["classes"][1]["title"],
train_hist["fg_iou"],
train_hist["mcc"],
)
@@ -148,12 +150,12 @@ def map_location(storage, _):
val_hist = validate(val_loader, num_classes, device, net, criterion)
log.log(
"Validate loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format(
val_hist["loss"], val_hist["miou"], config["classes"]["titles"][1], val_hist["fg_iou"], val_hist["mcc"]
val_hist["loss"], val_hist["miou"], config["classes"][1]["title"], val_hist["fg_iou"], val_hist["mcc"]
)
)

states = {"epoch": epoch + 1, "state_dict": net.state_dict(), "optimizer": optimizer.state_dict()}
checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}-of-{:05d}.pth".format(epoch + 1, num_epochs))
checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}-of-{:05d}.pth".format(epoch + 1, config["model"]["epochs"]))
torch.save(states, checkpoint_path)
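
The training entry point now writes CLI overrides back into the config dictionary, so the optimizer, the logs and the checkpoint names all read the same config["model"] values. A stripped-down sketch of that override pattern; the parser flags match the ones added above, while the TOML values are placeholders:

import argparse

import toml

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, help="if set, override dataset path value from config file")
parser.add_argument("--epochs", type=int, help="if set, override epochs value from config file")
parser.add_argument("--batch_size", type=int, help="if set, override batch_size value from config file")
parser.add_argument("--lr", type=float, help="if set, override learning rate value from config file")
args = parser.parse_args(["--lr", "0.0001"])  # example: only the learning rate is overridden

config = toml.loads("""
[dataset]
path = "~/rsp_dataset"

[model]
lr = 0.00025
epochs = 10
batch_size = 4
""")  # placeholder values

config["dataset"]["path"] = args.dataset if args.dataset else config["dataset"]["path"]
config["model"]["lr"] = args.lr if args.lr else config["model"]["lr"]
config["model"]["epochs"] = args.epochs if args.epochs else config["model"]["epochs"]
config["model"]["batch_size"] = args.batch_size if args.batch_size else config["model"]["batch_size"]

assert config["model"]["lr"] == 0.0001 and config["model"]["epochs"] == 10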

