In [1]:
%cd libraries
import LOTlib3

/Users/meilongzhang/knightlab/codet5/libraries


In [2]:
import os
print(os.getcwd())

/Users/meilongzhang/knightlab/codet5/libraries


In [3]:
import matplotlib.pyplot as plt

In [4]:
import numpy as np

In [5]:
import pandas as pd

In [6]:
from random import sample
from random import randrange
from random import choices

In [7]:
import json

In [8]:
from LOTlib3.Hypotheses.LOTHypothesis import LOTHypothesis
from LOTlib3.DataAndObjects import FunctionData, Obj

In [9]:
from LOTlib3.DefaultGrammars import DNF
from LOTlib3.Miscellaneous import q, random
from LOTlib3.Grammar import Grammar

In [10]:
from LOTlib3.Hypotheses import FunctionHypothesis, Hypothesis

In [11]:
from LOTlib3.Samplers.MetropolisHastings import MetropolisHastingsSampler
from LOTlib3 import break_ctrlc
from LOTlib3.Miscellaneous import qq
from LOTlib3.TopN import TopN

In [14]:
from LOTlib3.Hypotheses.Priors.RationalRules import RationaRulesPrior
from LOTlib3.Hypotheses.Likelihoods.BinaryLikelihood import BinaryLikelihood
from LOTlib3.Examples.RationalRules.Model import MyHypothesis

In [13]:
from scipy.spatial import distance
from scipy.special import logsumexp
import plotly.express as px

In [15]:
DEFAULT_FEATURE_WEIGHT = 5
grammar = Grammar()
grammar.add_rule('START', '', ['DISJ'], 1.0)
grammar.add_rule('START', '', ['PRE-PREDICATE'], DEFAULT_FEATURE_WEIGHT)
grammar.add_rule('START', 'True', None, DEFAULT_FEATURE_WEIGHT)
grammar.add_rule('START', 'False', None, DEFAULT_FEATURE_WEIGHT)

grammar.add_rule('DISJ', '',     ['CONJ'], 1.0)
grammar.add_rule('DISJ', '',     ['PRE-PREDICATE'], DEFAULT_FEATURE_WEIGHT)
grammar.add_rule('DISJ', '(%s or %s)',  ['PRE-PREDICATE', 'DISJ'], 1.0)

grammar.add_rule('CONJ', '',     ['PRE-PREDICATE'], DEFAULT_FEATURE_WEIGHT)
grammar.add_rule('CONJ', '(%s and %s)', ['PRE-PREDICATE', 'CONJ'], 1.0)

# A pre-predicate is how we treat negation
grammar.add_rule('PRE-PREDICATE', '(not %s)', ['PREDICATE'], DEFAULT_FEATURE_WEIGHT)
grammar.add_rule('PRE-PREDICATE', '',     ['PREDICATE'], DEFAULT_FEATURE_WEIGHT)

PRE-PREDICATE -> ['PREDICATE']	w/ p=5.0

In [16]:
#grammar.add_rule('PREDICATE', 'is_color_', ['x', 'COLOR'], 1.0)
#grammar.add_rule('PREDICATE', 'is_shape_', ['x', 'SHAPE'], 1.0)
grammar.add_rule('PREDICATE', "x['color'] == %s", ['COLOR'], 1.0)
grammar.add_rule('PREDICATE', "x['shape'] == %s", ['SHAPE'], 1.0)

# Some colors/shapes each (for this simple demo)
# These are written in quotes so they can be evaled
grammar.add_rule('COLOR', q('red'), None, 1.0)
grammar.add_rule('COLOR', q('blue'), None, 1.0)
grammar.add_rule('COLOR', q('green'), None, 1.0)


grammar.add_rule('SHAPE', q('square'), None, 1.0)
grammar.add_rule('SHAPE', q('circle'), None, 1.0)
grammar.add_rule('SHAPE', q('triangle'), None, 1.0)

SHAPE -> 'triangle'	w/ p=1.0

In [17]:
class Clause:
    name = ''
    leftChild = None
    rightChild = None
    parent = None
    node = None
    
    def __init__(self, n):
        self.name = n
        if (n != 'sentinel'):
            self.parent = self.getSentinel()
    
    def getSentinel(self):
        return Clause('sentinel')
        
    def getName(self):
        return self.name
    
    def setNode(self, no):
        self.node = no
        
    def getNode(self):
        return self.node
    
    def getChildren(self):
        lst = []
        if (self.leftChild != None):
            lst.append(self.leftChild)
        if (self.rightChild != None):
            lst.append(self.rightChild)
        return lst
    
    def isLeaf(self):
        return self.leftChild == None and self.rightChild == None

In [18]:
def makeClauseList(code, startIndex):
    clauseList = []
    i = startIndex
    substr = ''
    while i < len(code):
        if (code[i] == '('):
            clause, i = makeClause(code, i+1, 1)
            clauseList.append(clause)
        
        elif (code[i] == 'n' and code[i+1] == 'o'):
            clauseList.append("not")
            i = i+3
            
        elif (code[i].isspace() and code[i+1] == 'a'):
            if (substr != ''):
                clauseList.append(substr)
                substr = ''
            clauseList.append("and")
            i = i+4
            
        elif (code[i].isspace() and code[i+1] == 'o'):
            if (substr != ''):
                clauseList.append(substr)
                substr = ''
            clauseList.append("or")
            i = i+3
            
        else:
            substr += code[i]
            
        i = i+1
    
    if (substr != ''):
        clauseList.append(substr)
    
    clauseList = convertToClause(clauseList)
    return clauseList
    
            
def makeClause(code, ind, num):
    clause = ''
    while (num != 0):
        if (code[ind] == '('):
            num += 1
        elif (code[ind] == ')'):
            num -= 1
            if (num == 0):
                break
        clause += code[ind]
        ind += 1
    return clause, ind


def convertToClause(lst):
    new_lst = []
    for item in lst:
        clause = Clause(item)
        new_lst.append(clause)
    
    if (len(new_lst) > 1):
        return createParent(new_lst)
    return new_lst

def createParent(lst):
    for i in range(len(lst)):
        if (lst[i].name == 'and' or lst[i].name == 'or'):
            if (lst[i-1].parent.getName() != 'sentinel'):
                lst[i].leftChild = lst[i-2]
                lst[i-2].parent = lst[i]
            else:
                lst[i].leftChild = lst[i-1]
                lst[i-1].parent = lst[i]
            lst[i].rightChild = lst[i+1]
            lst[i+1].parent = lst[i]
        elif (lst[i].name == 'not'):
            lst[i].leftChild = None
            lst[i].rightChild = lst[i+1]
            lst[i+1].parent = lst[i]
    return lst


def recurseClauseList(lst):
    for i in range(len(lst)):
        tst = makeClauseList(lst[i].getName(), 0)
        if len(tst) == 1: # is base clause
            if (lst[i].getName() == 'or'):
                node = LOTlib3.FunctionNode.FunctionNode(returntype='DISJ', name='(%s or %s)', parent=None, args=None)
                lst[i].setNode(node)
            elif (lst[i].getName() == 'and'):
                node = LOTlib3.FunctionNode.FunctionNode(returntype='CONJ', name='(%s and %s)', parent=None, args=None)
                lst[i].setNode(node)
            elif (lst[i].getName() == 'not'):
                node = LOTlib3.FunctionNode.FunctionNode(returntype='PRE-PREDICATE', name='(not %s)', parent=None, args=None)
                lst[i].setNode(node)
            elif (lst[i].getName() == 'True'):
                node = LOTlib3.FunctionNode.FunctionNode(returntype='START', name='True', parent=None, args=None)
                lst[i].setNode(node)
            elif (lst[i].getName() == 'False'):
                node = LOTlib3.FunctionNode.FunctionNode(returntype='START', name='False', parent=None, args=None)
                lst[i].setNode(node)
            else:
                node = convert_predicate(lst[i].getName())
                lst[i].setNode(node)
        else:
            if (lst[i].parent.leftChild == lst[i]):
                rootNode = getRoot(tst)
                lst[i].parent.leftChild = rootNode
                rootNode.parent = lst[i].parent
            elif (lst[i].parent.rightChild == lst[i]):
                rootNode = getRoot(tst)
                lst[i].parent.rightChild = rootNode
                rootNode.parent = lst[i].parent
            recurseClauseList(tst) # need to continue recursing
       
            
def convert_predicate(code, par=None):
    code = code.split(' ')
    if code[0] == "x['shape']":
        node = LOTlib3.FunctionNode.FunctionNode(returntype='PREDICATE', name="x['shape'] == %s", parent=par, args=None)
        node2 = LOTlib3.FunctionNode.FunctionNode(returntype='SHAPE', name=code[2], parent=node, args=None)
        node.args = [node2]
    elif code[0] == "x['color']":
        node = LOTlib3.FunctionNode.FunctionNode(returntype='PREDICATE', name="x['color'] == %s", parent=par, args=None)
        node2 = LOTlib3.FunctionNode.FunctionNode(returntype='COLOR', name=code[2], parent=node, args=None)
        node.args = [node2]
    elif code[0] == "not":
        node = convert_negation(code, None)
    return node

