Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor entities #586

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli/medperf/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def execute(
please run the command again with the --no-cache option.\n"""
)
else:
ResultSubmission.run(result.generated_uid, approved=approval)
ResultSubmission.run(result.local_id, approved=approval)
config.ui.print("✅ Done!")


Expand Down
2 changes: 1 addition & 1 deletion cli/medperf/commands/benchmark/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def run_compatibility_test(self):
self.ui.print("Running compatibility test")
self.bmk.write()
data_uid, results = CompatibilityTestExecution.run(
benchmark=self.bmk.generated_uid,
benchmark=self.bmk.local_id,
no_cache=self.no_cache,
skip_data_preparation_step=self.skip_data_preparation_step,
)
Expand Down
2 changes: 1 addition & 1 deletion cli/medperf/commands/compatibility_test/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def cached_results(self):
"""
if self.no_cache:
return
uid = self.report.generated_uid
uid = self.report.local_id
try:
report = TestReport.get(uid)
except InvalidArgumentError:
Expand Down
14 changes: 7 additions & 7 deletions cli/medperf/commands/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,23 @@ def prepare(self):
logging.debug(f"tmp results output: {self.results_path}")

def __setup_logs_path(self):
    """Build per-experiment log file paths for the model and evaluator runs.

    Returns:
        tuple: (model_logs_path, metrics_logs_path), both located inside
        ``config.experiments_logs_folder/<model_uid>/<data_uid>/``.
    """
    model_uid = self.model.local_id
    eval_uid = self.evaluator.local_id
    data_uid = self.dataset.local_id

    # Logs are grouped by model and dataset; the evaluator id only
    # distinguishes the metrics log file name within that folder.
    logs_path = os.path.join(
        config.experiments_logs_folder, str(model_uid), str(data_uid)
    )
    os.makedirs(logs_path, exist_ok=True)
    model_logs_path = os.path.join(logs_path, "model.log")
    metrics_logs_path = os.path.join(logs_path, f"metrics_{eval_uid}.log")
    return model_logs_path, metrics_logs_path

def __setup_predictions_path(self):
model_uid = self.model.generated_uid
data_hash = self.dataset.generated_uid
model_uid = self.model.local_id
data_uid = self.dataset.local_id
preds_path = os.path.join(
config.predictions_folder, str(model_uid), str(data_hash)
config.predictions_folder, str(model_uid), str(data_uid)
)
if os.path.exists(preds_path):
msg = f"Found existing predictions for model {self.model.id} on dataset "
Expand Down
17 changes: 13 additions & 4 deletions cli/medperf/commands/list.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Type
from medperf.entities.interface import Entity
from medperf.exceptions import InvalidArgumentError
from tabulate import tabulate

Expand All @@ -8,8 +10,8 @@
class EntityList:
@staticmethod
def run(
entity_class,
fields,
entity_class: Type[Entity],
fields: List[str],
unregistered: bool = False,
mine_only: bool = False,
**kwargs,
Expand All @@ -18,7 +20,7 @@ def run(

Args:
unregistered (bool, optional): Display only local unregistered results. Defaults to False.
mine_only (bool, optional): Display all current-user results. Defaults to False.
mine_only (bool, optional): Display all registered current-user results. Defaults to False.
kwargs (dict): Additional parameters for filtering entity lists.
"""
entity_list = EntityList(
Expand All @@ -29,7 +31,14 @@ def run(
entity_list.filter()
entity_list.display()

def __init__(self, entity_class, fields, unregistered, mine_only, **kwargs):
def __init__(
self,
entity_class: Type[Entity],
fields: List[str],
unregistered: bool,
mine_only: bool,
**kwargs,
):
self.entity_class = entity_class
self.fields = fields
self.unregistered = unregistered
Expand Down
2 changes: 1 addition & 1 deletion cli/medperf/commands/result/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def print_summary(self):
data_lists_for_display.append(
[
experiment["model_uid"],
experiment["result"].generated_uid,
experiment["result"].local_id,
experiment["result"].metadata["partial"],
experiment["cached"],
experiment["error"],
Expand Down
29 changes: 14 additions & 15 deletions cli/medperf/commands/result/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from medperf.exceptions import CleanExit
from medperf.utils import remove_path, dict_pretty_print, approval_prompt
from medperf.entities.result import Result
from medperf.enums import Status
from medperf import config


class ResultSubmission:
@classmethod
def run(cls, result_uid, approved=False):
sub = cls(result_uid, approved=approved)
sub.get_result()
updated_result_dict = sub.upload_results()
sub.to_permanent_path(updated_result_dict)
sub.write(updated_result_dict)
Expand All @@ -21,27 +21,26 @@ def __init__(self, result_uid, approved=False):
self.ui = config.ui
self.approved = approved

def request_approval(self, result):
if result.approval_status == Status.APPROVED:
return True
def get_result(self):
    """Retrieve the Result entity for ``self.result_uid`` and cache it on ``self.result``."""
    self.result = Result.get(self.result_uid)

dict_pretty_print(result.results)
def request_approval(self):
    """Display the generated results and prompt the user to approve upload.

    Returns:
        bool: True if the user approved uploading the results.
    """
    dict_pretty_print(self.result.results)
    self.ui.print("Above are the results generated by the model")

    # Fix ungrammatical prompt: "to the MedPerf" -> "to MedPerf".
    approved = approval_prompt(
        "Do you approve uploading the presented results to MedPerf? [Y/n]"
    )

    return approved

def upload_results(self):
    """Upload ``self.result`` to the server after obtaining approval.

    Raises:
        CleanExit: if the user declines the upload.

    Returns:
        dict: the server-updated result body.
    """
    # Pre-granted approval (e.g. --approve flag) skips the interactive prompt.
    approved = self.approved or self.request_approval()

    if not approved:
        raise CleanExit("Results upload operation cancelled")

    updated_result_dict = self.result.upload()
    return updated_result_dict

def to_permanent_path(self, result_dict: dict):
    """Rename the result folder from its current (temporary) location to the
    permanent path derived from the server-updated result.

    Args:
        result_dict (dict): updated results dictionary
    """
    old_result_loc = self.result.path
    updated_result = Result(**result_dict)
    new_result_loc = updated_result.path
    # Clear any stale folder at the destination so os.rename cannot fail
    # on an existing directory.
    remove_path(new_result_loc)
    os.rename(old_result_loc, new_result_loc)

def write(self, updated_result_dict):
result = Result(**updated_result_dict)
Expand Down
13 changes: 10 additions & 3 deletions cli/medperf/commands/view.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import yaml
import json
from typing import Union
from typing import Union, Type

from medperf import config
from medperf.account_management import get_medperf_user_data
Expand All @@ -12,7 +12,7 @@ class EntityView:
@staticmethod
def run(
entity_id: Union[int, str],
entity_class: Entity,
entity_class: Type[Entity],
format: str = "yaml",
unregistered: bool = False,
mine_only: bool = False,
Expand Down Expand Up @@ -41,7 +41,14 @@ def run(
entity_view.store()

def __init__(
self, entity_id, entity_class, format, unregistered, mine_only, output, **kwargs
self,
entity_id: Union[int, str],
entity_class: Type[Entity],
format: str,
unregistered: bool,
mine_only: bool,
output: str,
**kwargs,
):
self.entity_id = entity_id
self.entity_class = entity_class
Expand Down
12 changes: 7 additions & 5 deletions cli/medperf/entities/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import medperf.config as config
from medperf.entities.interface import Entity
from medperf.entities.schemas import MedperfSchema, ApprovableSchema, DeployableSchema
from medperf.entities.schemas import ApprovableSchema, DeployableSchema
from medperf.account_management import get_medperf_user_data


class Benchmark(Entity, MedperfSchema, ApprovableSchema, DeployableSchema):
class Benchmark(Entity, ApprovableSchema, DeployableSchema):
"""
Class representing a Benchmark

Expand Down Expand Up @@ -58,10 +58,12 @@ def __init__(self, *args, **kwargs):
"""
super().__init__(*args, **kwargs)

self.generated_uid = f"p{self.data_preparation_mlcube}m{self.reference_model_mlcube}e{self.data_evaluator_mlcube}"
@property
def local_id(self):
    """str: local identifier for this benchmark — its name.

    NOTE(review): presumably used to reference the benchmark before it has a
    server-assigned id — confirm against Entity.local_id usage.
    """
    return self.name
VukW marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def _Entity__remote_prefilter(cls, filters: dict) -> callable:
@staticmethod
def remote_prefilter(filters: dict) -> callable:
"""Applies filtering logic that must be done before retrieving remote entities

Args:
Expand Down
15 changes: 9 additions & 6 deletions cli/medperf/entities/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
spawn_and_kill,
)
from medperf.entities.interface import Entity
from medperf.entities.schemas import MedperfSchema, DeployableSchema
from medperf.entities.schemas import DeployableSchema
from medperf.exceptions import InvalidArgumentError, ExecutionError, InvalidEntityError
import medperf.config as config
from medperf.comms.entity_resources import resources
from medperf.account_management import get_medperf_user_data


class Cube(Entity, MedperfSchema, DeployableSchema):
class Cube(Entity, DeployableSchema):
"""
Class representing an MLCube Container

Expand Down Expand Up @@ -70,14 +70,17 @@ def __init__(self, *args, **kwargs):
"""
super().__init__(*args, **kwargs)

self.generated_uid = self.name
self.cube_path = os.path.join(self.path, config.cube_filename)
self.params_path = None
if self.git_parameters_url:
self.params_path = os.path.join(self.path, config.params_filename)

@classmethod
def _Entity__remote_prefilter(cls, filters: dict):
@property
def local_id(self):
    """str: local identifier for this cube — its name."""
    return self.name

@staticmethod
def remote_prefilter(filters: dict):
"""Applies filtering logic that must be done before retrieving remote entities

Args:
Expand Down Expand Up @@ -245,7 +248,7 @@ def run(
"""
kwargs.update(string_params)
cmd = f"mlcube --log-level {config.loglevel} run"
cmd += f" --mlcube=\"{self.cube_path}\" --task={task} --platform={config.platform} --network=none"
cmd += f' --mlcube="{self.cube_path}" --task={task} --platform={config.platform} --network=none'
if config.gpus is not None:
cmd += f" --gpus={config.gpus}"
if read_protected_input:
Expand Down
13 changes: 8 additions & 5 deletions cli/medperf/entities/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

from medperf.utils import remove_path
from medperf.entities.interface import Entity
from medperf.entities.schemas import MedperfSchema, DeployableSchema
from medperf.entities.schemas import DeployableSchema

import medperf.config as config
from medperf.account_management import get_medperf_user_data


class Dataset(Entity, MedperfSchema, DeployableSchema):
class Dataset(Entity, DeployableSchema):
"""
Class representing a Dataset

Expand Down Expand Up @@ -62,13 +62,16 @@ def check_data_preparation_mlcube(cls, v, *, values, **kwargs):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.data_path = os.path.join(self.path, "data")
self.labels_path = os.path.join(self.path, "labels")
self.report_path = os.path.join(self.path, config.report_file)
self.metadata_path = os.path.join(self.path, config.metadata_folder)
self.statistics_path = os.path.join(self.path, config.statistics_filename)

@property
def local_id(self):
    """Local identifier for this dataset — its ``generated_uid``."""
    return self.generated_uid

def set_raw_paths(self, raw_data_path: str, raw_labels_path: str):
raw_paths_file = os.path.join(self.path, config.dataset_raw_paths_file)
data = {"data_path": raw_data_path, "labels_path": raw_labels_path}
Expand All @@ -94,8 +97,8 @@ def is_ready(self):
flag_file = os.path.join(self.path, config.ready_flag_file)
return os.path.exists(flag_file)

@classmethod
def _Entity__remote_prefilter(cls, filters: dict) -> callable:
@staticmethod
def remote_prefilter(filters: dict) -> callable:
"""Applies filtering logic that must be done before retrieving remote entities

Args:
Expand Down