27 changes: 22 additions & 5 deletions app.py
@@ -1,4 +1,5 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException, responses
+from pydantic import BaseModel
 
 from controller import stats
 from controller import integration
@@ -7,8 +8,6 @@
 # API creation and description
 app = FastAPI()
 
-from pydantic import BaseModel
-
 
 class WeakSupervisionRequest(BaseModel):
     project_id: str
@@ -29,6 +28,11 @@ class SourceStatsRequest(BaseModel):
     user_id: str
 
 
+class ExportWsStatsRequest(BaseModel):
+    project_id: str
+    labeling_task_id: str
+
+
 @app.post("/fit_predict")
 async def weakly_supervise(request: WeakSupervisionRequest) -> int:
     session_token = general.get_ctx_token()
@@ -43,7 +47,7 @@ async def weakly_supervise(request: WeakSupervisionRequest) -> int:
 
 
 @app.post("/labeling_task_statistics")
-async def calculate_stats(request: TaskStatsRequest):
+async def calculate_task_stats(request: TaskStatsRequest):
     session_token = general.get_ctx_token()
     stats.calculate_quality_statistics_for_labeling_task(
         request.project_id, request.labeling_task_id, request.user_id
@@ -53,7 +57,7 @@ async def calculate_stats(request: TaskStatsRequest):
 
 
 @app.post("/source_statistics")
-async def calculate_stats(request: SourceStatsRequest):
+async def calculate_source_stats(request: SourceStatsRequest):
     session_token = general.get_ctx_token()
     has_coverage = stats.calculate_quantity_statistics_for_labeling_task_from_source(
         request.project_id, request.source_id, request.user_id
@@ -64,3 +68,16 @@ async def calculate_stats(request: SourceStatsRequest):
     )
     general.remove_and_refresh_session(session_token)
     return None, 200
+
+
+@app.post("/export_ws_stats")
+async def export_ws_stats(request: ExportWsStatsRequest) -> responses.HTMLResponse:
+    session_token = general.get_ctx_token()
+    status_code, message = integration.export_weak_supervision_stats(
+        request.project_id, request.labeling_task_id
+    )
+    general.remove_and_refresh_session(session_token)
+
+    if status_code != 200:
+        raise HTTPException(status_code=status_code, detail=message)
+    return responses.HTMLResponse(status_code=status_code)
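
A minimal sketch of exercising the new endpoint from a client once the service is up; the host/port and the two ids are assumptions for illustration, only the route and payload shape come from this diff:

```python
import requests  # assumes the requests package is installed

# Hypothetical ids; in practice these come from the refinery project
payload = {"project_id": "<project-uuid>", "labeling_task_id": "<task-uuid>"}

resp = requests.post("http://localhost:7054/export_ws_stats", json=payload)
# On failure the API raises HTTPException, so a non-200 response
# carries the error message in the JSON "detail" field
resp.raise_for_status()
```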
43 changes: 42 additions & 1 deletion controller/integration.py
@@ -1,6 +1,8 @@
+import os
 from typing import Any, Dict, List, Tuple
 import traceback
 import pandas as pd
+import pickle
 from collections import defaultdict
 
 from submodules.model.models import (
@@ -26,6 +28,7 @@ def fit_predict(
     try:
         if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
             results = integrate_classification(df)
+
         else:
             results = integrate_extraction(df)
         weak_supervision.store_data(
@@ -37,7 +40,7 @@
             weak_supervision_task_id,
             with_commit=True,
         )
-    except:
+    except Exception:
         print(traceback.format_exc(), flush=True)
         general.rollback()
         weak_supervision.update_state(
@@ -48,6 +51,44 @@
         )
 
 
+def export_weak_supervision_stats(
+    project_id: str, labeling_task_id: str
+) -> Tuple[int, str]:
+
+    task_type, df = collect_data(project_id, labeling_task_id, False)
+    try:
+        if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
+            cnlm = util.get_cnlm_from_df(df)
+            stats_df = cnlm.quality_metrics()
+        elif task_type == enums.LabelingTaskType.INFORMATION_EXTRACTION.value:
+            enlm = util.get_enlm_from_df(df)
+            stats_df = enlm.quality_metrics()
+        else:
+            return 404, f"Task type {task_type} not implemented"
+
+        if len(stats_df) != 0:
+            stats_lkp = stats_df.set_index(["identifier", "label_name"]).to_dict(
+                orient="index"
+            )
+        else:
+            return 404, "Can't compute weak supervision"
+
+        os.makedirs(os.path.join("/inference", project_id), exist_ok=True)
+        with open(
+            os.path.join(
+                "/inference", project_id, f"weak-supervision-{labeling_task_id}.pkl"
+            ),
+            "wb",
+        ) as f:
+            pickle.dump(stats_lkp, f)
+
+    except Exception:
+        print(traceback.format_exc(), flush=True)
+        general.rollback()
+        return 500, "Internal server error"
+    return 200, "OK"
+
+
 def integrate_classification(df: pd.DataFrame):
     cnlm = util.get_cnlm_from_df(df)
     weak_supervision_results = cnlm.weakly_supervise()
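
Since the export uses `set_index(["identifier", "label_name"]).to_dict(orient="index")`, the pickled object is a dict keyed by `(identifier, label_name)` tuples, each mapping to a dict of metric columns. A sketch of how an inference-side consumer might read it back; the loader function is hypothetical, only the path layout comes from `export_weak_supervision_stats` above:

```python
import os
import pickle


def load_ws_stats(project_id: str, labeling_task_id: str) -> dict:
    # Mirrors the path layout used by export_weak_supervision_stats
    path = os.path.join(
        "/inference", project_id, f"weak-supervision-{labeling_task_id}.pkl"
    )
    with open(path, "rb") as f:
        # {(identifier, label_name): {metric_column: value, ...}, ...}
        return pickle.load(f)
```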
15 changes: 15 additions & 0 deletions start
@@ -5,6 +5,20 @@ echo -ne 'stopping old container...'
 docker stop refinery-weak-supervisor > /dev/null 2>&1
 echo -ne '\t [done]\n'
 
+INFERENCE_DIR=${PWD%/*}/dev-setup/inference/
+if [ ! -d "$INFERENCE_DIR" ]
+then
+    INFERENCE_DIR=${PWD%/*/*}/dev-setup/inference/
+    if [ ! -d "$INFERENCE_DIR" ]
+    then
+        # to include a volume for local development, use the dev-setup inference folder;
+        # alternatively, mount it manually with
+        # -v /path/to/dev-setup/inference:/inference \
+        echo "Can't find model data directory: $INFERENCE_DIR -> stopping"
+        exit 1
+    fi
+fi
+
 echo -ne 'building container...'
 docker build -t refinery-weak-supervisor-dev -f dev.Dockerfile . > /dev/null 2>&1
 echo -ne '\t\t [done]\n'
@@ -17,6 +31,7 @@ docker run -d --rm \
 -e WS_NOTIFY_ENDPOINT="http://refinery-websocket:8080" \
 --mount type=bind,source="$(pwd)"/,target=/app \
 -v /var/run/docker.sock:/var/run/docker.sock \
+-v "$INFERENCE_DIR":/inference \
 --network dev-setup_default \
 refinery-weak-supervisor-dev > /dev/null 2>&1
 echo -ne '\t\t\t [done]\n'
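
The new `-v "$INFERENCE_DIR":/inference` mount is what ties the pieces together: the pickle files written by `export_weak_supervision_stats` under `/inference/<project_id>/` inside the container land in the dev-setup inference folder on the host.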