def connectTree(lst):
    root = getRoot(lst)
    clauseStack = []
    clauseStack = recursiveConnect(root, clauseStack)
    while (len(clauseStack) != 0):
        c = clauseStack.pop(0)
        connectFromClause(c)
    return root.getNode()
    
    
def recursiveConnect(clause, stack):
    new_stack = stack
    if (clause.isLeaf()):
        new_stack.append(clause)
        return new_stack
    else:
        for child in clause.getChildren():
            new_stack = recursiveConnect(child, new_stack)
        new_stack.append(clause)
    return new_stack
        
def connectFromClause(clause):
    if clause.getNode().returntype == 'CONJ':
        if (clause.leftChild.getNode().returntype != 'PRE-PREDICATE'):
            left_node = LOTlib3.FunctionNode.FunctionNode(returntype='PRE-PREDICATE', name="", parent=clause.getNode(), args=[clause.leftChild.getNode()])
            clause.leftChild.getNode().parent = left_node
        else:
            left_node = clause.leftChild.getNode()
            
            
        if (clause.rightChild.getNode().returntype != 'CONJ'):
            rNode = LOTlib3.FunctionNode.FunctionNode(returntype='CONJ', name='', parent=clause.getNode(), args=[])
            
            if (clause.rightChild.getNode().returntype != 'PRE-PREDICATE'):
                rNode2 = LOTlib3.FunctionNode.FunctionNode(returntype='PRE-PREDICATE', name="", parent=rNode, args=[clause.rightChild.getNode()])
                clause.rightChild.getNode().parent = rNode2
            else:
                rNode2 = clause.rightChild.getNode()
                
            rNode.args = [rNode2]
        else:
            rNode = clause.rightChild.getNode()
            
        clause.getNode().args = [left_node, rNode]
        
    elif clause.getNode().returntype == 'DISJ':
        if (clause.leftChild.getNode().returntype != 'PRE-PREDICATE'):
            left_node = LOTlib3.FunctionNode.FunctionNode(returntype='PRE-PREDICATE', name="", parent=clause.getNode(), args=[clause.leftChild.getNode()])
            clause.leftChild.getNode().parent = left_node
        else:
            left_node = clause.leftChild.getNode()
            
        if (clause.rightChild.getNode().returntype != 'DISJ'):
            rNode = LOTlib3.FunctionNode.FunctionNode(returntype='DISJ', name='', parent=clause.getNode(), args=[])
            
            if (clause.rightChild.getNode().returntype == 'PREDICATE'):
                rNode2 = LOTlib3.FunctionNode.FunctionNode(returntype='PRE-PREDICATE', name="", parent=rNode, args=[clause.rightChild.getNode()])
                clause.rightChild.getNode().parent = rNode2
            else:
                rNode2 = clause.rightChild.getNode()
                
            rNode.args = [rNode2]
        else:
            rNode = clause.rightChild.getNode()
            
        clause.getNode().args = [left_node, rNode]
        
    elif clause.getNode().returntype == 'PRE-PREDICATE':
        clause.rightChild.getNode().parent = clause.getNode()
        clause.getNode().args = [clause.rightChild.getNode()]
        
    return clause    
        
        
def getRoot(lst):
    for item in lst:
        if item.parent.getName() == 'sentinel':
            return item

In [19]:
def convertToNode(string):
    string = string[11:-1]
    if string[0] == '(':
        t, i = makeClause(string, 1, 1)
        if (i == len(string) - 1):
            string = string[1:-1]
    l = makeClauseList(string, 0)
    recurseClauseList(l)
    return connectTree(l), string

In [20]:
def compareNodeString(node, string):
    nodeList = re.split('\(|\)', str(node))
    stringList = re.split('\(|\)', string)
    while '' in nodeList:
        nodeList.remove('')
    while '' in stringList:
        stringList.remove('')
        
    nodeS = ''
    for item in nodeList:
        nodeS += item
        
    stringS = ''
    for item in stringList:
        stringS += item
        
    return nodeS == stringS

In [21]:
def testMakeClause():
    assert(makeClause("(not (x['color'] == 'blue'))", 1, 1) == "not (x['color'] == 'blue')")
    assert(makeClause("(x['color'] == 'blue' and x['shape'] == 'triangle')", 1, 1) == "x['color'] == 'blue' and x['shape'] == 'triangle'")
 
testMakeClause()

AssertionError: 

In [22]:
convertToNode("lambda x: (x['color'] == 'blue' or (not x['shape'] == 'square'))")

((x['color'] == 'blue' or (not x['shape'] == 'square')),
 "x['color'] == 'blue' or (not x['shape'] == 'square')")

In [1795]:
poor = [144, 369, 396, 810, 819, 1017, 1071, 25]
skip = [38, 41, 42, 43, 49, 50, 60, 61, 68, 69, 70, 71, 87, 88, 89, 111, 113, 114, 120, 122, 124, 129, 
        132, 133, 139, 140, 141, 150, 152, 159, 160, 161]
for i in range(len(data['gen_reshaped'])):
    if i not in poor:
        print(i)
        node, orig_string = convertToNode(data['gen_reshaped'][i])
        try:
            assert compareNodeString(node, orig_string)
        except:
            print(f"+++++++++++++++++++++++++\nthis is Node: {node}\n\n\nthis is original: {orig_string} \n++++++++++++++++++++++++++++++")

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
2

In [715]:
rules = {i: set() for i in range(15)}
for _ in range(100000):
    #rule = grammar.generate()
    #log_prob = grammar.log_probability(rule)
    rule = LOTHypothesis(grammar=grammar)
    depth = rule.value.depth() #rule.depth()
    rules[depth].add(rule)

In [716]:
rules = {depth: list(depth_rules) for depth, depth_rules in rules.items()}

In [26]:
colors = ['red', 'blue', 'green']
shapes = ['circle', 'square', 'triangle']
all_stimuli = []

for color in colors:
    for shape in shapes:
        all_stimuli.append({'shape':shape, 'color':color})

In [27]:
all_stimuli

[{'shape': 'circle', 'color': 'red'},
 {'shape': 'square', 'color': 'red'},
 {'shape': 'triangle', 'color': 'red'},
 {'shape': 'circle', 'color': 'blue'},
 {'shape': 'square', 'color': 'blue'},
 {'shape': 'triangle', 'color': 'blue'},
 {'shape': 'circle', 'color': 'green'},
 {'shape': 'square', 'color': 'green'},
 {'shape': 'triangle', 'color': 'green'}]

# Diverse Prompt Generation

In [28]:
import itertools

In [29]:
l = [True, False]
all_results = [list(i) for i in itertools.product(l, repeat=9)]
all_results

