diff --git a/business_objects/information_source.py b/business_objects/information_source.py index 1c165758..bbefa2da 100644 --- a/business_objects/information_source.py +++ b/business_objects/information_source.py @@ -1,4 +1,5 @@ from datetime import datetime +from sqlalchemy import cast, TEXT from typing import Dict, List, Any, Optional from submodules.model import enums @@ -45,6 +46,20 @@ def get_all(project_id: str) -> List[InformationSource]: ) +def get_all_ids_by_labeling_task_id( + project_id: str, labeling_task_id: str +) -> List[str]: + values = ( + session.query(cast(InformationSource.id, TEXT)) + .filter( + InformationSource.project_id == project_id, + InformationSource.labeling_task_id == labeling_task_id, + ) + .all() + ) + return [value[0] for value in values] + + def get_all_statistics(project_id: str) -> List[InformationSourceStatistics]: return ( session.query(InformationSourceStatistics) @@ -90,23 +105,6 @@ def get_selected_information_sources(project_id: str) -> str: return ", ".join([str(x.name) for x in information_sources]) -def get_task_information_sources(project_id: str, labeling_task_id: str) -> str: - information_sources = ( - session.query(InformationSource.name) - .filter( - InformationSource.project_id == project_id, - InformationSource.labeling_task_id == labeling_task_id, - InformationSourceStatistics.source_id == InformationSource.id, - InformationSourceStatistics.true_positives - > 0, # only collect valid options - ) - .all() - ) - if not information_sources: - return "" - return ", ".join([str(x.name) for x in information_sources]) - - def get_payloads_by_project_id(project_id: str) -> List[Any]: query: str = f""" SELECT diff --git a/business_objects/record_label_association.py b/business_objects/record_label_association.py index 932db31b..1721ce94 100644 --- a/business_objects/record_label_association.py +++ b/business_objects/record_label_association.py @@ -800,12 +800,25 @@ def check_label_duplication_classification( return True -def is_any_record_manually_labeled(project_id: str): +def is_any_record_manually_labeled( + project_id: str, labeling_task_id: Optional[str] = None +) -> bool: + query_join_add = "" + query_where_add = "" + if labeling_task_id: + query_join_add = """ + INNER JOIN labeling_task_label ltl + ON rla.labeling_task_label_id = ltl.id AND ltl.project_id = rla.project_id""" + query_where_add = f""" + AND ltl.labeling_task_id = '{labeling_task_id}'""" + query = f""" - SELECT id + SELECT rla.id FROM record_label_association rla - WHERE project_id = '{project_id}' - AND source_type = '{enums.LabelSource.MANUAL.value}' + {query_join_add} + WHERE rla.project_id = '{project_id}' + AND rla.source_type = '{enums.LabelSource.MANUAL.value}' + {query_where_add} LIMIT 1 """ value = general.execute_first(query)