Continue implementing model zoo export
constantinpape committed May 20, 2021
1 parent a0d52e6 commit 0cacdc1
Showing 6 changed files with 69 additions and 14 deletions.
6 changes: 4 additions & 2 deletions experiments/plantseg/ovules/train_contrastive_2d.py
@@ -29,7 +29,7 @@ def get_loader(split, patch_shape, batch_size,
     label_key = 'label'
     paths = get_paths(split, patch_shape, raw_key)
 
-    sampler = torch_em.data.MinForegroundSampler(min_fraction=0.3, p_reject=1.)
+    sampler = torch_em.data.MinForegroundSampler(min_fraction=0.1, p_reject=1.)
     label_transform = partial(torch_em.transform.label.connected_components, ensure_zero=True)
 
     return torch_em.default_segmentation_loader(
@@ -62,7 +62,9 @@ def get_model():
 
 def train_contrastive(args):
     model = get_model()
-    patch_shape = [1, 384, 384]
+    patch_shape = [1, 736, 688]
+    # can train with larger batch sizes for scatter
+    batch_size = 4 if args.impl == 'scatter' else 1
 
     train_loader = get_loader(
         split='train',
2 changes: 1 addition & 1 deletion setup.py
@@ -22,7 +22,7 @@
     license="MIT",
     entry_points={
         "console_scripts": [
-            "torch_em.export_bioimageio_model = torch_em.util.modelzoo.main"
+            "torch_em.export_bioimageio_model = torch_em.util.modelzoo:main"
         ]
     }
 )
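This one-character fix matters because setuptools console_scripts entries use a colon to separate the module path from the callable; with a dot, "main" is treated as part of the module path and the script fails to resolve. In plain Python, the fixed entry point is equivalent to:

# equivalent of the entry point "torch_em.util.modelzoo:main"
from torch_em.util.modelzoo import main
main()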
6 changes: 6 additions & 0 deletions test/util/test_modelzoo.py
@@ -11,7 +11,13 @@
 from torch_em.model import UNet2d
 from torch_em.trainer import DefaultTrainer
 
+try:
+    import bioimageio
+except ImportError:
+    bioimageio = None
+
 
+@unittest.skipIf(bioimageio is None, "Need bioimageio package")
 class TestModelzoo(unittest.TestCase):
     data_path = './data.h5'
     checkpoint_folder = './checkpoints'
13 changes: 8 additions & 5 deletions torch_em/trainer/default_trainer.py
@@ -76,6 +76,10 @@ def from_checkpoint(cls, checkpoint_folder, name='best'):
         model_class = getattr(import_module(model_p), model_m)
         model = model_class(**init_data['model_kwargs'])
 
+        optimizer_p, optimizer_m = init_data['optimizer_class'].rsplit('.', 1)
+        optimizer_class = getattr(import_module(optimizer_p), optimizer_m)
+        optimizer = optimizer_class(model.parameters(), **init_data['optimizer_kwargs'])
+
         def _init(name, optional=False, only_class=False):
             this_cls = init_data.get(f'{name}_class', None)
             if this_cls is None and optional:
@@ -87,8 +91,8 @@ def _init(name, optional=False, only_class=False):
             if only_class:
                 return this_cls
             kwargs = init_data[f'{name}_kwargs']
-            if name == 'optimizer':
-                return this_cls(model.parameters(), **kwargs)
+            if name == 'lr_scheduler':
+                return this_cls(optimizer, **kwargs)
             else:
                 return this_cls(**kwargs)
@@ -106,7 +110,7 @@ def _init_loader(name):
             val_loader=_init_loader('val'),
             model=model,
             loss=_init('loss'),
-            optimizer=_init('optimizer'),
+            optimizer=optimizer,
             metric=_init('metric'),
             device=torch.device(init_data['device']),
             lr_scheduler=_init('lr_scheduler', optional=True),
@@ -117,7 +121,6 @@
         )
 
         trainer._initialize(0, save_dict)
-        print(trainer._iteration, trainer._epoch)
         return trainer
 
     def _build_init(self):
@@ -151,7 +154,7 @@ def _update_loader(init_data, loader, name):
         init_data = _update_loader(init_data, self.val_loader, 'val')
         if self.lr_scheduler is not None:
             init_data['lr_scheduler_class'] = _full_class_path(self.lr_scheduler)
-            init_data['lr_scheduler_kwargs'] = _full_class_path(self.lr_scheduler)
+            init_data['lr_scheduler_kwargs'] = get_constructor_arguments(self.lr_scheduler)
         return init_data
 
     def _initialize(self, iterations, load_from_checkpoint):
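The checkpoint-loading changes above hoist optimizer construction out of the generic _init helper because a learning-rate scheduler cannot be instantiated without its optimizer, so the optimizer has to be deserialized first. A minimal sketch of that dependency, using only standard PyTorch (not torch_em code):

import torch

model = torch.nn.Linear(4, 2)
# the optimizer has to exist first ...
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# ... because every lr scheduler is constructed around an optimizer instance
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')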
49 changes: 45 additions & 4 deletions torch_em/util/modelzoo.py
@@ -31,8 +31,8 @@ def _get_model(trainer):
     model = trainer.model
     model.eval()
     model_kwargs = model.init_kwargs
-    # TODO warn if we strip any non-standard arguments
     # clear the kwargs of non builtins
+    # TODO warn if we strip any non-standard arguments
     model_kwargs = {k: v for k, v in model_kwargs.items()
                     if not isinstance(v, type)}
     return model, model_kwargs
@@ -65,11 +65,30 @@ def _write_depedencies(export_folder, dependencies):
     copyfile(dependencies, dep_path)
 
 
+def _get_normalizer(trainer):
+    dataset = trainer.train_loader.dataset
+    if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
+        dataset = dataset.datasets[0]
+    # TODO the raw transform may contain multiple transformations besides the
+    # normalization functions. Try to parse this to only return the normalization.
+    preprocesser = dataset.raw_transform
+    return preprocesser
+
+
 def _write_data(input_data, model, trainer, export_folder):
-    test_input = _pad(input_data, trainer)
+    # normalize the input data if we have a normalization function
+    normalizer = _get_normalizer(trainer)
+    test_input = input_data if normalizer is None else normalizer(input_data)
+
+    # pad to 4d/5d
+    test_input = _pad(test_input, trainer)
+
+    # run prediction
     with torch.no_grad():
         test_tensor = torch.from_numpy(test_input).to(trainer.device)
         test_output = model(test_tensor).cpu().numpy()
 
+    # save the input / output
     test_in_path = os.path.join(export_folder, 'test_input.npy')
     np.save(test_in_path, test_input)
     test_out_path = os.path.join(export_folder, 'test_output.npy')
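Saving test_input.npy and test_output.npy gives downstream consumers a reference pair for validating the exported model: load the stored input, run the model, and compare against the stored output. A minimal sketch of such a consumer-side check (not part of this commit; helper name and tolerance chosen for illustration):

import numpy as np
import torch

def check_export(model, export_folder):
    test_input = np.load(f'{export_folder}/test_input.npy')
    expected = np.load(f'{export_folder}/test_output.npy')
    with torch.no_grad():
        actual = model(torch.from_numpy(test_input)).cpu().numpy()
    # the exported model should reproduce the stored output
    assert np.allclose(actual, expected, atol=1e-4)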
Expand All @@ -81,7 +100,7 @@ def _write_source(model, export_folder):
# copy the model source file if it's a torch_em model
# (for now only u-net). otherwise just put the full python class
module = str(model.__class__.__module__)
cls_name = str(module.__class__.__name__)
cls_name = str(model.__class__.__name__)
if module == 'torch_em.model.unet':
source_path = os.path.join(
os.path.split(__file__)[0],
Expand All @@ -98,6 +117,7 @@ def _write_source(model, export_folder):
def _get_kwargs(trainer, name, description,
authors, tags,
license, documentation,
git_repo, cite,
export_folder, input_optional_parameters):
if input_optional_parameters:
print("Enter values for the optional parameters.")
Expand All @@ -121,6 +141,7 @@ def _get_kwarg(kwarg_name, val, default, is_list=False, fname=None):
return f'./{fname}'

if is_list and isinstance(val, str):
val = val.replace("'", '"') # enable single quotes
val = json.loads(val)
if is_list:
assert isinstance(val, (list, tuple))
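The quote replacement works around json.loads accepting only double-quoted strings, so list-valued parameters can be typed with Python-style single quotes at the prompt. For example (values made up; note a naive replace would break on strings containing apostrophes):

import json

user_input = "['boundaries', 'membranes']"   # as a user might type it
# json.loads(user_input) raises json.JSONDecodeError
print(json.loads(user_input.replace("'", '"')))  # -> ['boundaries', 'membranes']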
Expand All @@ -141,18 +162,32 @@ def _default_authors():

return [author]

def _default_repo():
try:
call_res = subprocess.run(['git', 'remote', '-v'], capture_output=True)
repo = call_res.stdout.decode('utf8').split('\n')[0].split()[1]
repo = repo if repo else None
except Exception:
repo = None
return repo

# TODO derive better default values:
# - description: derive something from trainer.ndim, trainer.loss, trainer.model, ...
# - tags: derive something from trainer.ndim, trainer.loss, trainer.model, ...
# - documentation: derive something from trainer.ndim, trainer.loss, trainer.model, ...
# - cite: make doi for torch_em and add it instead of url + derive citation from model
kwargs = {
'name': _get_kwarg('name', name, lambda: trainer.name),
'description': _get_kwarg('description', name, lambda: trainer.name),
'authors': _get_kwarg('authors', authors, _default_authors, is_list=True),
'tags': _get_kwarg('tags', tags, lambda: [trainer.name], is_list=True),
'license': _get_kwarg('license', license, lambda: 'MIT'),
'documentation': _get_kwarg('documentation', documentation, lambda: trainer.name,
fname='documentation.md')
fname='documentation.md'),
'git_repo': _get_kwarg('git_repo', git_repo, _default_repo),
'cite': _get_kwarg('cite', cite,
lambda: ['https://github.com/constantinpape/torch-em.git'],
is_list=True)
}

return kwargs
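_default_repo derives its default from the first line of `git remote -v`, whose output conventionally looks like `origin <url> (fetch)`; splitting on whitespace and taking the second field yields the URL. A quick illustration with that assumed output format:

# typical first line of `git remote -v`:
line = "origin\thttps://github.com/constantinpape/torch-em.git (fetch)"
print(line.split()[1])  # -> https://github.com/constantinpape/torch-em.git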
@@ -257,11 +292,16 @@ def _write_covers(test_in_path, test_out_path, export_folder, covers):
 
 
 # TODO support conversion to onnx
+# TODO more options for the bioimageio export:
+# - preprocessing!
+# - variable input / output shapes, halo
+# - config for custom params (e.g. offsets for mws)
 def export_biomageio_model(trainer, input_data, export_folder,
                            dependencies=None, name=None,
                            description=None, authors=None,
                            tags=None, license=None,
                            documentation=None, covers=None,
+                           git_repo=None, cite=None,
                            input_optional_parameters=True):
     """
     """
@@ -295,6 +335,7 @@ def export_biomageio_model(trainer, input_data, export_folder,
     kwargs = _get_kwargs(trainer, name, description,
                          authors, tags,
                          license, documentation,
+                         git_repo, cite,
                          export_folder, input_optional_parameters)
 
     model_spec = build_spec(
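Putting the pieces together, a call to the exporter might look like the sketch below. The keyword names come from the signature in this diff; the checkpoint path and input array are hypothetical stand-ins:

import numpy as np
from torch_em.trainer import DefaultTrainer
from torch_em.util.modelzoo import export_biomageio_model

trainer = DefaultTrainer.from_checkpoint('./checkpoints/my-model')  # hypothetical checkpoint
input_data = np.random.rand(512, 512).astype('float32')             # stand-in test image
export_biomageio_model(
    trainer, input_data, './exported-model',
    name='my-model',
    license='MIT',
    git_repo='https://github.com/constantinpape/torch-em',
    cite=['https://github.com/constantinpape/torch-em.git'],
    input_optional_parameters=False,
)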
7 changes: 5 additions & 2 deletions torch_em/util/util.py
@@ -91,9 +91,12 @@ def get_constructor_arguments(obj):
     def _get_args(obj, param_names):
         return {name: getattr(obj, name) for name in param_names}
 
-    # we don't need to find the constructor arguments for optimizers,
+    # we don't need to find the constructor arguments for optimizers or schedulers
     # because we deserialize the state later
-    if isinstance(obj, torch.optim.Optimizer):
+    if isinstance(obj, (torch.optim.Optimizer,
+                        torch.optim.lr_scheduler._LRScheduler,
+                        # ReduceLROnPlateau does not inherit from _LRScheduler
+                        torch.optim.lr_scheduler.ReduceLROnPlateau)):
         return {}
 
     # recover the arguments for torch dataloader
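As the inline comment notes, ReduceLROnPlateau is the odd one out among the schedulers (at least in the PyTorch versions current when this was committed), so it has to be listed explicitly alongside the _LRScheduler base class. A quick check:

import torch
from torch.optim import lr_scheduler

print(issubclass(lr_scheduler.StepLR, lr_scheduler._LRScheduler))             # True
print(issubclass(lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))  # False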