[[True, True, True, True, True, True, True, True, True],
 [True, True, True, True, True, True, True, True, False],
 [True, True, True, True, True, True, True, False, True],
 [True, True, True, True, True, True, True, False, False],
 [True, True, True, True, True, True, False, True, True],
 [True, True, True, True, True, True, False, True, False],
 [True, True, True, True, True, True, False, False, True],
 [True, True, True, True, True, True, False, False, False],
 [True, True, True, True, True, False, True, True, True],
 [True, True, True, True, True, False, True, True, False],
 [True, True, True, True, True, False, True, False, True],
 [True, True, True, True, True, False, True, False, False],
 [True, True, True, True, True, False, False, True, True],
 [True, True, True, True, True, False, False, True, False],
 [True, True, True, True, True, False, False, False, True],
 [True, True, True, True, True, False, False, False, False],
 [True, True, True, True, False, True, True, True, True]

In [30]:
example_data = [FunctionData(input=[all_stimuli[i]], output=all_results[10][i], alpha=0.999) for i in range(9)]
example_data

[<{'shape': 'circle', 'color': 'red'} -> True>,
 <{'shape': 'square', 'color': 'red'} -> True>,
 <{'shape': 'triangle', 'color': 'red'} -> True>,
 <{'shape': 'circle', 'color': 'blue'} -> True>,
 <{'shape': 'square', 'color': 'blue'} -> True>,
 <{'shape': 'triangle', 'color': 'blue'} -> False>,
 <{'shape': 'circle', 'color': 'green'} -> True>,
 <{'shape': 'square', 'color': 'green'} -> False>,
 <{'shape': 'triangle', 'color': 'green'} -> True>]

In [31]:
class MyHypothesis(RationaRulesPrior, BinaryLikelihood, LOTHypothesis):
    def __init__(self, **kwargs):
        LOTHypothesis.__init__(self, grammar=grammar, **kwargs)
        self.rrAlpha=2.0

In [32]:
from contextlib import suppress

In [33]:
example_hypo = MyHypothesis()
example_top = TopN(N=10)
for h in (MetropolisHastingsSampler(example_hypo, example_data, steps=10000)):
    example_top << h

In [34]:
for h in example_top:
    print(h.posterior_score, h.prior, h.likelihood, qq(h))

-19.279947774322203 -11.675044314446662 -7.6049034598755405 "lambda x: ((not x['shape'] == 'square') or ((not x['color'] == 'green') or (not x['color'] == 'green')))"
-19.279947774322203 -11.675044314446662 -7.60490345987554 "lambda x: ((not x['shape'] == 'triangle') or ((not x['shape'] == 'triangle') or (not x['color'] == 'blue')))"
-19.279947774322203 -11.675044314446662 -7.60490345987554 "lambda x: ((not x['shape'] == 'triangle') or ((not x['color'] == 'blue') or (not x['color'] == 'blue')))"
-17.04635555281511 -9.44145209293957 -7.6049034598755405 "lambda x: ((not x['color'] == 'green') or (not x['shape'] == 'square'))"
-17.04635555281511 -9.44145209293957 -7.60490345987554 "lambda x: ((not x['shape'] == 'triangle') or (not x['color'] == 'blue'))"
-17.04635555281511 -9.44145209293957 -7.60490345987554 "lambda x: ((not x['color'] == 'blue') or (not x['shape'] == 'triangle'))"
-16.591600155495826 -1.386294361119889 -15.205305794375938 "lambda x: True"
-16.353208372255168 -8.748304912

In [35]:
posts = []
priors = []
likelihoods = []
codes = []
for h in example_top:
    posts.append(h.posterior_score)
    priors.append(h.prior)
    likelihoods.append(h.likelihood)
    codes.append(qq(h))

In [36]:
indices = [i for i in range(len(likelihoods)) if (likelihoods[i] == max(likelihoods))]

In [37]:
ex_c = codes[posts.index(max([posts[i] for i in indices]))]

In [38]:
ex_c

'"lambda x: ((not x[\'color\'] == \'green\') or (not x[\'shape\'] == \'square\'))"'

In [39]:
exec(f"def classify(x): return {ex_c[11:len(ex_c)-1]}")

In [40]:
correct = 0
for i in range(len(all_stimuli)):
    correct += classify(all_stimuli[i]) == all_results[10][i]

In [41]:
correct

8

# Automated

In [42]:
good_indices = []
with open('../data/revised_codex_prompts_2.json', 'w') as out:
    da = []
    for results in all_results:
    #for results in all_results:
        print(results)
        objs = [FunctionData(input=[all_stimuli[i]], output=results[i], alpha=0.999) for i in range(9)]
        hypo = MyHypothesis()
        top = TopN(N=10)
        print(f"sampling {all_results.index(results)}")
        for h in MetropolisHastingsSampler(hypo, objs, steps=10000):
            top << h
        
        """
        posts = []
        priors = []
        likelihoods = []
        codes = []
        for h in top:
            print(h.posterior_score, h.prior, h.likelihood, qq(h))
            posts.append(h.posterior_score)
            priors.append(h.prior)
            likelihoods.append(h.likelihood)
            codes.append(qq(h))
        
        indices = [i for i in range(len(likelihoods)) if (likelihoods[i] == max(likelihoods))]
        """
        codes = []
        posts = []
        priors = []
        likelihoods = []
        for h in top:
            codes.append(qq(h))
            posts.append(h.posterior_score)
            priors.append(h.prior)
            likelihoods.append(h.likelihood)
            
        corrects = []
        for code in codes:
            exec(f"def classify(x): return {code[11:len(code)-1]}")
            correct = 0
            for i in range(len(all_stimuli)):
                correct += classify(all_stimuli[i]) == results[i]
            corrects.append(correct)
        print(corrects)
            
        best_indices = [i for i in range(len(corrects)) if corrects[i] == max(corrects)]
        data = {}
        print(codes[posts.index(max([posts[i] for i in best_indices]))])
        best_index = posts.index(max([posts[i] for i in best_indices]))
        data["code"] = str(codes[best_index])
        data["accuracy"] = str(corrects[best_index] / 9)
        data["stims"] = str(all_stimuli)
        data["results"] = str(results)
        da.append(data)
        
        if (max(corrects) == 9):
            good_indices.append(all_results.index(results))

    out.write(json.dumps(da))
out.close()

[True, True, True, True, True, True, True, True, True]
sampling 0
[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]
"lambda x: True"
[True, True, True, True, True, True, True, True, False]
sampling 1


KeyboardInterrupt: 

In [1810]:
good_indices

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 11,
 16,
 18,
 19,
 24,
 25,
 27,
 31,
 32,
 36,
 38,
 40,
 45,
 48,
 50,
 54,
 56,
 58,
 63,
 64,
 65,
 69,
 72,
 73,
 79,
 89,
 91,
 104,
 121,
 127,
 128,
 130,
 144,
 146,
 150,
 151,
 152,
 155,
 176,
 178,
 182,
 184,
 186,
 191,
 192,
 193,
 194,
 195,
 199,
 200,
 201,
 208,
 211,
 216,
 217,
 219,
 223,
 255,
 256,
 260,
 261,
 262,
 288,
 292,
 293,
 295,
 296,
 316,
 319,
 320,
 321,
 324,
 325,
 352,
 357,
 360,
 361,
 365,
 367,
 381,
 383,
 384,
 386,
 388,
 390,
 391,
 432,
 434,
 438,
 439,
 440,
 446,
 447,
 448,
 451,
 452,
 455,
 457,
 463,
 466,
 475,
 479,
 484,
 487,
 493,
 495,
 502,
 503,
 504,
 505,
 506,
 507,
 508,
 509,
 510,
 511]

In [779]:
f = json.load(open("../data/revised_codex_prompts_2.json"))

In [780]:
f

