Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor entities #586

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 1 addition & 3 deletions cli/cli_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
################### Start Testing ########################
##########################################################


##########################################################
echo "=========================================="
echo "Printing MedPerf version"
Expand Down Expand Up @@ -186,7 +185,7 @@ echo "Running data submission step"
echo "====================================="
medperf dataset submit -p $PREP_UID -d $DIRECTORY/dataset_a -l $DIRECTORY/dataset_a --name="dataset_a" --description="mock dataset a" --location="mock location a" -y
checkFailed "Data submission step failed"
DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | cut -d ' ' -f 1)
DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1)
hasan7n marked this conversation as resolved.
Show resolved Hide resolved
##########################################################

echo "\n"
Expand All @@ -212,7 +211,6 @@ DSET_A_GENUID=$(medperf dataset view $DSET_A_UID | grep generated_uid | cut -d "

echo "\n"


##########################################################
echo "====================================="
echo "Moving storage to some other location"
Expand Down
16 changes: 9 additions & 7 deletions cli/medperf/commands/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
@app.command("ls")
@clean_except
def list(
local: bool = typer.Option(False, "--local", help="Get local benchmarks"),
unregistered: bool = typer.Option(
False, "--unregistered", help="Get unregistered benchmarks"
),
mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"),
):
"""List benchmarks stored locally and remotely from the user"""
"""List benchmarks"""
EntityList.run(
Benchmark,
fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"],
local_only=local,
unregistered=unregistered,
mine_only=mine,
)

Expand Down Expand Up @@ -162,10 +164,10 @@ def view(
"--format",
help="Format to display contents. Available formats: [yaml, json]",
),
local: bool = typer.Option(
unregistered: bool = typer.Option(
False,
"--local",
help="Display local benchmarks if benchmark ID is not provided",
"--unregistered",
help="Display unregistered benchmarks if benchmark ID is not provided",
),
mine: bool = typer.Option(
False,
Expand All @@ -180,4 +182,4 @@ def view(
),
):
"""Displays the information of one or more benchmarks"""
EntityView.run(entity_id, Benchmark, format, local, mine, output)
EntityView.run(entity_id, Benchmark, format, unregistered, mine, output)
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ def run(
@clean_except
def list():
"""List previously executed tests reports."""
EntityList.run(TestReport, fields=["UID", "Data Source", "Model", "Evaluator"])
EntityList.run(
TestReport,
fields=["UID", "Data Source", "Model", "Evaluator"],
unregistered=True,
)


@app.command("view")
Expand All @@ -116,4 +120,4 @@ def view(
),
):
"""Displays the information of one or more test reports"""
EntityView.run(entity_id, TestReport, format, output=output)
EntityView.run(entity_id, TestReport, format, unregistered=True, output=output)
16 changes: 8 additions & 8 deletions cli/medperf/commands/compatibility_test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,23 @@ def create_test_dataset(
# TODO: existing dataset could make problems
# make some changes since this is a test dataset
config.tmp_paths.remove(data_creation.dataset.path)
data_creation.dataset.write()
VukW marked this conversation as resolved.
Show resolved Hide resolved
if skip_data_preparation_step:
data_creation.make_dataset_prepared()
dataset = data_creation.dataset
old_generated_uid = dataset.generated_uid
old_path = dataset.path

# prepare/check dataset
DataPreparation.run(dataset.generated_uid)

# update dataset generated_uid
old_path = dataset.path
generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path])
dataset.generated_uid = generated_uid
dataset.write()
if dataset.input_data_hash != dataset.generated_uid:
VukW marked this conversation as resolved.
Show resolved Hide resolved
new_generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path])
if new_generated_uid != old_generated_uid:
# move to a correct location if it underwent preparation
new_path = old_path.replace(dataset.input_data_hash, generated_uid)
new_path = old_path.replace(old_generated_uid, new_generated_uid)
remove_path(new_path)
os.rename(old_path, new_path)
dataset.generated_uid = new_generated_uid
dataset.write()

