feat (backend): two behavior modes for labeling counts
arielge authored and alonh committed Jan 31, 2023
1 parent 0130c7a commit 2ff8f30
Showing 5 changed files with 56 additions and 36 deletions.
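
The commit gives get_label_counts two behavior modes: a fine-grained mode (the new default), where each label type is counted separately under its lowercase detailed name, and a training mode, where only the label types consumed by the configured training-set selector are counted, lumped together under the raw label value. A rough usage sketch of the two modes (the orchestrator instance, workspace and dataset names are made up for illustration, not part of this commit):

    # Fine-grained mode (default): one entry per label type, keyed by lowercase
    # detailed names, e.g. {'true': 20, 'false': 5, 'weak_true': 3}
    fine_grained = orchestrator.get_label_counts('my_workspace', 'my_dataset', category_id=0)

    # Training mode: the label types used for training are merged and keyed by the
    # raw label value, e.g. {True: 23, False: 5}
    for_training = orchestrator.get_label_counts('my_workspace', 'my_dataset', category_id=0,
                                                 counts_for_training=True)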
3 changes: 2 additions & 1 deletion label_sleuth/data_access/core/data_structs.py
@@ -36,7 +36,8 @@ class Label:
     label_type: LabelType = field(default=LabelType.Standard)

     def get_detailed_label_name(self):
-        return str(self.label) if self.label_type == LabelType.Standard else f'{self.label_type.name}_{self.label}'
+        return str(self.label).lower() if self.label_type == LabelType.Standard \
+            else f'{self.label_type.name}_{self.label}'.lower()

     def to_dict(self):
         dict_for_json = {'label': self.label, 'metadata': self.metadata, 'label_type': self.label_type.value}
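
The behavioral change in this file is that detailed label names are now consistently lowercase. A minimal self-contained sketch of the new logic (the Weak member is assumed for illustration; this diff only confirms LabelType.Standard):

    from enum import Enum

    class LabelType(Enum):
        Standard = 0
        Weak = 1  # assumed non-standard label type, for illustration

    def get_detailed_label_name(label, label_type=LabelType.Standard):
        # mirrors the new method body above
        return str(label).lower() if label_type == LabelType.Standard \
            else f'{label_type.name}_{label}'.lower()

    assert get_detailed_label_name(True) == 'true'  # previously 'True'
    assert get_detailed_label_name(True, LabelType.Weak) == 'weak_true'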
7 changes: 4 additions & 3 deletions label_sleuth/data_access/data_access_api.py
@@ -209,15 +209,16 @@ def get_labeled_text_elements(self, workspace_id: str, dataset_name: str, catego

     @abc.abstractmethod
     def get_label_counts(self, workspace_id: str, dataset_name: str, category_id: int, remove_duplicates=False,
-                         label_types: Set[LabelType] = frozenset({LabelType.Standard})) \
-            -> Mapping[bool, int]:
+                         label_types: Set[LabelType] = frozenset(LabelType._member_map_.values()),
+                         fine_grained_counts=True) -> Mapping[Union[str, bool], int]:
         """
         Return for each label value, assigned to category_id, the total count of its appearances in dataset_name.
         :param workspace_id: the workspace_id of the labeling effort.
         :param dataset_name: the name of the dataset from which labels count should be generated
         :param category_id: the id of the category whose label information is the target
         :param remove_duplicates: if True, do not include elements that are duplicates of each other.
-        :param label_types: by default, only the LabelType.Standard (strong labels) are retrieved.
+        :param label_types: by default, labels of all types are retrieved.
+        :param fine_grained_counts: if True, count labels of each label type separately.
         :return: a map whose keys are label values, and the values are the number of TextElements this label was
         assigned to.
         """
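
The net effect on the return value, with made-up counts:

    # fine_grained_counts=True (the new default): keys are lowercase detailed-name strings
    counts = {'true': 20, 'false': 5, 'weak_true': 3}

    # fine_grained_counts=False: label types are merged and keys are raw label values
    counts = {True: 23, False: 5}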
15 changes: 10 additions & 5 deletions label_sleuth/data_access/file_based/file_based_data_access.py
@@ -335,14 +335,16 @@ def get_labeled_text_elements(self, workspace_id: str, dataset_name: str, catego
         return results_dict

     def get_label_counts(self, workspace_id: str, dataset_name: str, category_id: int, remove_duplicates=False,
-                         label_types: Set[LabelType] = frozenset({LabelType.Standard})) -> Mapping[bool, int]:
+                         label_types: Set[LabelType] = frozenset(LabelType._member_map_.values()),
+                         fine_grained_counts=True) -> Mapping[Union[str, bool], int]:
         """
         Return for each label value, assigned to category_id, the total count of its appearances in dataset_name.
         :param workspace_id: the workspace_id of the labeling effort.
         :param dataset_name: the name of the dataset from which labels count should be generated
         :param category_id: the id of the category whose label information is the target
         :param remove_duplicates: if True, do not include elements that are duplicates of each other.
-        :param label_types: by default, only the LabelType.Standard (strong labels) are retrieved.
+        :param label_types: by default, labels of all types are retrieved.
+        :param fine_grained_counts: if True, count labels of each label type separately.
         :return: a map whose keys are label values, and the values are the number of TextElements this label was
         assigned to.
         """
@@ -355,10 +357,13 @@ def get_label_counts(self, workspace_id: str, dataset_name: str, category_id: in
                              if uri in uris_to_keep}

         category_label_list = \
-            [category_to_label[category_id].label for category_to_label in labels_by_uri.values()
+            [category_to_label[category_id] for category_to_label in labels_by_uri.values()
              if category_id in category_to_label and category_to_label[category_id].label_type in label_types]
-        category_label_counts = Counter(category_label_list)
-        return category_label_counts
+
+        if fine_grained_counts:
+            return Counter(lbl_obj.get_detailed_label_name() for lbl_obj in category_label_list)
+        else:
+            return Counter(lbl_obj.label for lbl_obj in category_label_list)

     def delete_all_labels(self, workspace_id, dataset_name):
         """
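
The two return paths differ only in the Counter key. A self-contained sketch of the branching, using a stand-in label object (its fields are assumed for illustration):

    from collections import Counter
    from dataclasses import dataclass

    @dataclass
    class FakeLabel:  # stand-in for data_structs.Label, illustration only
        label: bool
        detailed_name: str

        def get_detailed_label_name(self):
            return self.detailed_name

    category_label_list = [FakeLabel(True, 'true'), FakeLabel(True, 'weak_true'),
                           FakeLabel(False, 'false')]

    # fine_grained_counts=True: one bucket per detailed name
    assert Counter(lbl.get_detailed_label_name() for lbl in category_label_list) \
        == Counter({'true': 1, 'weak_true': 1, 'false': 1})

    # fine_grained_counts=False: label types merge under the raw value
    assert Counter(lbl.label for lbl in category_label_list) == Counter({True: 2, False: 1})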
21 changes: 11 additions & 10 deletions label_sleuth/data_access/test_file_based_data_access.py
@@ -117,10 +117,10 @@ def test_unset_labels(self):
         self.data_access.set_labels(workspace_id, dict(texts_and_labels_list))

         labels_count = self.data_access.get_label_counts(workspace_id, dataset_name, category_id)
-        self.assertGreater(labels_count[True], 0)
+        self.assertGreater(labels_count['true'], 0)
         self.data_access.unset_labels(workspace_id, category_id, [x[0] for x in texts_and_labels_list])
         labels_count_after_unset = self.data_access.get_label_counts(workspace_id, dataset_name, category_id)
-        self.assertEqual(0, labels_count_after_unset[True])
+        self.assertEqual(0, labels_count_after_unset['true'])
         self.data_access.delete_dataset(dataset_name)

     def test_get_all_document_uris(self):
@@ -355,7 +355,7 @@ def set_labels_thread(self, categories, category_id, corpus, workspace_id, start
             [1 for _, label_dict in sentences_and_labels_all.items() if label_dict[category_id].label == LABEL_POSITIVE])
         false_count = len(
             [1 for _, label_dict in sentences_and_labels_all.items() if label_dict[category_id].label == LABEL_NEGATIVE])
-        results[idx] = {LABEL_POSITIVE: true_count, LABEL_NEGATIVE: false_count}
+        results[idx] = {str(LABEL_POSITIVE).lower(): true_count, str(LABEL_NEGATIVE).lower(): false_count}

     def test_get_label_counts(self):
         workspace_id = 'test_get_label_counts'
@@ -376,7 +376,7 @@ def test_get_label_counts(self):
             for label_val, observed_count in category_label_counts.items():
                 expected_count = len(
                     [label for uri, label in uri_to_label_dict.items() if
-                     category_to_count in label and label_val == label[category_to_count].label])  # TODO verify
+                     category_to_count in label and label_val == str(label[category_to_count].label).lower()])
                 self.assertEqual(expected_count, observed_count, f'count for {label_val} does not match.')
         self.data_access.delete_dataset(dataset_name)

@@ -427,15 +427,15 @@ def test_duplicates_removal(self):
         # set labels without propagating to duplicates
         self.data_access.set_labels(workspace_id, uri_to_label_dict, apply_to_duplicate_texts=False)
         labels_count = self.data_access.get_label_counts(workspace_id, dataset_name, category_id)
-        self.assertEqual(labels_count[LABEL_POSITIVE], len(all_without_dups))
+        self.assertEqual(labels_count[str(LABEL_POSITIVE).lower()], len(all_without_dups))
         # unset labels
         self.data_access.unset_labels(workspace_id, category_id, [elem.uri for elem in all_without_dups])
         labels_count = self.data_access.get_label_counts(workspace_id, dataset_name, category_id)
-        self.assertEqual(labels_count[LABEL_POSITIVE], 0)
+        self.assertEqual(labels_count[str(LABEL_POSITIVE).lower()], 0)
         # set labels with propagating to duplicates
         self.data_access.set_labels(workspace_id, uri_to_label_dict, apply_to_duplicate_texts=True)
         labels_count = self.data_access.get_label_counts(workspace_id, dataset_name, category_id)
-        self.assertEqual(labels_count[LABEL_POSITIVE], len(all_elements))
+        self.assertEqual(labels_count[str(LABEL_POSITIVE).lower()], len(all_elements))
         self.data_access.unset_labels(workspace_id, category_id, [elem.uri for elem in all_elements])

         # 2. test sampling of duplicate examples:
@@ -448,17 +448,18 @@ def test_duplicates_removal(self):
         sampled = \
             self.data_access.get_labeled_text_elements(workspace_id, dataset_name, category_id, 10 ** 6,
                                                        remove_duplicates=True)['results']
-        self.assertEqual(labels_count[LABEL_POSITIVE], len(sampled), len(non_representative_duplicates))
+        self.assertEqual(labels_count[str(LABEL_POSITIVE).lower()], len(sampled), len(non_representative_duplicates))
         # set labels with propagating to duplicates
         self.data_access.set_labels(workspace_id, uri_to_label_dict, apply_to_duplicate_texts=True)
         labels_count = self.data_access.get_label_counts(workspace_id, dataset_name, category_id)
         sampled = \
             self.data_access.get_labeled_text_elements(workspace_id, dataset_name, category_id, 10 ** 6,
                                                        remove_duplicates=True)['results']
-        self.assertGreater(labels_count[LABEL_POSITIVE], len(non_representative_duplicates))
+        self.assertGreater(labels_count[str(LABEL_POSITIVE).lower()], len(non_representative_duplicates))
         labels_count_no_dups = self.data_access.get_label_counts(workspace_id, dataset_name, category_id,
                                                                  remove_duplicates=True)
-        self.assertEqual(labels_count_no_dups[LABEL_POSITIVE], len(sampled), len(non_representative_duplicates))
+        self.assertEqual(labels_count_no_dups[str(LABEL_POSITIVE).lower()], len(sampled),
+                         len(non_representative_duplicates))
         self.data_access.delete_all_labels(workspace_id, dataset_name)
         self.data_access.delete_dataset(dataset_name)

46 changes: 29 additions & 17 deletions label_sleuth/orchestrator/orchestrator_api.py
@@ -23,7 +23,7 @@
 from collections import Counter, defaultdict
 from concurrent.futures.thread import ThreadPoolExecutor
 from datetime import datetime
-from typing import Mapping, List, Sequence, Union, Tuple
+from typing import Mapping, List, Sequence, Union, Tuple, Set

 import jsonpickle
 import pandas as pd
@@ -293,18 +293,29 @@ def unset_labels(self, workspace_id: str, category_id: int, uris: Sequence[str],
         self.data_access.unset_labels(
             workspace_id, category_id, uris, apply_to_duplicate_texts=apply_to_duplicate_texts)

-    def get_label_counts(self, workspace_id: str, dataset_name: str, category_id: int, remove_duplicates=False) -> \
-            Mapping[bool, int]:
+    def get_label_counts(self, workspace_id: str, dataset_name: str, category_id: int, remove_duplicates=False,
+                         counts_for_training=False) -> Mapping[Union[str, bool], int]:
         """
         Get the number of elements that were labeled for the given category.
         :param workspace_id:
         :param dataset_name:
         :param category_id:
         :param remove_duplicates: whether to count all labeled elements or only unique instances
+        :param counts_for_training: if True, determine the counts as relevant for training a model, e.g. lumping
+        both strong and weak labels together; if False, count the different types of labels separately
         :return:
         """
-        return self.data_access.get_label_counts(workspace_id, dataset_name, category_id,
-                                                 remove_duplicates=remove_duplicates)
+        if counts_for_training:
+            train_set_selector = get_training_set_selector(self.data_access,
+                                                           strategy=self.config.training_set_selection_strategy)
+            used_label_types = train_set_selector.get_label_types()
+            return self.data_access.get_label_counts(workspace_id, dataset_name, category_id,
+                                                     remove_duplicates=remove_duplicates,
+                                                     fine_grained_counts=False,
+                                                     label_types=used_label_types)
+        else:
+            return self.data_access.get_label_counts(workspace_id, dataset_name, category_id,
+                                                     remove_duplicates=remove_duplicates, fine_grained_counts=True)

     # Iteration-related methods

@@ -613,13 +624,9 @@ def train_if_recommended(self, workspace_id: str, category_id: int, force=False)
             changes_since_last_model = \
                 self.orchestrator_state.get_label_change_count_since_last_train(workspace_id, category_id)

-            train_set_selector = get_training_set_selector(self.data_access,
-                                                           strategy=self.config.training_set_selection_strategy)
-
-            used_label_types = train_set_selector.get_label_types()
-            label_counts = self.data_access.get_label_counts(workspace_id=workspace_id, dataset_name=dataset_name,
-                                                             category_id=category_id, remove_duplicates=True,
-                                                             label_types=used_label_types)
+            label_counts = self.get_label_counts(workspace_id=workspace_id, dataset_name=dataset_name,
+                                                 category_id=category_id, remove_duplicates=True,
+                                                 counts_for_training=True)

             if force or (LABEL_POSITIVE in label_counts
                          and label_counts[LABEL_POSITIVE] >= self.config.first_model_positive_threshold
@@ -636,7 +643,8 @@ def train_if_recommended(self, workspace_id: str, category_id: int, force=False)
                         f"(>={self.config.changed_element_threshold}). Training a new model")
             iteration_num = len(iterations_without_errors)
             model_type = self.config.model_policy.get_model_type(iteration_num)
-
+            train_set_selector = get_training_set_selector(self.data_access,
+                                                           strategy=self.config.training_set_selection_strategy)
             train_data = train_set_selector.get_train_set(workspace_id=workspace_id,
                                                           train_dataset_name=dataset_name,
                                                           category_id=category_id)
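
Note why train_if_recommended above now passes counts_for_training=True: in fine-grained mode the keys are strings, so indexing with the boolean LABEL_POSITIVE would miss. An illustration (assuming LABEL_POSITIVE is True, as in label_sleuth; the counts are made up):

    LABEL_POSITIVE = True

    fine_grained = {'true': 20, 'false': 5, 'weak_true': 3}
    merged = {True: 23, False: 5}

    assert fine_grained.get(LABEL_POSITIVE, 0) == 0   # boolean key absent in fine-grained mode
    assert merged.get(LABEL_POSITIVE, 0) == 23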
@@ -798,12 +806,15 @@ def batched(iterable, batch_size=100):
         return elements_with_required_prediction[start_idx:start_idx + sample_size]

     def get_progress(self, workspace_id: str, dataset_name: str, category_id: int):
-        category_label_counts = self.get_label_counts(workspace_id, dataset_name, category_id, remove_duplicates=True)
+        category_label_counts = self.get_label_counts(workspace_id, dataset_name, category_id, remove_duplicates=True,
+                                                      counts_for_training=True)
         if category_label_counts[LABEL_POSITIVE]:
             changed_since_last_model_count = \
                 self.orchestrator_state.get_label_change_count_since_last_train(workspace_id, category_id)

             return {"all": min(
+                # for a new training to start both the number of labels changed and the number of positives must be
+                # above their respective thresholds; thus, we determine the status as the minimum of the two ratios
                 max(0, min(round(changed_since_last_model_count / self.config.changed_element_threshold * 100), 100)),
                 max(0, min(round(category_label_counts[LABEL_POSITIVE] /
                                  self.config.first_model_positive_threshold * 100), 100)))
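
A worked example of the progress formula above, with made-up threshold values:

    changed_element_threshold = 20        # made-up config value
    first_model_positive_threshold = 10   # made-up config value

    changed_since_last_model_count = 5    # 5/20 -> 25%
    positive_count = 8                    # 8/10 -> 80%

    progress = min(
        max(0, min(round(changed_since_last_model_count / changed_element_threshold * 100), 100)),
        max(0, min(round(positive_count / first_model_positive_threshold * 100), 100)))
    assert progress == 25  # the slower of the two conditions gates the next training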
@@ -847,7 +858,7 @@ def import_category_labels(self, workspace_id, labels_df_to_import: pd.DataFrame
                                        apply_to_duplicate_texts=self.config.apply_labels_to_duplicate_texts,
                                        update_label_counter=True)

-            label_counts_dict = self.get_label_counts(workspace_id, dataset_name, category_id, False)
+            label_counts_dict = self.get_label_counts(workspace_id, dataset_name, category_id, remove_duplicates=False)
             logging.info(f"Updated total label count in workspace '{workspace_id}' for category id {category_id} "
                          f"is {sum(label_counts_dict.values())} ({label_counts_dict})")
             categories_counter[category_id] = len(uri_to_label)
@@ -875,8 +886,9 @@ def export_workspace_labels(self, workspace_id, labeled_only) -> pd.DataFrame:
         list_of_dicts = []

         for category_id, category in categories.items():
-            label_counts = self.get_label_counts(workspace_id, dataset_name, category_id, False)
-            total_count = sum(self.get_label_counts(workspace_id, dataset_name, category_id, False).values())
+            label_counts = self.get_label_counts(workspace_id, dataset_name, category_id, False,
+                                                 counts_for_training=True)
+            total_count = sum(label_counts.values())

             if labeled_only or label_counts[LABEL_POSITIVE] == 0:  # if there are no positive elements,
                 # training set selector cannot be used, so we only use the labeled elements
