diff --git a/app.py b/app.py index 53dd07b..e790bc6 100644 --- a/app.py +++ b/app.py @@ -1,4 +1,5 @@ -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException, responses +from pydantic import BaseModel from controller import stats from controller import integration @@ -7,8 +8,6 @@ # API creation and description app = FastAPI() -from pydantic import BaseModel - class WeakSupervisionRequest(BaseModel): project_id: str @@ -29,6 +28,11 @@ class SourceStatsRequest(BaseModel): user_id: str +class ExportWsStatsRequest(BaseModel): + project_id: str + labeling_task_id: str + + @app.post("/fit_predict") async def weakly_supervise(request: WeakSupervisionRequest) -> int: session_token = general.get_ctx_token() @@ -43,7 +47,7 @@ async def weakly_supervise(request: WeakSupervisionRequest) -> int: @app.post("/labeling_task_statistics") -async def calculate_stats(request: TaskStatsRequest): +async def calculate_task_stats(request: TaskStatsRequest): session_token = general.get_ctx_token() stats.calculate_quality_statistics_for_labeling_task( request.project_id, request.labeling_task_id, request.user_id @@ -53,7 +57,7 @@ async def calculate_stats(request: TaskStatsRequest): @app.post("/source_statistics") -async def calculate_stats(request: SourceStatsRequest): +async def calculate_source_stats(request: SourceStatsRequest): session_token = general.get_ctx_token() has_coverage = stats.calculate_quantity_statistics_for_labeling_task_from_source( request.project_id, request.source_id, request.user_id @@ -64,3 +68,16 @@ async def calculate_stats(request: SourceStatsRequest): ) general.remove_and_refresh_session(session_token) return None, 200 + + +@app.post("/export_ws_stats") +async def export_ws_stats(request: ExportWsStatsRequest) -> responses.HTMLResponse: + session_token = general.get_ctx_token() + status_code, message = integration.export_weak_supervision_stats( + request.project_id, request.labeling_task_id + ) + general.remove_and_refresh_session(session_token) + + if status_code != 200: + raise HTTPException(status_code=status_code, detail=message) + return responses.HTMLResponse(status_code=status_code) diff --git a/controller/integration.py b/controller/integration.py index a34f621..47b5256 100644 --- a/controller/integration.py +++ b/controller/integration.py @@ -1,6 +1,8 @@ +import os from typing import Any, Dict, List, Tuple import traceback import pandas as pd +import pickle from collections import defaultdict from submodules.model.models import ( @@ -26,6 +28,7 @@ def fit_predict( try: if task_type == enums.LabelingTaskType.CLASSIFICATION.value: results = integrate_classification(df) + else: results = integrate_extraction(df) weak_supervision.store_data( @@ -37,7 +40,7 @@ def fit_predict( weak_supervision_task_id, with_commit=True, ) - except: + except Exception: print(traceback.format_exc(), flush=True) general.rollback() weak_supervision.update_state( @@ -48,6 +51,44 @@ def fit_predict( ) +def export_weak_supervision_stats( + project_id: str, labeling_task_id: str +) -> Tuple[int, str]: + + task_type, df = collect_data(project_id, labeling_task_id, False) + try: + if task_type == enums.LabelingTaskType.CLASSIFICATION.value: + cnlm = util.get_cnlm_from_df(df) + stats_df = cnlm.quality_metrics() + elif task_type == enums.LabelingTaskType.INFORMATION_EXTRACTION.value: + enlm = util.get_enlm_from_df(df) + stats_df = enlm.quality_metrics() + else: + return 404, f"Task type {task_type} not implemented" + + if len(stats_df) != 0: + stats_lkp = stats_df.set_index(["identifier", "label_name"]).to_dict( + orient="index" + ) + else: + return 404, "Can't compute weak supervision" + + os.makedirs(os.path.join("/inference", project_id), exist_ok=True) + with open( + os.path.join( + "/inference", project_id, f"weak-supervision-{labeling_task_id}.pkl" + ), + "wb", + ) as f: + pickle.dump(stats_lkp, f) + + except Exception: + print(traceback.format_exc(), flush=True) + general.rollback() + return 500, "Internal server error" + return 200, "OK" + + def integrate_classification(df: pd.DataFrame): cnlm = util.get_cnlm_from_df(df) weak_supervision_results = cnlm.weakly_supervise() diff --git a/start b/start index b066f6d..cd6c64f 100755 --- a/start +++ b/start @@ -5,6 +5,20 @@ echo -ne 'stopping old container...' docker stop refinery-weak-supervisor > /dev/null 2>&1 echo -ne '\t [done]\n' +INFERENCE_DIR=${PWD%/*}/dev-setup/inference/ +if [ ! -d "$_DIR" ] +then + INFERENCE_DIR=${PWD%/*/*}/dev-setup/inference/ + if [ ! -d "$INFERENCE_DIR" ] + then + # to include volume for local development, use the dev-setup inference folder: + # alternative use manual logic with + # -v /path/to/dev-setup/inference:/models \ + echo "Can't find model data directory: $INFERENCE_DIR -> stopping" + exit 1 + fi +fi + echo -ne 'building container...' docker build -t refinery-weak-supervisor-dev -f dev.Dockerfile . > /dev/null 2>&1 echo -ne '\t\t [done]\n' @@ -17,6 +31,7 @@ docker run -d --rm \ -e WS_NOTIFY_ENDPOINT="http://refinery-websocket:8080" \ --mount type=bind,source="$(pwd)"/,target=/app \ -v /var/run/docker.sock:/var/run/docker.sock \ +-v "$INFERENCE_DIR":/inference \ --network dev-setup_default \ refinery-weak-supervisor-dev > /dev/null 2>&1 echo -ne '\t\t\t [done]\n'