Skip to content

Commit

Permalink
BayesianNetwork.EM_learning() doesn't seem to return correct results
Browse files Browse the repository at this point in the history
  • Loading branch information
mellesies committed Jun 27, 2020
1 parent 1db918b commit 027c7d4
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 74 deletions.
11 changes: 5 additions & 6 deletions tests/test_factor.py
Expand Up @@ -57,14 +57,13 @@ def test_add(self):
with self.assertRaises(Exception):
fA.add('noooooo')

@unittest.skip('deprecate?')
def test_extract_values(self):
"""Test factor.extract_values()."""
def test_get(self):
"""Test factor.get()."""
fA, fB_A, fC_A, fD_BC, fE_C = examples.get_sprinkler_factors()

self.assertIsInstance(fA.extract_values(A='a0'), float)
self.assertIsInstance(fB_A.extract_values(A='a0', B='b1'), float)
self.assertIsInstance(fB_A.extract_values(A='a0'), Factor)
self.assertIsInstance(fA.get(A='a0'), np.ndarray)
self.assertIsInstance(fB_A.get(A='a0', B='b1'), np.ndarray)
self.assertIsInstance(fB_A.get(A='a0'), np.ndarray)

def test_mul(self):
"""Test factor.mul()."""
Expand Down
75 changes: 36 additions & 39 deletions thomas/core/bayesiannetwork.py
Expand Up @@ -146,7 +146,6 @@ def setWidget(self, widget):
"""Associate this BayesianNetwork with a BayesianNetworkWidget."""
self.__widget = widget

# --- semi-private ---
def complete_case(self, case, include_weights=True):
"""Complete a single case.
Expand Down Expand Up @@ -184,16 +183,17 @@ def complete_case(self, case, include_weights=True):

# Create a dataframe by repeating the evidence multiple times: once for
# each possible (combination of) value(s) of the missing variable(s).
# The brackets enclosing the evidence transpose the matrix. Note that
# indexing a Series with a dict treats the dict like an iterable.
imputed = pd.DataFrame([case[evidence]], index=jpt.index)
# The brackets enclosing the evidence transpose the matrix, because the
# index is set, the row is broadcast.
idx = jpt.get_pandas_index()
imputed = pd.DataFrame([case[evidence]], index=idx)

# The combinations of missing variables are in the index. Reset the
# index to make them part of the dataframe.
imputed = imputed.reset_index()

# Add the computed probabilities as weights
imputed.loc[:, 'weight'] = jpt.values
imputed.loc[:, 'weight'] = jpt.flat

if include_weights:
order = list(case.index) + ['weight']
Expand All @@ -213,31 +213,31 @@ def estimate_emperical(self, data):
# JPT is complete
return JPT(Factor(0, self.states) + (summed / summed.sum())['weight'])

def complete_cases(self, data, inplace=False):
"""Impute missing values in data frame.
Args:
data (pandas.DataFrame): DataFrame that may have NAs.
Return:
pandas.DataFrame with NAs imputed.
"""
# Subset of all rows that have missing values.
NAs = data[data.isna().any(axis=1)]
imputed = NAs.apply(
self.complete_case,
axis=1,
include_weights=False
)

# DataFrame.update updates values *in place* by default.
if inplace:
data.update(imputed)
else:
data = data.copy()
data.update(imputed)

return data
# def complete_cases(self, data, inplace=False):
# """Impute missing values in data frame.
#
# Args:
# data (pandas.DataFrame): DataFrame that may have NAs.
#
# Return:
# pandas.DataFrame with NAs imputed.
# """
# # Subset of all rows that have missing values.
# NAs = data[data.isna().any(axis=1)]
# imputed = NAs.apply(
# self.complete_case,
# axis=1,
# include_weights=False
# )
#
# # DataFrame.update updates values *in place* by default.
# if inplace:
# data.update(imputed)
# else:
# data = data.copy()
# data.update(imputed)
#
# return data

# --- graph manipulation ---
def add_nodes(self, nodes):
Expand Down Expand Up @@ -282,7 +282,7 @@ def EM_learning(self, data, max_iterations=1, notify=True):

# Create a dataset with unique rows (& counts) ...
overlapping_cols = list(set(data.columns).intersection(self.vars))
counts = counts = data.fillna('NaN')
counts = data.fillna('NaN')
counts = counts.groupby(overlapping_cols, observed=True).size()
counts.name = 'count'
counts = pd.DataFrame(counts)
Expand All @@ -291,8 +291,6 @@ def EM_learning(self, data, max_iterations=1, notify=True):
counts = counts.reset_index(drop=True)
counts = counts.replace('NaN', np.nan)

print(counts)

