In [1]:
%load_ext autoreload
%autoreload

In [2]:
from sklearn.linear_model import LogisticRegression

from dtcontrol.benchmark_suite import BenchmarkSuite
from dtcontrol.decision_tree.decision_tree import DecisionTree
from dtcontrol.decision_tree.determinization.max_freq_determinizer import MaxFreqDeterminizer
from dtcontrol.decision_tree.impurity.entropy import Entropy
from dtcontrol.decision_tree.impurity.multi_label_entropy import MultiLabelEntropy
from dtcontrol.decision_tree.splitting.axis_aligned import AxisAlignedSplittingStrategy
from dtcontrol.decision_tree.splitting.linear_classifier import LinearClassifierSplittingStrategy

In [3]:
# Benchmark harness shared by all cells below.
#   timeout        - per-run time limit (NOTE(review): unit not visible here, presumably seconds)
#   save_folder    - directory name passed for saved classifiers
#   benchmark_file - base name of the statistics files (the suite reports benchmark.json / benchmark.html)
#   rerun=False    - results already stored are reused instead of recomputed
suite = BenchmarkSuite(
    rerun=False,
    benchmark_file='benchmark',
    save_folder='saved_classifiers',
    timeout=60,
)

INFO: Benchmark statistics will be available in benchmark.json and benchmark.html.
INFO: Constructed trees will be written to decision_trees.



In [4]:
# Register datasets from both example directories; `include` narrows the selection to
# 'cartpole', which is the only dataset the benchmark run evaluates.
suite.add_datasets(
    ['examples', 'examples/prism'],
    include=['cartpole'],
)

In [5]:
# Splitting strategies; `aa` is shared by every tree in the list below.
aa = AxisAlignedSplittingStrategy()
# Splits learned by scikit-learn's LogisticRegression. The keyword arguments are presumably
# forwarded to the LogisticRegression constructor; penalty=None (no regularisation) is only
# accepted by scikit-learn >= 1.2.
logreg = LinearClassifierSplittingStrategy(
    LogisticRegression,
    solver='lbfgs',
    penalty=None,
)

# Candidate classifiers, benchmarked in list order. The third argument is the label that
# appears in the benchmark log ("Evaluating <label> on <dataset>").
classifiers = [
    # Axis-aligned splits only.
    DecisionTree([aa], Entropy(), 'CART'),
    # Axis-aligned splits plus logistic-regression splits.
    DecisionTree([aa, logreg], Entropy(), 'LogReg'),
    # The remaining variants all enable early stopping and differ only in the impurity measure.
    DecisionTree([aa], Entropy(), 'Early-stopping', early_stopping=True),
    DecisionTree([aa], Entropy(MaxFreqDeterminizer()), 'MaxFreq', early_stopping=True),
    DecisionTree([aa], MultiLabelEntropy(), 'MultiLabelEntropy', early_stopping=True),
]

In [6]:
# Evaluate every classifier on every registered dataset. Because the suite was built with
# rerun=False, combinations that already have a stored result are skipped
# ("Not running since the result is already available").
suite.benchmark(classifiers)

1/5: Evaluating CART on cartpole... 
1/5: Not running since the result is already available.
2/5: Evaluating LogReg on cartpole... 
2/5: Not running since the result is already available.
3/5: Evaluating Early-stopping on cartpole... 
3/5: Not running since the result is already available.
4/5: Evaluating MaxFreq on cartpole... 
4/5: Not running since the result is already available.
5/5: Evaluating MultiLabelEntropy on cartpole... 
5/5: Not running since the result is already available.
All benchmarks completed. Shutting down dtControl.


In [7]:
# Show the collected benchmark statistics inline
# (NOTE(review): presumably renders the benchmark.html report announced by the suite — confirm).
suite.display_html()

In [7]:
# Uncomment to discard a stored result so the next `suite.benchmark(...)` recomputes it
# (NOTE(review): presumed behaviour of delete_cell given the rerun=False caching — confirm).
# Arguments are presumably (dataset name, classifier label as passed to DecisionTree).
# suite.delete_cell('cartpole', 'LogReg')
# suite.delete_cell('cartpole', 'MaxFreq')