Skip to content

Commit

Permalink
Refactor GAM Predictor to be Public-ly Creatable (#2142)
Browse files Browse the repository at this point in the history
Refactoring the GAM Trainer / Predictor to move all Training information into the trainer, and make the GAM predictor public so that models can be instantiated from generic parameters.
  • Loading branch information
rogancarr committed Jan 22, 2019
1 parent f8b2f39 commit 240faeb
Show file tree
Hide file tree
Showing 4 changed files with 1,021 additions and 938 deletions.
33 changes: 25 additions & 8 deletions src/Microsoft.ML.FastTree/GamClassification.cs
Expand Up @@ -108,8 +108,8 @@ private static bool[] ConvertTargetsToBool(double[] targets)
private protected override IPredictorProducing<float> TrainModelCore(TrainContext context)
{
TrainBase(context);
var predictor = new BinaryClassificationGamModelParameters(Host, InputLength, TrainSet,
MeanEffect, BinEffects, FeatureMap);
var predictor = new BinaryClassificationGamModelParameters(Host,
BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
var calibrator = new PlattCalibrator(Host, -1.0 * _sigmoidParameter, 0);
return new CalibratedPredictor(Host, predictor, calibrator);
}
Expand Down Expand Up @@ -157,14 +157,29 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
}
}

public class BinaryClassificationGamModelParameters : GamModelParametersBase, IPredictorProducing<float>
/// <summary>
/// The model parameters class for Binary Classification GAMs
/// </summary>
public sealed class BinaryClassificationGamModelParameters : GamModelParametersBase, IPredictorProducing<float>
{
internal const string LoaderSignature = "BinaryClassGamPredictor";
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

internal BinaryClassificationGamModelParameters(IHostEnvironment env, int inputLength, Dataset trainset,
double meanEffect, double[][] binEffects, int[] featureMap)
: base(env, LoaderSignature, inputLength, trainset, meanEffect, binEffects, featureMap) { }
/// <summary>
/// Construct a new Binary Classification GAM with the defined properties.
/// </summary>
/// <param name="env">The Host Environment</param>
/// <param name="binUpperBounds">An array of arrays of bin-upper-bounds for each feature.</param>
/// <param name="binEffects">Anay array of arrays of effect sizes for each bin for each feature.</param>
/// <param name="intercept">The intercept term for the model. Also referred to as the bias or the mean effect.</param>
/// <param name="inputLength">The number of features passed from the dataset. Used when the number of input features is
/// different than the number of shape functions. Use default if all features have a shape function.</param>
/// <param name="featureToInputMap">A map from the feature shape functions (as described by the binUpperBounds and BinEffects)
/// to the input feature. Used when the number of input features is different than the number of shape functions. Use default if all features have
/// a shape function.</param>
public BinaryClassificationGamModelParameters(IHostEnvironment env,
double[][] binUpperBounds, double[][] binEffects, double intercept, int inputLength, int[] featureToInputMap)
: base(env, LoaderSignature, binUpperBounds, binEffects, intercept, inputLength, featureToInputMap) { }

private BinaryClassificationGamModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, ctx) { }
Expand All @@ -173,8 +188,10 @@ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "GAM BINP",
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
// verWrittenCur: 0x00010001, // Initial
// verWrittenCur: 0x00010001, // Added Intercept but collided from release 0.6-0.9
verWrittenCur: 0x00010002, // Added Intercept (version revved to address collisions)
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(BinaryClassificationGamModelParameters).Assembly.FullName);
Expand Down

0 comments on commit 240faeb

Please sign in to comment.