-
Notifications
You must be signed in to change notification settings - Fork 5
/
deepcoder_dataset_loader.py
79 lines (68 loc) · 2.62 KB
/
deepcoder_dataset_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from Predictions.models import RulesPredictor
import json
from program import BasicPrimitive, Function, Variable
from type_system import BOOL, INT, Arrow, List, Type
from typing import Any, Tuple
import typing
from DSL import deepcoder
import dsl
from experiment_helper import filter_examples
# Module-level DSL instance assembled from ``deepcoder.semantics``,
# ``deepcoder.primitive_types`` and ``deepcoder.no_repetitions``.
# NOTE(review): not referenced in this module's visible code; presumably
# imported by other modules — verify before removing.
my_dsl = dsl.DSL(
    deepcoder.semantics,
    deepcoder.primitive_types,
    deepcoder.no_repetitions,
)
def load_tasks(file: str) -> Tuple[typing.List[Tuple[Any, Any]], set]:
    """Load DeepCoder tasks from a JSON dataset file.

    The file must contain a JSON array of objects, each with a ``"program"``
    string (parsed by ``__str2prog``) and an ``"examples"`` list whose items
    have an ``"inputs"`` list and an ``"output"`` value.

    Args:
        file: path to the JSON dataset.

    Returns:
        ``(tasks, all_types)``. ``tasks`` is a list of ``(program, examples)``
        pairs, where ``program`` is the parsed program object (not its source
        string) and each example is an ``(inputs + [None], output)`` tuple.
        ``all_types`` is the set of distinct type requests returned by
        ``__str2prog`` across the file.
    """
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
    # relying on the platform locale, and release the handle before parsing.
    with open(file, "r", encoding="utf-8") as fd:
        raw_tasks = json.load(fd)

    tasks = []
    all_types = set()
    for raw_task in raw_tasks:
        source = raw_task["program"]
        # NOTE(review): a trailing None is appended to every input list;
        # presumably a placeholder slot expected downstream — confirm before
        # removing.
        examples = [(raw_example["inputs"] + [None], raw_example["output"])
                    for raw_example in raw_task["examples"]]
        prog, type_request = __str2prog(source)
        tasks.append((prog, examples))
        all_types.add(type_request)
    return tasks, all_types
def __str2prog(s: str):
    """Parse a serialized DeepCoder program into a program tree and its type.

    ``s`` is a sequence of tokens separated by ``"|"``; each token is a list
    of fields separated by ``","``. Tokens are read left to right and each
    one pushes exactly one node onto a node list:

    * ``LIST`` / ``INT`` declare the next program input. It is pushed as a
      ``Variable`` numbered by declaration order (0, 1, ...) with type
      ``List(INT)`` / ``INT``.
    * Any other token names a primitive. If the bare name is not a key of
      ``deepcoder.primitive_types``, the next field is appended in brackets
      (``NAME[field]``) to form the key. The remaining fields are integer
      positions into the node list selecting the primitive's arguments. The
      node is a ``Function`` whose ``type_`` is ``primitive.type.returns()``.

    Returns:
        ``(program, type_request)``. ``program`` is the last node pushed.
        ``type_request`` wraps ``program.type`` in one ``Arrow`` per declared
        input, with the first-declared input outermost.

    Errors propagate unchanged:
        * An unknown primitive key fails at the
          ``deepcoder.primitive_types[...]`` lookup.
        * A missing bracket field or an out-of-range position raises
          ``IndexError``.
        * A non-integer position raises ``ValueError``.
    """
    nodes = []        # every node built so far; positions index into this
    input_types = []  # declared input types, in declaration order

    for token in s.split("|"):
        fields = token.split(",")
        head = fields.pop(0)
        if head == "LIST":
            # The new variable's index equals the number of inputs declared so far.
            nodes.append(Variable(len(input_types), List(INT)))
            input_types.append(List(INT))
        elif head == "INT":
            nodes.append(Variable(len(input_types), INT))
            input_types.append(INT)
        else:
            if head not in deepcoder.primitive_types:
                # Parameterised primitive: the first field completes the key.
                head = f"{head}[{fields.pop(0)}]"
            primitive = BasicPrimitive(head, deepcoder.primitive_types[head])
            positions = list(map(int, fields))
            nodes.append(
                Function(
                    primitive,
                    [nodes[p] for p in positions],
                    type_=primitive.type.returns(),
                )
            )

    program = nodes[-1]
    type_request = program.type
    # Curry from the last-declared input outward so the first input ends up outermost.
    for arg_type in reversed(input_types):
        type_request = Arrow(arg_type, type_request)
    return program, type_request
def filter_tasks_for_model(tasks, model) -> typing.List[Tuple[str, Any]]:
    """Return the ``(program, examples)`` tasks that ``model`` can consume.

    A task is dropped when any of these holds:

    * any example's output is ``None``;
    * ``model`` is a ``RulesPredictor`` and ``program.type`` differs from
      ``Arrow(List(INT), List(INT))``;
    * ``filter_examples`` leaves no example. It is called with the limits of
      ``model.IOEncoder``: ``nb_arguments_max``, ``size_max`` and
      ``symbolToIndex``.

    Surviving tasks are returned in input order, paired with their filtered
    example lists.

    NOTE(review): ``__str2prog`` builds each ``Function`` with
    ``type_=primitive.type.returns()``, and input nodes are plain
    ``List(INT)``/``INT`` variables. So ``program.type`` looks like an output
    type, never an ``Arrow``. If so, the ``RulesPredictor`` check drops every
    task — confirm.
    """
    kept = []
    for program, examples in tasks:
        # Tasks whose expected output is null are unusable.
        if any(output is None for _, output in examples):
            continue
        type_request: Type = program.type
        if isinstance(model, RulesPredictor) and type_request != Arrow(List(INT), List(INT)):
            continue
        encoder = model.IOEncoder
        usable = filter_examples(
            examples, encoder.nb_arguments_max, encoder.size_max, encoder.symbolToIndex)
        if len(usable) > 0:
            kept.append((program, usable))
    return kept