Skip to content

Commit

Permalink
feat(backend): add INSUFFICIENT_TRAIN_DATA status and save iteration a…
Browse files Browse the repository at this point in the history
…nd label source information
  • Loading branch information
alonh committed Jul 31, 2023
1 parent f90ef15 commit e506471
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
5 changes: 4 additions & 1 deletion label_sleuth/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,8 @@ def set_element_label(workspace_id, element_id):

category_id = int(post_data["category_id"])
value = post_data["value"]
iteration = int(post_data.get("iteration", -1))
source = post_data.get("source", "n/a")
update_counter = post_data.get('update_counter', True)

if value == 'none':
Expand All @@ -729,7 +731,8 @@ def set_element_label(workspace_id, element_id):
else:
raise Exception(f"cannot convert label to boolean. Input label = {value}")

uri_with_updated_label = {element_id: {category_id: Label(value)}}
uri_with_updated_label = {element_id: {category_id: Label(value,
metadata={"iteration":iteration, "source":source})}}
curr_app.orchestrator_api. \
set_labels(workspace_id, uri_with_updated_label,
apply_to_duplicate_texts=curr_app.config["CONFIGURATION"].apply_labels_to_duplicate_texts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class IterationStatus(Enum):
READY = 4
ERROR = 5
MODEL_DELETED = 6
INSUFFICIENT_TRAIN_DATA = 7


@dataclass
Expand Down
26 changes: 20 additions & 6 deletions label_sleuth/orchestrator/orchestrator_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,14 @@ def run_iteration(self, workspace_id: str, dataset_name: str, category_id: int,
def _train(self, workspace_id, category_id, model_type, iteration_index, future):
try:
train_data = future.result()
if train_data is None:
logging.info(f"labeled data was not provided in workspace `{workspace_id} category id `{category_id}`"
f" on iteration {iteration_index}. Stopping iteration. A new model will not be trained.")
self.orchestrator_state.update_iteration_status(workspace_id=workspace_id, category_id=category_id,
iteration_index=iteration_index,
new_status=IterationStatus.INSUFFICIENT_TRAIN_DATA)
return

except Exception:
logging.exception(f"Train set selection failed. Marking workspace '{workspace_id}' "
f"category id '{category_id}' iteration {iteration_index} as error")
Expand Down Expand Up @@ -703,7 +711,8 @@ def train_if_recommended(self, workspace_id: str, category_id: int, force=False)

try:
iterations_without_errors = [iteration for iteration in iterations
if iteration.status != IterationStatus.ERROR]
if iteration.status not in [IterationStatus.ERROR,
IterationStatus.INSUFFICIENT_TRAIN_DATA]]

changes_since_last_model = \
self.orchestrator_state.get_label_change_count_since_last_train(workspace_id, category_id)
Expand Down Expand Up @@ -783,7 +792,8 @@ def infer(self, workspace_id: str, category_id: int, elements_to_infer: Sequence
else:
iteration = iterations[iteration_index]
if iteration.status in [IterationStatus.PREPARING_DATA, IterationStatus.TRAINING,
IterationStatus.MODEL_DELETED, IterationStatus.ERROR]:
IterationStatus.MODEL_DELETED, IterationStatus.ERROR,
IterationStatus.INSUFFICIENT_TRAIN_DATA]:
raise Exception(
f"iteration {iteration_index} in workspace '{workspace_id}' category id '{category_id}' "
f"is not ready for inference. "
Expand Down Expand Up @@ -1095,9 +1105,12 @@ def add_documents_from_file(self, dataset_name, temp_file_path):
self.data_access.set_labels(workspace_id, uri_to_label, apply_to_duplicate_texts=True)

if len(category.iterations) > 0:
iteration_index = len(category.iterations) - 1
new_data_infer_thread_pool.submit(self._infer_missing_elements, workspace_id, category_id,
dataset_name, iteration_index)
all_iterations_and_indices = self. \
get_all_iterations_by_status(workspace_id, category_id, IterationStatus.READY)
if len(all_iterations_and_indices)>0:
iteration_index = all_iterations_and_indices[-1][1]
new_data_infer_thread_pool.submit(self._infer_missing_elements, workspace_id, category_id,
dataset_name, iteration_index)
total_infer_jobs += 1
logging.info(f"done adding documents to {dataset_name} upload statistics: {document_statistics}."
f"{total_infer_jobs} infer jobs were submitted in the background")
Expand Down Expand Up @@ -1180,7 +1193,8 @@ def recover_unfinished_iterations(self):
for workspace_id in self.list_workspaces():
for category_id, category in self.get_all_categories(workspace_id).items():
if len(category.iterations) > 0 \
and category.iterations[-1].status not in [IterationStatus.ERROR, IterationStatus.READY]:
and category.iterations[-1].status not in [IterationStatus.ERROR, IterationStatus.READY
, IterationStatus.INSUFFICIENT_TRAIN_DATA]:
logging.info(f"workspace '{workspace_id}', category id {category_id} ('{category.name}') has "
f"iteration in status {category.iterations[-1]}. Restarting iteration")
self.restart_last_iteration(workspace_id, category_id)
Expand Down

0 comments on commit e506471

Please sign in to comment.