diff --git a/app.py b/app.py
index ef54340..feed7a9 100644
--- a/app.py
+++ b/app.py
@@ -1,5 +1,6 @@
 from fastapi import FastAPI, HTTPException, responses, status
 from pydantic import BaseModel
+from typing import Union, Dict, Optional
 from controller import stats
 from controller import integration
 
@@ -14,6 +15,7 @@ class WeakSupervisionRequest(BaseModel):
     labeling_task_id: str
     user_id: str
     weak_supervision_task_id: str
+    overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]]
 
 
 class TaskStatsRequest(BaseModel):
@@ -31,6 +33,7 @@ class SourceStatsRequest(BaseModel):
 class ExportWsStatsRequest(BaseModel):
     project_id: str
     labeling_task_id: str
+    overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]]
 
 
 @app.post("/fit_predict")
@@ -43,6 +46,7 @@ def weakly_supervise(
         request.labeling_task_id,
         request.user_id,
         request.weak_supervision_task_id,
+        request.overwrite_weak_supervision,
     )
     general.remove_and_refresh_session(session_token)
     return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
@@ -80,7 +84,7 @@ def calculate_source_stats(
 def export_ws_stats(request: ExportWsStatsRequest) -> responses.PlainTextResponse:
     session_token = general.get_ctx_token()
     status_code, message = integration.export_weak_supervision_stats(
-        request.project_id, request.labeling_task_id
+        request.project_id, request.labeling_task_id, request.overwrite_weak_supervision
     )
     general.remove_and_refresh_session(session_token)
 
diff --git a/controller/integration.py b/controller/integration.py
index 9f6251d..954b2e3 100644
--- a/controller/integration.py
+++ b/controller/integration.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Optional, Union
 import traceback
 import pandas as pd
 import pickle
@@ -18,19 +18,62 @@
     labeling_task,
     record_label_association,
     weak_supervision,
+    labeling_task_label,
+    information_source,
 )
 
+NO_LABEL_WS_PRECISION = 0.8
+
+
+def __create_quality_metrics(
+    project_id: str,
+    labeling_task_id: str,
+    overwrite_weak_supervision: Union[float, Dict[str, float]],
+) -> Dict[Tuple[str, str], Dict[str, float]]:
+    if isinstance(overwrite_weak_supervision, float):
+        ws_weights = {}
+        for heuristic_id in information_source.get_all_ids_by_labeling_task_id(
+            project_id, labeling_task_id
+        ):
+            ws_weights[str(heuristic_id)] = overwrite_weak_supervision
+    else:
+        ws_weights = overwrite_weak_supervision
+
+    ws_stats = {}
+    for heuristic_id in ws_weights:
+        label_ids = labeling_task_label.get_all_ids(project_id, labeling_task_id)
+        for (label_id,) in label_ids:
+            ws_stats[(heuristic_id, str(label_id))] = {
+                "precision": ws_weights[heuristic_id]
+            }
+    return ws_stats
+
 
 def fit_predict(
-    project_id: str, labeling_task_id: str, user_id: str, weak_supervision_task_id: str
+    project_id: str,
+    labeling_task_id: str,
+    user_id: str,
+    weak_supervision_task_id: str,
+    overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None,
 ):
+    quality_metrics_overwrite = None
+    if overwrite_weak_supervision is not None:
+        quality_metrics_overwrite = __create_quality_metrics(
+            project_id, labeling_task_id, overwrite_weak_supervision
+        )
+    elif not record_label_association.is_any_record_manually_labeled(
+        project_id, labeling_task_id
+    ):
+        quality_metrics_overwrite = __create_quality_metrics(
+            project_id, labeling_task_id, NO_LABEL_WS_PRECISION
+        )
+
     task_type, df = collect_data(project_id, labeling_task_id, True)
     try:
         if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
-            results = integrate_classification(df)
-
+            results = integrate_classification(df, quality_metrics_overwrite)
         else:
-            results = integrate_extraction(df)
+            results = integrate_extraction(df, quality_metrics_overwrite)
         weak_supervision.store_data(
             project_id,
             labeling_task_id,
@@ -52,46 +95,62 @@ def fit_predict(
 
 
 def export_weak_supervision_stats(
-    project_id: str, labeling_task_id: str
+    project_id: str,
+    labeling_task_id: str,
+    overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None,
 ) -> Tuple[int, str]:
+    if overwrite_weak_supervision is not None:
+        ws_stats = __create_quality_metrics(
+            project_id, labeling_task_id, overwrite_weak_supervision
+        )
+    elif not record_label_association.is_any_record_manually_labeled(
+        project_id, labeling_task_id
+    ):
+        ws_stats = __create_quality_metrics(
+            project_id, labeling_task_id, NO_LABEL_WS_PRECISION
+        )
+    else:
+        task_type, df = collect_data(project_id, labeling_task_id, False)
+        try:
+            if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
+                cnlm = util.get_cnlm_from_df(df)
+                stats_df = cnlm.quality_metrics()
+            elif task_type == enums.LabelingTaskType.INFORMATION_EXTRACTION.value:
+                enlm = util.get_enlm_from_df(df)
+                stats_df = enlm.quality_metrics()
+            else:
+                return 404, f"Task type {task_type} not implemented"
 
-    task_type, df = collect_data(project_id, labeling_task_id, False)
-    try:
-        if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
-            cnlm = util.get_cnlm_from_df(df)
-            stats_df = cnlm.quality_metrics()
-        elif task_type == enums.LabelingTaskType.INFORMATION_EXTRACTION.value:
-            enlm = util.get_enlm_from_df(df)
-            stats_df = enlm.quality_metrics()
-        else:
-            return 404, f"Task type {task_type} not implemented"
+            if len(stats_df) != 0:
+                ws_stats = stats_df.set_index(["identifier", "label_name"]).to_dict(
+                    orient="index"
+                )
+            else:
+                return 404, "Can't compute weak supervision"
 
-        if len(stats_df) != 0:
-            stats_lkp = stats_df.set_index(["identifier", "label_name"]).to_dict(
-                orient="index"
-            )
-        else:
-            return 404, "Can't compute weak supervision"
+        except Exception:
+            print(traceback.format_exc(), flush=True)
+            general.rollback()
+            return 500, "Internal server error"
 
-        os.makedirs(os.path.join("/inference", project_id), exist_ok=True)
-        with open(
-            os.path.join(
-                "/inference", project_id, f"weak-supervision-{labeling_task_id}.pkl"
-            ),
-            "wb",
-        ) as f:
-            pickle.dump(stats_lkp, f)
+    os.makedirs(os.path.join("/inference", project_id), exist_ok=True)
+    with open(
+        os.path.join(
+            "/inference", project_id, f"weak-supervision-{labeling_task_id}.pkl"
+        ),
+        "wb",
+    ) as f:
+        pickle.dump(ws_stats, f)
 
-    except Exception:
-        print(traceback.format_exc(), flush=True)
-        general.rollback()
-        return 500, "Internal server error"
     return 200, "OK"
 
 
-def integrate_classification(df: pd.DataFrame):
+def integrate_classification(
+    df: pd.DataFrame,
+    quality_metrics_overwrite: Optional[Dict[Tuple[str, str], Dict[str, float]]] = None,
+):
     cnlm = util.get_cnlm_from_df(df)
-    weak_supervision_results = cnlm.weakly_supervise()
+    weak_supervision_results = cnlm.weakly_supervise(quality_metrics_overwrite)
     return_values = defaultdict(list)
     for record_id, (
         label_id,
@@ -103,9 +162,12 @@ def integrate_classification(df: pd.DataFrame):
     return return_values
 
 
-def integrate_extraction(df: pd.DataFrame):
+def integrate_extraction(
+    df: pd.DataFrame,
+    quality_metrics_overwrite: Optional[Dict[Tuple[str, str], Dict[str, float]]] = None,
+):
     enlm = util.get_enlm_from_df(df)
-    weak_supervision_results = enlm.weakly_supervise()
+    weak_supervision_results = enlm.weakly_supervise(quality_metrics_overwrite)
     return_values = defaultdict(list)
     for record_id, preds in weak_supervision_results.items():
         for pred in preds:
@@ -128,12 +190,12 @@ collect_data(
     query_results = []
 
     if labeling_task_item.task_type == enums.LabelingTaskType.CLASSIFICATION.value:
-        for information_source in labeling_task_item.information_sources:
-            if only_selected and not information_source.is_selected:
+        for information_source_item in labeling_task_item.information_sources:
+            if only_selected and not information_source_item.is_selected:
                 continue
             results = (
                 record_label_association.get_all_classifications_for_information_source(
-                    project_id, information_source.id
+                    project_id, information_source_item.id
                 )
             )
             query_results.extend(results)
@@ -149,11 +211,11 @@
         labeling_task_item.task_type
         == enums.LabelingTaskType.INFORMATION_EXTRACTION.value
     ):
-        for information_source in labeling_task_item.information_sources:
-            if only_selected and not information_source.is_selected:
+        for information_source_item in labeling_task_item.information_sources:
+            if only_selected and not information_source_item.is_selected:
                 continue
             results = record_label_association.get_all_extraction_tokens_for_information_source(
-                project_id, information_source.id
+                project_id, information_source_item.id
             )
             query_results.extend(results)
 
diff --git a/requirements.txt b/requirements.txt
index c7d03cd..2daeb60 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -107,5 +107,5 @@ urllib3==1.26.16
     #   requests
 uvicorn==0.22.0
     # via -r requirements/common-requirements.txt
-weak-nlp==0.0.12
+weak-nlp==0.0.13
     # via -r requirements/requirements.in
diff --git a/requirements/requirements.in b/requirements/requirements.in
index ba36fb9..71279f9 100644
--- a/requirements/requirements.in
+++ b/requirements/requirements.in
@@ -1,2 +1,2 @@
 -r common-requirements.txt
-weak-nlp==0.0.12
+weak-nlp==0.0.13
diff --git a/submodules/model b/submodules/model
index bfd0695..5c96146 160000
--- a/submodules/model
+++ b/submodules/model
@@ -1 +1 @@
-Subproject commit bfd06954dfb5669e3c812f406ecf69c83dd38991
+Subproject commit 5c96146323de5150450c192ba1583917bc53076d
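Usage note (not part of the patch): a minimal sketch of how a client could exercise the new optional overwrite_weak_supervision field on the /fit_predict route added above. The route and field names come from the diff; the base URL, the example IDs, and the use of the requests library are placeholders for illustration only.

# Sketch only: exercise the new "overwrite_weak_supervision" field.
# Base URL and all IDs below are placeholders, not values from the patch.
import requests

payload = {
    "project_id": "<project-id>",
    "labeling_task_id": "<labeling-task-id>",
    "user_id": "<user-id>",
    "weak_supervision_task_id": "<weak-supervision-task-id>",
    # A single float is applied as precision to every heuristic of the task ...
    "overwrite_weak_supervision": 0.9,
    # ... or pass a per-heuristic mapping instead:
    # "overwrite_weak_supervision": {"<heuristic-id>": 0.9, "<other-heuristic-id>": 0.7},
}

response = requests.post("http://localhost:80/fit_predict", json=payload)
print(response.status_code)  # 200 on success, per the endpoint above

If the field is omitted, the service behaves as before, except that tasks without any manually labeled record now fall back to the new NO_LABEL_WS_PRECISION default of 0.8 introduced in controller/integration.py.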