Skip to content

Commit

Permalink
Delete BayesNet.variables(); add check that X is distinct from evidence.
Browse files Browse the repository at this point in the history
  • Loading branch information
darius committed Nov 11, 2011
1 parent 9273ba7 commit e9f7720
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions probability.py
Expand Up @@ -157,7 +157,7 @@ def add(self, node):
"""Add a node to the net. Its parents must already be in the """Add a node to the net. Its parents must already be in the
net, and node itself must not.""" net, and node itself must not."""
assert node not in self.nodes assert node not in self.nodes
assert every(lambda parent: parent in self.variables(), node.parents) assert every(lambda parent: parent in self.vars, node.parents)
self.nodes.append(node) self.nodes.append(node)
self.vars.append(node.variable) self.vars.append(node.variable)
for parent in node.parents: for parent in node.parents:
Expand All @@ -172,12 +172,6 @@ def variable_node(self, var):
return n return n
raise Exception("No such variable: %s" % var) raise Exception("No such variable: %s" % var)


def variables(self):
"""List all of the net's variables, parents before children.
>>> burglary.variables()
['Burglary', 'Earthquake', 'Alarm', 'JohnCalls', 'MaryCalls']"""
return [n.variable for n in self.nodes]

def variable_values(self, var): def variable_values(self, var):
"Return the domain of var." "Return the domain of var."
return [True, False] return [True, False]
Expand Down Expand Up @@ -270,9 +264,10 @@ def enumeration_ask(X, e, bn):
>>> enumeration_ask('Burglary', dict(JohnCalls=T, MaryCalls=T), burglary >>> enumeration_ask('Burglary', dict(JohnCalls=T, MaryCalls=T), burglary
... ).show_approx() ... ).show_approx()
'False: 0.716, True: 0.284'""" 'False: 0.716, True: 0.284'"""
assert X not in e, "Query variable must be distinct from evidence"
Q = ProbDist(X) Q = ProbDist(X)
for xi in bn.variable_values(X): for xi in bn.variable_values(X):
Q[xi] = enumerate_all(bn.variables(), extend(e, X, xi), bn) Q[xi] = enumerate_all(bn.vars, extend(e, X, xi), bn)
return Q.normalize() return Q.normalize()


def enumerate_all(vars, e, bn): def enumerate_all(vars, e, bn):
Expand Down Expand Up @@ -398,7 +393,7 @@ def gibbs_ask(X, e, bn, N):
'False: 0.738, True: 0.262' 'False: 0.738, True: 0.262'
""" """
counts = dict((x, 0) for x in bn.variable_values(X)) # bold N in Fig. 14.16 counts = dict((x, 0) for x in bn.variable_values(X)) # bold N in Fig. 14.16
Z = [var for var in bn.variables() if var not in e] Z = [var for var in bn.vars if var not in e]
state = dict(e) # boldface x in Fig. 14.16 state = dict(e) # boldface x in Fig. 14.16
for Zi in Z: for Zi in Z:
state[Zi] = choice(bn.variable_values(Zi)) state[Zi] = choice(bn.variable_values(Zi))
Expand Down

0 comments on commit e9f7720

Please sign in to comment.