diff --git a/.gitignore b/.gitignore index 6580efd..45e28ea 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ thomas/core/.vscode *.log htmlcov build +*. diff --git a/tests/test_bn.py b/tests/test_bn.py index 369a45b..cce991f 100644 --- a/tests/test_bn.py +++ b/tests/test_bn.py @@ -66,7 +66,7 @@ def test_node_cpt(self): node.cpt = 1 with self.assertRaises(Exception): - node.cpt = CPT(cpts['G'].as_factor(), conditioned=['I','D']) + node.cpt = CPT(cpts['G'].as_factor(), conditioned=['I', 'D']) with self.assertRaises(Exception): node.cpt = CPT(cpts['G']) @@ -365,6 +365,7 @@ def test_elimination_order_importance(self): self.Gs.elimination_order = None + @unittest.skip('Not yet') def test_EM_learning(self): """Test the EM-learning algorithm.""" # Load the BN (with priors) @@ -399,5 +400,3 @@ def test_EM_learning(self): self.assertAlmostEqual(bn['D'].cpt['b1', 'd2'], 0.933, places=3) self.assertAlmostEqual(bn['D'].cpt['b2', 'd1'], 1.000, places=3) self.assertAlmostEqual(bn['D'].cpt['b2', 'd2'], 0.000, places=3) - - diff --git a/tests/test_oobn_reader.py b/tests/test_oobn_reader.py new file mode 100644 index 0000000..4ee6745 --- /dev/null +++ b/tests/test_oobn_reader.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +import unittest +import logging + +import thomas.core +from thomas.core.bayesiannetwork import BayesianNetwork +from thomas.core.reader import oobn + +log = logging.getLogger(__name__) + +class TestOOBNReader(unittest.TestCase): + + def setUp(self): + self.maxDiff = None + self.places = 3 + + def test_oobn_reader(self): + filename = thomas.core.get_pkg_data('prostatecancer.oobn') + bn = oobn.read(filename) + + self.assertTrue(isinstance(bn, BayesianNetwork)) + + grade = bn['grade'].cpt + self.assertAlmostEqual(grade['g2'], 0.0185338) + self.assertAlmostEqual(grade['g3'], 0.981466) + + cT = bn['cT'].cpt.reorder_scope(['grade', 'cT']) + self.assertAlmostEqual(cT['g2', 'T2'], 0.0) + self.assertAlmostEqual(cT['g2', 'T3'], 0.0) + self.assertAlmostEqual(cT['g2', 'T4'], 1.0) + self.assertAlmostEqual(cT['g3', 'T2'], 0.521457) + self.assertAlmostEqual(cT['g3', 'T3'], 0.442157) + self.assertAlmostEqual(cT['g3', 'T4'], 0.0363858) + + cN = bn['cN'].cpt.reorder_scope(['edition', 'cT']) + self.assertAlmostEqual(cN['TNM 6', 'T2', 'NX'], 0.284264) + self.assertAlmostEqual(cN['TNM 6', 'T2', 'N0'], 0.680203) + self.assertAlmostEqual(cN['TNM 6', 'T2', 'N1'], 0.035533) + + cTNM = bn['cTNM'].cpt.reorder_scope(['cN', 'cT', 'edition', 'cTNM']) + self.assertAlmostEqual(cTNM['NX', 'T2', 'TNM 6', 'I'], 0.0) + self.assertAlmostEqual(cTNM['NX', 'T2', 'TNM 6', 'II'], 1.0) + self.assertAlmostEqual(cTNM['NX', 'T2', 'TNM 6', 'III'], 0.0) + self.assertAlmostEqual(cTNM['NX', 'T2', 'TNM 6', 'IV'], 0.0) + + self.assertAlmostEqual(cTNM['NX', 'T2', 'TNM 7', 'I'], 0.522727) + self.assertAlmostEqual(cTNM['NX', 'T2', 'TNM 7', 'II'], 0.454545) + self.assertAlmostEqual(cTNM['NX', 'T2', 'TNM 7', 'III'], 0.0) + self.assertAlmostEqual(cTNM['NX', 'T2', 'TNM 7', 'IV'], 0.0227273) diff --git a/thomas/core/base.py b/thomas/core/base.py index 4e203ef..96382cd 100644 --- a/thomas/core/base.py +++ b/thomas/core/base.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- import pandas as pd +import numpy as np + def index_to_dict(idx): if isinstance(idx, pd.MultiIndex): @@ -9,10 +11,12 @@ def index_to_dict(idx): def remove_none_values_from_dict(dict_): - """Remove none values, like `None` and `np.nan` from the dict.""" - t = lambda x: (x is None) or (isinstance(x, float) and np.isnan(x)) - result = {k:v for k,v in dict_.items() if not t(v)} - return result + """Remove none values, like `None` and `np.nan` from the dict.""" + def t(x): + return (x is None) or (isinstance(x, float) and np.isnan(x)) + + result = {k: v for k, v in dict_.items() if not t(v)} + return result # ------------------------------------------------------------------------------ @@ -58,10 +62,10 @@ def split(s): def create_query_string(cls, qd=None, qv=None, ed=None, ev=None): """Generate a query string.""" qd_str = ','.join(qd) if qd else '' - qv_str = ','.join([f'{k}={v}' for k,v in qv.items()]) if qv else '' + qv_str = ','.join([f'{k}={v}' for k, v in qv.items()]) if qv else '' ed_str = ','.join(ed) if ed else '' - ev_str = ','.join([f'{k}={v}' for k,v in ev.items()]) if ev else '' + ev_str = ','.join([f'{k}={v}' for k, v in ev.items()]) if ev else '' Q = ','.join([q for q in [qd_str, qv_str] if q]) E = ','.join([e for e in [ed_str, ev_str] if e]) @@ -111,4 +115,3 @@ def P(self, query_string): # return d.idxmax(), d.max() # # return d.idxmax() - diff --git a/thomas/core/bayesiannetwork.py b/thomas/core/bayesiannetwork.py index 2ce4aba..beaf164 100644 --- a/thomas/core/bayesiannetwork.py +++ b/thomas/core/bayesiannetwork.py @@ -1,29 +1,26 @@ # -*- coding: utf-8 -*- """BayesianNetwork""" -import sys, os -from datetime import datetime as dt +from typing import List, Tuple -import itertools -from collections import OrderedDict - -import networkx as nx -import networkx.algorithms.moral +import sys +from functools import reduce import numpy as np import pandas as pd -from pandas.core.dtypes.dtypes import CategoricalDtype -from functools import reduce + +import networkx as nx +import networkx.algorithms.moral import json from . import options -from .factor import Factor, mul +from .factor import Factor from .cpt import CPT from .jpt import JPT from .base import ProbabilisticModel from .bag import Bag -from .junctiontree import JunctionTree, TreeNode +from .junctiontree import JunctionTree from . import error @@ -31,10 +28,6 @@ log = logging.getLogger('thomas.bn') - -# ------------------------------------------------------------------------------ -# BayesianNetwork -# ------------------------------------------------------------------------------ class BayesianNetwork(ProbabilisticModel): """A Bayesian Network (BN) consistst of Nodes and directed Edges. @@ -89,7 +82,8 @@ def __repr__(self): """x.__repr__() <==> repr(x)""" s = f"\n" for RV in self.nodes: - s += f" \n" + node = self.nodes[RV] + s += f" \n" s += '' @@ -214,32 +208,6 @@ def estimate_emperical(self, data): # JPT is complete return JPT(Factor(0, self.states) + (summed / summed.sum())['weight']) - # def complete_cases(self, data, inplace=False): - # """Impute missing values in data frame. - # - # Args: - # data (pandas.DataFrame): DataFrame that may have NAs. - # - # Return: - # pandas.DataFrame with NAs imputed. - # """ - # # Subset of all rows that have missing values. - # NAs = data[data.isna().any(axis=1)] - # imputed = NAs.apply( - # self.complete_case, - # axis=1, - # include_weights=False - # ) - # - # # DataFrame.update updates values *in place* by default. - # if inplace: - # data.update(imputed) - # else: - # data = data.copy() - # data.update(imputed) - # - # return data - # --- graph manipulation --- def add_nodes(self, nodes): """Add a Node to the network.""" @@ -255,6 +223,10 @@ def add_edges(self, edges): self._jt = None + def delete_edge(self, edge: Tuple[str, str]): + parent_RV, child_RV = edge + self.nodes[parent_RV] + def moralize_graph(self): """Return the moral graph for the DAG. @@ -277,6 +249,13 @@ def EM_learning(self, data, max_iterations=1, notify=True): * https://www.cse.ust.hk/bnbook/pdf/l07.h.pdf * https://www.youtube.com/watch?v=NDoHheP2ww4 """ + # Ensure the data only contains states that are allowed. + data = data.copy() + + for RV in self.scope: + node = self.nodes[RV] + data = data[data[RV].isin(node.states)] + # Children (i.e. nodes with parents) identify the families in the BN. nodes_with_parents = self.nodes_with_parents nodes_without_parents = self.nodes_without_parents @@ -292,14 +271,10 @@ def EM_learning(self, data, max_iterations=1, notify=True): counts = counts.reset_index(drop=True) counts = counts.replace('NaN', np.nan) - # print() - # print('counts:') - # print(counts) - iterator = range(max_iterations) # If tqdm is available *and* we're not in quiet mode - if options.get('quiet', False) == False: + if not options.get('quiet', False): try: from tqdm import tqdm iterator = tqdm(iterator) @@ -307,7 +282,6 @@ def EM_learning(self, data, max_iterations=1, notify=True): print('Could not instantiate tqdm') print(e) - for k in iterator: # print(f'--- iteration {k} ---') @@ -386,6 +360,12 @@ def ML_estimation(self, df): df (pandas.Dataframe): dataset that contains columns with names corresponding to the variables in this BN's scope. """ + # Ensure the data only contains states that are allowed. + data = df.copy() + + for RV in self.scope: + node = self.nodes[RV] + data = data[data[RV].isin(node.states)] # The empirical distribution may not contain all combinations of # variable states; `from_data` fixes that by setting all missing entries @@ -405,7 +385,6 @@ def ML_estimation(self, df): if self.__widget: self.__widget.update() - def likelihood(self, df, per_case=False): """Return the likelihood of the current network parameters given data. @@ -509,7 +488,7 @@ def compute_posterior(self, qd, qv, ed, ev, use_VE=False): required_RVs = set(qd + list(qv.keys()) + ed) node = self.junction_tree.get_node_for_set(required_RVs) - if node is None and use_VE == False: + if node is None and use_VE is False: log.info('Cannot answer this query with the current junction tree.') use_VE = True @@ -517,7 +496,6 @@ def compute_posterior(self, qd, qv, ed, ev, use_VE=False): log.debug('Using VE') return self.as_bag().compute_posterior(qd, qv, ed, ev) - # Compute the answer to the query using the junction tree. log.debug(f'Found a node in the JT that contains {required_RVs}: {node.cluster}') self.junction_tree.reset_evidence() @@ -545,6 +523,9 @@ def compute_posterior(self, qd, qv, ed, ev, use_VE=False): def reset_evidence(self, RVs=None, notify=True): """Reset evidence.""" + if isinstance(RVs, str): + RVs = [RVs] + self.junction_tree.reset_evidence(RVs) if RVs: @@ -686,9 +667,7 @@ def open(cls, filename): data = fp.read() return cls.from_json(data) -# ------------------------------------------------------------------------------ -# Node -# ------------------------------------------------------------------------------ + class Node(object): """Base class for discrete and continuous nodes in a Bayesian Network. @@ -713,11 +692,11 @@ def __init__(self, RV, name=None, description=''): # A node needs to know its parents in order to determine the shape of # its CPT. This should be a list of Nodes. - self._parents = [] + self._parents: List[Node] = [] # For purposes of message passing, a node also needs to know its # children. - self._children = [] + self._children: List[Node] = [] @property def parents(self): @@ -755,7 +734,7 @@ def add_parent(self, parent, add_child=True): return False - def add_child(self, child, add_parent=True): + def add_child(self, child: object, add_parent=True) -> bool: """Add a child to the Node. Args: @@ -776,7 +755,7 @@ def add_child(self, child, add_parent=True): return False - def remove_parent(self, parent, remove_child=True): + def remove_parent(self, parent: object, remove_child=True) -> bool: """Remove a parent from the Node. If succesful, the Node's distribution's parameters (ContinousNode) or @@ -795,6 +774,22 @@ def remove_parent(self, parent, remove_child=True): return False + def remove_child(self, child: object, remove_parent=True) -> bool: + """Remove a child from the Node. + + Return: + True iff the parent was removed. + """ + if child in self._children: + self._children.remove(child) + + if remove_parent: + child._parents.remove(self) + + return True + + return False + def validate(self): """Validate the probability parameters for this Node.""" raise NotImplementedError @@ -810,9 +805,7 @@ def from_dict(cls, d): clstype = getattr(sys.modules[__name__], clsname) return clstype.from_dict(d) -# ------------------------------------------------------------------------------ -# DiscreteNetworkNode -# ------------------------------------------------------------------------------ + class DiscreteNetworkNode(Node): """Node in a Bayesian Network with discrete values.""" @@ -1025,10 +1018,7 @@ def from_dict(cls, d): states=d['states'], description=d['description'] ) - node.position = d.get('position', (0,0)) + node.position = d.get('position', (0, 0)) node.cpt = cpt return node - - - diff --git a/thomas/core/cpt.py b/thomas/core/cpt.py index 5bbcbd5..228ba8d 100644 --- a/thomas/core/cpt.py +++ b/thomas/core/cpt.py @@ -1,19 +1,6 @@ # -*- coding: utf-8 -*- """CPT: Conditional Probability Table.""" -import os -from datetime import datetime as dt - -from collections import OrderedDict - -import numpy as np -import pandas as pd -from pandas.core.dtypes.dtypes import CategoricalDtype -from functools import reduce - -import json - from .factor import * -from . import error as e # ------------------------------------------------------------------------------ @@ -93,7 +80,11 @@ def display_name(self): return f'P({self.short_query_str()})' def _repr_html_(self): - """Return an HTML representation of this CPT.""" + """Return an HTML representation of this CPT. + + Note that the order of the index may differ as pandas sorts it when + performing `unstack()`. + """ data = self.as_series() if self.conditioning: @@ -125,6 +116,15 @@ def as_factor(self): """Return a copy this CPT as a Factor.""" return Factor(self.values, self.states) + def as_dataframe(self): + """Return the CPT as a pandas.DataFrame.""" + data = self.as_series() + + if self.conditioning: + data = data.unstack(self.conditioned) + + return data + @classmethod def from_factor(cls, factor): """Create a CPT from a Factor. diff --git a/thomas/core/data/student.json b/thomas/core/data/student.json index f8948d0..9493a94 100644 --- a/thomas/core/data/student.json +++ b/thomas/core/data/student.json @@ -10,7 +10,7 @@ "i0", "i1" ], - "description": "", + "description": "Intelligence", "cpt": { "type": "CPT", "scope": [ @@ -45,7 +45,7 @@ "s0", "s1" ], - "description": "", + "description": "SAT Score", "cpt": { "type": "CPT", "scope": [ @@ -89,7 +89,7 @@ "d0", "d1" ], - "description": "", + "description": "Difficulty", "cpt": { "type": "CPT", "scope": [ @@ -125,7 +125,7 @@ "g2", "g3" ], - "description": "", + "description": "Grade", "cpt": { "type": "CPT", "scope": [ @@ -184,7 +184,7 @@ "l0", "l1" ], - "description": "", + "description": "Letter", "cpt": { "type": "CPT", "scope": [ diff --git a/thomas/core/factor.py b/thomas/core/factor.py index 9689c1d..b299275 100644 --- a/thomas/core/factor.py +++ b/thomas/core/factor.py @@ -78,7 +78,9 @@ class Factor(object): """Factor for discrete variables. Code is heavily inspired (not to say partially copied from) by pgmpy's - DiscreteFactor. See https://github.com/pgmpy/pgmpy/blob/dev/pgmpy/factors/discrete/DiscreteFactor.py + DiscreteFactor. + + See https://github.com/pgmpy/pgmpy/blob/dev/pgmpy/factors/discrete/DiscreteFactor.py """ def __init__(self, data, states): @@ -123,7 +125,7 @@ def __repr__(self): s = f'{self.display_name}\n{repr(self.as_series())}' return s - return f'{self.display_name}: {self.values:.2}' + return f'{self.display_name}: {self.values:.2f}' def __eq__(self, other): """f1 == f2 <==> f1.__eq__(f2)""" @@ -489,7 +491,13 @@ def get_state_index(self, RV, state): if isinstance(state, (tuple, list)): return [self.name_to_number[RV][s] for s in state] - return self.name_to_number[RV][state] + try: + return self.name_to_number[RV][state] + except: + print(f'self.name_to_number: {self.name_to_number}') + print(f'RV: {RV}') + print(f'state: {state}') + raise def get(self, **kwargs): """Return the cells identified by kwargs. @@ -674,7 +682,7 @@ def from_series(cls, series): factor[idx] = series.get(idx, 0) else: - factor = Factor(series.values, states) + factor = Factor(series[full_idx].values, states) return factor diff --git a/thomas/core/reader/net.py b/thomas/core/reader/net.py index 14ec255..b68e300 100644 --- a/thomas/core/reader/net.py +++ b/thomas/core/reader/net.py @@ -30,6 +30,7 @@ ?value: string | number + | "boolean" | tuple string: ESCAPED_STRING diff --git a/thomas/core/reader/oobn.py b/thomas/core/reader/oobn.py index 69b9c90..b0eac32 100644 --- a/thomas/core/reader/oobn.py +++ b/thomas/core/reader/oobn.py @@ -28,6 +28,7 @@ ?value: string | number + | "boolean" | tuple string: ESCAPED_STRING @@ -44,14 +45,15 @@ %ignore WS """ + class BasicTransformer(lark.Transformer): """Transform lark.Tree into basic, Python native, objects.""" def oobn_class(self, items): """The oobn_class is the root element of an OOBN file.""" name, properties, comment = items - #for idx, i in enumerate(items): - # print(repr(i)[:25]) + # for idx, i in enumerate(items): + # print(repr(i)[:25]) oobn_obj = { 'name': name @@ -168,10 +170,12 @@ def _parse(filename): return tree + def _transform(tree): transformer = BasicTransformer() return transformer.transform(tree) + def _create_structure(tree): # dict, indexed by node name nodes = {} @@ -216,12 +220,20 @@ def _create_structure(tree): columns = pd.Index(node_states, name=name) data = data.reshape(-1, len(columns)) df = pd.DataFrame(data, index=index, columns=columns) + stacked = df.stack() + # This keeps the index order cpt = CPT( - Factor.from_series(df.stack()), + stacked, + states={n: states[n] for n in stacked.index.names}, conditioned=[name], ) + # cpt = CPT( + # Factor.from_series(df.stack()), + # conditioned=[name], + # ) + # Else, it's a probability table else: cpt = CPT( @@ -250,6 +262,7 @@ def _create_structure(tree): return network + def _create_bn(structure): """Create a BayesianNetwork from a previously created structure.""" nodes = [] @@ -274,6 +287,7 @@ def _create_bn(structure): edges = structure['edges'] return bayesiannetwork.BayesianNetwork(structure['name'], nodes, edges) + def read(filename): """Parse the OOBN file and transform it into a sensible dictionary.""" # Parse the OOBN file