[{'code': '"lambda x: True"',
  'accuracy': '1.0',
  'stims': "[{'shape': 'circle', 'color': 'red'}, {'shape': 'square', 'color': 'red'}, {'shape': 'triangle', 'color': 'red'}, {'shape': 'circle', 'color': 'blue'}, {'shape': 'square', 'color': 'blue'}, {'shape': 'triangle', 'color': 'blue'}, {'shape': 'circle', 'color': 'green'}, {'shape': 'square', 'color': 'green'}, {'shape': 'triangle', 'color': 'green'}]",
  'results': '[True, True, True, True, True, True, True, True, True]'},
 {'code': '"lambda x: ((not (x[\'color\'] == \'green\')) or (not (x[\'shape\'] == \'triangle\')))"',
  'accuracy': '1.0',
  'stims': "[{'shape': 'circle', 'color': 'red'}, {'shape': 'square', 'color': 'red'}, {'shape': 'triangle', 'color': 'red'}, {'shape': 'circle', 'color': 'blue'}, {'shape': 'square', 'color': 'blue'}, {'shape': 'triangle', 'color': 'blue'}, {'shape': 'circle', 'color': 'green'}, {'shape': 'square', 'color': 'green'}, {'shape': 'triangle', 'color': 'green'}]",
  'results': '[True, True, Tr

In [274]:
df = pd.read_json("../data/revised_codex_prompts.json")

In [275]:
df

Unnamed: 0,code,stims,results
0,"""lambda x: True""","[{'shape': 'circle', 'color': 'red', 'alpha': ...","[True, True, True, True, True, True, True, Tru..."
1,"""lambda x: (not (x['shape']=='triangle') or no...","[{'shape': 'circle', 'color': 'red', 'alpha': ...","[True, True, True, True, True, True, True, Tru..."
2,"""lambda x: (not (x['shape']=='circle') or not ...","[{'shape': 'circle', 'color': 'red', 'alpha': ...","[True, True, True, True, True, True, False, Tr..."
3,"""lambda x: not (x['color']=='green')""","[{'shape': 'circle', 'color': 'red', 'alpha': ...","[True, True, True, True, True, True, False, Fa..."
4,"""lambda x: (not (x['color']=='blue') or (x['co...","[{'shape': 'circle', 'color': 'red', 'alpha': ...","[True, True, True, True, True, False, True, Tr..."
...,...,...,...
128,"""lambda x: (x['color']=='green' and not (x['sh...","[{'shape': 'circle', 'color': 'red', 'alpha': ...","[False, False, False, False, False, False, Tru..."
129,"""lambda x: (x['color']=='green' and x['shape']...","[{'shape': 'circle', 'color': 'red', 'alpha': ...","[False, False, False, False, False, False, Tru..."
130,"""lambda x: (x['color']=='green' and not (x['sh...","[{'shape': 'circle', 'color': 'red', 'alpha': ...","[False, False, False, False, False, False, Fal..."
131,"""lambda x: (x['color']=='green' and x['shape']...","[{'shape': 'circle', 'color': 'red', 'alpha': ...","[False, False, False, False, False, False, Fal..."


# LOTlib3 on Codex Generated Code

In [46]:
data = pd.read_csv("../data/full_output.csv").drop("Unnamed: 0", axis=1)

In [47]:
data

Unnamed: 0,Problem_num,accuracy,tr_code_concat,gen_code_concat,true_code_size,gen_code_size,num_stims_seen,stims_seen,passed_tests,failed_tests,tr_code_full,gen_code_full,tr_domain,gen_domain,gen_reshaped
0,1,0.555556,True,(color == 'red' or shape == 'circle'),4,37,1,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('blue', 'square') == True""...","""lambda x: True""","def categorize(color, shape):\n\treturn (color...",9,5,"""lambda x: (x['color'] == 'red' or x['shape'] ..."
1,1,0.555556,True,(color == 'red' or shape == 'square'),4,37,2,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('blue', 'circle') == True""...","""lambda x: True""","def categorize(color, shape):\n\treturn (color...",9,5,"""lambda x: (x['color'] == 'red' or x['shape'] ..."
2,1,1.000000,True,True,4,4,3,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True"""
3,1,1.000000,True,True,4,4,4,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True"""
4,1,1.000000,True,True,4,4,5,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True"""
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1192,133,0.777778,(shape == 'triangle' and color == 'green'),(color == 'blue' and shape == 'square'),42,39,5,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('blue', 'square') == False...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn (color...",1,1,"""lambda x: (x['color'] == 'blue' and x['shape'..."
1193,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,6,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False"""
1194,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,7,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False"""
1195,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,8,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False"""


In [48]:
val = grammar.generate(x=eval(data['gen_reshaped'][0]))
type(val)
#mh = LOTHypothesis(grammar = grammar, value = val)

str

In [49]:
LOTlib3.FunctionNode.FunctionNode(returntype='START', name='', parent = None, args = [eval(eval(data['gen_reshaped'][0]))])

TypeError: expected string or bytes-like object

In [50]:
a = grammar.generate('DISJ')
a

(not x['shape'] == 'circle')

In [51]:
a.args[0].args[0].name

'(not %s)'

In [52]:
eval(data['gen_reshaped'][1])

"lambda x: (x['color'] == 'red' or x['shape'] == 'square')"

In [53]:
eval(data['gen_reshaped'][0])

"lambda x: (x['color'] == 'red' or x['shape'] == 'circle')"

In [54]:
for i in range(len(data['gen_reshaped'])):
    print(data['gen_reshaped'][i])

"lambda x: (x['color'] == 'red' or x['shape'] == 'circle')"
"lambda x: (x['color'] == 'red' or x['shape'] == 'square')"
"lambda x: True"
"lambda x: True"
"lambda x: True"
"lambda x: True"
"lambda x: True"
"lambda x: True"
"lambda x: True"
"lambda x: True"
"lambda x: (x['color'] == 'red' or x['shape'] == 'square')"
"lambda x: True"
"lambda x: (x['color'] == 'red' or x['color'] == 'blue')"
"lambda x: (x['color'] == 'red' or x['color'] == 'blue')"
"lambda x: True"
"lambda x: True"
"lambda x: True"
"lambda x: (not x['shape'] == 'triangle')"
"lambda x: True"
"lambda x: (x['color'] == 'red' or x['shape'] == 'square')"
"lambda x: True"
"lambda x: True"
"lambda x: True"
"lambda x: True"
"lambda x: (not x['color'] == 'green')"
"lambda x: (x['color'] != 'green' or x['shape'] == 'square')"
"lambda x: ((x['color'] == 'green' and x['shape'] == 'square') or (not x['color'] == 'green'))"
"lambda x: True"
"lambda x: (x['color'] == 'red' or x['shape'] == 'square')"
"lambda x: True"
"lambda x: True"
"la

In [55]:
node1 = LOTlib3.FunctionNode.FunctionNode(returntype='SHAPE', name="'circle'", parent= None, args=None)
node2 = LOTlib3.FunctionNode.FunctionNode(returntype='PREDICATE', name="x['shape'] == %s", parent= None, args=[node1])
node7 = LOTlib3.FunctionNode.FunctionNode(returntype='PRE-PREDICATE', name='', parent=None, args=[node2])
node6 = LOTlib3.FunctionNode.FunctionNode(returntype='DISJ', name='', parent=None, args=[node7])
node1.parent = node2
node2.parent=node7
node7.parent=node6

node3 = LOTlib3.FunctionNode.FunctionNode(returntype='COLOR', name="'red'", parent= None, args=None)
node4 = LOTlib3.FunctionNode.FunctionNode(returntype='PREDICATE', name="x['color'] == %s", parent=None, args=[node3])
node5 = LOTlib3.FunctionNode.FunctionNode(returntype='PRE-PREDICATE', name='', parent=None, args=[node4])
node3.parent=node4
node4.parent=node5


node3 = LOTlib3.FunctionNode.FunctionNode(returntype='DISJ', name="(%s or %s)", parent=None, args=[node5, node6])
node5.parent = node3
node6.parent = node3

node8 = LOTlib3.FunctionNode.FunctionNode(returntype='START', name='', parent=None, args=[node3])
node3.parent=node8

In [56]:
n1 = LOTlib3.FunctionNode.FunctionNode(returntype='CONJ', name='', parent=None, args=None)
n2 = LOTlib3.FunctionNode.FunctionNode(returntype='PRE-PREDICATE', name='', parent=n1, args=None)
n1.args = [n2]
n3 = LOTlib3.FunctionNode.FunctionNode(returntype='PREDICATE', name="x['color'] == %s", parent=n2, args=None)
n2.args = [n3]
n4 = LOTlib3.FunctionNode.FunctionNode(returntype='COLOR', name="'red'", parent=n3, args=None)
n3.args = [n4]
nh = MyHypothesis(value = n1)

In [57]:
nh.compute_prior(), nh.compute_posterior(example_data2), nh.compute_likelihood(example_data2)

NameError: name 'example_data2' is not defined

In [940]:
nh1 = LOTlib3.FunctionNode.FunctionNode(returntype='START', name='', parent=None, args=None)
nh2 = LOTlib3.FunctionNode.FunctionNode(returntype='PRE-PREDICATE', name='', parent=nh1, args=None)
nh1.args = [nh2]
nh3 = LOTlib3.FunctionNode.FunctionNode(returntype='PREDICATE', name="x['color'] == %s", parent=nh2, args=None)
nh2.args = [nh3]
nh4 = LOTlib3.FunctionNode.FunctionNode(returntype='COLOR', name="'red'", parent=nh3, args=None)
nh3.args = [nh4]
nhh = MyHypothesis(value = nh1)

In [941]:
nh.compute_prior(), nh.compute_posterior(example_data2), nh.compute_likelihood(example_data2)

(-3.1780538303479444, -48.78496896272548, -45.60691513237754)

In [515]:
node8

(x['color'] == 'red' or x['shape'] == 'circle')

In [523]:
mh = MyHypothesis(value=node8)

In [544]:
mh

lambda x: (x['color'] == 'red' or x['shape'] == 'circle')

In [540]:
example_data2 = [FunctionData(input=[all_stimuli[i]], output=all_results[0][i], alpha=0.999) for i in range(9)]

In [541]:
example_data2

[<{'shape': 'circle', 'color': 'red'} -> True>,
 <{'shape': 'square', 'color': 'red'} -> True>,
 <{'shape': 'triangle', 'color': 'red'} -> True>,
 <{'shape': 'circle', 'color': 'blue'} -> True>,
 <{'shape': 'square', 'color': 'blue'} -> True>,
 <{'shape': 'triangle', 'color': 'blue'} -> True>,
 <{'shape': 'circle', 'color': 'green'} -> True>,
 <{'shape': 'square', 'color': 'green'} -> True>,
 <{'shape': 'triangle', 'color': 'green'} -> True>]

In [542]:
example_top2 = TopN(N=10)
for h2 in MetropolisHastingsSampler(example_hypo, example_data2, steps=1000):
    example_top2 << h2
example_top2 << mh

In [543]:
for h in example_top2:
    print(h)

lambda x: (not (x['shape'] == 'triangle') or (not (x['shape'] == 'circle') or not (x['shape'] == 'circle')))
lambda x: (not (x['shape'] == 'square') or (not (x['shape'] == 'square') or not (x['shape'] == 'circle')))
lambda x: (not (x['shape'] == 'square') or (not (x['shape'] == 'circle') or not (x['shape'] == 'circle')))
lambda x: (not (x['color'] == 'green') or x['color'] == 'green')
lambda x: (not (x['shape'] == 'circle') or x['shape'] == 'circle')
lambda x: (not (x['color'] == 'green') or not (x['color'] == 'red'))
lambda x: (not (x['shape'] == 'circle') or not (x['shape'] == 'square'))
lambda x: (not (x['shape'] == 'triangle') or not (x['shape'] == 'square'))
lambda x: (not (x['shape'] == 'triangle') or x['shape'] == 'triangle')
lambda x: True


In [545]:
mh.likelihood

0.0

In [546]:
mh.posterior_score

-inf

In [553]:
mh.compute_prior()

-8.748304912379627

In [551]:
mh.compute_likelihood(example_data2)

-30.406110463376734

In [555]:
mh.compute_posterior(example_data2)

-39.15441537575636

In [601]:
grammar.pack_ascii(node8)

'06abd5ach'

In [602]:
grammar.unpack_ascii('06abd5ach')

(x['color'] == 'red' or x['shape'] == 'circle')

In [603]:
node8

(x['color'] == 'red' or x['shape'] == 'circle')

In [1038]:
test1 = eval(data['gen_reshaped'][0])
test2 = eval(data['gen_reshaped'][0])[10:]

In [1039]:
test2

"(x['color'] == 'red' or x['shape'] == 'circle')"

In [1040]:
t2c, t2b = build_brick(test2)




In [1041]:
t2b

(x['color'] == 'red' or x['shape'] == 'circle')

In [647]:
grammar.pack_ascii(node8)

'06abd5ach'

In [649]:
grammar.unpack_ascii('06abd5ach')

(x['color'] == 'red' or x['shape'] == 'circle')

In [703]:
for i in grammar.get_all_rules():
    print(i)

START -> ['DISJ']	w/ p=1.0
START -> ['PRE-PREDICATE']	w/ p=5.0
START -> True	w/ p=5.0
START -> False	w/ p=5.0
DISJ -> ['CONJ']	w/ p=1.0
DISJ -> ['PRE-PREDICATE']	w/ p=5.0
DISJ -> (%s or %s)['PRE-PREDICATE', 'DISJ']	w/ p=1.0
CONJ -> ['PRE-PREDICATE']	w/ p=5.0
CONJ -> (%s and %s)['PRE-PREDICATE', 'CONJ']	w/ p=1.0
PRE-PREDICATE -> not (%s)['PREDICATE']	w/ p=5.0
PRE-PREDICATE -> ['PREDICATE']	w/ p=5.0
PREDICATE -> x['color'] == %s['COLOR']	w/ p=1.0
PREDICATE -> x['shape'] == %s['SHAPE']	w/ p=1.0
COLOR -> 'red'	w/ p=1.0
COLOR -> 'blue'	w/ p=1.0
COLOR -> 'green'	w/ p=1.0
SHAPE -> 'square'	w/ p=1.0
SHAPE -> 'circle'	w/ p=1.0
SHAPE -> 'triangle'	w/ p=1.0


In [632]:
testhyp = MyHypothesis()

In [633]:
testhyp.force_function(eval(eval(data['gen_reshaped'][0])))

In [634]:
eval(eval(data['gen_reshaped'][0]))

<function __main__.<lambda>(x)>

In [639]:
testhyp.compute_likelihood(example_data2)

-30.406110463376734

In [640]:
testhyp.compute_posterior(example_data2)

AttributeError: 'str' object has no attribute 'count_subnodes'

In [641]:
testhyp, mh

(lambda x: <FORCED_FUNCTION>,
 lambda x: (x['color'] == 'red' or x['shape'] == 'circle'))

In [642]:
testhyp.value

'<FORCED_FUNCTION>'

In [737]:
experiment = [str(r) for r in itertools.chain(*[grammar.rules[nt] for nt in list(grammar.rules.keys())])]

In [741]:
experiment

["START -> ['DISJ']\tw/ p=1.0",
 "START -> ['PRE-PREDICATE']\tw/ p=5.0",
 'START -> True\tw/ p=5.0',
 'START -> False\tw/ p=5.0',
 "DISJ -> ['CONJ']\tw/ p=1.0",
 "DISJ -> ['PRE-PREDICATE']\tw/ p=5.0",
 "DISJ -> (%s or %s)['PRE-PREDICATE', 'DISJ']\tw/ p=1.0",
 "CONJ -> ['PRE-PREDICATE']\tw/ p=5.0",
 "CONJ -> (%s and %s)['PRE-PREDICATE', 'CONJ']\tw/ p=1.0",
 "PRE-PREDICATE -> (not (%s))['PREDICATE']\tw/ p=5.0",
 "PRE-PREDICATE -> ['PREDICATE']\tw/ p=5.0",
 "PREDICATE -> x['color'] == %s['COLOR']\tw/ p=1.0",
 "PREDICATE -> x['shape'] == %s['SHAPE']\tw/ p=1.0",
 "COLOR -> 'red'\tw/ p=1.0",
 "COLOR -> 'blue'\tw/ p=1.0",
 "COLOR -> 'green'\tw/ p=1.0",
 "SHAPE -> 'square'\tw/ p=1.0",
 "SHAPE -> 'circle'\tw/ p=1.0",
 "SHAPE -> 'triangle'\tw/ p=1.0"]

In [749]:
grammar.rules

defaultdict(list,
            {'START': [START -> ['DISJ']	w/ p=1.0,
              START -> ['PRE-PREDICATE']	w/ p=5.0,
              START -> True	w/ p=5.0,
              START -> False	w/ p=5.0],
             'DISJ': [DISJ -> ['CONJ']	w/ p=1.0,
              DISJ -> ['PRE-PREDICATE']	w/ p=5.0,
              DISJ -> (%s or %s)['PRE-PREDICATE', 'DISJ']	w/ p=1.0],
             'CONJ': [CONJ -> ['PRE-PREDICATE']	w/ p=5.0,
              CONJ -> (%s and %s)['PRE-PREDICATE', 'CONJ']	w/ p=1.0],
             'PRE-PREDICATE': [PRE-PREDICATE -> (not (%s))['PREDICATE']	w/ p=5.0,
              PRE-PREDICATE -> ['PREDICATE']	w/ p=5.0],
             'PREDICATE': [PREDICATE -> x['color'] == %s['COLOR']	w/ p=1.0,
              PREDICATE -> x['shape'] == %s['SHAPE']	w/ p=1.0],
             'COLOR': [COLOR -> 'red'	w/ p=1.0,
              COLOR -> 'blue'	w/ p=1.0,
              COLOR -> 'green'	w/ p=1.0],
             'SHAPE': [SHAPE -> 'square'	w/ p=1.0,
              SHAPE -> 'circle'	w/ p=1.0,
     

In [786]:
exp_str = eval(data['gen_reshaped'][232])
exp_str

"lambda x: (not x['color'] == 'red')"

In [787]:
exp_str = exp_str[10:]
exp_str

"(not x['color'] == 'red')"

In [66]:
import re

In [788]:
exp_str = exp_str.split(' ')
exp_str

['(not', "x['color']", '==', "'red')"]

In [759]:
startnode = LOTlib3.FunctionNode.FunctionNode(returntype='START', name='', parent=None, args=None)
for i in range(len(exp_str)):
    if exp_str[i] == '(':
        y = i
        nodes = []
        while (exp_str[y] != ')' and y < len(exp_str)):
            y += 1
            

<class 'str'> (
<class 'str'> x
<class 'str'> [
<class 'str'> '
<class 'str'> c
<class 'str'> o
<class 'str'> l
<class 'str'> o
<class 'str'> r
<class 'str'> '
<class 'str'> ]
<class 'str'>  
<class 'str'> =
<class 'str'> =
<class 'str'>  
<class 'str'> '
<class 'str'> r
<class 'str'> e
<class 'str'> d
<class 'str'> '
<class 'str'>  
<class 'str'> o
<class 'str'> r
<class 'str'>  
<class 'str'> x
<class 'str'> [
<class 'str'> '
<class 'str'> s
<class 'str'> h
<class 'str'> a
<class 'str'> p
<class 'str'> e
<class 'str'> '
<class 'str'> ]
<class 'str'>  
<class 'str'> =
<class 'str'> =
<class 'str'>  
<class 'str'> '
<class 'str'> c
<class 'str'> i
<class 'str'> r
<class 'str'> c
<class 'str'> l
<class 'str'> e
<class 'str'> '
<class 'str'> )


In [58]:
data

Unnamed: 0,Problem_num,accuracy,tr_code_concat,gen_code_concat,true_code_size,gen_code_size,num_stims_seen,stims_seen,passed_tests,failed_tests,tr_code_full,gen_code_full,tr_domain,gen_domain,gen_reshaped
0,1,0.555556,True,(color == 'red' or shape == 'circle'),4,37,1,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('blue', 'square') == True""...","""lambda x: True""","def categorize(color, shape):\n\treturn (color...",9,5,"""lambda x: (x['color'] == 'red' or x['shape'] ..."
1,1,0.555556,True,(color == 'red' or shape == 'square'),4,37,2,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('blue', 'circle') == True""...","""lambda x: True""","def categorize(color, shape):\n\treturn (color...",9,5,"""lambda x: (x['color'] == 'red' or x['shape'] ..."
2,1,1.000000,True,True,4,4,3,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True"""
3,1,1.000000,True,True,4,4,4,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True"""
4,1,1.000000,True,True,4,4,5,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True"""
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1192,133,0.777778,(shape == 'triangle' and color == 'green'),(color == 'blue' and shape == 'square'),42,39,5,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('blue', 'square') == False...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn (color...",1,1,"""lambda x: (x['color'] == 'blue' and x['shape'..."
1193,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,6,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False"""
1194,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,7,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False"""
1195,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,8,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False"""


In [1821]:
all_stimuli

[{'shape': 'circle', 'color': 'red'},
 {'shape': 'square', 'color': 'red'},
 {'shape': 'triangle', 'color': 'red'},
 {'shape': 'circle', 'color': 'blue'},
 {'shape': 'square', 'color': 'blue'},
 {'shape': 'triangle', 'color': 'blue'},
 {'shape': 'circle', 'color': 'green'},
 {'shape': 'square', 'color': 'green'},
 {'shape': 'triangle', 'color': 'green'}]

In [1802]:
data['stims_seen'][0]

"[{'shape': 'circle', 'color': 'red', 'alpha': 0.999}]"

In [1807]:
data['failed_tests'][0]

'["assert categorize(\'blue\', \'square\') == True", "assert categorize(\'blue\', \'triangle\') == True", "assert categorize(\'green\', \'square\') == True", "assert categorize(\'green\', \'triangle\') == True"]'

In [1808]:
data['passed_tests'][0]

'["assert categorize(\'red\', \'circle\') == True", "assert categorize(\'red\', \'square\') == True", "assert categorize(\'red\', \'triangle\') == True", "assert categorize(\'blue\', \'circle\') == True", "assert categorize(\'green\', \'circle\') == True"]'

In [59]:
works = []
for i in range(len(data['gen_reshaped'])):
    try:
        nodeItem, stringItem = convertToNode(data['gen_reshaped'][i])
        assert compareNodeString(nodeItem, stringItem)
    except:
        continue
        
    works.append(i)

In [64]:
nodeItem, stringItem = convertToNode(data['gen_reshaped'][0])

In [65]:
compareNodeString(nodeItem, stringItem)

NameError: name 're' is not defined

In [1846]:
poor

[144, 369, 396, 810, 819, 1017, 1071, 25]

In [1847]:
data['Problem_num']

0         1
1         1
2         1
3         1
4         1
       ... 
1192    133
1193    133
1194    133
1195    133
1196    133
Name: Problem_num, Length: 1197, dtype: int64

In [43]:
old_data = pd.read_json('../data/revised_codex_prompts.json')

In [44]:
len(old_data['results'])

133

In [67]:
priors = []
posteriors = []
likelihoods = []
correctResults = []
for i in range(len(data['gen_reshaped'])):
    results = eval(old_data['results'][int(i/9)])
    correctResults.append(results)
    try:
        nodeItem, stringItem = convertToNode(data['gen_reshaped'][i])
        assert compareNodeString(nodeItem, stringItem)
    except:
        priors.append(88888.88888)
        posteriors.append(88888.88888)
        likelihoods.append(88888.88888)
        continue
    
    nodeData = [FunctionData(input=[all_stimuli[i]], output=results[i], alpha=0.999) for i in range(9)]
    newHypothesis = MyHypothesis(value = nodeItem)
    try:
        priors.append(newHypothesis.compute_prior())
    except:
        priors.append(99999.99999)
    
    try:
        posteriors.append(newHypothesis.compute_posterior(nodeData))
    except:
        posteriors.append(99999.99999)
    
    try:
        likelihoods.append(newHypothesis.compute_likelihood(nodeData))
    except:
        likelihoods.append(99999.99999)

In [68]:
data['priors'] = priors
data['posteriors'] = posteriors
data['likelihoods'] = likelihoods
data['correctResults'] = correctResults
data

Unnamed: 0,Problem_num,accuracy,tr_code_concat,gen_code_concat,true_code_size,gen_code_size,num_stims_seen,stims_seen,passed_tests,failed_tests,tr_code_full,gen_code_full,tr_domain,gen_domain,gen_reshaped,priors,posteriors,likelihoods,correctResults
0,1,0.555556,True,(color == 'red' or shape == 'circle'),4,37,1,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('blue', 'square') == True""...","""lambda x: True""","def categorize(color, shape):\n\treturn (color...",9,5,"""lambda x: (x['color'] == 'red' or x['shape'] ...",-7.362011,-37.768121,-30.406110,"[True, True, True, True, True, True, True, Tru..."
1,1,0.555556,True,(color == 'red' or shape == 'square'),4,37,2,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('blue', 'circle') == True""...","""lambda x: True""","def categorize(color, shape):\n\treturn (color...",9,5,"""lambda x: (x['color'] == 'red' or x['shape'] ...",-7.362011,-37.768121,-30.406110,"[True, True, True, True, True, True, True, Tru..."
2,1,1.000000,True,True,4,4,3,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True""",-1.386294,-1.390795,-0.004501,"[True, True, True, True, True, True, True, Tru..."
3,1,1.000000,True,True,4,4,4,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True""",-1.386294,-1.390795,-0.004501,"[True, True, True, True, True, True, True, Tru..."
4,1,1.000000,True,True,4,4,5,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True""",-1.386294,-1.390795,-0.004501,"[True, True, True, True, True, True, True, Tru..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1192,133,0.777778,(shape == 'triangle' and color == 'green'),(color == 'blue' and shape == 'square'),42,39,5,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('blue', 'square') == False...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn (color...",1,1,"""lambda x: (x['color'] == 'blue' and x['shape'...",-6.620073,-21.825379,-15.205306,"[False, False, False, False, False, False, Fal..."
1193,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,6,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False""",-1.386294,-8.991198,-7.604903,"[False, False, False, False, False, False, Fal..."
1194,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,7,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False""",-1.386294,-8.991198,-7.604903,"[False, False, False, False, False, False, Fal..."
1195,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,8,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False""",-1.386294,-8.991198,-7.604903,"[False, False, False, False, False, False, Fal..."


In [1903]:
data.to_json('../data/full_with_posteriors.json')

In [1912]:
data[data['likelihoods'] == 99999.99999]['']

Unnamed: 0,Problem_num,accuracy,tr_code_concat,gen_code_concat,true_code_size,gen_code_size,num_stims_seen,stims_seen,passed_tests,failed_tests,tr_code_full,gen_code_full,tr_domain,gen_domain,gen_reshaped,priors,posteriors,likelihoods,correctResults
140,16,0.777778,(color == 'blue' or not (shape == 'square')),((color == 'red' and shape == 'circle') or (co...,44,215,6,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('green', 'circle') == True...","""lambda x: (x['color']=='blue' or not (x['shap...","def categorize(color, shape):\n\treturn ((colo...",7,5,"""lambda x: ((x['color'] == 'red' and x['shape'...",-inf,-inf,99999.99999,"[True, False, True, True, True, True, True, Fa..."
141,16,0.888889,(color == 'blue' or not (shape == 'square')),((color == 'red' and shape == 'circle') or (co...,44,259,7,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['color']=='blue' or not (x['shap...","def categorize(color, shape):\n\treturn ((colo...",7,6,"""lambda x: ((x['color'] == 'red' and x['shape'...",-inf,-inf,99999.99999,"[True, False, True, True, True, True, True, Fa..."
152,17,1.0,not (shape == 'square'),((color == 'red' or color == 'blue' or color =...,23,104,9,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: not (x['shape']=='square')""","def categorize(color, shape):\n\treturn ((colo...",6,6,"""lambda x: ((x['color'] == 'red' or x['color']...",-inf,-inf,99999.99999,"[True, False, True, True, False, True, True, F..."
186,21,0.777778,(not (shape == 'circle') or (not (color == 're...,((color == 'red' and shape == 'square') or (co...,78,259,7,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'square') == True...","""lambda x: (not (x['shape']=='circle') or (not...","def categorize(color, shape):\n\treturn ((colo...",8,6,"""lambda x: ((x['color'] == 'red' and x['shape'...",-inf,-inf,99999.99999,"[False, True, True, True, True, True, True, Tr..."
187,21,0.888889,(not (shape == 'circle') or (not (color == 're...,((color == 'red' and shape == 'square') or (co...,78,303,8,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (not (x['shape']=='circle') or (not...","def categorize(color, shape):\n\treturn ((colo...",8,7,"""lambda x: ((x['color'] == 'red' and x['shape'...",-inf,-inf,99999.99999,"[False, True, True, True, True, True, True, Tr..."
192,22,0.555556,(not (shape == 'circle') or color == 'blue'),((color == 'red' and shape == 'square') or (co...,44,127,4,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('blue', 'square') == True""...","""lambda x: (not (x['shape']=='circle') or x['c...","def categorize(color, shape):\n\treturn ((colo...",7,3,"""lambda x: ((x['color'] == 'red' and x['shape'...",-inf,-inf,99999.99999,"[False, True, True, True, True, True, False, T..."
195,22,0.777778,(not (shape == 'circle') or color == 'blue'),((color == 'red' and shape == 'square') or (co...,44,215,7,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'square') == True...","""lambda x: (not (x['shape']=='circle') or x['c...","def categorize(color, shape):\n\treturn ((colo...",7,5,"""lambda x: ((x['color'] == 'red' and x['shape'...",-inf,-inf,99999.99999,"[False, True, True, True, True, True, False, T..."
196,22,0.888889,(not (shape == 'circle') or color == 'blue'),((color == 'red' and shape == 'square') or (co...,44,259,8,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (not (x['shape']=='circle') or x['c...","def categorize(color, shape):\n\treturn ((colo...",7,6,"""lambda x: ((x['color'] == 'red' and x['shape'...",-inf,-inf,99999.99999,"[False, True, True, True, True, True, False, T..."
583,65,0.777778,(shape == 'circle' or (color == 'green' or (sh...,((color == 'red' or color == 'blue' or color =...,83,102,8,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('blue', 'square') == False...","""lambda x: (x['shape']=='circle' or (x['color'...","def categorize(color, shape):\n\treturn ((colo...",6,6,"""lambda x: ((x['color'] == 'red' or x['color']...",-inf,-inf,99999.99999,"[True, True, False, True, False, False, True, ..."
609,68,0.777778,(not (color == 'red') or not (shape == 'square')),((color == 'red' and shape == 'circle') or (co...,49,259,7,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('green', 'square') == True...","""lambda x: (not (x['color']=='red') or not (x[...","def categorize(color, shape):\n\treturn ((colo...",8,6,"""lambda x: ((x['color'] == 'red' and x['shape'...",-inf,-inf,99999.99999,"[True, False, True, True, True, True, True, Tr..."


In [1922]:
data[data['Problem_num'] == 133]['gen_reshaped'].iloc[1]

'"lambda x: (x[\'color\'] == \'blue\' and x[\'shape\'] == \'triangle\')"'

In [1923]:
data[data['Problem_num'] == 133]['gen_reshaped']

1188                                    "lambda x: False"
1189    "lambda x: (x['color'] == 'blue' and x['shape'...
1190                                    "lambda x: False"
1191    "lambda x: (x['color'] == 'blue' and x['shape'...
1192    "lambda x: (x['color'] == 'blue' and x['shape'...
1193                                    "lambda x: False"
1194                                    "lambda x: False"
1195                                    "lambda x: False"
1196               "lambda x: (x['shape'] == 'triangle')"
Name: gen_reshaped, dtype: object

In [1924]:
checking_data = pd.read_csv('../data/full_with_posteriors.csv')
checking_data

Unnamed: 0.1,Unnamed: 0,Problem_num,accuracy,tr_code_concat,gen_code_concat,true_code_size,gen_code_size,num_stims_seen,stims_seen,passed_tests,failed_tests,tr_code_full,gen_code_full,tr_domain,gen_domain,gen_reshaped,priors,posteriors,likelihoods,correctResults
0,0,1,0.555556,True,(color == 'red' or shape == 'circle'),4,37,1,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('blue', 'square') == True""...","""lambda x: True""","def categorize(color, shape):\n\treturn (color...",9,5,"""lambda x: (x['color'] == 'red' or x['shape'] ...",-7.362011,-37.768121,-30.406110,"[True, True, True, True, True, True, True, Tru..."
1,1,1,0.555556,True,(color == 'red' or shape == 'square'),4,37,2,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('blue', 'circle') == True""...","""lambda x: True""","def categorize(color, shape):\n\treturn (color...",9,5,"""lambda x: (x['color'] == 'red' or x['shape'] ...",-7.362011,-37.768121,-30.406110,"[True, True, True, True, True, True, True, Tru..."
2,2,1,1.000000,True,True,4,4,3,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True""",-1.386294,-1.390795,-0.004501,"[True, True, True, True, True, True, True, Tru..."
3,3,1,1.000000,True,True,4,4,4,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True""",-1.386294,-1.390795,-0.004501,"[True, True, True, True, True, True, True, Tru..."
4,4,1,1.000000,True,True,4,4,5,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True""",-1.386294,-1.390795,-0.004501,"[True, True, True, True, True, True, True, Tru..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1192,1192,133,0.777778,(shape == 'triangle' and color == 'green'),(color == 'blue' and shape == 'square'),42,39,5,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('blue', 'square') == False...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn (color...",1,1,"""lambda x: (x['color'] == 'blue' and x['shape'...",-6.620073,-21.825379,-15.205306,"[False, False, False, False, False, False, Fal..."
1193,1193,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,6,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False""",-1.386294,-8.991198,-7.604903,"[False, False, False, False, False, False, Fal..."
1194,1194,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,7,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False""",-1.386294,-8.991198,-7.604903,"[False, False, False, False, False, False, Fal..."
1195,1195,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,8,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False""",-1.386294,-8.991198,-7.604903,"[False, False, False, False, False, False, Fal..."


# Log-Sum-Exp

In [71]:
rule_data['gen_reshaped']

1188                                    "lambda x: False"
1189    "lambda x: (x['color'] == 'blue' and x['shape'...
1190                                    "lambda x: False"
1191    "lambda x: (x['color'] == 'blue' and x['shape'...
1192    "lambda x: (x['color'] == 'blue' and x['shape'...
1193                                    "lambda x: False"
1194                                    "lambda x: False"
1195                                    "lambda x: False"
1196               "lambda x: (x['shape'] == 'triangle')"
Name: gen_reshaped, dtype: object

In [70]:
i = 133
rule_data = data[data['Problem_num'] == i]
rule_results = rule_data['correctResults'].iloc[0]
rule_objects = [FunctionData(input=[all_stimuli[k]], output=rule_results[k], alpha=0.999) for k in range(9)]
top_hypotheses = set()
codex_hypotheses = set()
for j in range(1, 10):
    codex_hypothesis = rule_data['gen_reshaped'].iloc[j-1]
    sub_objects = rule_objects[:j]
    sub_top = TopN(N=20)
    sub_hypo = MyHypothesis()
    for h in MetropolisHastingsSampler(sub_hypo, sub_objects, steps=100):
        sub_top << h
        
    for item in sub_top:
        top_hypotheses.add(item)
            
    try:
        node, string = convertToNode(codex_hypothesis)
        hy = MyHypothesis(value = node)
        top_hypotheses.add(hy)
        codex_hypotheses.add(hy)
    except:
        continue

In [72]:
top_hypotheses = list(top_hypotheses)
codexHypoIndices = []
for item in codex_hypotheses:
    
    
    
    
    print(item)
    correct = 0
    exec(f"def categorize(x): return {str(item)[10:]}")
    for i in range(len(all_stimuli)):
        print(categorize(all_stimuli[i]), rule_results[i])
        if categorize(all_stimuli[i]) == rule_results[i]:
            correct += 1
    print(correct)
    
    
    
    
    
    
    
    try:
        codexHypoIndices.append(top_hypotheses.index(item))
    except:
        print()

codexHypoIndices

lambda x: (x['color'] == 'blue' and x['shape'] == 'square')
False False
False False
False False
False False
True False
False False
False False
False False
False True
7
lambda x: (x['color'] == 'blue' and x['shape'] == 'circle')
False False
False False
False False
True False
False False
False False
False False
False False
False True
7
lambda x: False
False False
False False
False False
False False
False False
False False
False False
False False
False True
8
lambda x: x['shape'] == 'triangle'
False False
False False
True False
False False
False False
True False
False False
False False
True True
7
lambda x: (x['color'] == 'blue' and x['shape'] == 'triangle')
False False
False False
False False
False False
False False
True False
False False
False False
False True
7


[9, 10, 8, 4, 7]

In [2103]:
rule_data['gen_reshaped']

1188                                    "lambda x: False"
1189    "lambda x: (x['color'] == 'blue' and x['shape'...
1190                                    "lambda x: False"
1191    "lambda x: (x['color'] == 'blue' and x['shape'...
1192    "lambda x: (x['color'] == 'blue' and x['shape'...
1193                                    "lambda x: False"
1194                                    "lambda x: False"
1195                                    "lambda x: False"
1196               "lambda x: (x['shape'] == 'triangle')"
Name: gen_reshaped, dtype: object

In [2075]:
eval(rule_data['gen_reshaped'][33])

KeyError: 33

In [2071]:
str(top_hypotheses[2]) == eval(rule_data['gen_reshaped'][33])

True

In [2099]:
rule_data['gen_reshaped'].iloc[8]

'"lambda x: (x[\'shape\'] == \'triangle\')"'

In [73]:
normalized_posteriors = []
for j in range(1,10):
    postsForAmount = []
    sub_objects = rule_objects[:j]
    for item in top_hypotheses:
        postsForAmount.append(item.compute_posterior(sub_objects))
        
    postsForAmount_normed = np.array(postsForAmount[:])
    lse = logsumexp(postsForAmount_normed)
    postsForAmount_normed -= lse
    postsForAmount_normed = np.exp(postsForAmount_normed)
    
    for ind in codexHypoIndices:
        if compareNodeString(str(top_hypotheses[ind]), eval(rule_data['gen_reshaped'].iloc[j-1])):
            normalized_posteriors.append(postsForAmount_normed[ind])
normalized_posteriors

[0.43860182679818427,
 0.002523642620965115,
 0.7485728885851715,
 2.3772288369155434e-06,
 2.3888857908304197e-06,
 0.8999933422285405,
 0.9999362012072206,
 0.9999917735897651,
 0.00033334918592863044]

In [74]:
posteriorDf = pd.DataFrame()
for j in range(1,10):
    df = pd.DataFrame()
    postsForAmount = []
    likesForAmount = []
    priorsForAmount = []
    sub_objects = rule_objects[:j]
    for item in top_hypotheses:
        postsForAmount.append(item.compute_posterior(sub_objects))
        likesForAmount.append(item.compute_likelihood(sub_objects))
        priorsForAmount.append(item.compute_prior())
        
    postsForAmount_normed = np.array(postsForAmount[:])
    lse = logsumexp(postsForAmount_normed)
    postsForAmount_normed -= lse
    postsForAmount_normed = np.exp(postsForAmount_normed)
    df['num_stims'] = [j] * len(postsForAmount)
    hypo_num = list(np.arange(0, len(postsForAmount)))
    for ind in codexHypoIndices:
        hypo_num[ind] = f"Codex Prompt {ind}"
    
    df['hypo_num'] = hypo_num
    df['scores'] = postsForAmount_normed
    df['unnormed_scores'] = postsForAmount
    df['likelihoods'] = likesForAmount
    df['priors'] = priorsForAmount
    posteriorDf = pd.concat([posteriorDf, df], axis=0)
    
posteriorDf

Unnamed: 0,num_stims,hypo_num,scores,unnormed_scores,likelihoods,priors
0,1,0,1.218338e-02,-4.970313,-0.000500,-4.969813
1,1,1,1.218338e-02,-4.970313,-0.000500,-4.969813
2,1,2,3.655015e-02,-3.871701,-0.000500,-3.871201
3,1,3,1.828422e-05,-11.472103,-7.600902,-3.871201
4,1,Codex Prompt 4,2.924012e-01,-1.792260,-0.000500,-1.791759
...,...,...,...,...,...,...
12,9,12,5.216402e-15,-41.877714,-38.006513,-3.871201
13,9,13,1.042759e-11,-34.277311,-30.406110,-3.871201
14,9,14,4.166865e-05,-19.076507,-15.205306,-3.871201
15,9,15,1.738801e-15,-42.976326,-38.006513,-4.969813


In [2077]:
px.line(posteriorDf, x='num_stims', y='likelihoods', color='hypo_num')

In [2090]:
px.line(posteriorDf, x='num_stims', y='scores', color='hypo_num', hover_data=['unnormed_scores', 'likelihoods', 'priors'])

# Automated

In [2125]:
normalized_posteriors = []
for i in range(1, 134):
    print(f"Started rule {i}")
    rule_data = data[data['Problem_num'] == i]
    rule_results = rule_data['correctResults'].iloc[0]
    rule_objects = [FunctionData(input=[all_stimuli[k]], output=rule_results[k], alpha=0.999) for k in range(9)]
    top_hypotheses = set()
    codex_hypotheses = set()
    for j in range(1, 10):
        print(f"\tStarted subset {j}")
        codex_hypothesis = rule_data['gen_reshaped'].iloc[j-1]
        sub_objects = rule_objects[:j]
        sub_top = TopN(N=20)
        sub_hypo = MyHypothesis()
        for h in MetropolisHastingsSampler(sub_hypo, sub_objects, steps=20000):
            sub_top << h
        
        for item in sub_top:
            top_hypotheses.add(item)
            
        try:
            node, string = convertToNode(codex_hypothesis)
            hy = MyHypothesis(value = node)
            top_hypotheses.add(hy)
            codex_hypotheses.add(hy)
        except:
            continue
            
    print(f"\tHypothesis set completed.")     
    # Get the indices of hypotheses generated by Codex
    top_hypotheses = list(top_hypotheses)
    codexHypoIndices = []
    for item in codex_hypotheses:
        try:
            codexHypoIndices.append(top_hypotheses.index(item))
        except:
            continue
            
    ## TODO: Evaluate the set top_hypotheses at every individual data amount
    for l in range(1,10):
        postsForAmount = []
        sub_objects = rule_objects[:l]
        for item in top_hypotheses:
            postsForAmount.append(item.compute_posterior(sub_objects))

        postsForAmount_normed = np.array(postsForAmount[:])
        lse = logsumexp(postsForAmount_normed)
        postsForAmount_normed -= lse
        postsForAmount_normed = np.exp(postsForAmount_normed)

        added = False
        for ind in codexHypoIndices:
            if compareNodeString(str(top_hypotheses[ind]), eval(rule_data['gen_reshaped'].iloc[l-1])):
                normalized_posteriors.append(postsForAmount_normed[ind])
                added = True
        
        if(not added):
            normalized_posteriors.append(0)
                
normalized_posteriors

Started rule 1
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 2
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 3
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 4
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 5
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypoth

	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 42
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 43
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 44
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 45
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started r

	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 82
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 83
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 84
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 85
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started r

	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 122
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 123
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 124
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Started rule 125
	Started subset 1
	Started subset 2
	Started subset 3
	Started subset 4
	Started subset 5
	Started subset 6
	Started subset 7
	Started subset 8
	Started subset 9
	Hypothesis set completed.
Start

[0.001421601230700486,
 0.0016540901107463077,
 0.7100687458086044,
 0.8682209035412115,
 0.8702821762772519,
 0.870332276081712,
 0.9792856791789267,
 0.9805658773902193,
 0.9811764955106363,
 0.4714086888852307,
 0.0013579717541365576,
 0.7088117981756773,
 0.0028289732175434718,
 0.0028294380255373328,
 0.8666889604901516,
 0.978566550969145,
 0.9786297763111066,
 3.02772392968421e-05,
 0.471024317680383,
 0.0013567159839217504,
 0.5728334550297148,
 0.6715130394277784,
 0.6715985433268071,
 0.6716276124469542,
 0.7082337706445376,
 0,
 6.577918050020683e-06,
 0.47064460969181066,
 0.0013545908917915034,
 0.5718708133696362,
 0.6702088772975181,
 0.0021887152927711415,
 0.6703230147176367,
 0.7038999318385833,
 0.7073465484822994,
 0.7089388055672785,
 0.55825952991847,
 0.0016476395698521086,
 0.7070505370653036,
 0.002193556209136494,
 0.8657526543951798,
 0.4377761807157827,
 1.8723151504326506e-06,
 1.8725163673663702e-06,
 0.23840613703080227,
 0.5582841476418932,
 0.0016477240

In [2126]:
len(normalized_posteriors)

1197

In [2127]:
data['normalized_posteriors'] = normalized_posteriors

In [2128]:
data

Unnamed: 0,Problem_num,accuracy,tr_code_concat,gen_code_concat,true_code_size,gen_code_size,num_stims_seen,stims_seen,passed_tests,failed_tests,tr_code_full,gen_code_full,tr_domain,gen_domain,gen_reshaped,priors,posteriors,likelihoods,correctResults,normalized_posteriors
0,1,0.555556,True,(color == 'red' or shape == 'circle'),4,37,1,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('blue', 'square') == True""...","""lambda x: True""","def categorize(color, shape):\n\treturn (color...",9,5,"""lambda x: (x['color'] == 'red' or x['shape'] ...",-7.362011,-37.768121,-30.406110,"[True, True, True, True, True, True, True, Tru...",0.001422
1,1,0.555556,True,(color == 'red' or shape == 'square'),4,37,2,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...","[""assert categorize('blue', 'circle') == True""...","""lambda x: True""","def categorize(color, shape):\n\treturn (color...",9,5,"""lambda x: (x['color'] == 'red' or x['shape'] ...",-7.362011,-37.768121,-30.406110,"[True, True, True, True, True, True, True, Tru...",0.001654
2,1,1.000000,True,True,4,4,3,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True""",-1.386294,-1.390795,-0.004501,"[True, True, True, True, True, True, True, Tru...",0.710069
3,1,1.000000,True,True,4,4,4,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True""",-1.386294,-1.390795,-0.004501,"[True, True, True, True, True, True, True, Tru...",0.868221
4,1,1.000000,True,True,4,4,5,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == True"",...",[],"""lambda x: True""","def categorize(color, shape):\n\treturn True",9,9,"""lambda x: True""",-1.386294,-1.390795,-0.004501,"[True, True, True, True, True, True, True, Tru...",0.870282
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1192,133,0.777778,(shape == 'triangle' and color == 'green'),(color == 'blue' and shape == 'square'),42,39,5,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('blue', 'square') == False...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn (color...",1,1,"""lambda x: (x['color'] == 'blue' and x['shape'...",-6.620073,-21.825379,-15.205306,"[False, False, False, False, False, False, Fal...",0.000002
1193,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,6,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False""",-1.386294,-8.991198,-7.604903,"[False, False, False, False, False, False, Fal...",0.874313
1194,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,7,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False""",-1.386294,-8.991198,-7.604903,"[False, False, False, False, False, False, Fal...",0.985238
1195,133,0.888889,(shape == 'triangle' and color == 'green'),False,42,5,8,"[{'shape': 'circle', 'color': 'red', 'alpha': ...","[""assert categorize('red', 'circle') == False""...","[""assert categorize('green', 'triangle') == Tr...","""lambda x: (x['shape']=='triangle' and x['colo...","def categorize(color, shape):\n\treturn False",1,0,"""lambda x: False""",-1.386294,-8.991198,-7.604903,"[False, False, False, False, False, False, Fal...",0.985301


In [2129]:
data.to_json("../data/full_normed.json")