Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import FastAPI, HTTPException, responses, status
from pydantic import BaseModel
from typing import Union, Dict, Optional

from controller import stats
from controller import integration
Expand All @@ -14,6 +15,7 @@ class WeakSupervisionRequest(BaseModel):
labeling_task_id: str
user_id: str
weak_supervision_task_id: str
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]]


class TaskStatsRequest(BaseModel):
Expand All @@ -31,6 +33,7 @@ class SourceStatsRequest(BaseModel):
class ExportWsStatsRequest(BaseModel):
project_id: str
labeling_task_id: str
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]]


@app.post("/fit_predict")
Expand All @@ -43,6 +46,7 @@ def weakly_supervise(
request.labeling_task_id,
request.user_id,
request.weak_supervision_task_id,
request.overwrite_weak_supervision,
)
general.remove_and_refresh_session(session_token)
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
Expand Down Expand Up @@ -80,7 +84,7 @@ def calculate_source_stats(
def export_ws_stats(request: ExportWsStatsRequest) -> responses.PlainTextResponse:
session_token = general.get_ctx_token()
status_code, message = integration.export_weak_supervision_stats(
request.project_id, request.labeling_task_id
request.project_id, request.labeling_task_id, request.overwrite_weak_supervision
)
general.remove_and_refresh_session(session_token)

Expand Down
150 changes: 106 additions & 44 deletions controller/integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Optional, Union
import traceback
import pandas as pd
import pickle
Expand All @@ -18,19 +18,62 @@
labeling_task,
record_label_association,
weak_supervision,
labeling_task_label,
information_source,
)

NO_LABEL_WS_PRECISION = 0.8


def __create_quality_metrics(
    project_id: str,
    labeling_task_id: str,
    overwrite_weak_supervision: Union[float, Dict[str, float]],
) -> Dict[Tuple[str, str], Dict[str, float]]:
    """Build a ``(heuristic_id, label_id) -> {"precision": <weight>}`` lookup.

    If *overwrite_weak_supervision* is a single float, that precision is
    applied to every heuristic of the labeling task; otherwise it is used
    as an explicit ``heuristic_id -> precision`` mapping.

    Args:
        project_id: id of the project the labeling task belongs to.
        labeling_task_id: labeling task whose heuristics/labels are covered.
        overwrite_weak_supervision: global precision or per-heuristic map.

    Returns:
        Dict keyed by ``(heuristic_id, label_id)`` (both as str) with a
        ``{"precision": float}`` entry per pair.
    """
    if isinstance(overwrite_weak_supervision, float):
        ws_weights = {
            str(heuristic_id): overwrite_weak_supervision
            for heuristic_id in information_source.get_all_ids_by_labeling_task_id(
                project_id, labeling_task_id
            )
        }
    else:
        ws_weights = overwrite_weak_supervision

    # The label set is the same for every heuristic of the task, so fetch it
    # once instead of re-querying inside the loop; materialize as a list so
    # it can be iterated once per heuristic.
    label_ids = list(labeling_task_label.get_all_ids(project_id, labeling_task_id))
    ws_stats = {}
    for heuristic_id, weight in ws_weights.items():
        for (label_id,) in label_ids:
            ws_stats[(heuristic_id, str(label_id))] = {"precision": weight}
    return ws_stats


def fit_predict(
project_id: str, labeling_task_id: str, user_id: str, weak_supervision_task_id: str
project_id: str,
labeling_task_id: str,
user_id: str,
weak_supervision_task_id: str,
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None,
):
quality_metrics_overwrite = None
if overwrite_weak_supervision is not None:
quality_metrics_overwrite = __create_quality_metrics(
project_id, labeling_task_id, overwrite_weak_supervision
)
elif not record_label_association.is_any_record_manually_labeled(
project_id, labeling_task_id
):
quality_metrics_overwrite = __create_quality_metrics(
project_id, labeling_task_id, NO_LABEL_WS_PRECISION
)

task_type, df = collect_data(project_id, labeling_task_id, True)
try:
if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
results = integrate_classification(df)

