Skip to content

Commit

Permalink
PUBDEV-3071: Add RMSE to model metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
navdeep-G committed Jul 13, 2016
1 parent 5aa433f commit 666184f
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 2 deletions.
2 changes: 2 additions & 0 deletions h2o-core/src/main/java/hex/ModelMetrics.java
Expand Up @@ -53,13 +53,15 @@ public ModelMetrics(Model model, Frame frame, long nobs, double MSE, String desc
sb.append(" model id: " + _modelKey + "\n");
sb.append(" frame id: " + _frameKey + "\n");
sb.append(" MSE: " + (float)_MSE + "\n");
sb.append(" RMSE: " + (float)rmse() + "\n");
return sb.toString();
}

public Model model() { return _model==null ? (_model=DKV.getGet(_modelKey)) : _model; }
public Frame frame() { return _frame==null ? (_frame=DKV.getGet(_frameKey)) : _frame; }

public double mse() { return _MSE; }
public double rmse() { return Math.sqrt(_MSE);}
public ConfusionMatrix cm() { return null; }
public float[] hr() { return null; }
public AUC2 auc_obj() { return null; }
Expand Down
2 changes: 1 addition & 1 deletion h2o-core/src/main/java/hex/schemas/GridSchemaV99.java
Expand Up @@ -27,7 +27,7 @@ public class GridSchemaV99 extends SchemaV3<Grid, GridSchemaV99> {
@API(help = "Grid id")
public KeyV3.GridKeyV3 grid_id;

@API(help = "Model performance metric to sort by. Examples: logloss, residual_deviance, mse, auc, r2, f1, recall, precision, accuracy, mcc, err, err_count, lift_top_group, max_per_class_error", required = false, direction = API.Direction.INOUT)
@API(help = "Model performance metric to sort by. Examples: logloss, residual_deviance, mse, rmse, auc, r2, f1, recall, precision, accuracy, mcc, err, err_count, lift_top_group, max_per_class_error", required = false, direction = API.Direction.INOUT)
public String sort_by;

@API(help = "Specify whether sort order should be decreasing.", required = false, direction = API.Direction.INOUT)
Expand Down
Expand Up @@ -46,6 +46,9 @@ public class ModelMetricsBaseV3<I extends ModelMetrics, S extends ModelMetricsBa
@API(help = "The Mean Squared Error of the prediction for this scoring run.", direction = API.Direction.OUTPUT)
public double MSE;

@API(help = "The Root Mean Squared Error of the prediction for this scoring run.", direction = API.Direction.OUTPUT)
public double RMSE;

@API(help="Number of observations.")
public long nobs;

Expand All @@ -70,7 +73,7 @@ public ModelMetricsBaseV3() {}

PojoUtils.copyProperties(this, modelMetrics, PojoUtils.FieldNaming.ORIGIN_HAS_UNDERSCORES,
new String[]{"model", "model_category", "model_checksum", "frame", "frame_checksum"});

RMSE=modelMetrics.rmse();

return (S) this;
}
Expand Down
7 changes: 7 additions & 0 deletions h2o-py/h2o/model/metrics_base.py
Expand Up @@ -59,6 +59,7 @@ def show(self):
print(reported_on.format("test"))
print()
print("MSE: " + str(self.mse()))
print("RMSE: " + str(self.rmse()))
if metric_type in types_w_r2:
print("R^2: " + str(self.r2()))
if metric_type in types_w_mean_residual_deviance:
Expand Down Expand Up @@ -144,6 +145,12 @@ def mse(self):
"""
return self._metric_json['MSE']

def rmse(self):
"""
:return: Retrieve the RMSE for this set of metrics
"""
return self._metric_json['RMSE']

def residual_deviance(self):
"""
:return: the residual deviance if the model has residual deviance, or None if no residual deviance.
Expand Down
25 changes: 25 additions & 0 deletions h2o-py/h2o/model/model_base.py
Expand Up @@ -473,6 +473,31 @@ def mse(self, train=False, valid=False, xval=False):
for k,v in zip(list(tm.keys()),list(tm.values())): m[k] = None if v is None else v.mse()
return list(m.values())[0] if len(m) == 1 else m

