Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions business_objects/information_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from sqlalchemy import cast, TEXT
from typing import Dict, List, Any, Optional

from submodules.model import enums
Expand Down Expand Up @@ -45,6 +46,20 @@ def get_all(project_id: str) -> List[InformationSource]:
)


def get_all_ids_by_labeling_task_id(
project_id: str, labeling_task_id: str
) -> List[str]:
values = (
session.query(cast(InformationSource.id, TEXT))
.filter(
InformationSource.project_id == project_id,
InformationSource.labeling_task_id == labeling_task_id,
)
.all()
)
return [value[0] for value in values]


def get_all_statistics(project_id: str) -> List[InformationSourceStatistics]:
return (
session.query(InformationSourceStatistics)
Expand Down Expand Up @@ -90,23 +105,6 @@ def get_selected_information_sources(project_id: str) -> str:
return ", ".join([str(x.name) for x in information_sources])


def get_task_information_sources(project_id: str, labeling_task_id: str) -> str:
information_sources = (
session.query(InformationSource.name)
.filter(
InformationSource.project_id == project_id,
InformationSource.labeling_task_id == labeling_task_id,
InformationSourceStatistics.source_id == InformationSource.id,
InformationSourceStatistics.true_positives
> 0, # only collect valid options
)
.all()
)
if not information_sources:
return ""
return ", ".join([str(x.name) for x in information_sources])


def get_payloads_by_project_id(project_id: str) -> List[Any]:
query: str = f"""
SELECT
Expand Down
21 changes: 17 additions & 4 deletions business_objects/record_label_association.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,12 +800,25 @@ def check_label_duplication_classification(
return True


def is_any_record_manually_labeled(project_id: str):
def is_any_record_manually_labeled(
project_id: str, labeling_task_id: Optional[str] = None
) -> bool:
query_join_add = ""
query_where_add = ""
if labeling_task_id:
query_join_add = """
INNER JOIN labeling_task_label ltl
ON rla.labeling_task_label_id = ltl.id AND ltl.project_id = rla.project_id"""
query_where_add = f"""
AND ltl.labeling_task_id = '{labeling_task_id}'"""

query = f"""
SELECT id
SELECT rla.id
FROM record_label_association rla
WHERE project_id = '{project_id}'
AND source_type = '{enums.LabelSource.MANUAL.value}'
{query_join_add}
WHERE rla.project_id = '{project_id}'
AND rla.source_type = '{enums.LabelSource.MANUAL.value}'
{query_where_add}
LIMIT 1
"""
value = general.execute_first(query)
Expand Down