diff --git a/app.py b/app.py index e790bc6..f134a45 100644 --- a/app.py +++ b/app.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI, HTTPException, responses +from fastapi import FastAPI, HTTPException, responses, status from pydantic import BaseModel from controller import stats @@ -34,7 +34,9 @@ class ExportWsStatsRequest(BaseModel): @app.post("/fit_predict") -async def weakly_supervise(request: WeakSupervisionRequest) -> int: +async def weakly_supervise( + request: WeakSupervisionRequest, +) -> responses.PlainTextResponse: session_token = general.get_ctx_token() integration.fit_predict( request.project_id, @@ -43,21 +45,25 @@ async def weakly_supervise(request: WeakSupervisionRequest) -> int: request.weak_supervision_task_id, ) general.remove_and_refresh_session(session_token) - return None, 200 + return responses.PlainTextResponse(status_code=status.HTTP_200_OK) @app.post("/labeling_task_statistics") -async def calculate_task_stats(request: TaskStatsRequest): +async def calculate_task_stats( + request: TaskStatsRequest, +) -> responses.PlainTextResponse: session_token = general.get_ctx_token() stats.calculate_quality_statistics_for_labeling_task( request.project_id, request.labeling_task_id, request.user_id ) general.remove_and_refresh_session(session_token) - return None, 200 + return responses.PlainTextResponse(status_code=status.HTTP_200_OK) @app.post("/source_statistics") -async def calculate_source_stats(request: SourceStatsRequest): +async def calculate_source_stats( + request: SourceStatsRequest, +) -> responses.PlainTextResponse: session_token = general.get_ctx_token() has_coverage = stats.calculate_quantity_statistics_for_labeling_task_from_source( request.project_id, request.source_id, request.user_id @@ -67,11 +73,11 @@ async def calculate_source_stats(request: SourceStatsRequest): request.project_id, request.source_id, request.user_id ) general.remove_and_refresh_session(session_token) - return None, 200 + return responses.PlainTextResponse(status_code=status.HTTP_200_OK) @app.post("/export_ws_stats") -async def export_ws_stats(request: ExportWsStatsRequest) -> responses.HTMLResponse: +async def export_ws_stats(request: ExportWsStatsRequest) -> responses.PlainTextResponse: session_token = general.get_ctx_token() status_code, message = integration.export_weak_supervision_stats( request.project_id, request.labeling_task_id @@ -80,4 +86,4 @@ async def export_ws_stats(request: ExportWsStatsRequest) -> responses.HTMLRespon if status_code != 200: raise HTTPException(status_code=status_code, detail=message) - return responses.HTMLResponse(status_code=status_code) + return responses.PlainTextResponse(status_code=status_code) diff --git a/controller/integration.py b/controller/integration.py index 47b5256..9f6251d 100644 --- a/controller/integration.py +++ b/controller/integration.py @@ -96,7 +96,7 @@ def integrate_classification(df: pd.DataFrame): for record_id, ( label_id, confidence, - ) in weak_supervision_results.dropna().iteritems(): + ) in weak_supervision_results.dropna().items(): return_values[record_id].append( {"label_id": label_id, "confidence": confidence} ) @@ -107,7 +107,7 @@ def integrate_extraction(df: pd.DataFrame): enlm = util.get_enlm_from_df(df) weak_supervision_results = enlm.weakly_supervise() return_values = defaultdict(list) - for record_id, preds in weak_supervision_results.iteritems(): + for record_id, preds in weak_supervision_results.items(): for pred in preds: label, confidence, token_min, token_max = pred return_values[record_id].append(