Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import FastAPI, HTTPException, responses
from fastapi import FastAPI, HTTPException, responses, status
from pydantic import BaseModel

from controller import stats
Expand Down Expand Up @@ -34,7 +34,9 @@ class ExportWsStatsRequest(BaseModel):


@app.post("/fit_predict")
async def weakly_supervise(request: WeakSupervisionRequest) -> int:
async def weakly_supervise(
request: WeakSupervisionRequest,
) -> responses.PlainTextResponse:
session_token = general.get_ctx_token()
integration.fit_predict(
request.project_id,
Expand All @@ -43,21 +45,25 @@ async def weakly_supervise(request: WeakSupervisionRequest) -> int:
request.weak_supervision_task_id,
)
general.remove_and_refresh_session(session_token)
return None, 200
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)


@app.post("/labeling_task_statistics")
async def calculate_task_stats(request: TaskStatsRequest):
async def calculate_task_stats(
request: TaskStatsRequest,
) -> responses.PlainTextResponse:
session_token = general.get_ctx_token()
stats.calculate_quality_statistics_for_labeling_task(
request.project_id, request.labeling_task_id, request.user_id
)
general.remove_and_refresh_session(session_token)
return None, 200
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)


@app.post("/source_statistics")
async def calculate_source_stats(request: SourceStatsRequest):
async def calculate_source_stats(
request: SourceStatsRequest,
) -> responses.PlainTextResponse:
session_token = general.get_ctx_token()
has_coverage = stats.calculate_quantity_statistics_for_labeling_task_from_source(
request.project_id, request.source_id, request.user_id
Expand All @@ -67,11 +73,11 @@ async def calculate_source_stats(request: SourceStatsRequest):
request.project_id, request.source_id, request.user_id
)
general.remove_and_refresh_session(session_token)
return None, 200
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)


@app.post("/export_ws_stats")
async def export_ws_stats(request: ExportWsStatsRequest) -> responses.HTMLResponse:
async def export_ws_stats(request: ExportWsStatsRequest) -> responses.PlainTextResponse:
session_token = general.get_ctx_token()
status_code, message = integration.export_weak_supervision_stats(
request.project_id, request.labeling_task_id
Expand All @@ -80,4 +86,4 @@ async def export_ws_stats(request: ExportWsStatsRequest) -> responses.HTMLRespon

if status_code != 200:
raise HTTPException(status_code=status_code, detail=message)
return responses.HTMLResponse(status_code=status_code)
return responses.PlainTextResponse(status_code=status_code)
4 changes: 2 additions & 2 deletions controller/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def integrate_classification(df: pd.DataFrame):
for record_id, (
label_id,
confidence,
) in weak_supervision_results.dropna().iteritems():
) in weak_supervision_results.dropna().items():
return_values[record_id].append(
{"label_id": label_id, "confidence": confidence}
)
Expand All @@ -107,7 +107,7 @@ def integrate_extraction(df: pd.DataFrame):
enlm = util.get_enlm_from_df(df)
weak_supervision_results = enlm.weakly_supervise()
return_values = defaultdict(list)
for record_id, preds in weak_supervision_results.iteritems():
for record_id, preds in weak_supervision_results.items():
for pred in preds:
label, confidence, token_min, token_max = pred
return_values[record_id].append(
Expand Down