diff --git a/farm/data_handler/data_silo.py b/farm/data_handler/data_silo.py index 9b24e21f3..b1d1f8ffe 100644 --- a/farm/data_handler/data_silo.py +++ b/farm/data_handler/data_silo.py @@ -478,8 +478,12 @@ def calculate_class_weights(self, task_name, source="train"): else: raise Exception("source argument expects one of [\"train\", \"all\"]") for dataset in datasets: - if dataset is not None: + if "multilabel" in self.processor.tasks[task_name]["task_type"]: + for x in dataset: + observed_labels += [label_list[label_id] for label_id in (x[tensor_idx] == 1).nonzero()] + else: observed_labels += [label_list[x[tensor_idx].item()] for x in dataset] + #TODO scale e.g. via logarithm to avoid crazy spikes for rare classes class_weights = list(compute_class_weight("balanced", np.asarray(label_list), observed_labels)) return class_weights