Skip to content

Commit

Permalink
feat(be): adapt force_train for multiclass workspaces
Browse files Browse the repository at this point in the history
  • Loading branch information
martinscooper committed Nov 6, 2023
1 parent 9603ed5 commit 3472e43
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
18 changes: 11 additions & 7 deletions label_sleuth/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,25 +1046,29 @@ def get_elements_to_label(workspace_id):
@login_if_required
@validate_category_id
@validate_workspace_id
def force_train_for_category(workspace_id):
def force_train(workspace_id):
"""
This call is used for manually triggering a new Iteration flow.
:param workspace_id:
:request_arg category_id:
"""
category_id = int(request.args['category_id'])
is_multiclass = request.args.get('mode') == WorkspaceModelType.MultiClass.name
category_id = request.args.get('category_id')
if category_id is not None:
category_id = int(category_id)

dataset_name = curr_app.orchestrator_api.get_dataset_name(workspace_id)

model_id = curr_app.orchestrator_api.train_if_recommended(workspace_id, category_id, force=True)
curr_app.orchestrator_api.train_if_recommended(workspace_id, category_id, force=True)

labeling_counts = curr_app.orchestrator_api.get_label_counts(workspace_id, dataset_name, category_id)
logging.info(f"force training a new model in workspace '{workspace_id}' for category '{category_id}', "
f"model id: {model_id}")
to_log_message = f"Force training a new model in workspace '{workspace_id}'"
if not is_multiclass: to_log_message += f" for category '{category_id}'"
logging.info(to_log_message)

return jsonify({
"labeling_counts": labeling_counts,
"model_id": model_id
"labeling_counts": labeling_counts
})


Expand Down
4 changes: 3 additions & 1 deletion label_sleuth/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,6 @@ def test_full_flow_multiclass(self):
for element in res.get_json()["elements"]:
self.assertEqual(0, element["model_predictions"], msg=f"element {element} expected prediction is 0 but got {element['model_predictions']}")


# delete category 0
res = self.client.delete(f"/workspace/{workspace_name}/category/0?mode=MultiClass", headers=HEADERS)
self.assertEqual(200, res.status_code, msg="Failed to delete category 0 in multiclass workspace")
Expand All @@ -771,6 +770,9 @@ def test_full_flow_multiclass(self):
self.assertEqual({'elements': [], 'hit_count': 0},
res.get_json(), msg="labeled elements for category 0 were deleted should not exist in this workspace")

# force model training
res = self.client.get(f"/workspace/{workspace_name}/force_train?mode=MultiClass", headers=HEADERS)
self.assertEqual(200, res.status_code, msg="Failed to force train in multiclass workspace")

# delete workspace
res = self.client.delete(f"/workspace/{workspace_name}",
Expand Down

0 comments on commit 3472e43

Please sign in to comment.