for k in range(max_iterations):
# dict of joint distributions, indexed by family index
joints = {}
Expand All @@ -303,10 +301,11 @@ def EM_learning(self, data, max_iterations=1, notify=True):

N = row.pop('count')
evidence = row.dropna().to_dict()
print(f'applying evidence: {evidence}')

if (idx % 10) == 0:
print(idx, end=', ')
sys.stdout.flush()
# if (idx % 10) == 0:
# print(idx, end=', ')
# sys.stdout.flush()

self.reset_evidence()
self.junction_tree.set_evidence_hard(**evidence)
Expand Down Expand Up @@ -502,10 +501,8 @@ def compute_posterior(self, qd, qv, ed, ev, use_VE=False):

# If query values were specified we can extract them from the factor.
if qv:
result = result.extract_values(**qv)
result = result.get(**qv)

# FIXME: not sure what I think of the fact that we return scalars
# if the result doesn't have a MultiIndex ...
if isinstance(result, Factor):
return CPT(result, conditioned=query_vars)

Expand Down
64 changes: 35 additions & 29 deletions thomas/core/factor.py
Expand Up @@ -198,21 +198,22 @@ def _get_state_idx(self, RV):
"""Return ..."""

# Return the column that corresponds to the position of 'RV'
idx = np.array(self._get_index_tuples())
return idx[:, self.variables.index(RV)]
idx_cols = np.array(self._get_index_tuples())
return idx_cols[:, self.variables.index(RV)]

def _get_bool_idx(self, **kwargs):
def _get_bool_idx(self, **states):
"""Return ..."""
idx_cols = np.array(self._get_index_tuples())

# Only keep RVs that are in this factor's scope.
states = {RV: state for RV, state in kwargs.items() if RV in self.scope}
# Select all entries by default
trues = np.ones(idx_cols.shape) == 1

bools_per_RV = np.array([
[s == state for s in self._get_state_idx(RV)]
for RV, state in states.items()
])
# Filter on kwargs
for idx, RV in enumerate(self.scope):
if RV in states:
trues[:, idx] = idx_cols[:, idx] == states[RV]

return bools_per_RV.all(axis=0)
return trues.all(axis=1)

@property
def display_name(self):
Expand Down Expand Up @@ -374,16 +375,6 @@ def div(self, other, inplace=False):
factor.values = factor.values / other.values
factor.values[np.isnan(factor.values)] = 0

# if len(w) > 0:
# print()
# print(w)
# print(factor.scope, factor.values)
# print(other.scope, other.values)

# assert len(w) == 0

# factor.values = values

if not inplace:
return factor

Expand All @@ -409,7 +400,20 @@ def get_state_index(self, RV, state):
return self.name_to_number[RV][state]

def get(self, **kwargs):
"""..."""
"""Return the cells identified by kwargs.
Examples
--------
>>> factor = Factor([1, 1], {'A': ['a0', 'a1']})
>>> print(factor)
factor(A)
A
a0 1.0
a1 1.0
dtype: float64
>>> factor.get(A='a0')
array([1.])
"""
return self.flat[self._get_bool_idx(**kwargs)]

def set(self, value, inplace=False, **kwargs):
Expand All @@ -434,7 +438,7 @@ def set(self, value, inplace=False, **kwargs):
"""
factor = self if inplace else Factor.copy(self)

factor.values.reshape(-1)[factor._get_bool_idx(**kwargs)] = value
factor.flat[factor._get_bool_idx(**kwargs)] = value

if not inplace:
return factor
Expand Down Expand Up @@ -490,7 +494,7 @@ def normalize(self, inplace=False):
if not inplace:
return factor

def sum_out(self, variables, simplify=False, inplace=False):
def sum_out(self, variables, inplace=False):
"""Sum-out (marginalize) a variable (or list of variables) from the
factor.
Expand All @@ -511,17 +515,19 @@ def sum_out(self, variables, simplify=False, inplace=False):
# Nothing to sum out ...
return factor

scope = set(factor.variables)
scope_set = set(factor.scope)

if not variable_set.issubset(scope):
raise error.NotInScopeError(variable_set, scope)
if not variable_set.issubset(scope_set):
raise error.NotInScopeError(variable_set, scope_set)

# Unstack the requested variables into columns and sum over them.
var_indexes = [factor.variables.index(var) for var in variables]
# Find the indices of the variables to sum out
var_indexes = [factor.scope.index(var) for var in variables]

index_to_keep = set(range(len(factor.variables))) - set(var_indexes)
# Remove the variables from the factor
factor.del_state_names(variables)

# Sum over the variables we just deleted. This can reduce the result
# to a scalar.
factor.values = np.sum(factor.values, axis=tuple(var_indexes))

if not inplace:
Expand Down

0 comments on commit 027c7d4

Please sign in to comment.