Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions controller/payload/payload_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
InformationSourceStatisticsExclusion,
RecordLabelAssociation,
InformationSourcePayload,
User,
)
from util import daemon, doc_ock, notification
from submodules.s3 import controller as s3
Expand All @@ -54,7 +53,6 @@
from util.notification import create_notification
from util.miscellaneous_functions import chunk_dict
from controller.weak_supervision import weak_supervision_service as weak_supervision
from controller.user import manager as user_manager
import time
import uuid

Expand Down Expand Up @@ -99,8 +97,10 @@ def prepare_and_run_execution_pipeline(
payload_id: str,
project_id: str,
information_source_item: InformationSource,
in_thread: bool = False,
) -> None:
ctx_token = general.get_ctx_token()
if in_thread:
ctx_token = general.get_ctx_token()
try:
add_file_name, input_data = prepare_input_data_for_payload(
information_source_item
Expand All @@ -125,7 +125,8 @@ def prepare_and_run_execution_pipeline(
information_source_item.name,
)
finally:
general.reset_ctx_token(ctx_token, True)
if in_thread:
general.reset_ctx_token(ctx_token, True)

def prepare_input_data_for_payload(
information_source_item: InformationSource,
Expand Down Expand Up @@ -327,6 +328,7 @@ def execution_pipeline(
payload.id,
project_id,
information_source_item,
in_thread=True,
)
else:
prepare_and_run_execution_pipeline(
Expand Down
26 changes: 1 addition & 25 deletions controller/weak_supervision/manager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import timeit
from typing import Any, Tuple, Optional
from typing import Optional

from submodules.model import enums, WeakSupervisionTask
from submodules.model.business_objects import labeling_task
from submodules.model.business_objects import weak_supervision
from controller.weak_supervision.weak_supervision_service import (
initiate_weak_supervision,
)


def create_task(
Expand Down Expand Up @@ -35,22 +30,3 @@ def update_weak_supervision_task_stats(
enums.PayloadState.FINISHED.value,
with_commit=True,
)


def start_weak_supervision_by_project_id(
project_id: str, user_id: str, ws_task_id: str
) -> None:
selected_tasks = labeling_task.get_labeling_tasks_by_selected_sources(project_id)
for labeling_task_item in selected_tasks:
initiate_weak_supervision(
project_id, labeling_task_item.id, user_id, ws_task_id
)


def start_weak_supervision_by_task_id(
project_id: str, task_id: str, user_id: str, ws_task_id: str
) -> Tuple[float, float]:
start = timeit.default_timer()
initiate_weak_supervision(project_id, task_id, user_id, ws_task_id)
stop = timeit.default_timer()
return start, stop
9 changes: 7 additions & 2 deletions controller/weak_supervision/weak_supervision_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any
from typing import Any, Optional, Union, Dict

from util import notification, service_requests
from util.decorator import debounce
Expand All @@ -8,14 +8,19 @@


def initiate_weak_supervision(
project_id: str, labeling_task_id: str, user_id: str, weak_supervision_task_id: str
project_id: str,
labeling_task_id: str,
user_id: str,
weak_supervision_task_id: str,
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None,
) -> Any:
url = f"{BASE_URI}/fit_predict"
data = {
"project_id": str(project_id),
"labeling_task_id": str(labeling_task_id),
"user_id": str(user_id),
"weak_supervision_task_id": str(weak_supervision_task_id),
"overwrite_weak_supervision": overwrite_weak_supervision,
}
return service_requests.post_call_or_raise(url, data)

Expand Down
71 changes: 55 additions & 16 deletions graphql_api/mutation/weak_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import graphene
import traceback
from typing import Optional, Dict

from controller.auth import manager as auth
from controller.weak_supervision import manager as ws_manager
Expand All @@ -16,14 +17,12 @@
)
from submodules.model.business_objects.information_source import (
get_selected_information_sources,
get_task_information_sources,
)
from submodules.model.business_objects.labeling_task import (
get_selected_labeling_task_names,
)
from submodules.model.enums import NotificationType
from util import daemon
from util import notification
from util import daemon, notification
from controller.weak_supervision.weak_supervision_service import (
initiate_weak_supervision,
)
Expand All @@ -36,10 +35,18 @@
class InitiateWeakSupervisionByProjectId(graphene.Mutation):
class Arguments:
project_id = graphene.ID(required=True)
overwrite_default_precision = graphene.Float(required=False)
overwrite_weak_supervision = graphene.JSONString(required=False)

ok = graphene.Boolean()

def mutate(self, info, project_id: str):
def mutate(
self,
info,
project_id: str,
overwrite_default_precision: Optional[float] = None,
overwrite_weak_supervision: Optional[Dict[str, float]] = None,
):
auth.check_demo_access(info)
auth.check_project_access(info, project_id)
user = auth.get_user_by_info(info)
Expand All @@ -49,7 +56,7 @@ def mutate(self, info, project_id: str):
project_id,
"Weak Supervision Task",
)
notification.send_organization_update(project_id, f"weak_supervision_started")
notification.send_organization_update(project_id, "weak_supervision_started")

weak_supervision_task = ws_manager.create_task(
project_id=project_id,
Expand All @@ -62,19 +69,29 @@ def mutate(self, info, project_id: str):
)

def execution_pipeline(
project_id: str, user_id: str, weak_supervision_task_id: str
project_id: str,
user_id: str,
weak_supervision_task_id: str,
overwrite_default_precision: Optional[float] = None,
overwrite_weak_supervision: Optional[Dict[str, float]] = None,
):
ctx_token = general.get_ctx_token()
try:
labeling_tasks = labeling_task.get_labeling_tasks_by_selected_sources(
project_id
)
for labeling_task_item in labeling_tasks:
overwrite_ws = overwrite_default_precision
if overwrite_weak_supervision is not None:
overwrite_ws = overwrite_weak_supervision.get(
str(labeling_task_item.id)
)
initiate_weak_supervision(
project_id,
labeling_task_item.id,
user_id,
weak_supervision_task_id,
overwrite_ws,
)
ws_manager.update_weak_supervision_task_stats(
weak_supervision_task_id, project_id
Expand All @@ -87,7 +104,7 @@ def execution_pipeline(
"Weak Supervision Task",
)
notification.send_organization_update(
project_id, f"weak_supervision_finished"
project_id, "weak_supervision_finished"
)
except Exception as e:
print(traceback.format_exc(), flush=True)
Expand All @@ -99,14 +116,19 @@ def execution_pipeline(
with_commit=True,
)
notification.send_organization_update(
project_id, f"weak_supervision_finished"
project_id, "weak_supervision_finished"
)
raise e
finally:
general.reset_ctx_token(ctx_token, True)

daemon.run(
execution_pipeline, project_id, str(user.id), str(weak_supervision_task.id)
execution_pipeline,
project_id,
str(user.id),
str(weak_supervision_task.id),
overwrite_default_precision,
overwrite_weak_supervision,
)
return InitiateWeakSupervisionByProjectId(ok=True)

Expand All @@ -116,28 +138,44 @@ class Arguments:
project_id = graphene.ID(required=True)
information_source_id = graphene.ID(required=True)
labeling_task_id = graphene.ID(required=True)
overwrite_default_precision = graphene.Float(required=False)
overwrite_weak_supervision = graphene.JSONString(required=False)

ok = graphene.Boolean()

def mutate(
self, info, project_id: str, information_source_id: str, labeling_task_id: str
self,
info,
project_id: str,
information_source_id: str,
labeling_task_id: str,
overwrite_default_precision: Optional[float] = None,
overwrite_weak_supervision: Optional[Dict[str, float]] = None,
):
auth.check_demo_access(info)
auth.check_project_access(info, project_id)
user = auth.get_user_by_info(info)
pl_manager.create_payload(
payload = pl_manager.create_payload(
project_id, information_source_id, user.id, asynchronous=False
)
if not payload.state == enums.PayloadState.FINISHED.value:
return RunInformationSourceAndInitiateWeakSupervisionByLabelingTaskId(
ok=True
)

source_names = []
labeling_task_item = labeling_task.get(project_id, labeling_task_id)
for information_source in labeling_task_item.information_sources:
information_source.is_selected = any(
source_statistic.true_positives > 0
for source_statistic in information_source.source_statistics
if source_statistic.true_positives is not None
information_source.is_selected = (
information_source.payloads[0].state
== enums.PayloadState.FINISHED.value
if len(information_source.payloads) > 0
else False
)
if information_source.is_selected:
source_names.append(information_source.name)
general.commit()

source_names = get_task_information_sources(project_id, labeling_task_id)
if len(source_names) > 0:
create_notification(
NotificationType.WEAK_SUPERVISION_TASK_STARTED,
Expand All @@ -164,6 +202,7 @@ def mutate(
labeling_task_id,
user.id,
weak_supervision_task.id,
overwrite_weak_supervision or overwrite_default_precision,
)
ws_manager.update_weak_supervision_task_stats(
weak_supervision_task.id, project_id
Expand Down