Scrubbing online learners #2892

Merged
@@ -19,7 +19,7 @@ public static void Example()
var trainTestData = mlContext.Data.TrainTestSplit(data, testFraction: 0.1);

// Create data training pipeline.
-var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(numIterations: 10);
+var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(numberOfIterations: 10);

// Fit this pipeline to the training data.
var model = pipeline.Fit(trainTestData.TrainSet);
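For orientation, the snippet below sketches what typically follows this sample: scoring the held-out split and evaluating. It is an assumption about the elided remainder of the sample file, not part of this diff.

```csharp
// A typical follow-up to the sample above (assumed, not part of this diff).
// AveragedPerceptron produces uncalibrated scores, hence EvaluateNonCalibrated.
var predictions = model.Transform(trainTestData.TestSet);
var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(predictions);
Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}");
```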
@@ -25,7 +25,7 @@ public static void Example()
{
LossFunction = new SmoothedHingeLoss(),
LearningRate = 0.1f,
-DoLazyUpdates = false,
+LazyUpdate = false,
RecencyGain = 0.1f,
NumberOfIterations = 10
@wschin (Member) commented on Mar 9, 2019:

Suggested change:
-NumberOfIterations = 10
+MaximumNumberOfIterations = 10

#ByDesign

The PR author (Contributor) replied:
No, it's the number of iterations.



};
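For context, here is a self-contained sketch of the Options-based call this hunk belongs to, using the post-rename field names; the MLContext setup is assumed, and the values mirror the hunk above.

```csharp
// Sketch of the renamed Options in use. 'LazyUpdate' replaces
// 'DoLazyUpdates'; 'NumberOfIterations' keeps its name (the
// 'MaximumNumberOfIterations' suggestion was declined above).
var mlContext = new MLContext(seed: 0);
var options = new AveragedPerceptronTrainer.Options
{
    LossFunction = new SmoothedHingeLoss(),
    LearningRate = 0.1f,
    LazyUpdate = false,
    RecencyGain = 0.1f,
    NumberOfIterations = 10
};
var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(options);
```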
30 changes: 15 additions & 15 deletions src/Microsoft.ML.StandardTrainers/Standard/Online/AveragedLinear.cs
@@ -57,16 +57,16 @@ public abstract class AveragedLinearOptions : OnlineLinearOptions
/// <see langword="false" /> to update averaged weights on every example.
/// Default is <see langword="true" />.
/// </value>
[Argument(ArgumentType.AtMostOnce, HelpText = "Instead of updating averaged weights on every example, only update when loss is nonzero", ShortName = "lazy")]
public bool DoLazyUpdates = true;
[Argument(ArgumentType.AtMostOnce, HelpText = "Instead of updating averaged weights on every example, only update when loss is nonzero", ShortName = "lazy,DoLazyUpdates")]
public bool LazyUpdate = true;

/// <summary>
/// The L2 weight for <a href='tmpurl_regularization'>regularization</a>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)]
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg,L2RegularizerWeight", SortOrder = 50)]
[TGUI(Label = "L2 Regularization Weight")]
[TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)]
public float L2RegularizerWeight = AveragedDefault.L2RegularizerWeight;
public float L2Regularization = AveragedDefault.L2Regularization;

/// <summary>
/// Extra weight given to more recent updates.
@@ -85,8 +85,8 @@ public abstract class AveragedLinearOptions : OnlineLinearOptions
/// <see langword="false" /> means <see cref="RecencyGain"/> is additive.
/// Default is <see langword="false" />.
/// </value>
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether Recency Gain is multiplicative (vs. additive)", ShortName = "rgm")]
public bool RecencyGainMulti = false;
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether Recency Gain is multiplicative (vs. additive)", ShortName = "rgm,RecencyGainMulti")]
public bool RecencyGainMultiplicative = false;

/// <summary>
/// Determines whether to do averaging or not.
@@ -109,7 +109,7 @@ internal class AveragedDefault : OnlineLinearOptions.OnlineDefault
{
public const float LearningRate = 1;
public const bool DecreaseLearningRate = false;
-public const float L2RegularizerWeight = 0;
+public const float L2Regularization = 0;
}

internal abstract IComponentFactory<IScalarLoss> LossFunctionFactory { get; }
@@ -186,7 +186,7 @@ public override void FinishIteration(IChannel ch)
// Finalize things
if (Averaged)
{
if (_args.DoLazyUpdates && NumNoUpdates > 0)
if (_args.LazyUpdate && NumNoUpdates > 0)
{
// Update the total weights to include the final loss=0 updates
VectorUtils.AddMult(in Weights, NumNoUpdates * WeightsScale, ref TotalWeights);
@@ -221,10 +221,10 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, fl
// REVIEW: Should this be biasUpdate != 0?
// This loss does not incorporate L2 if present, but the chance of that addition to the loss
// exactly cancelling out loss is remote.
-if (loss != 0 || _args.L2RegularizerWeight > 0)
+if (loss != 0 || _args.L2Regularization > 0)
{
// If doing lazy weights, we need to update the totalWeights and totalBias before updating weights/bias
if (_args.DoLazyUpdates && _args.Averaged && NumNoUpdates > 0 && TotalMultipliers * _args.AveragedTolerance <= PendingMultipliers)
if (_args.LazyUpdate && _args.Averaged && NumNoUpdates > 0 && TotalMultipliers * _args.AveragedTolerance <= PendingMultipliers)
{
VectorUtils.AddMult(in Weights, NumNoUpdates * WeightsScale, ref TotalWeights);
TotalBias += Bias * NumNoUpdates * WeightsScale;
@@ -242,7 +242,7 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, fl

// Perform the update to weights and bias.
VectorUtils.AddMult(in feat, biasUpdate / WeightsScale, ref Weights);
-WeightsScale *= 1 - 2 * _args.L2RegularizerWeight; // L2 regularization.
+WeightsScale *= 1 - 2 * _args.L2Regularization; // L2 regularization.
ScaleWeightsIfNeeded();
Bias += biasUpdate;
PendingMultipliers += Math.Abs(biasUpdate);
@@ -251,7 +251,7 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, fl
// Add to averaged weights and increment the count.
if (Averaged)
{
-if (!_args.DoLazyUpdates)
+if (!_args.LazyUpdate)
IncrementAverageNonLazy();
else
NumNoUpdates++;
@@ -282,7 +282,7 @@ private void IncrementAverageNonLazy()
VectorUtils.AddMult(in Weights, Gain * WeightsScale, ref TotalWeights);
TotalBias += Gain * Bias;
NumWeightUpdates += Gain;
-Gain = (_args.RecencyGainMulti ? Gain * _args.RecencyGain : Gain + _args.RecencyGain);
+Gain = (_args.RecencyGainMultiplicative ? Gain * _args.RecencyGain : Gain + _args.RecencyGain);

// If gains got too big, rescale!
if (Gain > 1000)
@@ -303,11 +303,11 @@ private protected AveragedLinearTrainer(AveragedLinearOptions options, IHostEnvi
Contracts.CheckUserArg(!options.ResetWeightsAfterXExamples.HasValue || options.ResetWeightsAfterXExamples > 0, nameof(options.ResetWeightsAfterXExamples), UserErrorPositive);

// Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible.
Contracts.CheckUserArg(0 <= options.L2RegularizerWeight && options.L2RegularizerWeight < 0.5, nameof(options.L2RegularizerWeight), "must be in range [0, 0.5)");
Contracts.CheckUserArg(0 <= options.L2Regularization && options.L2Regularization < 0.5, nameof(options.L2Regularization), "must be in range [0, 0.5)");
Contracts.CheckUserArg(options.RecencyGain >= 0, nameof(options.RecencyGain), UserErrorNonNegative);
Contracts.CheckUserArg(options.AveragedTolerance >= 0, nameof(options.AveragedTolerance), UserErrorNonNegative);
// Verify user didn't specify parameters that conflict
-Contracts.Check(!options.DoLazyUpdates || !options.RecencyGainMulti && options.RecencyGain == 0, "Cannot have both recency gain and lazy updates.");
+Contracts.Check(!options.LazyUpdate || !options.RecencyGainMultiplicative && options.RecencyGain == 0, "Cannot have both recency gain and lazy updates.");

AveragedLinearTrainerOptions = options;
}
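As a reading aid for the renamed fields, a simplified standalone sketch of the two averaging modes follows. It models the bookkeeping only; it is illustrative, not the shipped implementation.

```csharp
// Toy model of the averaging bookkeeping in AveragedLinear.cs. The trainer
// forbids mixing the modes: LazyUpdate requires RecencyGain == 0 and
// RecencyGainMultiplicative == false (see the Contracts.Check above).
internal sealed class AveragingSketch
{
    private float _gain = 1f;          // weight of the next averaged contribution
    private float _numWeightUpdates;   // running denominator of the average
    private int _numNoUpdates;         // zero-loss examples deferred while lazy

    // Eager mode (LazyUpdate == false): contribute on every example and grow
    // the gain, so more recent examples count more when RecencyGain > 0.
    public void EagerStep(float recencyGain, bool recencyGainMultiplicative)
    {
        _numWeightUpdates += _gain;
        _gain = recencyGainMultiplicative ? _gain * recencyGain
                                          : _gain + recencyGain;
    }

    // Lazy mode (LazyUpdate == true): skip zero-loss examples, then credit
    // all deferred examples at once just before the next real update.
    public void LazyStep(bool lossWasZero)
    {
        if (lossWasZero)
        {
            _numNoUpdates++;
            return;
        }
        _numWeightUpdates += _numNoUpdates + 1;
        _numNoUpdates = 0;
    }
}
```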
@@ -131,24 +131,24 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, Options options)
/// <param name="featureColumnName">The name of the feature column.</param>
/// <param name="learningRate">The learning rate. </param>
/// <param name="decreaseLearningRate">Whether to decrease learning rate as iterations progress.</param>
/// <param name="l2RegularizerWeight">L2 Regularization Weight.</param>
/// <param name="numIterations">The number of training iterations.</param>
/// <param name="l2Regularization">Weight of L2 regularization term.</param>
/// <param name="numberOfIterations">The number of training iterations.</param>
internal AveragedPerceptronTrainer(IHostEnvironment env,
string labelColumnName = DefaultColumnNames.Label,
string featureColumnName = DefaultColumnNames.Features,
IClassificationLoss lossFunction = null,
float learningRate = Options.AveragedDefault.LearningRate,
bool decreaseLearningRate = Options.AveragedDefault.DecreaseLearningRate,
-float l2RegularizerWeight = Options.AveragedDefault.L2RegularizerWeight,
-int numIterations = Options.AveragedDefault.NumIterations)
+float l2Regularization = Options.AveragedDefault.L2Regularization,
+int numberOfIterations = Options.AveragedDefault.NumberOfIterations)
: this(env, new Options
{
LabelColumnName = labelColumnName,
FeatureColumnName = featureColumnName,
LearningRate = learningRate,
DecreaseLearningRate = decreaseLearningRate,
-L2RegularizerWeight = l2RegularizerWeight,
-NumberOfIterations = numIterations,
+L2Regularization = l2Regularization,
+NumberOfIterations = numberOfIterations,
LossFunction = lossFunction ?? new HingeLoss()
})
{
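At the public surface these parameters are reached through the MLContext extension; a hedged sketch with placeholder column names:

```csharp
// Post-rename public call; parameter names now line up with the Options
// fields. Column names are placeholders.
var trainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
    labelColumnName: "Label",
    featureColumnName: "Features",
    learningRate: 1f,
    decreaseLearningRate: false,
    l2Regularization: 0f,
    numberOfIterations: 10);
```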
16 changes: 8 additions & 8 deletions src/Microsoft.ML.StandardTrainers/Standard/Online/LinearSvm.cs
@@ -69,8 +69,8 @@ public sealed class Options : OnlineLinearOptions
/// <summary>
/// Column to use for example weight.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string WeightColumn = null;
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight,WeightColumn", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string ExampleWeightColumnName = null;
}

private sealed class TrainState : TrainStateBase
@@ -232,20 +232,20 @@ public override LinearBinaryModelParameters CreatePredictor()
/// <param name="env">The environment to use.</param>
/// <param name="labelColumn">The name of the label column. </param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="weightColumn">The optional name of the weight column.</param>
/// <param name="numIterations">The number of training iteraitons.</param>
/// <param name="exampleWeightColumnName">The name of the example weight column (optional).</param>
/// <param name="numberOfIterations">The number of training iteraitons.</param>
[BestFriend]
internal LinearSvmTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
-string weightColumn = null,
-int numIterations = Options.OnlineDefault.NumIterations)
+string exampleWeightColumnName = null,
+int numberOfIterations = Options.OnlineDefault.NumberOfIterations)
: this(env, new Options
{
LabelColumnName = labelColumn,
FeatureColumnName = featureColumn,
-WeightColumn = weightColumn,
-NumberOfIterations = numIterations,
+ExampleWeightColumnName = exampleWeightColumnName,
+NumberOfIterations = numberOfIterations,
})
{
}
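The corresponding public call after this rename, sketched with placeholder column names:

```csharp
// 'exampleWeightColumnName' replaces 'weightColumn' at the public surface.
// Column names are placeholders; pass null to train unweighted.
var svm = mlContext.BinaryClassification.Trainers.LinearSvm(
    labelColumnName: "Label",
    featureColumnName: "Features",
    exampleWeightColumnName: "Weight",
    numberOfIterations: 1);
```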
@@ -97,23 +97,23 @@ public override LinearRegressionModelParameters CreatePredictor()
/// <param name="featureColumn">Name of the feature column.</param>
/// <param name="learningRate">The learning Rate.</param>
/// <param name="decreaseLearningRate">Decrease learning rate as iterations progress.</param>
/// <param name="l2RegularizerWeight">L2 Regularization Weight.</param>
/// <param name="numIterations">Number of training iterations through the data.</param>
/// <param name="l2Regularization">Weight of L2 regularization term.</param>
/// <param name="numberOfIterations">Number of training iterations through the data.</param>
/// <param name="lossFunction">The custom loss functions. Defaults to <see cref="SquaredLoss"/> if not provided.</param>
internal OnlineGradientDescentTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
float learningRate = Options.OgdDefaultArgs.LearningRate,
bool decreaseLearningRate = Options.OgdDefaultArgs.DecreaseLearningRate,
-float l2RegularizerWeight = Options.OgdDefaultArgs.L2RegularizerWeight,
-int numIterations = Options.OgdDefaultArgs.NumIterations,
+float l2Regularization = Options.OgdDefaultArgs.L2Regularization,
+int numberOfIterations = Options.OgdDefaultArgs.NumberOfIterations,
IRegressionLoss lossFunction = null)
: this(env, new Options
{
LearningRate = learningRate,
DecreaseLearningRate = decreaseLearningRate,
-L2RegularizerWeight = l2RegularizerWeight,
-NumberOfIterations = numIterations,
+L2Regularization = l2Regularization,
+NumberOfIterations = numberOfIterations,
LabelColumnName = labelColumn,
FeatureColumnName = featureColumn,
LossFunction = lossFunction ?? new SquaredLoss()
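And the regression counterpart, with the same renamed knobs; the values shown are believed to be the OGD defaults:

```csharp
// OnlineGradientDescent after the rename; the l2Regularization and
// numberOfIterations names now match the other online learners.
var ogd = mlContext.Regression.Trainers.OnlineGradientDescent(
    labelColumnName: "Label",
    featureColumnName: "Features",
    learningRate: 0.1f,
    decreaseLearningRate: true,
    l2Regularization: 0f,
    numberOfIterations: 1);
```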
@@ -27,7 +27,7 @@ public abstract class OnlineLinearOptions : TrainerInputBaseWithLabel
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter,numIterations", SortOrder = 50)]
[TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")]
[TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)]
-public int NumberOfIterations = OnlineDefault.NumIterations;
+public int NumberOfIterations = OnlineDefault.NumberOfIterations;

/// <summary>
/// Initial weights and bias, comma-separated.
@@ -62,7 +62,7 @@ public abstract class OnlineLinearOptions : TrainerInputBaseWithLabel
[BestFriend]
internal class OnlineDefault
{
-public const int NumIterations = 1;
+public const int NumberOfIterations = 1;
}
}
