Add MetaData model handling. Refactor. Cf #40

ocourtin committed Apr 6, 2019
1 parent 36710fd commit a5b79c77dc82bf783cd0b3af11e208133a86c375
@@ -77,8 +77,8 @@ it_train:
 # Integration Tests: Post Training
 it_post:
 	@echo "==================================================================================="
-	@rsp export --checkpoint it/pth/checkpoint-00005-of-00005.pth --config config.toml --type jit it/pth/export.jit
-	@rsp export --checkpoint it/pth/checkpoint-00005-of-00005.pth --config config.toml --type onnx it/pth/export.onnx
+	@rsp export --checkpoint it/pth/checkpoint-00005-of-00005.pth --type jit it/pth/export.jit
+	@rsp export --checkpoint it/pth/checkpoint-00005-of-00005.pth --type onnx it/pth/export.onnx
 	@rsp predict --config config.toml --bs 4 --checkpoint it/pth/checkpoint-00005-of-00005.pth it/prediction it/prediction/masks
 	@rsp compare --images it/prediction/images it/prediction/labels it/prediction/masks --mode stack --labels it/prediction/labels --masks it/prediction/masks it/prediction/compare
 	@rsp compare --images it/prediction/images it/prediction/compare --mode side it/prediction/compare_side
@@ -40,8 +40,8 @@
 # Learning rate for the optimizer
 #lr = 0.000025

-# Model internal input tile size
-#ts = 512
+# Model internal input tile size [W, H]
+#ts = [512, 512]

 # Dataset loader name
 loader = "SemSegTiles"
@@ -42,8 +42,8 @@
 # Learning rate for the optimizer
 #lr = 0.000025

-# Model internal input tile size
-#ts = 512
+# Model internal input tile size [W, H]
+#ts = [512, 512]

 # Dataset loader name
 loader = "SemSegTiles"
@@ -91,17 +91,14 @@ Web UI:
 ```
 ## rsp export
 ```
-usage: rsp export [-h] --checkpoint CHECKPOINT [--type {onnx,jit}]
-                  [--config CONFIG]
-                  out
+usage: rsp export [-h] --checkpoint CHECKPOINT [--type {onnx,jit}] out

 optional arguments:
  -h, --help               show this help message and exit

 Inputs:
  --checkpoint CHECKPOINT  model checkpoint to load [required]
  --type {onnx,jit}        output type [default: jit]
- --config CONFIG          path to config file [required]

 Output:
  out                      path to save export model to [required]
@@ -122,8 +119,8 @@ Output:
 ```
 ## rsp predict
 ```
-usage: rsp predict [-h] --checkpoint CHECKPOINT [--config CONFIG] [--nn NN]
-                   [--ts TS] [--workers WORKERS] [--bs BS]
+usage: rsp predict [-h] --checkpoint CHECKPOINT [--config CONFIG]
+                   [--workers WORKERS] [--bs BS]
                    [--web_ui_base_url WEB_UI_BASE_URL]
                    [--web_ui_template WEB_UI_TEMPLATE] [--no_web_ui]
                    tiles out
@@ -135,8 +132,6 @@ Inputs:
  tiles                    tiles directory path [required]
  --checkpoint CHECKPOINT  path to the trained model to use [required]
  --config CONFIG          path to config file [required]
- --nn NN                  if set, override neural network name from config file
- --ts TS                  if set, override tile size value from config file

 Outputs:
  out                      output directory path [required]
@@ -236,8 +231,8 @@ Web UI:
 ## rsp train
 ```
 usage: rsp train [-h] [--config CONFIG] [--loader LOADER] [--workers WORKERS]
-                 [--bs BS] [--lr LR] [--nn NN] [--loss LOSS] [--da DA]
-                 [--dap DAP] [--epochs EPOCHS] [--resume]
+                 [--bs BS] [--lr LR] [--ts TS] [--nn NN] [--loss LOSS]
+                 [--da DA] [--dap DAP] [--epochs EPOCHS] [--resume]
                  [--checkpoint CHECKPOINT]
                  dataset out
@@ -251,8 +246,9 @@ Dataset:
  --workers WORKERS  number of pre-processing images workers [default: GPUs x 2]

 Hyper Parameters [if set override config file value]:
- --bs BS      batch_size
+ --bs BS      batch size
  --lr LR      learning rate
+ --ts TS      tile size
  --nn NN      neural network name
  --loss LOSS  model loss
  --da DA      kind of data augmentation
@@ -29,7 +29,7 @@ def load_config(path):
     config["model"] = {}

     if "ts" not in config["model"].keys():
-        config["model"]["ts"] = 512
+        config["model"]["ts"] = (512, 512)

     if "pretrained" not in config["model"].keys():
         config["model"]["pretrained"] = True
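
Note that `ts` changes type here, from a scalar to a `(W, H)` pair, so configs written with the old `ts = 512` form no longer match. A minimal sketch of a backward-compatibility shim (a hypothetical helper, not part of this commit):

```python
# Hypothetical helper, not in the commit: accept both the legacy scalar ts
# and the new [W, H] pair when reading a config.
def normalize_ts(ts):
    if isinstance(ts, int):   # legacy config: ts = 512
        return (ts, ts)
    assert len(ts) == 2       # new config: ts = [512, 512]
    return tuple(ts)

assert normalize_ts(512) == (512, 512)
assert normalize_ts([640, 480]) == (640, 480)
```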
@@ -67,7 +67,7 @@ def check_classes(config):

 def check_model(config):

-    hps = {"nn": "str", "pretrained": "bool", "loss": "str", "ts": "int", "da": "str"}
+    hps = {"nn": "str", "pretrained": "bool", "loss": "str", "da": "str"}

     for hp in hps:
         if hp not in config["model"].keys() or type(config["model"][hp]).__name__ != hps[hp]:
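
With `ts` dropped from the table (it is no longer an `int`), `check_model` only type-checks the remaining hyper-parameters by comparing type names, roughly as in this sketch (the config values are illustrative, not taken from the commit):

```python
# Sketch of the check above; values are illustrative.
config = {"model": {"nn": "Albunet", "pretrained": True, "loss": "Lovasz", "da": "Strong"}}
hps = {"nn": "str", "pretrained": "bool", "loss": "str", "da": "str"}

for hp in hps:
    assert hp in config["model"] and type(config["model"][hp]).__name__ == hps[hp]
```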
@@ -1,27 +1,33 @@
"""PyTorch-compatible Data Augmentation."""

import sys
import cv2
import torch
import numpy as np
from importlib import import_module


def to_normalized_tensor(config, mode, image, mask=None):
def to_normalized_tensor(config, ts, mode, image, mask=None):

assert mode == "train" or mode == "predict"
assert len(ts) == 2
assert image is not None

# To Tensor and Data Augmentation
# Resize, ToTensor and Data Augmentation
if mode == "train":
try:
module = import_module("robosat_pink.da.{}".format(config["model"]["da"].lower()))
except:
sys.exit("Unable to load data augmentation module")

transform = module.transform(config, image, mask)
image = cv2.resize(image, ts, interpolation=cv2.INTER_LINEAR)
image = torch.from_numpy(np.moveaxis(transform["image"], 2, 0)).float()
mask = cv2.resize(mask, ts, interpolation=cv2.INTER_NEAREST)
mask = torch.from_numpy(transform["mask"]).long()

elif mode == "predict":
image = cv2.resize(image, ts, interpolation=cv2.INTER_LINEAR)
image = torch.from_numpy(np.moveaxis(image, 2, 0)).float()

# Normalize
@@ -34,8 +40,8 @@ def to_normalized_tensor(config, mode, image, mask=None):
             mean.extend(channel["mean"])
         except:
             if config["model"]["pretrained"] and not len(std) and not len(mean):
-                mean = [0.485, 0.456, 0.406]  # Use RGB ImageNet default
-                std = [0.229, 0.224, 0.225]  # Use RGB ImageNet default
+                mean = [0.485, 0.456, 0.406]  # RGB ImageNet default
+                std = [0.229, 0.224, 0.225]  # RGB ImageNet default

     assert len(std) and len(mean)
     image.sub_(torch.as_tensor(mean, device=image.device)[:, None, None])
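
Taken together, predict mode now resizes with OpenCV, converts to a channels-first tensor, and normalizes per channel, falling back to the RGB ImageNet statistics when the model is pretrained and no per-channel mean/std is configured. A hedged usage sketch, assuming the function lives in `robosat_pink.da.core` and that this minimal config dict is sufficient:

```python
import numpy as np
from robosat_pink.da.core import to_normalized_tensor  # module path assumed

# Minimal config: no per-channel mean/std given, so the pretrained
# fallback to the RGB ImageNet defaults above should apply.
config = {"model": {"pretrained": True}, "channels": [{"name": "images", "bands": [1, 2, 3]}]}

image = np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)
tensor = to_normalized_tensor(config, (512, 512), "predict", image)
print(tensor.shape)  # expected: torch.Size([3, 512, 512])
```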
@@ -1,6 +1,5 @@
 from albumentations import (
     Compose,
-    Resize,
     IAAAdditiveGaussianNoise,
     GaussNoise,
     OneOf,
@@ -25,12 +24,10 @@ def transform(config, image, mask):
         p = 1

     assert 0 <= p <= 1
-    assert 64 < config["model"]["ts"]

     # Inspired by: https://albumentations.readthedocs.io/en/latest/examples.html
     return Compose(
         [
-            Resize(config["model"]["ts"], config["model"]["ts"]),
             Flip(),
             Transpose(),
             OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2),
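
Dropping `Resize` here means the albumentations pipeline augments at the tile's native resolution, and `to_normalized_tensor` resizes afterwards. A hedged sketch of the resulting call, assuming `transform` applies the composed pipeline and returns the `{"image", "mask"}` dict that `to_normalized_tensor` subscripts, and that this minimal config carries enough keys:

```python
import numpy as np

# Illustrative config; config["model"]["da"] is what selects this module.
config = {"model": {"da": "Strong"}}

image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (512, 512), dtype=np.uint8)

augmented = transform(config, image, mask)
assert augmented["image"].shape[:2] == image.shape[:2]  # no resize in the da step anymore
```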
@@ -9,21 +9,25 @@


 class SemSegTiles(torch.utils.data.Dataset):
-    def __init__(self, config, root, mode):
+    def __init__(self, config, ts, root, mode):
         super().__init__()

         self.root = os.path.expanduser(root)
         self.config = config
         self.mode = mode

-        self.tiles = {}
-
         assert mode == "train" or mode == "predict"

+        num_channels = 0
+        self.tiles = {}
         for channel in config["channels"]:
             path = os.path.join(self.root, channel["name"])
             self.tiles[channel["name"]] = [(tile, path) for tile, path in tiles_from_slippy_map(path)]
             self.tiles[channel["name"]].sort(key=lambda tile: tile[0])
+            num_channels += len(channel["bands"])

+        self.shape_in = (num_channels,) + ts  # C,W,H
+        self.shape_out = (len(config["classes"]),) + ts  # C,W,H

         if self.mode == "train":
             path = os.path.join(self.root, "labels")
@@ -60,8 +64,9 @@ def __getitem__(self, i):
             mask = tile_label_from_file(self.tiles["labels"][i][1])
             assert mask is not None, "Dataset mask not retrieved"

-            image, mask = to_normalized_tensor(self.config, self.mode, image, mask)
+            image, mask = to_normalized_tensor(self.config, self.shape_in[1:3], self.mode, image, mask)
             return image, mask, tile

         if self.mode == "predict":
-            return to_normalized_tensor(self.config, self.mode, image), torch.IntTensor([tile.x, tile.y, tile.z])
+            image = to_normalized_tensor(self.config, self.shape_in[1:3], self.mode, image)
+            return image, torch.IntTensor([tile.x, tile.y, tile.z])
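
The loader now derives `shape_in` and `shape_out` from `ts`, so downstream code can size a network straight from the dataset. A hedged usage sketch; the path and `config` are illustrative (a real config also needs `channels` and `classes`, and tiles on disk):

```python
# Illustrative usage of the new constructor signature.
dataset = SemSegTiles(config, (512, 512), "ds/training", "train")
print(dataset.shape_in)   # (num_channels, 512, 512) - C,W,H
print(dataset.shape_out)  # (num_classes, 512, 512)  - C,W,H

image, mask, tile = dataset[0]  # tensors already resized to ts
```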
@@ -26,21 +26,19 @@ def forward(self, x):


 class Albunet(nn.Module):
-    """U-Net inspired encoder-decoder architecture with a ResNet encoder as proposed by Alexander Buslaev.
+    def __init__(self, shape_in, shape_out, pretrained=False):
+        self.doc = """
+            U-Net inspired encoder-decoder architecture with a ResNet encoder as proposed by Alexander Buslaev.

-      - https://arxiv.org/abs/1505.04597 - U-Net: Convolutional Networks for Biomedical Image Segmentation
-      - https://arxiv.org/pdf/1804.08024 - Angiodysplasia Detection and Localization Using DCNN
-      - https://arxiv.org/abs/1806.00844 - TernausNetV2: Fully Convolutional Network for Instance Segmentation
-    """
+             - https://arxiv.org/abs/1505.04597 - U-Net: Convolutional Networks for Biomedical Image Segmentation
+             - https://arxiv.org/pdf/1804.08024 - Angiodysplasia Detection and Localization Using DCNN
+             - https://arxiv.org/abs/1806.00844 - TernausNetV2: Fully Convolutional Network for Instance Segmentation
+        """
+        self.version = 1

-    def __init__(self, config):
-
-        num_classes = len(config["classes"])
-        pretrained = config["model"]["pretrained"]
         num_filters = 32
-        num_channels = 0
-        for channel in config["channels"]:
-            num_channels += len(channel["bands"])
+        num_channels = shape_in[0]
+        num_classes = shape_out[0]

         super().__init__()
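
Because the constructor now takes explicit shapes instead of the whole config, a model can be rebuilt from checkpoint metadata alone. A minimal sketch, assuming RGB input and two output classes:

```python
# Shapes as stored in a checkpoint, (C, W, H) in and out.
# Assumption for illustration: 3 input bands, 2 classes.
net = Albunet(shape_in=(3, 512, 512), shape_out=(2, 512, 512), pretrained=True)
print(net.version)  # 1
print(net.doc)      # embedded model documentation
```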

@@ -6,57 +6,44 @@
 import torch.onnx
 import torch.autograd

-from robosat_pink.config import load_config, check_classes, check_model
-

 def add_parser(subparser, formatter_class):
     parser = subparser.add_parser("export", help="Export a model to ONNX or Torch JIT", formatter_class=formatter_class)

     inp = parser.add_argument_group("Inputs")
     inp.add_argument("--checkpoint", type=str, required=True, help="model checkpoint to load [required]")
     inp.add_argument("--type", type=str, choices=["onnx", "jit"], default="jit", help="output type [default: jit]")
-    inp.add_argument("--config", type=str, help="path to config file [required]")

     out = parser.add_argument_group("Output")
     out.add_argument("out", type=str, help="path to save export model to [required]")

     parser.set_defaults(func=main)


 def main(args):
-    config = load_config(args.config)
-    check_classes(config)
-    check_model(config)
-
-    print("RoboSat.pink - export to {} - (Torch:{})".format(args.type, torch.__version__))
-
     try:
-        model_module = import_module("robosat_pink.models.{}".format(config["model"]["nn"].lower()))
-    except:
-        sys.exit("ERROR: Unknown {} model.".format(config["model"]["nn"]))
-
-    try:
-        net = getattr(model_module, config["model"]["nn"])(config).to("cpu")
         chkpt = torch.load(args.checkpoint, map_location=torch.device("cpu"))
+        assert chkpt["producer_name"] == "RoboSat.pink"
+        model_module = import_module("robosat_pink.models.{}".format(chkpt["nn"].lower()))
+        nn = getattr(model_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"]).to("cpu")
     except:
-        sys.exit("ERROR: Unable to load {} in {} model.".format(args.checkpoint, config["model"]["nn"]))
+        sys.exit("ERROR: Unable to load checkpoint: {}".format(args.checkpoint))

+    print("RoboSat.pink - export model to {}".format(args.type))
+    print("Model: {} - UUID: {} - Torch {}".format(chkpt["nn"], chkpt["uuid"], torch.__version__))
+    print(chkpt["doc_string"])

     try:  # https://github.com/pytorch/pytorch/issues/9176
-        net.module.state_dict(chkpt["state_dict"])
+        nn.module.load_state_dict(chkpt["state_dict"])
     except AttributeError:
-        net.state_dict(chkpt["state_dict"])
-
-    num_channels = 0
-    for channel in config["channels"]:
-        for band in channel["bands"]:
-            num_channels += 1
-
-    batch = torch.rand(1, num_channels, config["model"]["ts"], config["model"]["ts"])
+        nn.load_state_dict(chkpt["state_dict"])

     try:
+        batch = torch.rand(1, *chkpt["shape_in"])
         if args.type == "onnx":
-            torch.onnx.export(net, torch.autograd.Variable(batch), args.out)
+            torch.onnx.export(nn, torch.autograd.Variable(batch), args.out)

         if args.type == "jit":
-            torch.jit.trace(net, batch).save(args.out)
+            torch.jit.trace(nn, batch).save(args.out)
     except:
-        sys.exit("ERROR: Unable to export model {} in {}.".format(config["model"]["nn"]), args.type)
+        sys.exit("ERROR: Unable to export model {} as {}.".format(chkpt["uuid"], args.type))