-
Notifications
You must be signed in to change notification settings - Fork 1.9k
making GetCoefficientStatistics public #1979
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,12 +27,12 @@ namespace Microsoft.ML.Learners | |
| public readonly struct CoefficientStatistics | ||
| { | ||
| public readonly string Name; | ||
| public readonly Single Estimate; | ||
| public readonly Single StandardError; | ||
| public readonly Single ZScore; | ||
| public readonly Single PValue; | ||
| public readonly float Estimate; | ||
| public readonly float StandardError; | ||
| public readonly float ZScore; | ||
| public readonly float PValue; | ||
|
|
||
| public CoefficientStatistics(string name, Single estimate, Single stdError, Single zScore, Single pValue) | ||
| public CoefficientStatistics(string name, float estimate, float stdError, float zScore, float pValue) | ||
| { | ||
| Contracts.AssertNonEmpty(name); | ||
| Name = name; | ||
|
|
@@ -69,10 +69,10 @@ private static VersionInfo GetVersionInfo() | |
| private readonly long _trainingExampleCount; | ||
|
|
||
| // The deviance of this model. | ||
| private readonly Single _deviance; | ||
| private readonly float _deviance; | ||
|
|
||
| // The deviance of the null hypothesis. | ||
| private readonly Single _nullDeviance; | ||
| private readonly float _nullDeviance; | ||
|
|
||
| // Total count of parameters. | ||
| private readonly int _paramCount; | ||
|
|
@@ -82,17 +82,17 @@ private static VersionInfo GetVersionInfo() | |
| // It could be null when there are too many non-zero weights so that | ||
| // the memory is insufficient to hold the Hessian matrix necessary for the computation | ||
| // of the variance-covariance matrix. | ||
| private readonly VBuffer<Single>? _coeffStdError; | ||
| private readonly VBuffer<float>? _coeffStdError; | ||
|
|
||
| public long TrainingExampleCount => _trainingExampleCount; | ||
|
|
||
| public Single Deviance => _deviance; | ||
| public float Deviance => _deviance; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Why do we need `_deviance`?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
need proper comment. |
||
|
|
||
| public Single NullDeviance => _nullDeviance; | ||
| public float NullDeviance => _nullDeviance; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I don't see any reason for _nullDeviance to exist.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Need proper comment |
||
|
|
||
| public int ParametersCount => _paramCount; | ||
|
|
||
| internal LinearModelStatistics(IHostEnvironment env, long trainingExampleCount, int paramCount, Single deviance, Single nullDeviance) | ||
| internal LinearModelStatistics(IHostEnvironment env, long trainingExampleCount, int paramCount, float deviance, float nullDeviance) | ||
| { | ||
| Contracts.AssertValue(env); | ||
| env.Assert(trainingExampleCount > 0); | ||
|
|
@@ -104,7 +104,7 @@ internal LinearModelStatistics(IHostEnvironment env, long trainingExampleCount, | |
| _nullDeviance = nullDeviance; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. none of _variable from that block should exist. |
||
| } | ||
|
|
||
| internal LinearModelStatistics(IHostEnvironment env, long trainingExampleCount, int paramCount, Single deviance, Single nullDeviance, in VBuffer<Single> coeffStdError) | ||
| internal LinearModelStatistics(IHostEnvironment env, long trainingExampleCount, int paramCount, float deviance, float nullDeviance, in VBuffer<float> coeffStdError) | ||
| : this(env, trainingExampleCount, paramCount, deviance, nullDeviance) | ||
| { | ||
| _env.Assert(coeffStdError.GetValues().Length == _paramCount); | ||
|
|
@@ -120,10 +120,10 @@ internal LinearModelStatistics(IHostEnvironment env, ModelLoadContext ctx) | |
| // *** Binary Format *** | ||
| // int: count of parameters | ||
| // long: count of training examples | ||
| // Single: deviance | ||
| // Single: null deviance | ||
| // float: deviance | ||
| // float: null deviance | ||
| // bool: whether standard error is included | ||
| // (Conditional) Single[_paramCount]: values of std errors of coefficients | ||
| // (Conditional) float[_paramCount]: values of std errors of coefficients | ||
| // (Conditional) int: length of std errors of coefficients | ||
| // (Conditional) int[_paramCount]: indices of std errors of coefficients | ||
|
|
||
|
|
@@ -143,18 +143,18 @@ internal LinearModelStatistics(IHostEnvironment env, ModelLoadContext ctx) | |
| return; | ||
| } | ||
|
|
||
| Single[] stdErrorValues = ctx.Reader.ReadFloatArray(_paramCount); | ||
| float[] stdErrorValues = ctx.Reader.ReadFloatArray(_paramCount); | ||
| int length = ctx.Reader.ReadInt32(); | ||
| _env.CheckDecode(length >= _paramCount); | ||
| if (length == _paramCount) | ||
| { | ||
| _coeffStdError = new VBuffer<Single>(length, stdErrorValues); | ||
| _coeffStdError = new VBuffer<float>(length, stdErrorValues); | ||
| return; | ||
| } | ||
|
|
||
| _env.Assert(length > _paramCount); | ||
| int[] stdErrorIndices = ctx.Reader.ReadIntArray(_paramCount); | ||
| _coeffStdError = new VBuffer<Single>(length, _paramCount, stdErrorValues, stdErrorIndices); | ||
| _coeffStdError = new VBuffer<float>(length, _paramCount, stdErrorValues, stdErrorIndices); | ||
| } | ||
|
|
||
| internal static LinearModelStatistics Create(IHostEnvironment env, ModelLoadContext ctx) | ||
|
|
@@ -178,10 +178,10 @@ private void SaveCore(ModelSaveContext ctx) | |
| // *** Binary Format *** | ||
| // int: count of parameters | ||
| // long: count of training examples | ||
| // Single: deviance | ||
| // Single: null deviance | ||
| // float: deviance | ||
| // float: null deviance | ||
| // bool: whether standard error is included | ||
| // (Conditional) Single[_paramCount]: values of std errors of coefficients | ||
| // (Conditional) float[_paramCount]: values of std errors of coefficients | ||
| // (Conditional) int: length of std errors of coefficients | ||
| // (Conditional) int[_paramCount]: indices of std errors of coefficients | ||
|
|
||
|
|
@@ -212,7 +212,7 @@ private void SaveCore(ModelSaveContext ctx) | |
| /// <summary> | ||
| /// Computes the standart deviation, Z-Score and p-Value. | ||
| /// </summary> | ||
| public static bool TryGetBiasStatistics(LinearModelStatistics stats, Single bias, out Single stdError, out Single zScore, out Single pValue) | ||
| public static bool TryGetBiasStatistics(LinearModelStatistics stats, float bias, out float stdError, out float zScore, out float pValue) | ||
| { | ||
| if (!stats._coeffStdError.HasValue) | ||
| { | ||
|
|
@@ -226,12 +226,12 @@ public static bool TryGetBiasStatistics(LinearModelStatistics stats, Single bias | |
| stdError = stats._coeffStdError.Value.GetValues()[0]; | ||
| Contracts.Assert(stdError == stats._coeffStdError.Value.GetItemOrDefault(0)); | ||
| zScore = bias / stdError; | ||
| pValue = 1.0f - (Single)ProbabilityFunctions.Erf(Math.Abs(zScore / sqrt2)); | ||
| pValue = 1.0f - (float)ProbabilityFunctions.Erf(Math.Abs(zScore / sqrt2)); | ||
| return true; | ||
| } | ||
|
|
||
| private static void GetUnorderedCoefficientStatistics(LinearModelStatistics stats, in VBuffer<Single> weights, in VBuffer<ReadOnlyMemory<char>> names, | ||
| ref VBuffer<Single> estimate, ref VBuffer<Single> stdErr, ref VBuffer<Single> zScore, ref VBuffer<Single> pValue, out ValueGetter<VBuffer<ReadOnlyMemory<char>>> getSlotNames) | ||
| private static void GetUnorderedCoefficientStatistics(LinearModelStatistics stats, in VBuffer<float> weights, in VBuffer<ReadOnlyMemory<char>> names, | ||
| ref VBuffer<float> estimate, ref VBuffer<float> stdErr, ref VBuffer<float> zScore, ref VBuffer<float> pValue, out ValueGetter<VBuffer<ReadOnlyMemory<char>>> getSlotNames) | ||
| { | ||
| if (!stats._coeffStdError.HasValue) | ||
| { | ||
|
|
@@ -260,7 +260,7 @@ private static void GetUnorderedCoefficientStatistics(LinearModelStatistics stat | |
| var weight = estimateEditor.Values[i - 1] = weights.GetItemOrDefault(wi); | ||
| var stdError = stdErrorEditor.Values[wi] = coeffStdErrorValues[i]; | ||
| zScoreEditor.Values[i - 1] = weight / stdError; | ||
| pValueEditor.Values[i - 1] = 1 - (Single)ProbabilityFunctions.Erf(Math.Abs(zScoreEditor.Values[i - 1] / sqrt2)); | ||
| pValueEditor.Values[i - 1] = 1 - (float)ProbabilityFunctions.Erf(Math.Abs(zScoreEditor.Values[i - 1] / sqrt2)); | ||
| } | ||
|
|
||
| estimate = estimateEditor.Commit(); | ||
|
|
@@ -283,28 +283,30 @@ private static void GetUnorderedCoefficientStatistics(LinearModelStatistics stat | |
| }; | ||
| } | ||
|
|
||
| private List<CoefficientStatistics> GetUnorderedCoefficientStatistics(LinearBinaryModelParameters parent, RoleMappedSchema schema) | ||
| private List<CoefficientStatistics> GetUnorderedCoefficientStatistics(LinearBinaryModelParameters parent, Schema.Column featureColumn) | ||
| { | ||
| Contracts.AssertValue(_env); | ||
| _env.CheckValue(parent, nameof(parent)); | ||
|
|
||
| if (!_coeffStdError.HasValue) | ||
| return new List<CoefficientStatistics>(); | ||
|
|
||
| var weights = parent.Weights as IReadOnlyList<Single>; | ||
| var weights = parent.Weights as IReadOnlyList<float>; | ||
| _env.Assert(_paramCount == 1 || weights != null); | ||
| _env.Assert(_coeffStdError.Value.Length == weights.Count + 1); | ||
|
|
||
| var names = default(VBuffer<ReadOnlyMemory<char>>); | ||
| MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, weights.Count, ref names); | ||
|
|
||
| featureColumn.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref names); | ||
| _env.Assert(names.Length > 0, "FeatureColumn has no metadata."); | ||
|
|
||
| ReadOnlySpan<float> stdErrorValues = _coeffStdError.Value.GetValues(); | ||
| const Double sqrt2 = 1.41421356237; // Math.Sqrt(2); | ||
|
|
||
| List<CoefficientStatistics> result = new List<CoefficientStatistics>(_paramCount - 1); | ||
| bool denseStdError = _coeffStdError.Value.IsDense; | ||
| ReadOnlySpan<int> stdErrorIndices = _coeffStdError.Value.GetIndices(); | ||
| Single[] zScores = new Single[_paramCount - 1]; | ||
| float[] zScores = new float[_paramCount - 1]; | ||
| for (int i = 1; i < _paramCount; i++) | ||
| { | ||
| int wi = denseStdError ? i - 1 : stdErrorIndices[i] - 1; | ||
|
|
@@ -315,7 +317,7 @@ private List<CoefficientStatistics> GetUnorderedCoefficientStatistics(LinearBina | |
| var weight = weights[wi]; | ||
| var stdError = stdErrorValues[i]; | ||
| var zScore = zScores[i - 1] = weight / stdError; | ||
| var pValue = 1 - (Single)ProbabilityFunctions.Erf(Math.Abs(zScore / sqrt2)); | ||
| var pValue = 1 - (float)ProbabilityFunctions.Erf(Math.Abs(zScore / sqrt2)); | ||
| result.Add(new CoefficientStatistics(name, weight, stdError, zScore, pValue)); | ||
| } | ||
| return result; | ||
|
|
@@ -324,33 +326,31 @@ private List<CoefficientStatistics> GetUnorderedCoefficientStatistics(LinearBina | |
| /// <summary> | ||
| /// Gets the coefficient statistics as an object. | ||
| /// </summary> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need better summary than this one. |
||
| internal CoefficientStatistics[] GetCoefficientStatistics(LinearBinaryModelParameters parent, RoleMappedSchema schema, int paramCountCap) | ||
| public CoefficientStatistics[] GetCoefficientStatistics(LinearBinaryModelParameters parent, Schema.Column featureColumn, int paramCountCap) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
this isn't great name. Param is such a generic term, I have no idea what it's capping. |
||
| { | ||
| Contracts.AssertValue(_env); | ||
| _env.CheckValue(parent, nameof(parent)); | ||
| _env.CheckValue(schema, nameof(schema)); | ||
| _env.CheckParam(paramCountCap >= 0, nameof(paramCountCap)); | ||
|
|
||
| if (paramCountCap > _paramCount) | ||
| paramCountCap = _paramCount; | ||
|
|
||
| Single stdError; | ||
| Single zScore; | ||
| Single pValue; | ||
| float stdError; | ||
| float zScore; | ||
| float pValue; | ||
| var bias = parent.Bias; | ||
| if (!TryGetBiasStatistics(parent.Statistics, bias, out stdError, out zScore, out pValue)) | ||
| return null; | ||
|
|
||
| var order = GetUnorderedCoefficientStatistics(parent, schema).OrderByDescending(stat => stat.ZScore).Take(paramCountCap - 1); | ||
| var order = GetUnorderedCoefficientStatistics(parent, featureColumn).OrderByDescending(stat => stat.ZScore).Take(paramCountCap - 1); | ||
| return order.Prepend(new[] { new CoefficientStatistics("(Bias)", bias, stdError, zScore, pValue) }).ToArray(); | ||
| } | ||
|
|
||
| internal void SaveText(TextWriter writer, LinearBinaryModelParameters parent, RoleMappedSchema schema, int paramCountCap) | ||
| internal void SaveText(TextWriter writer, LinearBinaryModelParameters parent, Schema.Column featureColumn, int paramCountCap) | ||
| { | ||
| Contracts.AssertValue(_env); | ||
| _env.CheckValue(writer, nameof(writer)); | ||
| _env.AssertValueOrNull(parent); | ||
| _env.AssertValueOrNull(schema); | ||
| writer.WriteLine(); | ||
| writer.WriteLine("*** MODEL STATISTICS SUMMARY *** "); | ||
| writer.WriteLine("Count of training examples:\t{0}", _trainingExampleCount); | ||
|
|
@@ -361,7 +361,7 @@ internal void SaveText(TextWriter writer, LinearBinaryModelParameters parent, Ro | |
| if (parent == null) | ||
| return; | ||
|
|
||
| var coeffStats = GetCoefficientStatistics(parent, schema, paramCountCap); | ||
| var coeffStats = GetCoefficientStatistics(parent, featureColumn, paramCountCap); | ||
| if (coeffStats == null) | ||
| return; | ||
|
|
||
|
|
@@ -387,7 +387,7 @@ internal void SaveText(TextWriter writer, LinearBinaryModelParameters parent, Ro | |
| /// Support method for linear models and <see cref="ICanGetSummaryInKeyValuePairs"/>. | ||
| /// </summary> | ||
| internal void SaveSummaryInKeyValuePairs(LinearBinaryModelParameters parent, | ||
| RoleMappedSchema schema, int paramCountCap, List<KeyValuePair<string, object>> resultCollection) | ||
| Schema.Column featureColumn, int paramCountCap, List<KeyValuePair<string, object>> resultCollection) | ||
| { | ||
| Contracts.AssertValue(_env); | ||
| _env.AssertValue(resultCollection); | ||
|
|
@@ -400,15 +400,15 @@ internal void SaveSummaryInKeyValuePairs(LinearBinaryModelParameters parent, | |
| if (parent == null) | ||
| return; | ||
|
|
||
| var coeffStats = GetCoefficientStatistics(parent, schema, paramCountCap); | ||
| var coeffStats = GetCoefficientStatistics(parent, featureColumn, paramCountCap); | ||
| if (coeffStats == null) | ||
| return; | ||
|
|
||
| foreach (var coeffStat in coeffStats) | ||
| { | ||
| resultCollection.Add(new KeyValuePair<string, object>( | ||
| coeffStat.Name, | ||
| new Single[] { coeffStat.Estimate, coeffStat.StandardError, coeffStat.ZScore, coeffStat.PValue })); | ||
| new float[] { coeffStat.Estimate, coeffStat.StandardError, coeffStat.ZScore, coeffStat.PValue })); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -458,7 +458,7 @@ internal Schema.Metadata MakeStatisticsMetadata(LinearBinaryModelParameters pare | |
| return builder.GetMetadata(); | ||
| } | ||
|
|
||
| private string DecorateProbabilityString(Single probZ) | ||
| private string DecorateProbabilityString(float probZ) | ||
| { | ||
| Contracts.AssertValue(_env); | ||
| _env.Assert(0 <= probZ && probZ <= 1); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| // The .NET Foundation licenses this file to you under the MIT license. | ||
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using System.Linq; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why we need Linq here? |
||
| using Microsoft.ML.Core.Data; | ||
| using Microsoft.ML.Data; | ||
| using Microsoft.ML.Internal.Calibration; | ||
|
|
@@ -54,7 +55,7 @@ public void TestEstimatorPoissonRegression() | |
| } | ||
|
|
||
| [Fact] | ||
| public void TestLogisticRegressionStats() | ||
| public void TestLogisticRegressionNoStats() | ||
| { | ||
| (IEstimator<ITransformer> pipe, IDataView dataView) = GetBinaryClassificationPipeline(); | ||
|
|
||
|
|
@@ -70,7 +71,7 @@ public void TestLogisticRegressionStats() | |
| } | ||
|
|
||
| [Fact] | ||
| public void TestLogisticRegressionStats_MKL() | ||
| public void TestLogisticRegressionWithStats() | ||
| { | ||
| (IEstimator<ITransformer> pipe, IDataView dataView) = GetBinaryClassificationPipeline(); | ||
|
|
||
|
|
@@ -80,14 +81,24 @@ public void TestLogisticRegressionStats_MKL() | |
| s.StdComputer = new ComputeLRTrainingStdThroughHal(); | ||
| })); | ||
|
|
||
| var transformerChain = pipe.Fit(dataView) as TransformerChain<BinaryPredictionTransformer<ParameterMixingCalibratedPredictor>>; | ||
| var transformer = pipe.Fit(dataView) as TransformerChain<BinaryPredictionTransformer<ParameterMixingCalibratedPredictor>>; | ||
|
|
||
| var linearModel = transformerChain.LastTransformer.Model.SubPredictor as LinearBinaryModelParameters; | ||
| var linearModel = transformer.LastTransformer.Model.SubPredictor as LinearBinaryModelParameters; | ||
| var stats = linearModel.Statistics; | ||
| LinearModelStatistics.TryGetBiasStatistics(stats, 2, out float stdError, out float zScore, out float pValue); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is the first thing that does not make much sense to me. It seems like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 As Tom pointed out, the API for getting statistics needs some cleanup. linearModel.Statistics or some other simple 1-liner should be the only code a user needs. In reply to: 244374354 [](ancestors = 244374354) |
||
|
|
||
| CompareNumbersWithTolerance(stdError, 0.250672936); | ||
| CompareNumbersWithTolerance(zScore, 7.97852373); | ||
|
|
||
| var scoredData = transformer.Transform(dataView); | ||
|
|
||
| var coeffcients = stats.GetCoefficientStatistics(linearModel, scoredData.Schema["Features"], 100); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is the second thing that does not make sense to me. Again should just be an instance method. As discussed also, the complication that this code itself is responsible for extracting out the slot names was a mistake. Anticipating some people's reaction, if the reaction from some people would be "well it's more convenient," then the correct remediation is to just make slot names easier to access (somehow), definitely not to "solve" the problem here as is done.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I find it strange to have
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| Assert.Equal(19, coeffcients.Length); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Not a big deal at all since it's just a test, but while you're at it, this is a typo. |
||
|
|
||
| foreach(var coefficient in coeffcients) | ||
| Assert.True(coefficient.StandardError < 1.0); | ||
|
|
||
| } | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed, if this is to be part of our public API we should probably as part of this PR make this entire structure a bit less goofy. If I were to trace back the root of what makes this structure a confusing structure, I think this choice here -- having the name of the slot instead of just the slot index, is the original sin. (There are lots of other really strange choices, but I think much of the evil and confusion stems from here.)
This should have just been
The fact that this is a `Name` instead of just a `SlotIndex` is the cause for an enormous amount of confusion in this structure. This requires that we link it with the schema column for features, but from an API perspective this is just something that could have easily (in fact, more easily) been done by the user themselves. As it is, the structure is offering "convenience" but in a way that inevitably leads to a more complicated API. Once you do that, restructuring the odd API below (which I'll comment more on in the test file) should become more approachable.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not to say that it might not be convenient to have an internal utility method that makes saving as text more convenient for the `LinearModelParameters` and whatnot, but confusing the needs of the text exporter (which is not part of our public API) with what we actually do in the public API was a misstep. (Not your misstep to be clear, since this code has existed, I think, more or less forever, but still something we ought to correct now that we are making it part of the public API.) In reply to: 244374085 [](ancestors = 244374085)