This repository has been archived by the owner on Mar 21, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 142
/
move_model.py
156 lines (127 loc) · 5.63 KB
/
move_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import argparse
import json
import sys
from argparse import ArgumentParser
from pathlib import Path
from typing import Tuple
from attr import dataclass
from azureml.core import Environment, Model, Workspace
# Look three directory levels above this file (``innereye_root``). If that directory contains an
# "InnerEye" folder, prepend it to sys.path so the ``InnerEye`` package import that follows can be
# resolved even when the package is not installed.
innereye_root = Path(__file__).resolve().parent.parent.parent
if (innereye_root / "InnerEye").is_dir():
    innereye_root_str = str(innereye_root)
    # Only insert once, so repeated imports do not grow sys.path with duplicates.
    if innereye_root_str not in sys.path:
        print(f"Adding InnerEye folder to sys.path: {innereye_root_str}")
        sys.path.insert(0, innereye_root_str)
from InnerEye.ML.common import FINAL_ENSEMBLE_MODEL_FOLDER, FINAL_MODEL_FOLDER
# Model tag whose value is the name of the AzureML environment exported alongside the model.
PYTHON_ENVIRONMENT_NAME = "python_environment_name"
# Sub-folder (under <path>/<model id with ":" replaced by "_">) holding the model files and MODEL_JSON.
MODEL_PATH = "MODEL"
# Sub-folder (under <path>/<model id with ":" replaced by "_">) holding the saved AzureML environment.
ENVIRONMENT_PATH = "ENVIRONMENT"
# File name of the serialized model metadata, written on download and read back on upload.
MODEL_JSON = "model.json"
@dataclass
class MoveModelConfig:
    """
    Settings for copying a registered AzureML model (and its Python environment) between an AzureML
    workspace and a local folder.
    """
    # AzureML model ID; any ":" is replaced by "_" when it is used as a folder name.
    model_id: str
    # Local base folder under which the per-model folder is created.
    path: str
    # "download" or "upload"; any other value is rejected by ``move``.
    action: str
    # Workspace coordinates, passed to ``Workspace.get`` by ``get_workspace``.
    workspace_name: str = ""
    subscription_id: str = ""
    resource_group: str = ""

    def get_paths(self) -> Tuple[Path, Path]:
        """
        Returns the model and environment folders for this config, creating them on disk when missing.

        Layout: ``<path>/<model_id with ":" replaced by "_">/MODEL_PATH`` and ``.../ENVIRONMENT_PATH``.

        :return: Tuple of (model folder, environment folder)
        """
        model_root = Path(self.path) / self.model_id.replace(":", "_")
        model_dir = model_root / MODEL_PATH
        env_dir = model_root / ENVIRONMENT_PATH
        for folder in (model_root, model_dir, env_dir):
            folder.mkdir(parents=True, exist_ok=True)
        return model_dir, env_dir
def download_model(ws: Workspace, config: MoveModelConfig) -> Model:
    """
    Downloads an InnerEye model and its Python environment from an AzureML workspace.

    The serialized model metadata (``MODEL_JSON``) and the model files are written to the model folder
    returned by ``config.get_paths()``. The environment named by the model's ``PYTHON_ENVIRONMENT_NAME``
    tag is saved to the environment folder, overwriting any previous copy.

    :param ws: The AzureML workspace
    :param config: move config; ``model_id`` selects the model and ``path`` the local destination
    :return: the exported Model
    :raises ValueError: if the model has no ``PYTHON_ENVIRONMENT_NAME`` tag, or the environment it names
        is not registered in ``ws``
    """
    model = Model(ws, id=config.model_id)
    # Resolve the environment before writing anything locally. Otherwise a missing tag or environment
    # only shows up as an AttributeError on ``None`` after the (possibly large) model download completes.
    env_name = model.tags.get(PYTHON_ENVIRONMENT_NAME)
    if not env_name:
        raise ValueError(f"Model {config.model_id} has no '{PYTHON_ENVIRONMENT_NAME}' tag, "
                         f"so its Python environment cannot be exported")
    environment = ws.environments.get(env_name)
    if environment is None:
        raise ValueError(f"Environment '{env_name}' used by model {config.model_id} "
                         f"is not registered in the workspace")
    model_path, environment_path = config.get_paths()
    with open(model_path / MODEL_JSON, 'w') as f:
        json.dump(model.serialize(), f)
    model.download(target_dir=str(model_path))
    environment.save_to_directory(str(environment_path), overwrite=True)
    return model
