Skip to content
Permalink
Browse files

Extent metrics cf #53. Centralize import module, cf #53. Fix symlink …

…overwrite Cf #58
  • Loading branch information...
ocourtin committed Apr 15, 2019
1 parent 3c32a7c commit 18d3ebb9eedb98d27dc62fcc4a476d4618b53fbf
@@ -139,6 +139,7 @@ echo '
da = "Strong"
loss = "Lovasz"
loader = "SemSegTiles"
metrics = ["iou"]
' > ~/.rsp_config
```
@@ -51,3 +51,6 @@

# Data Augmentation probability
#dap = 1.0

# Metrics
metrics = ["iou", "mcc"]
@@ -53,4 +53,7 @@
# Data Augmentation probability
#dap = 1.0
# Metrics
metrics = ["iou", "mcc"]
```
@@ -44,6 +44,7 @@ echo '
da = "Strong"
bs = 4
lr = 0.000025
metrics = ["iou"]
' > ~/.rsp_config
```

@@ -2,6 +2,7 @@
import sys
import glob
import toml
from importlib import import_module

import re
import colorsys
@@ -11,6 +12,17 @@
from robosat_pink.tiles import tile_pixel_to_location, tiles_to_geojson


#
# Import module
#
def load_module(module):
try:
module = import_module(module)
except:
sys.exit("ERROR: Unable to load {} module".format(module))
return module


#
# Config
#
@@ -158,6 +170,8 @@ def web_ui(out, base_url, coverage_tiles, selected_tiles, ext, template):
templates = glob.glob(os.path.join(Path(__file__).parent, "web_ui", "*"))
if os.path.isfile(template):
templates.append(template)
if os.path.isfile(os.path.join(out, "index.html")):
os.remove(os.path.join(out, "index.html")) # if already existing output dir, as symlink can't be overwriten
os.symlink(os.path.basename(template), os.path.join(out, "index.html"))

def process_template(template):

This file was deleted.

Oops, something went wrong.
@@ -0,0 +1,35 @@
import torch
from robosat_pink.core import load_module


class Metrics:
def __init__(self, metrics, config=None):
self.config = config
self.metrics = {metric: 0.0 for metric in metrics}
self.modules = {metric: load_module("robosat_pink.metrics." + metric) for metric in metrics}
self.n = 0

def add(self, mask, output):
assert self.modules
assert self.metrics
self.n += 1
for metric, module in self.modules.items():
dist = module.get(mask, output, self.config)
dist = dist if dist == dist else 0.0
self.metrics[metric] += dist

def get(self):
assert self.metrics
assert self.n
return {metric: value / self.n for metric, value in self.metrics.items()}


def confusion(predicted, label):
confusion = predicted.view(-1).float() / label.view(-1).float()

tn = torch.sum(torch.isnan(confusion)).item()
fn = torch.sum(confusion == float("inf")).item()
fp = torch.sum(confusion == 0).item()
tp = torch.sum(confusion == 1).item()

return tn, fn, fp, tp
@@ -0,0 +1,13 @@
from robosat_pink.metrics.core import confusion


def get(label, predicted, config=None):

tn, fn, fp, tp = confusion(label, predicted)

try:
iou = tp / (tp + fn + fp)
except ZeroDivisionError:
iou = float("NaN")

return iou
@@ -0,0 +1,14 @@
import math
from robosat_pink.metrics.core import confusion


def get(label, predicted, config=None):

tn, fn, fp, tp = confusion(label, predicted)

try:
mcc = (tp * tn - fp * fn) / math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
except ZeroDivisionError:
mcc = float("NaN")

return mcc
@@ -0,0 +1,23 @@
import torch
import math

from robosat_pink.metrics.core import confusion


def get(label, mask, config=None):

tn, fn, fp, tp = confusion(label, mask)

try:
iou = tp / (tp + fn + fp)
except ZeroDivisionError:
iou = float("NaN")

W, H = mask.size()
ratio = float(100 * torch.max(torch.sum(mask != 0), torch.sum(label != 0)) / (W * H))
dist = 0.0 if iou != iou else 1.0 - iou

qod = 100 - (dist * (math.log(ratio + 1.0) + 1e-7) * (100 / math.log(100)))
qod = 0.0 if qod < 0.0 else qod # Corner case prophilaxy

return (dist, ratio, qod)
@@ -13,7 +13,7 @@

from robosat_pink.core import web_ui, Logs
from robosat_pink.tiles import tiles_from_slippy_map, tile_from_slippy_map, tile_image_from_file, tile_image_to_file
from robosat_pink.metrics import Metrics
from robosat_pink.metrics.qod import get as compare


def add_parser(subparser, formatter_class):
@@ -51,27 +51,6 @@ def add_parser(subparser, formatter_class):
parser.set_defaults(func=main)


def compare(masks, labels, tile):

x, y, z = list(map(str, tile))
label = np.array(Image.open(os.path.join(labels, z, x, "{}.png".format(y))))
mask = np.array(Image.open(os.path.join(masks, z, x, "{}.png".format(y))))

assert label.shape == mask.shape

metrics = Metrics()
metrics.add(torch.from_numpy(label), torch.from_numpy(mask), is_prob=False)
fg_iou = metrics.get_fg_iou()

fg_ratio = 100 * max(np.sum(mask != 0), np.sum(label != 0)) / mask.size
dist = 0.0 if math.isnan(fg_iou) else 1.0 - fg_iou

qod = 100 - (dist * (math.log(fg_ratio + 1.0) + np.finfo(float).eps) * (100 / math.log(100)))
qod = 0.0 if qod < 0.0 else qod # Corner case prophilaxy

return dist, fg_ratio, qod


def main(args):

args.out = os.path.expanduser(args.out)
@@ -116,8 +95,13 @@ def worker(tile):

if args.masks and args.labels:

label = np.array(Image.open(os.path.join(args.labels, z, x, "{}.png".format(y))))
mask = np.array(Image.open(os.path.join(args.masks, z, x, "{}.png".format(y))))

assert label.shape == mask.shape

try:
dist, fg_ratio, qod = compare(args.masks, args.labels, tile)
dist, fg_ratio, qod = compare(torch.as_tensor(label, device="cpu"), torch.as_tensor(mask, device="cpu"))
except:
progress.update()
return False, tile
@@ -1,11 +1,11 @@
import sys

from importlib import import_module

import torch
import torch.onnx
import torch.autograd

from robosat_pink.core import load_module


def add_parser(subparser, formatter_class):
parser = subparser.add_parser("export", help="Export a model to ONNX or Torch JIT", formatter_class=formatter_class)
@@ -24,11 +24,12 @@ def main(args):
try:
chkpt = torch.load(args.checkpoint, map_location=torch.device("cpu"))
assert chkpt["producer_name"] == "RoboSat.pink"
model_module = import_module("robosat_pink.models.{}".format(chkpt["nn"].lower()))
nn = getattr(model_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"]).to("cpu")
except:
sys.exit("ERROR: Unable to load checkpoint: {}".format(args.checkpoint))

model_module = load_module("robosat_pink.models.{}".format(chkpt["nn"].lower()))
nn = getattr(model_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"]).to("cpu")

print("RoboSat.pink - export model to {}".format(args.type))
print("Model: {} - UUID: {} - Torch {}".format(chkpt["nn"], chkpt["uuid"], torch.__version__))
print(chkpt["doc_string"])
@@ -1,6 +1,7 @@
import os
import sys
from importlib import import_module

from robosat_pink.core import load_module


def add_parser(subparser, formatter_class):
@@ -20,13 +21,10 @@ def main(args):

print("RoboSat.pink - extract {} from {}. Could take some time. Please wait.".format(args.type, args.pbf))

try:
module = import_module("robosat_pink.osm.{}".format(args.type.lower()))
except:
sys.exit("ERROR: Unknown OSM {} type extactor".format(args.type))
module = load_module("robosat_pink.osm.{}".format(args.type.lower()))
osmium_handler = getattr(module, "{}Handler".format(args.type))()

try:
osmium_handler = getattr(module, "{}Handler".format(args.type))()
osmium_handler.apply_file(filename=os.path.expanduser(args.pbf), locations=True)
except:
sys.exit("ERROR: Unable to extract {} from {}".format(args.type, args.pbf))
@@ -1,7 +1,6 @@
import os
import sys
from tqdm import tqdm
from importlib import import_module

import numpy as np
import mercantile
@@ -10,7 +9,7 @@
import torch.backends.cudnn
from torch.utils.data import DataLoader

from robosat_pink.core import load_config, check_classes, check_channels, make_palette, web_ui, Logs
from robosat_pink.core import load_config, load_module, check_classes, check_channels, make_palette, web_ui, Logs
from robosat_pink.tiles import tiles_from_slippy_map, tile_label_to_file


@@ -59,7 +58,7 @@ def main(args):
try:
chkpt = torch.load(args.checkpoint, map_location=device)
assert chkpt["producer_name"] == "RoboSat.pink"
model_module = import_module("robosat_pink.models.{}".format(chkpt["nn"].lower()))
model_module = load_module("robosat_pink.models.{}".format(chkpt["nn"].lower()))
nn = getattr(model_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"]).to(device)
nn = torch.nn.DataParallel(nn)
nn.load_state_dict(chkpt["state_dict"])
@@ -69,11 +68,8 @@ def main(args):

log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

try:
loader_module = import_module("robosat_pink.loaders.{}".format(chkpt["loader"].lower()))
loader_predict = getattr(loader_module, chkpt["loader"])(config, chkpt["shape_in"][1:3], args.tiles, mode="predict")
except:
sys.exit("ERROR: Unable to load {} data loader.".format(chkpt["loader"]))
loader_module = load_module("robosat_pink.loaders.{}".format(chkpt["loader"].lower()))
loader_predict = getattr(loader_module, chkpt["loader"])(config, chkpt["shape_in"][1:3], args.tiles, mode="predict")

loader = DataLoader(loader_predict, batch_size=args.bs, num_workers=args.workers)
palette = make_palette(config["classes"][0]["color"], config["classes"][1]["color"])
Oops, something went wrong.

0 comments on commit 18d3ebb

Please sign in to comment.
You can’t perform that action at this time.