Skip to content

Commit

Permalink
add DYNOTEARS implementation (#50)
Browse files Browse the repository at this point in the history
Adds DYNOTEARS and corresponding data generator (for testing)
  • Loading branch information
GabrielAzevedoFerreiraQB committed Sep 10, 2020
1 parent 92968ee commit 4134ab8
Show file tree
Hide file tree
Showing 12 changed files with 2,747 additions and 9 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Upcoming release

* Add dynotears (`from_numpy_dynamic`, an algorithm for structure learning on Dynamic Bayesian Networks)
* Add a count data type to the data generator using a zero-inflated Poisson
* Set bounds/max class imbalance for binary features for the data generators
* Add non-linear data generators for multiple data types
Expand Down
2 changes: 1 addition & 1 deletion causalnex/structure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@
``causalnex.structure`` provides functionality to define or learn structure.
"""

__all__ = ["StructureModel", "notears", "data_generators"]
__all__ = ["StructureModel", "notears", "dynotears", "data_generators"]

from .structuremodel import StructureModel
6 changes: 6 additions & 0 deletions causalnex/structure/data_generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,20 @@
"generate_continuous_data",
"generate_continuous_dataframe",
"generate_count_dataframe",
"gen_stationary_dyn_net_and_df",
"generate_dataframe_dynamic",
"generate_structure_dynamic",
]

from .core import generate_structure, nonlinear_sem_generator, sem_generator
from .wrappers import (
gen_stationary_dyn_net_and_df,
generate_binary_data,
generate_binary_dataframe,
generate_categorical_dataframe,
generate_continuous_data,
generate_continuous_dataframe,
generate_count_dataframe,
generate_dataframe_dynamic,
generate_structure_dynamic,
)
7 changes: 5 additions & 2 deletions causalnex/structure/data_generators/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@
import pandas as pd
from sklearn.gaussian_process.kernels import RBF, Kernel

from causalnex.structure import StructureModel
from causalnex.structure.categorical_variable_mapper import (
VariableFeatureMapper,
validate_schema,
)
from causalnex.structure.structuremodel import StructureModel

# dict mapping distributions names to their functions
__distribution_mapper = {
Expand Down Expand Up @@ -117,7 +117,10 @@ def generate_structure(
edge_flags = np.tril(np.ones([num_nodes, num_nodes]), k=-1)

else:
raise ValueError("unknown graph type")
raise ValueError(
"Unknown graph type {t}. ".format(t=graph_type)
+ "Available types are ['erdos-renyi', 'barabasi-albert', 'full']"
)

# randomly permute edges - required because we limited ourselves to lower diagonal previously
perms = np.random.permutation(np.eye(num_nodes, num_nodes))
Expand Down
Loading

0 comments on commit 4134ab8

Please sign in to comment.