def rmse(self, train=False, valid=False, xval=False):
"""
Get the RMSE(s).
If all are False (default), then return the training metric value.
If more than one options is set to True, then return a dictionary of metrics where the keys are "train", "valid",
and "xval"
Parameters
----------
train : bool, default=True
If train is True, then return the RMSE value for the training data.
valid : bool, default=True
If valid is True, then return the RMSE value for the validation data.
xval : bool, default=True
If xval is True, then return the RMSE value for the cross validation data.
Returns
-------
The RMSE for this regression model.
"""
tm = ModelBase._get_metrics(self, train, valid, xval)
m = {}
for k,v in zip(list(tm.keys()),list(tm.values())): m[k] = None if v is None else v.rmse()
return list(m.values())[0] if len(m) == 1 else m

def logloss(self, train=False, valid=False, xval=False):
"""
Get the Log Loss(s).
Expand Down
73 changes: 73 additions & 0 deletions h2o-r/h2o-package/R/models.R
Expand Up @@ -1024,6 +1024,79 @@ h2o.mse <- function(object, train=FALSE, valid=FALSE, xval=FALSE) {
invisible(NULL)
}

#' Retrieves Root Mean Squared Error Value
#'
#' Retrieves the root mean squared error value from an \linkS4class{H2OModelMetrics}
#' object.
#' If "train", "valid", and "xval" parameters are FALSE (default), then the training RMSEvalue is returned. If more
#' than one parameter is set to TRUE, then a named vector of RMSEs are returned, where the names are "train", "valid"
#' or "xval".
#'
#' This function only supports \linkS4class{H2OBinomialMetrics},
#' \linkS4class{H2OMultinomialMetrics}, and \linkS4class{H2ORegressionMetrics} objects.
#'
#' @param object An \linkS4class{H2OModelMetrics} object of the correct type.
#' @param train Retrieve the training RMSE
#' @param valid Retrieve the validation RMSE
#' @param xval Retrieve the cross-validation RMSE
#' @seealso \code{\link{h2o.auc}} for AUC, \code{\link{h2o.mse}} for RMSE, and
#' \code{\link{h2o.metric}} for the various threshold metrics. See
#' \code{\link{h2o.performance}} for creating H2OModelMetrics objects.
#' @examples
#' \donttest{
#' library(h2o)
#' h2o.init()
#'
#' prosPath <- system.file("extdata", "prostate.csv", package="h2o")
#' hex <- h2o.uploadFile(prosPath)
#'
#' hex[,2] <- as.factor(hex[,2])
#' model <- h2o.gbm(x = 3:9, y = 2, training_frame = hex, distribution = "bernoulli")
#' perf <- h2o.performance(model, hex)
#' h2o.rmse(perf)
#' }
#' @export
h2o.rmse <- function(object, train=FALSE, valid=FALSE, xval=FALSE) {
if( is(object, "H2OModelMetrics") ) return( object@metrics$RMSE )
if( is(object, "H2OModel") ) {
metrics <- NULL # break out special for clustering vs the rest
model.parts <- .model.parts(object)
if ( !train && !valid && !xval ) {
metric <- model.parts$tm@metrics$RMSE
if ( !is.null(metric) ) return(metric)
}
v <- c()
v_names <- c()
if ( train ) {
if( is(object, "H2OClusteringModel") ) v <- model.parts$tm@metrics$centroid_stats$within_cluster_sum_of_squares
else v <- c(v,model.parts$tm@metrics$RMSE)
v_names <- c(v_names,"train")
}
if ( valid ) {
if( is.null(model.parts$vm) ) return(invisible(.warn.no.validation()))
else {
if( is(object, "H2OClusteringModel") ) v <- model.parts$vm@metrics$centroid_stats$within_cluster_sum_of_squares
else v <- c(v,model.parts$vm@metrics$MSE)
v_names <- c(v_names,"valid")
}
}
if ( xval ) {
if( is.null(model.parts$xm) ) return(invisible(.warn.no.cross.validation()))
else {
if( is(object, "H2OClusteringModel") ) v <- model.parts$xm@metrics$centroid_stats$within_cluster_sum_of_squares
else v <- c(v,model.parts$xm@metrics$MSE)
v_names <- c(v_names,"xval")
}
}
if ( !is.null(v) ) {
names(v) <- v_names
if ( length(v)==1 ) { return( v[[1]] ) } else { return( v ) }
}
}
warning(paste0("No MSE for ",class(object)))
invisible(NULL)
}

#' Retrieve the Log Loss Value
#'
#' Retrieves the log loss output for a \linkS4class{H2OBinomialMetrics} or
Expand Down

0 comments on commit 666184f

Please sign in to comment.