results = integrate_classification(df, quality_metrics_overwrite)
else:
results = integrate_extraction(df)
results = integrate_extraction(df, quality_metrics_overwrite)
weak_supervision.store_data(
project_id,
labeling_task_id,
Expand All @@ -52,46 +95,62 @@ def fit_predict(


def export_weak_supervision_stats(
project_id: str, labeling_task_id: str
project_id: str,
labeling_task_id: str,
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None,
) -> Tuple[int, str]:
if overwrite_weak_supervision is not None:
ws_stats = __create_quality_metrics(
project_id, labeling_task_id, overwrite_weak_supervision
)
elif not record_label_association.is_any_record_manually_labeled(
project_id, labeling_task_id
):
ws_stats = __create_quality_metrics(
project_id, labeling_task_id, NO_LABEL_WS_PRECISION
)
else:
task_type, df = collect_data(project_id, labeling_task_id, False)
try:
if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
cnlm = util.get_cnlm_from_df(df)
stats_df = cnlm.quality_metrics()
elif task_type == enums.LabelingTaskType.INFORMATION_EXTRACTION.value:
enlm = util.get_enlm_from_df(df)
stats_df = enlm.quality_metrics()
else:
return 404, f"Task type {task_type} not implemented"

task_type, df = collect_data(project_id, labeling_task_id, False)
try:
if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
cnlm = util.get_cnlm_from_df(df)
stats_df = cnlm.quality_metrics()
elif task_type == enums.LabelingTaskType.INFORMATION_EXTRACTION.value:
enlm = util.get_enlm_from_df(df)
stats_df = enlm.quality_metrics()
else:
return 404, f"Task type {task_type} not implemented"
if len(stats_df) != 0:
ws_stats = stats_df.set_index(["identifier", "label_name"]).to_dict(
orient="index"
)
else:
return 404, "Can't compute weak supervision"

if len(stats_df) != 0:
stats_lkp = stats_df.set_index(["identifier", "label_name"]).to_dict(
orient="index"
)
else:
return 404, "Can't compute weak supervision"
except Exception:
print(traceback.format_exc(), flush=True)
general.rollback()
return 500, "Internal server error"

os.makedirs(os.path.join("/inference", project_id), exist_ok=True)
with open(
os.path.join(
"/inference", project_id, f"weak-supervision-{labeling_task_id}.pkl"
),
"wb",
) as f:
pickle.dump(stats_lkp, f)
os.makedirs(os.path.join("/inference", project_id), exist_ok=True)
with open(
os.path.join(
"/inference", project_id, f"weak-supervision-{labeling_task_id}.pkl"
),
"wb",
) as f:
pickle.dump(ws_stats, f)

except Exception:
print(traceback.format_exc(), flush=True)
general.rollback()
return 500, "Internal server error"
return 200, "OK"


def integrate_classification(df: pd.DataFrame):
def integrate_classification(
df: pd.DataFrame,
quality_metrics_overwrite: Optional[Dict[Tuple[str, str], Dict[str, float]]] = None,
):
cnlm = util.get_cnlm_from_df(df)
weak_supervision_results = cnlm.weakly_supervise()
weak_supervision_results = cnlm.weakly_supervise(quality_metrics_overwrite)
return_values = defaultdict(list)
for record_id, (
label_id,
Expand All @@ -103,9 +162,12 @@ def integrate_classification(df: pd.DataFrame):
return return_values


def integrate_extraction(df: pd.DataFrame):
def integrate_extraction(
df: pd.DataFrame,
quality_metrics_overwrite: Optional[Dict[Tuple[str, str], Dict[str, float]]] = None,
):
enlm = util.get_enlm_from_df(df)
weak_supervision_results = enlm.weakly_supervise()
weak_supervision_results = enlm.weakly_supervise(quality_metrics_overwrite)
return_values = defaultdict(list)
for record_id, preds in weak_supervision_results.items():
for pred in preds:
Expand All @@ -128,12 +190,12 @@ def collect_data(

query_results = []
if labeling_task_item.task_type == enums.LabelingTaskType.CLASSIFICATION.value:
for information_source in labeling_task_item.information_sources:
if only_selected and not information_source.is_selected:
for information_source_item in labeling_task_item.information_sources:
if only_selected and not information_source_item.is_selected:
continue
results = (
record_label_association.get_all_classifications_for_information_source(
project_id, information_source.id
project_id, information_source_item.id
)
)
query_results.extend(results)
Expand All @@ -149,11 +211,11 @@ def collect_data(
labeling_task_item.task_type
== enums.LabelingTaskType.INFORMATION_EXTRACTION.value
):
for information_source in labeling_task_item.information_sources:
if only_selected and not information_source.is_selected:
for information_source_item in labeling_task_item.information_sources:
if only_selected and not information_source_item.is_selected:
continue
results = record_label_association.get_all_extraction_tokens_for_information_source(
project_id, information_source.id
project_id, information_source_item.id
)
query_results.extend(results)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,5 @@ urllib3==1.26.16
# requests
uvicorn==0.22.0
# via -r requirements/common-requirements.txt
weak-nlp==0.0.12
weak-nlp==0.0.13
# via -r requirements/requirements.in
2 changes: 1 addition & 1 deletion requirements/requirements.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
-r common-requirements.txt
weak-nlp==0.0.12
weak-nlp==0.0.13