diff --git a/neurons/validator.py b/neurons/validator.py index 846683d29..6f8370d2d 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -40,7 +40,7 @@ def __init__(self, config=None): mock=self.config.mock, ) - if sum(self.config.neuron.task_p) != 1: + if abs(1-sum(self.config.neuron.task_p)) > 0.001: raise ValueError("Task probabilities do not sum to 1.") # Filter out tasks with 0 probability diff --git a/prompting/task_registry.py b/prompting/task_registry.py index 22ea3b19f..22c1d6e4c 100644 --- a/prompting/task_registry.py +++ b/prompting/task_registry.py @@ -5,7 +5,7 @@ mock_task, mock_dataset = MockTask.name, [MockDataset.name] summarization_task, summarization_dataset = SummarizationTask.name, [WikiDataset.name] qa_task, qa_dataset = QuestionAnsweringTask.name, [WikiDataset.name] -debugging_task, debugging_dataset = DebuggingTask.name, [HFCodingDataset.name] +#debugging_task, debugging_dataset = DebuggingTask.name, [HFCodingDataset.name] math_task, math_dataset = MathTask.name, [MathDataset.name] date_qa_task, date_qa_dataset = DateQuestionAnsweringTask.name, [WikiDateDataset.name] generic_instruction_task, generic_instruction_dataset = GenericInstructionTask.name, [GenericInstructionDataset.name] @@ -14,7 +14,7 @@ mock_task: mock_dataset, summarization_task: summarization_dataset, qa_task: qa_dataset, - debugging_task: debugging_dataset, + #debugging_task: debugging_dataset, math_task: math_dataset, date_qa_task: date_qa_dataset, generic_instruction_task: generic_instruction_dataset diff --git a/prompting/tasks/__init__.py b/prompting/tasks/__init__.py index 81e8e373c..bc08c7fe3 100644 --- a/prompting/tasks/__init__.py +++ b/prompting/tasks/__init__.py @@ -13,7 +13,7 @@ QuestionAnsweringTask.name: QuestionAnsweringTask, DateQuestionAnsweringTask.name: DateQuestionAnsweringTask, SummarizationTask.name: SummarizationTask, - DebuggingTask.name: DebuggingTask, + #DebuggingTask.name: DebuggingTask, GenericInstructionTask.name: GenericInstructionTask, MathTask.name: MathTask, } diff --git a/prompting/tools/__init__.py b/prompting/tools/__init__.py index 8ee07bc0d..e9ef44df4 100644 --- a/prompting/tools/__init__.py +++ b/prompting/tools/__init__.py @@ -13,7 +13,7 @@ DATASETS = { MockDataset.name: MockDataset, - HFCodingDataset.name: HFCodingDataset, + #HFCodingDataset.name: HFCodingDataset, WikiDataset.name: WikiDataset, #StackOverflowDataset.name: StackOverflowDataset, MathDataset.name: MathDataset, diff --git a/prompting/utils/config.py b/prompting/utils/config.py index 6da117f8c..d52208193 100644 --- a/prompting/utils/config.py +++ b/prompting/utils/config.py @@ -286,7 +286,7 @@ def add_validator_args(cls, parser): type=float, nargs="+", help="The probability of sampling each task.", - default=[.2, .2, .2, 0, .2, .2], + default=[1.0 / (len(TASKS)-1)] * (len(TASKS)-1), ) parser.add_argument(