Skip to content

Commit

Permalink
Adding XML Docs to public functions; making accessors public and seal…
Browse files Browse the repository at this point in the history
…ing classes; fixing a serialization error; fixing an error in the sparsity calculation.
  • Loading branch information
Rogan Carr committed Jan 15, 2019
1 parent 1df23d6 commit add0cc8
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
13 changes: 9 additions & 4 deletions src/Microsoft.ML.FastTree/GamClassification.cs
Expand Up @@ -156,12 +156,15 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
}
}

public class BinaryClassificationGamModelParameters : GamModelParametersBase, IPredictorProducing<float>
/// <summary>
/// The model parameters class for Binary Classification GAMs
/// </summary>
public sealed class BinaryClassificationGamModelParameters : GamModelParametersBase, IPredictorProducing<float>
{
internal const string LoaderSignature = "BinaryClassGamPredictor";
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

internal BinaryClassificationGamModelParameters(IHostEnvironment env,
public BinaryClassificationGamModelParameters(IHostEnvironment env,
double[][] binUpperBounds, double[][] binEffects, double intercept, int[] featureToInputMap)
: base(env, LoaderSignature, binUpperBounds, binEffects, intercept, featureToInputMap) { }

Expand All @@ -172,8 +175,10 @@ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "GAM BINP",
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
// verWrittenCur: 0x00010001, // Initial
// verWrittenCur: 0x00010001, // Added Intercept but collided from release 0.6-0.9
verWrittenCur: 0x00020001, // Added Intercept (version revved to address collisions)
verReadableCur: 0x00020001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(BinaryClassificationGamModelParameters).Assembly.FullName);
Expand Down
14 changes: 8 additions & 6 deletions src/Microsoft.ML.FastTree/GamModelParameters.cs
Expand Up @@ -21,6 +21,9 @@

namespace Microsoft.ML.Trainers.FastTree
{
/// <summary>
/// The base class for GAM Model Parameters.
/// </summary>
public abstract class GamModelParametersBase : ModelParametersBase<float>, IValueMapper, ICalculateFeatureContribution,
IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary, ICanSaveInIniFormat
{
Expand Down Expand Up @@ -62,7 +65,7 @@ public abstract class GamModelParametersBase : ModelParametersBase<float>, IValu
_numFeatures = binEffects.Length;

// For sparse inputs we have a fast lookup
_binsAtAllZero = new int[_numFeatures]; // All 0s at 0 -- bug?
_binsAtAllZero = new int[_numFeatures];
_valueAtAllZero = 0;

// Walk through each feature and perform checks / updates
Expand All @@ -73,7 +76,7 @@ public abstract class GamModelParametersBase : ModelParametersBase<float>, IValu
Host.CheckParam(binUpperBounds[i].Length == binEffects[i].Length, nameof(binEffects), "Array contained wrong number of effect values");

// Update the value at zero
_valueAtAllZero += _binEffects[i][_binsAtAllZero[i]];
_valueAtAllZero += GetBinEffect(i, 0, out _binsAtAllZero[i]);
}

_featureMap = featureToInputMap;
Expand Down Expand Up @@ -104,6 +107,9 @@ protected GamModelParametersBase(IHostEnvironment env, string name, ModelLoadCon
_inputLength = reader.ReadInt32();
Host.CheckDecode(_inputLength >= 0);
Intercept = reader.ReadDouble();
if (ctx.Header.ModelVerWritten == 0x00010001)
using (var ch = env.Start("GamWarningChannel"))
ch.Warning("GAMs models written prior to ML.NET 0.6 are loaded with an incorrect Intercept. For these models, subtract the value of the intercept from the prediction.");

_binEffects = new double[_numFeatures][];
_binUpperBounds = new double[_numFeatures][];
Expand All @@ -116,10 +122,6 @@ protected GamModelParametersBase(IHostEnvironment env, string name, ModelLoadCon
for (int i = 0; i < _numFeatures; i++)
{
_binUpperBounds[i] = reader.ReadDoubleArray(_binEffects[i].Length);
// Ideally should verify that the sum of these matches _baseOutput,
// but due to differences in JIT over time and other considerations,
// it's possible that the sum may change even in the absence of
// model corruption.
_valueAtAllZero += GetBinEffect(i, 0, out _binsAtAllZero[i]);
}
int len = reader.ReadInt32();
Expand Down
25 changes: 20 additions & 5 deletions src/Microsoft.ML.FastTree/GamRegression.cs
Expand Up @@ -106,13 +106,26 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
}
}

public class RegressionGamModelParameters : GamModelParametersBase
/// <summary>
/// The model parameters class for Binary Classification GAMs
/// </summary>
public sealed class RegressionGamModelParameters : GamModelParametersBase
{
internal const string LoaderSignature = "RegressionGamPredictor";
public override PredictionKind PredictionKind => PredictionKind.Regression;

internal RegressionGamModelParameters(IHostEnvironment env,
double[][] binUpperBounds, double[][] binEffects, double intercept, int[] featureToInputMap)
/// <summary>
/// Construct a new Regression GAM with the defined properties.
/// </summary>
/// <param name="env">The Host Environment</param>
/// <param name="binUpperBounds">An array of arrays of bin-upper-bounds for each feature.</param>
/// <param name="binEffects">Anay array of arrays of effect sizes for each bin for each feature.</param>
/// <param name="intercept">The intercept term for the model. Also referred to as the bias or the mean effect.</param>
/// <param name="featureToInputMap">A map from the feature shape functions (as described by the binUpperBounds and BinEffects)
/// to the input feature. Used when multiple input features map to the same shape function. Leave null if all features have
/// a shape function.</param>
public RegressionGamModelParameters(IHostEnvironment env,
double[][] binUpperBounds, double[][] binEffects, double intercept, int[] featureToInputMap = null)
: base(env, LoaderSignature, binUpperBounds, binEffects, intercept, featureToInputMap) { }

private RegressionGamModelParameters(IHostEnvironment env, ModelLoadContext ctx)
Expand All @@ -122,8 +135,10 @@ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "GAM REGP",
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
// verWrittenCur: 0x00010001, // Initial
// verWrittenCur: 0x00010001, // Added Intercept but collided from release 0.6-0.9
verWrittenCur: 0x00020001, // Added Intercept (version revved to address collisions)
verReadableCur: 0x00020001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(RegressionGamModelParameters).Assembly.FullName);
Expand Down

0 comments on commit add0cc8

Please sign in to comment.