diff --git a/h2o-py/h2o/model/dim_reduction.py b/h2o-py/h2o/model/dim_reduction.py index 9d9c9bab865c..aae8c7956fee 100644 --- a/h2o-py/h2o/model/dim_reduction.py +++ b/h2o-py/h2o/model/dim_reduction.py @@ -7,4 +7,28 @@ class H2ODimReductionModel(ModelBase): def __init__(self, dest_key, model_json): - super(H2ODimReductionModel, self).__init__(dest_key, model_json,H2ODimReductionModelMetrics) \ No newline at end of file + super(H2ODimReductionModel, self).__init__(dest_key, model_json,H2ODimReductionModelMetrics) + + def screeplot(self, type="barplot", show=True): + """ + Produce the scree plot + :param type: type of plot. "barplot" and "lines" currently supported + :param show: if False, the plot is not shown. matplotlib show method is blocking. + :return: None + """ + # check for matplotlib. exit if absent. + try: + imp.find_module('matplotlib') + import matplotlib.pyplot as plt + except ImportError: + print "matplotlib is required for this function!" + return + + variances = [s**2 for s in self._model_json['output']['pc_importance'].cell_values[0][1:]] + plt.xlabel('Components') + plt.ylabel('Variances') + plt.title('Scree Plot') + plt.xticks(range(1,len(variances)+1)) + if type == "barplot": plt.bar(range(1,len(variances)+1), variances) + elif type == "lines": plt.plot(range(1,len(variances)+1), variances, 'b--') + if show: plt.show() \ No newline at end of file diff --git a/h2o-py/tests/testdir_misc/pyunit_screeplot.py b/h2o-py/tests/testdir_misc/pyunit_screeplot.py new file mode 100644 index 000000000000..e38e556abb43 --- /dev/null +++ b/h2o-py/tests/testdir_misc/pyunit_screeplot.py @@ -0,0 +1,15 @@ +import sys +sys.path.insert(1, "../../") +import h2o + +def screeplot_test(ip,port): + # Connect to h2o + h2o.init(ip,port) + + australia = h2o.upload_file(h2o.locate("smalldata/pca_test/AustraliaCoast.csv")) + australia_pca = h2o.prcomp(x=australia[0:8], k = 4, transform = "STANDARDIZE") + australia_pca.screeplot(type="barplot", show=False) + australia_pca.screeplot(type="lines", show=False) + +if __name__ == "__main__": + h2o.run_test(sys.argv, screeplot_test) \ No newline at end of file