Skip to content

Commit

Permalink
Made sure all labels are represented in the training set
Browse files (browse the repository at this point in the history)
  • Loading branch information
x-tabdeveloping committed May 8, 2024
1 parent 49afeb1 commit 96312c7
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mteb/abstasks/AbsTaskMultilabelClassification.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -168,8 +168,11 @@ def _undersample_data_indices(self, y, samples_per_label, idxs=None):
idxs = np.arange(len(y))
np.random.shuffle(idxs)
label_counter = defaultdict(int)
unique_labels = set(itertools.chain.from_iterable(y))
for i in idxs:
if any((label_counter[label] < samples_per_label) for label in y[i]):
if any(
(label_counter[label] < samples_per_label) for label in unique_labels
):
sample_indices.append(i)
for label in y[i]:
label_counter[label] += 1
Expand Down

0 comments on commit 96312c7

Please sign in to comment.