def upload_model(ws: Workspace, config: MoveModelConfig) -> Model:
    """
    Uploads an InnerEye model, in the local layout written by ``download_model``, to an AzureML workspace.

    Registers the files in the ``FINAL_MODEL_FOLDER`` subfolder of the model folder, falling back to
    ``FINAL_ENSEMBLE_MODEL_FOLDER``. The model's name, tags, properties and description are reused from
    ``MODEL_JSON``. The Python environment saved in the environment folder is then registered as well.

    :param ws: The AzureML workspace
    :param config: move config; ``model_id`` and ``path`` locate the local copy
    :return: imported Model
    :raises FileNotFoundError: if ``MODEL_JSON`` is missing, or neither final-model subfolder exists
    """
    model_path, environment_path = config.get_paths()
    with open(model_path / MODEL_JSON, 'r') as f:
        model_dict = json.load(f)
    # Prefer FINAL_MODEL_FOLDER and fall back to FINAL_ENSEMBLE_MODEL_FOLDER. If neither exists, fail
    # here rather than asking AzureML to register a non-existent path.
    candidates = (model_path / FINAL_MODEL_FOLDER, model_path / FINAL_ENSEMBLE_MODEL_FOLDER)
    full_model_path = next((candidate for candidate in candidates if candidate.exists()), None)
    if full_model_path is None:
        raise FileNotFoundError(f"No model files found in {model_path}: expected a "
                                f"'{FINAL_MODEL_FOLDER}' or '{FINAL_ENSEMBLE_MODEL_FOLDER}' folder")
    # Load the environment definition (a local read) before registering anything. A problem with the
    # environment folder then surfaces before the model is registered without its environment.
    env = Environment.load_from_directory(str(environment_path))
    new_model = Model.register(ws, model_path=str(full_model_path), model_name=model_dict['name'],
                               tags=model_dict['tags'], properties=model_dict['properties'],
                               description=model_dict['description'])
    env.register(workspace=ws)
    print(f"Environment {env.name} registered")
    return new_model
def get_workspace(config: MoveModelConfig) -> Workspace:
    """
    Connects to the AzureML workspace identified on the command line.

    :param config: MoveModelConfig holding ``workspace_name``, ``subscription_id`` and ``resource_group``
    :return: the matching Azure ML workspace, as returned by ``Workspace.get``
    """
    workspace_coordinates = {
        "name": config.workspace_name,
        "subscription_id": config.subscription_id,
        "resource_group": config.resource_group,
    }
    return Workspace.get(**workspace_coordinates)
def get_move_model_parser() -> argparse.ArgumentParser:
    """
    Builds the command-line parser for this script. Every option is a required string.

    :return: parser exposing action, workspace_name, subscription_id, resource_group, path and model_id
    """
    # (short flag, long flag, help text), registered in this order.
    option_specs = (
        ("-a", "--action", "Action (download or upload)"),
        ("-w", "--workspace_name", "Azure ML workspace name"),
        ("-s", "--subscription_id", "AzureML subscription id"),
        ("-r", "--resource_group", "AzureML resource group"),
        ("-p", "--path", "The path to download or upload model"),
        ("-m", "--model_id", "The AzureML model ID"),
    )
    parser = ArgumentParser()
    for short_flag, long_flag, help_text in option_specs:
        parser.add_argument(short_flag, long_flag, type=str, required=True, help=help_text)
    return parser
def main() -> None:
    """Command-line entry point: parse the arguments, connect to the workspace, then move the model."""
    args = get_move_model_parser().parse_args()
    config = MoveModelConfig(
        model_id=args.model_id,
        path=args.path,
        action=args.action,
        workspace_name=args.workspace_name,
        subscription_id=args.subscription_id,
        resource_group=args.resource_group,
    )
    move(get_workspace(config), config)
def move(ws: Workspace, config: MoveModelConfig) -> Model:
    """
    Moves a model: downloads it from, or uploads it to, ``ws`` depending on ``config.action``.

    :param ws: The Azure ML workspace
    :param config: the move model config; ``action`` must be "download" or "upload"
    :return: the downloaded or uploaded model
    :raises ValueError: if ``config.action`` is neither "download" nor "upload"
    """
    if config.action == "download":
        return download_model(ws, config)
    if config.action == "upload":
        return upload_model(ws, config)
    raise ValueError(f'Invalid action {config.action}, allowed values: download or upload')
# Run the command-line tool only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()