Skip to content

Commit

Permalink
fix metrics generation (#5718)
Browse files Browse the repository at this point in the history
  • Loading branch information
krasinski committed Mar 22, 2024
1 parent 8fb6b7e commit 3af804b
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ import water.api.API

trait MetricResolver {
def resolveMetrics(substitutionContext: ModelMetricsSubstitutionContext): Seq[Metric] = {
val h2oSchemaClass = substitutionContext.h2oSchemaClass
val h2oSchemaClass: Class[_] = substitutionContext.h2oSchemaClass

val parameters =
for (field <- h2oSchemaClass.getDeclaredFields
if field.getAnnotation(classOf[API]) != null && !MetricFieldExceptions.ignored().contains(field.getName))
if field.getAnnotation(classOf[API]) != null
if !MetricFieldExceptions.ignored().contains(field.getName)
if !substitutionContext.skipFields.contains(field.getName))
yield {
val (swFieldName, swMetricName) = MetricNameConverter.convertFromH2OToSW(field.getName)
Metric(swFieldName, swMetricName, field.getName, field.getType, field.getAnnotation(classOf[API]).help())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,82 +22,92 @@ import water.api.schemas3._

trait MetricsConfigurations {
def metricsConfiguration: Seq[ModelMetricsSubstitutionContext] = {
val duplicatedGLMMetrics = Seq("loglikelihood", "AIC")
Seq(
ModelMetricsSubstitutionContext(
"H2OCommonMetrics",
classOf[ModelMetricsBaseV3[_, _]],
Seq("H2OMetrics"),
"The class makes available all metrics that shared across all algorithms, and ML problems." +
entityName = "H2OCommonMetrics",
h2oSchemaClass = classOf[ModelMetricsBaseV3[_, _]],
parentEntities = Seq("H2OMetrics"),
classDescription = "The class makes available all metrics that shared across all algorithms, and ML problems." +
" (classification, regression, dimension reduction)."),
ModelMetricsSubstitutionContext(
"H2OBinomialMetrics",
classOf[ModelMetricsBinomialV3[_, _]],
Seq("H2OCommonMetrics"),
"The class makes available all metrics that shared across all algorithms supporting binomial classification."),
entityName = "H2OBinomialMetrics",
h2oSchemaClass = classOf[ModelMetricsBinomialV3[_, _]],
parentEntities = Seq("H2OCommonMetrics"),
classDescription =
"The class makes available all metrics that shared across all algorithms supporting binomial classification."),
ModelMetricsSubstitutionContext(
"H2OBinomialGLMMetrics",
classOf[ModelMetricsBinomialGLMV3],
Seq("H2OBinomialMetrics", "H2OGLMMetrics"),
"The class makes available all binomial metrics supported by GLM algorithm."),
entityName = "H2OBinomialGLMMetrics",
h2oSchemaClass = classOf[ModelMetricsBinomialGLMV3],
parentEntities = Seq("H2OBinomialMetrics", "H2OGLMMetrics"),
classDescription = "The class makes available all binomial metrics supported by GLM algorithm.",
skipFields = duplicatedGLMMetrics),
ModelMetricsSubstitutionContext(
"H2ORegressionMetrics",
classOf[ModelMetricsRegressionV3[_, _]],
Seq("H2OCommonMetrics"),
"The class makes available all metrics that shared across all algorithms supporting regression."),
entityName = "H2ORegressionMetrics",
h2oSchemaClass = classOf[ModelMetricsRegressionV3[_, _]],
parentEntities = Seq("H2OCommonMetrics"),
classDescription =
"The class makes available all metrics that shared across all algorithms supporting regression."),
ModelMetricsSubstitutionContext(
"H2ORegressionGLMMetrics",
classOf[ModelMetricsRegressionGLMV3],
Seq("H2ORegressionMetrics", "H2OGLMMetrics"),
"The class makes available all regression metrics supported by GLM algorithm."),
entityName = "H2ORegressionGLMMetrics",
h2oSchemaClass = classOf[ModelMetricsRegressionGLMV3],
parentEntities = Seq("H2ORegressionMetrics", "H2OGLMMetrics"),
classDescription = "The class makes available all regression metrics supported by GLM algorithm.",
skipFields = duplicatedGLMMetrics),
ModelMetricsSubstitutionContext(
"H2ORegressionCoxPHMetrics",
classOf[ModelMetricsRegressionCoxPHV3],
Seq("H2ORegressionMetrics"),
"The class makes available all regression metrics supported by CoxPH algorithm."),
entityName = "H2ORegressionCoxPHMetrics",
h2oSchemaClass = classOf[ModelMetricsRegressionCoxPHV3],
parentEntities = Seq("H2ORegressionMetrics"),
classDescription = "The class makes available all regression metrics supported by CoxPH algorithm."),
ModelMetricsSubstitutionContext(
"H2OMultinomialMetrics",
classOf[ModelMetricsMultinomialV3[_, _]],
Seq("H2OCommonMetrics"),
"The class makes available all metrics that shared across all algorithms supporting multinomial classification."),
entityName = "H2OMultinomialMetrics",
h2oSchemaClass = classOf[ModelMetricsMultinomialV3[_, _]],
parentEntities = Seq("H2OCommonMetrics"),
classDescription =
"The class makes available all metrics that shared across all algorithms supporting multinomial classification."),
ModelMetricsSubstitutionContext(
"H2OMultinomialGLMMetrics",
classOf[ModelMetricsMultinomialGLMV3],
Seq("H2OMultinomialMetrics", "H2OGLMMetrics"),
"The class makes available all multinomial metrics supported by GLM algorithm."),
entityName = "H2OMultinomialGLMMetrics",
h2oSchemaClass = classOf[ModelMetricsMultinomialGLMV3],
parentEntities = Seq("H2OMultinomialMetrics", "H2OGLMMetrics"),
classDescription = "The class makes available all multinomial metrics supported by GLM algorithm.",
skipFields = duplicatedGLMMetrics),
ModelMetricsSubstitutionContext(
"H2OOrdinalMetrics",
classOf[ModelMetricsOrdinalV3[_, _]],
Seq("H2OCommonMetrics"),
"The class makes available all metrics that shared across all algorithms supporting ordinal regression."),
entityName = "H2OOrdinalMetrics",
h2oSchemaClass = classOf[ModelMetricsOrdinalV3[_, _]],
parentEntities = Seq("H2OCommonMetrics"),
classDescription =
"The class makes available all metrics that shared across all algorithms supporting ordinal regression."),
ModelMetricsSubstitutionContext(
"H2OOrdinalGLMMetrics",
classOf[ModelMetricsOrdinalGLMV3],
Seq("H2OOrdinalMetrics", "H2OGLMMetrics"),
"The class makes available all ordinal metrics supported by GLM algorithm."),
entityName = "H2OOrdinalGLMMetrics",
h2oSchemaClass = classOf[ModelMetricsOrdinalGLMV3],
parentEntities = Seq("H2OOrdinalMetrics", "H2OGLMMetrics"),
classDescription = "The class makes available all ordinal metrics supported by GLM algorithm."),
ModelMetricsSubstitutionContext(
"H2OAnomalyMetrics",
classOf[ModelMetricsAnomalyV3],
Seq("H2OCommonMetrics"),
"The class makes available all metrics that shared across all algorithms supporting anomaly detection."),
entityName = "H2OAnomalyMetrics",
h2oSchemaClass = classOf[ModelMetricsAnomalyV3],
parentEntities = Seq("H2OCommonMetrics"),
classDescription =
"The class makes available all metrics that shared across all algorithms supporting anomaly detection."),
ModelMetricsSubstitutionContext(
"H2OClusteringMetrics",
classOf[ModelMetricsClusteringV3],
Seq("H2OCommonMetrics"),
"The class makes available all metrics that shared across all algorithms supporting clustering."),
entityName = "H2OClusteringMetrics",
h2oSchemaClass = classOf[ModelMetricsClusteringV3],
parentEntities = Seq("H2OCommonMetrics"),
classDescription =
"The class makes available all metrics that shared across all algorithms supporting clustering."),
ModelMetricsSubstitutionContext(
"H2OAutoEncoderMetrics",
classOf[ModelMetricsAutoEncoderV3],
Seq("H2OCommonMetrics"),
"The class provides all metrics available for ``H2OAutoEncoder``."),
entityName = "H2OAutoEncoderMetrics",
h2oSchemaClass = classOf[ModelMetricsAutoEncoderV3],
parentEntities = Seq("H2OCommonMetrics"),
classDescription = "The class provides all metrics available for ``H2OAutoEncoder``."),
ModelMetricsSubstitutionContext(
"H2OGLRMMetrics",
classOf[ModelMetricsGLRMV99],
Seq("H2OCommonMetrics"),
"The class provides all metrics available for ``H2OGLRM``."),
entityName = "H2OGLRMMetrics",
h2oSchemaClass = classOf[ModelMetricsGLRMV99],
parentEntities = Seq("H2OCommonMetrics"),
classDescription = "The class provides all metrics available for ``H2OGLRM``."),
ModelMetricsSubstitutionContext(
"H2OPCAMetrics",
classOf[ModelMetricsPCAV3],
Seq("H2OCommonMetrics"),
"The class provides all metrics available for ``H2OPCA``."))
entityName = "H2OPCAMetrics",
h2oSchemaClass = classOf[ModelMetricsPCAV3],
parentEntities = Seq("H2OCommonMetrics"),
classDescription = "The class provides all metrics available for ``H2OPCA``."))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ case class ModelMetricsSubstitutionContext(
entityName: String,
h2oSchemaClass: Class[_],
parentEntities: Seq[String],
classDescription: String)
classDescription: String,
skipFields: Seq[String] = Seq.empty)
extends SubstitutionContextBase {

val namespace = "ai.h2o.sparkling.ml.metrics"
Expand Down

0 comments on commit 3af804b

Please sign in to comment.