Skip to content

Commit

Permalink
Typed SDCA binary trainers
Browse files Browse the repository at this point in the history
Fix some tests failed because of duplicated load names
  • Loading branch information
wschin committed Feb 11, 2019
1 parent a254bf5 commit 27536c3
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 129 deletions.
@@ -1,5 +1,4 @@
using System;
using Microsoft.ML.Data;
using Microsoft.ML.StaticPipe;

namespace Microsoft.ML.Samples.Static
Expand Down Expand Up @@ -74,7 +73,7 @@ public static void SdcaBinaryClassification()
.Append(row => (
Features: row.Features.Normalize(),
Label: row.Label,
Score: mlContext.BinaryClassification.Trainers.Sdca(
Score: mlContext.BinaryClassification.Trainers.SdcaCalibrated(
row.Label,
row.Features,
l1Threshold: 0.25f,
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
Expand Up @@ -31,7 +31,7 @@

namespace Microsoft.ML.Trainers.SymSgd
{
using TPredictor = CalibratedModelParametersBase<LinearBinaryModelParameters,PlattCalibrator>;
using TPredictor = CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>;

/// <include file='doc.xml' path='doc/members/member[@name="SymSGD"]/*' />
public sealed class SymSgdClassificationTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<TPredictor>, TPredictor>
Expand Down Expand Up @@ -208,7 +208,7 @@ private TPredictor CreatePredictor(VBuffer<float> weights, float bias)
VBufferUtils.CreateMaybeSparseCopy(in weights, ref maybeSparseWeights,
Conversions.Instance.GetIsDefaultPredicate<float>(NumberType.R4));
var predictor = new LinearBinaryModelParameters(Host, in maybeSparseWeights, bias);
return new ParameterMixingCalibratedModelParameters<LinearBinaryModelParameters,PlattCalibrator>(Host, predictor, new PlattCalibrator(Host, -1, 0));
return new ParameterMixingCalibratedModelParameters<LinearBinaryModelParameters, PlattCalibrator>(Host, predictor, new PlattCalibrator(Host, -1, 0));
}

protected override BinaryPredictionTransformer<TPredictor> MakeTransformer(TPredictor model, Schema trainSchema)
Expand Down
203 changes: 127 additions & 76 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
Expand Up @@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -32,6 +31,12 @@
"lc",
"sasdca")]

[assembly: LoadableClass(typeof(SdcaCalibratedBinaryTrainer), typeof(SdcaCalibratedBinaryTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
SdcaCalibratedBinaryTrainer.UserNameValue,
SdcaCalibratedBinaryTrainer.LoadNameValue,
"SdcaLogisticRegression")]

[assembly: LoadableClass(typeof(StochasticGradientDescentClassificationTrainer), typeof(StochasticGradientDescentClassificationTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
StochasticGradientDescentClassificationTrainer.UserNameValue,
Expand Down Expand Up @@ -1395,11 +1400,9 @@ public void Add(Double summand)
}
}

public sealed class SdcaBinaryTrainer : SdcaTrainerBase<SdcaBinaryTrainer.Options, BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor>
public abstract class SdcaBinaryTrainerBase<TModelParameters> : SdcaTrainerBase<SdcaBinaryTrainerBase<TModelParameters>.Options, BinaryPredictionTransformer<TModelParameters>, TModelParameters>
where TModelParameters : class, IPredictorProducing<float>
{
internal const string LoadNameValue = "SDCA";
internal const string UserNameValue = "Fast Linear (SA-SDCA)";

public sealed class Options : ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
Expand All @@ -1409,7 +1412,7 @@ public sealed class Options : ArgumentsBase
public float PositiveInstanceWeight = 1;

[Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory();
public ICalibratorTrainerFactory Calibrator = null;

[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public int MaxCalibrationExamples = 1000000;
Expand All @@ -1435,7 +1438,7 @@ internal override void Check(IHostEnvironment env)
public override TrainerInfo Info { get; }

/// <summary>
/// Initializes a new instance of <see cref="SdcaBinaryTrainer"/>
/// Initializes a new instance of <see cref="SdcaBinaryTrainerBase{TModelParameters}"/>
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="labelColumn">The label, or dependent variable.</param>
Expand All @@ -1445,7 +1448,7 @@ internal override void Check(IHostEnvironment env)
/// <param name="l2Const">The L2 regularization hyperparameter.</param>
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
internal SdcaBinaryTrainer(IHostEnvironment env,
protected SdcaBinaryTrainerBase(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weightColumn = null,
Expand All @@ -1462,77 +1465,21 @@ internal override void Check(IHostEnvironment env)
Loss = _loss;
Info = new TrainerInfo(calibration: !(_loss is LogLoss));
_positiveInstanceWeight = Args.PositiveInstanceWeight;

var outCols = new List<SchemaShape.Column>()
{
new SchemaShape.Column(
DefaultColumnNames.Score,
SchemaShape.Column.VectorKind.Scalar,
NumberType.R4,
false,
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())
),
new SchemaShape.Column(
DefaultColumnNames.PredictedLabel,
SchemaShape.Column.VectorKind.Scalar,
BoolType.Instance,
false,
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))

};

if (!Info.NeedCalibration)
{
outCols.Insert(1, new SchemaShape.Column(
DefaultColumnNames.Probability,
SchemaShape.Column.VectorKind.Scalar,
NumberType.R4,
false,
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))));
};

_outputColumns = outCols.ToArray();
_outputColumns = ComputeSdcaBinaryClassifierSchemaShape();
}

internal SdcaBinaryTrainer(IHostEnvironment env, Options options)
protected SdcaBinaryTrainerBase(IHostEnvironment env, Options options)
: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn))
{
_loss = options.LossFunction.CreateComponent(env);
Loss = _loss;
Info = new TrainerInfo(calibration: !(_loss is LogLoss));
_positiveInstanceWeight = Args.PositiveInstanceWeight;

var outCols = new List<SchemaShape.Column>()
{
new SchemaShape.Column(
DefaultColumnNames.Score,
SchemaShape.Column.VectorKind.Scalar,
NumberType.R4,
false,
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())
),
new SchemaShape.Column(
DefaultColumnNames.PredictedLabel,
SchemaShape.Column.VectorKind.Scalar,
BoolType.Instance,
false,
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))

};

if (!Info.NeedCalibration)
{
outCols.Insert(1, new SchemaShape.Column(
DefaultColumnNames.Probability,
SchemaShape.Column.VectorKind.Scalar,
NumberType.R4,
false,
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))));
};