return generated_uid
return new_generated_uid
16 changes: 10 additions & 6 deletions cli/medperf/commands/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@
@app.command("ls")
@clean_except
def list(
local: bool = typer.Option(False, "--local", help="Get local datasets"),
unregistered: bool = typer.Option(
False, "--unregistered", help="Get unregistered datasets"
),
mine: bool = typer.Option(False, "--mine", help="Get current-user datasets"),
mlcube: int = typer.Option(
None, "--mlcube", "-m", help="Get datasets for a given data prep mlcube"
),
):
"""List datasets stored locally and remotely from the user"""
"""List datasets"""
EntityList.run(
Dataset,
fields=["UID", "Name", "Data Preparation Cube UID", "State", "Status", "Owner"],
local_only=local,
unregistered=unregistered,
mine_only=mine,
mlcube=mlcube,
)
Expand Down Expand Up @@ -149,8 +151,10 @@ def view(
"--format",
help="Format to display contents. Available formats: [yaml, json]",
),
local: bool = typer.Option(
False, "--local", help="Display local datasets if dataset ID is not provided"
unregistered: bool = typer.Option(
False,
"--unregistered",
help="Display unregistered datasets if dataset ID is not provided",
),
mine: bool = typer.Option(
False,
Expand All @@ -165,4 +169,4 @@ def view(
),
):
"""Displays the information of one or more datasets"""
EntityView.run(entity_id, Dataset, format, local, mine, output)
EntityView.run(entity_id, Dataset, format, unregistered, mine, output)
14 changes: 8 additions & 6 deletions cli/medperf/commands/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,29 @@ class EntityList:
def run(
entity_class,
fields,
local_only: bool = False,
unregistered: bool = False,
mine_only: bool = False,
**kwargs,
):
"""Lists all local datasets

Args:
local_only (bool, optional): Display all local results. Defaults to False.
unregistered (bool, optional): Display only local unregistered results. Defaults to False.
mine_only (bool, optional): Display all current-user results. Defaults to False.
hasan7n marked this conversation as resolved.
Show resolved Hide resolved
kwargs (dict): Additional parameters for filtering entity lists.
"""
entity_list = EntityList(entity_class, fields, local_only, mine_only, **kwargs)
entity_list = EntityList(
entity_class, fields, unregistered, mine_only, **kwargs
)
entity_list.prepare()
entity_list.validate()
entity_list.filter()
entity_list.display()

def __init__(self, entity_class, fields, local_only, mine_only, **kwargs):
def __init__(self, entity_class, fields, unregistered, mine_only, **kwargs):
self.entity_class = entity_class
self.fields = fields
self.local_only = local_only
self.unregistered = unregistered
self.mine_only = mine_only
self.filters = kwargs
self.data = []
Expand All @@ -40,7 +42,7 @@ def prepare(self):
self.filters["owner"] = get_medperf_user_data()["id"]

entities = self.entity_class.all(
local_only=self.local_only, filters=self.filters
unregistered=self.unregistered, filters=self.filters
)
self.data = [entity.display_dict() for entity in entities]

Expand Down
16 changes: 10 additions & 6 deletions cli/medperf/commands/mlcube/mlcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
@app.command("ls")
@clean_except
def list(
local: bool = typer.Option(False, "--local", help="Get local mlcubes"),
unregistered: bool = typer.Option(
False, "--unregistered", help="Get unregistered mlcubes"
),
mine: bool = typer.Option(False, "--mine", help="Get current-user mlcubes"),
):
"""List mlcubes stored locally and remotely from the user"""
"""List mlcubes"""
EntityList.run(
Cube,
fields=["UID", "Name", "State", "Registered"],
local_only=local,
unregistered=unregistered,
mine_only=mine,
)

Expand Down Expand Up @@ -148,8 +150,10 @@ def view(
"--format",
help="Format to display contents. Available formats: [yaml, json]",
),
local: bool = typer.Option(
False, "--local", help="Display local mlcubes if mlcube ID is not provided"
unregistered: bool = typer.Option(
False,
"--unregistered",
help="Display unregistered mlcubes if mlcube ID is not provided",
),
mine: bool = typer.Option(
False,
Expand All @@ -164,4 +168,4 @@ def view(
),
):
"""Displays the information of one or more mlcubes"""
EntityView.run(entity_id, Cube, format, local, mine, output)
EntityView.run(entity_id, Cube, format, unregistered, mine, output)
5 changes: 4 additions & 1 deletion cli/medperf/commands/result/create.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import List, Optional
from medperf.account_management.account_management import get_medperf_user_data
from medperf.commands.execution import Execution
from medperf.entities.result import Result
from tabulate import tabulate
Expand Down Expand Up @@ -143,7 +144,9 @@ def __validate_models(self, benchmark_models):
raise InvalidArgumentError(msg)

def load_cached_results(self):
results = Result.all()
user_id = get_medperf_user_data()["id"]
results = Result.all(filters={"owner": user_id})
results += Result.all(unregistered=True)
benchmark_dset_results = [
result
for result in results
Expand Down
18 changes: 12 additions & 6 deletions cli/medperf/commands/result/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,19 @@ def submit(
@app.command("ls")
@clean_except
def list(
local: bool = typer.Option(False, "--local", help="Get local results"),
unregistered: bool = typer.Option(
False, "--unregistered", help="Get unregistered results"
),
mine: bool = typer.Option(False, "--mine", help="Get current-user results"),
benchmark: int = typer.Option(
None, "--benchmark", "-b", help="Get results for a given benchmark"
),
):
"""List results stored locally and remotely from the user"""
"""List results"""
EntityList.run(
Result,
fields=["UID", "Benchmark", "Model", "Dataset", "Registered"],
local_only=local,
unregistered=unregistered,
mine_only=mine,
benchmark=benchmark,
)
Expand All @@ -88,8 +90,10 @@ def view(
"--format",
help="Format to display contents. Available formats: [yaml, json]",
),
local: bool = typer.Option(
False, "--local", help="Display local results if result ID is not provided"
unregistered: bool = typer.Option(
False,
"--unregistered",
help="Display unregistered results if result ID is not provided",
),
mine: bool = typer.Option(
False,
Expand All @@ -107,4 +111,6 @@ def view(
),
):
"""Displays the information of one or more results"""
EntityView.run(entity_id, Result, format, local, mine, output, benchmark=benchmark)
EntityView.run(
entity_id, Result, format, unregistered, mine, output, benchmark=benchmark
)
12 changes: 6 additions & 6 deletions cli/medperf/commands/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run(
entity_id: Union[int, str],
entity_class: Entity,
format: str = "yaml",
local_only: bool = False,
unregistered: bool = False,
mine_only: bool = False,
output: str = None,
**kwargs,
Expand All @@ -24,14 +24,14 @@ def run(
Args:
entity_id (Union[int, str]): Entity identifies
entity_class (Entity): Entity type
local_only (bool, optional): Display all local entities. Defaults to False.
unregistered (bool, optional): Display only local unregistered entities. Defaults to False.
mine_only (bool, optional): Display all current-user entities. Defaults to False.
format (str, optional): What format to use to display the contents. Valid formats: [yaml, json]. Defaults to yaml.
output (str, optional): Path to a file for storing the entity contents. If not provided, the contents are printed.
kwargs (dict): Additional parameters for filtering entity lists.
"""
entity_view = EntityView(
entity_id, entity_class, format, local_only, mine_only, output, **kwargs
entity_id, entity_class, format, unregistered, mine_only, output, **kwargs
)
entity_view.validate()
entity_view.prepare()
Expand All @@ -41,12 +41,12 @@ def run(
entity_view.store()

def __init__(
self, entity_id, entity_class, format, local_only, mine_only, output, **kwargs
self, entity_id, entity_class, format, unregistered, mine_only, output, **kwargs
):
self.entity_id = entity_id
self.entity_class = entity_class
self.format = format
self.local_only = local_only
self.unregistered = unregistered
self.mine_only = mine_only
self.output = output
self.filters = kwargs
Expand All @@ -65,7 +65,7 @@ def prepare(self):
self.filters["owner"] = get_medperf_user_data()["id"]

entities = self.entity_class.all(
local_only=self.local_only, filters=self.filters
unregistered=self.unregistered, filters=self.filters
)
self.data = [entity.todict() for entity in entities]

Expand Down