_outputColumns = outCols.ToArray();
_outputColumns = ComputeSdcaBinaryClassifierSchemaShape();
}

protected abstract SchemaShape.Column[] ComputeSdcaBinaryClassifierSchemaShape();

protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
Contracts.Assert(labelCol.IsValid);
Expand All @@ -1547,7 +1494,7 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
error();
}

protected override TScalarPredictor CreatePredictor(VBuffer<float>[] weights, float[] bias)
protected LinearBinaryModelParameters CreateLinearBinaryModelParameters(VBuffer<float>[] weights, float[] bias)
{
Host.CheckParam(Utils.Size(weights) == 1, nameof(weights));
Host.CheckParam(Utils.Size(bias) == 1, nameof(bias));
Expand All @@ -1558,10 +1505,7 @@ protected override TScalarPredictor CreatePredictor(VBuffer<float>[] weights, fl
VBufferUtils.CreateMaybeSparseCopy(weights[0], ref maybeSparseWeights,
Conversions.Instance.GetIsDefaultPredicate<float>(NumberType.Float));

var predictor = new LinearBinaryModelParameters(Host, in maybeSparseWeights, bias[0]);
if (Info.NeedCalibration)
return predictor;
return new ParameterMixingCalibratedModelParameters<LinearBinaryModelParameters, PlattCalibrator>(Host, predictor, new PlattCalibrator(Host, -1, 0));
return new LinearBinaryModelParameters(Host, in maybeSparseWeights, bias[0]);
}

private protected override float GetInstanceWeight(FloatLabelCursor cursor)
Expand All @@ -1575,8 +1519,115 @@ private protected override void CheckLabel(RoleMappedData examples, out int weig
weightSetCount = 1;
}

protected override BinaryPredictionTransformer<TScalarPredictor> MakeTransformer(TScalarPredictor model, Schema trainSchema)
=> new BinaryPredictionTransformer<TScalarPredictor>(Host, model, trainSchema, FeatureColumn.Name);
protected override BinaryPredictionTransformer<TModelParameters> MakeTransformer(TModelParameters model, Schema trainSchema)
=> new BinaryPredictionTransformer<TModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
}

public sealed class SdcaCalibratedBinaryTrainer :
SdcaBinaryTrainerBase<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>
{
internal const string LoadNameValue = "SDCALR";
internal const string UserNameValue = "Fast Linear (SA-SDCA) Logistic Regression";

internal SdcaCalibratedBinaryTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weightColumn = null,
float? l2Const = null,
float? l1Threshold = null,
int? maxIterations = null)
: base(env, labelColumn, featureColumn, weightColumn, new LogLoss(), l2Const, l1Threshold, maxIterations)
{
}

internal SdcaCalibratedBinaryTrainer(IHostEnvironment env, Options options)
: base(env, options)
{
}

protected override CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator> CreatePredictor(VBuffer<float>[] weights, float[] bias)
{
var linearModel = CreateLinearBinaryModelParameters(weights, bias);
var calibrator = new PlattCalibrator(Host, -1, 0);
return new ParameterMixingCalibratedModelParameters<LinearBinaryModelParameters, PlattCalibrator>(Host, linearModel, calibrator);
}

protected override SchemaShape.Column[] ComputeSdcaBinaryClassifierSchemaShape()
{
return new SchemaShape.Column[]
{
new SchemaShape.Column(
DefaultColumnNames.Score,
SchemaShape.Column.VectorKind.Scalar,
NumberType.R4,
false,
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
new SchemaShape.Column(
DefaultColumnNames.Probability,
SchemaShape.Column.VectorKind.Scalar,
NumberType.R4,
false,
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))),
new SchemaShape.Column(
DefaultColumnNames.PredictedLabel,
SchemaShape.Column.VectorKind.Scalar,
BoolType.Instance,
false,
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))

};
}
}

public sealed class SdcaBinaryTrainer : SdcaBinaryTrainerBase<LinearBinaryModelParameters>
{
internal const string LoadNameValue = "SDCA";
internal const string UserNameValue = "Fast Linear (SA-SDCA)";

internal SdcaBinaryTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weightColumn = null,
ISupportSdcaClassificationLoss loss = null,
float? l2Const = null,
float? l1Threshold = null,
int? maxIterations = null)
: base(env, labelColumn, featureColumn, weightColumn, loss, l2Const, l1Threshold, maxIterations)
{
}

internal SdcaBinaryTrainer(IHostEnvironment env, Options options)
: base(env, options)
{
}

protected override SchemaShape.Column[] ComputeSdcaBinaryClassifierSchemaShape()
{
return new SchemaShape.Column[]
{
new SchemaShape.Column(
DefaultColumnNames.Score,
SchemaShape.Column.VectorKind.Scalar,
NumberType.R4,
false,
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())
),
new SchemaShape.Column(
DefaultColumnNames.PredictedLabel,
SchemaShape.Column.VectorKind.Scalar,
BoolType.Instance,
false,
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
};
}

/// <summary>
/// Comparing with <see cref="SdcaCalibratedBinaryTrainer.CreatePredictor(VBuffer{float}[], float[])"/>,
/// <see cref="CreatePredictor"/> directly outputs a <see cref="LinearBinaryModelParameters"/> built from
/// the learned weights and bias without calibration.
/// </summary>
protected override LinearBinaryModelParameters CreatePredictor(VBuffer<float>[] weights, float[] bias)
=> CreateLinearBinaryModelParameters(weights, bias);
}

public sealed class StochasticGradientDescentClassificationTrainer :
Expand Down
46 changes: 46 additions & 0 deletions src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs
Expand Up @@ -134,6 +134,52 @@ public static class StandardLearnersCatalog
return new SdcaBinaryTrainer(env, labelColumn, featureColumn, weights, loss, l2Const, l1Threshold, maxIterations);
}

/// <summary>
/// Predict a target using a logistic regression model trained with the SDCA trainer.
/// </summary>
/// <param name="catalog">The binary classification catalog trainer object.</param>
/// <param name="labelColumn">The labelColumn, or dependent variable.</param>
/// <param name="featureColumn">The features, or independent variables.</param>
/// <param name="weights">The optional example weights.</param>
/// <param name="l2Const">The L2 regularization hyperparameter.</param>
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[SDCA](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs)]
/// ]]></format>
/// </example>
public static SdcaCalibratedBinaryTrainer StochasticDualCoordinateAscentCalibrated(
this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
float? l2Const = null,
float? l1Threshold = null,
int? maxIterations = null)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new SdcaCalibratedBinaryTrainer(env, labelColumn, featureColumn, weights, l2Const, l1Threshold, maxIterations);
}

/// <summary>
/// Predict a target using a linear binary classification model trained with the SDCA trainer.
/// </summary>
/// <param name="catalog">The binary classification catalog trainer object.</param>
/// <param name="options">Advanced arguments to the algorithm.</param>
public static SdcaCalibratedBinaryTrainer StochasticDualCoordinateAscentCalibrated(
this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
SdcaCalibratedBinaryTrainer.Options options)
{
Contracts.CheckValue(catalog, nameof(catalog));
Contracts.CheckValue(options, nameof(options));

var env = CatalogUtils.GetEnvironment(catalog);
return new SdcaCalibratedBinaryTrainer(env, options);
}

/// <summary>
/// Predict a target using a linear binary classification model trained with the SDCA trainer.
/// </summary>
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.StaticPipe/OnlineLearnerStatic.cs
Expand Up @@ -69,7 +69,7 @@ public static class AveragedPerceptronStaticExtensions
else
return trainer;
}, label, features, weights, hasProbs);
}, label, features, weights);

return rec.Output;
}
Expand Down Expand Up @@ -129,7 +129,7 @@ public static class AveragedPerceptronStaticExtensions
else
return trainer;
}, label, features, weights, hasProbs);
}, label, features, weights);

return rec.Output;
}
Expand Down

0 comments on commit 27536c3

Please sign in to comment.