Scrubbing online learners #2892
Changes from 1 commit
@@ -25,7 +25,7 @@ public static void Example()
             {
                 LossFunction = new SmoothedHingeLoss.Options(),
                 LearningRate = 0.1f,
-                DoLazyUpdates = false,
+                LazyUpdates = false,
                 RecencyGain = 0.1f,
                 NumberOfIterations = 10
             };
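For orientation, this is how the renamed option reads at a full call site. A minimal sketch, not part of the diff: the MLContext/IDataView plumbing and the placeholder names (mlContext, trainingData, LazyUpdatesSketch) are assumptions, while the option values mirror the sample above.

```csharp
// Sketch only: setting the renamed option from caller code.
// mlContext and trainingData are placeholders, not part of this diff.
using Microsoft.ML;
using Microsoft.ML.Trainers;

public static class LazyUpdatesSketch
{
    public static void Train(MLContext mlContext, IDataView trainingData)
    {
        var options = new AveragedPerceptronTrainer.Options
        {
            LossFunction = new SmoothedHingeLoss.Options(),
            LearningRate = 0.1f,
            LazyUpdates = false,   // renamed from DoLazyUpdates
            RecencyGain = 0.1f,    // legal only because LazyUpdates is false
            NumberOfIterations = 10
        };

        // Catalog extension that accepts the options object.
        var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(options);
        var model = pipeline.Fit(trainingData);
    }
}
```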
@@ -58,15 +58,15 @@ public abstract class AveragedLinearOptions : OnlineLinearOptions
         /// Default is <see langword="true" />.
         /// </value>
         [Argument(ArgumentType.AtMostOnce, HelpText = "Instead of updating averaged weights on every example, only update when loss is nonzero", ShortName = "lazy")]
-        public bool DoLazyUpdates = true;
+        public bool LazyUpdates = true;

         /// <summary>
         /// The L2 weight for <a href='tmpurl_regularization'>regularization</a>.
         /// </summary>
-        [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg", SortOrder = 50)]
+        [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg,L2RegularizerWeight", SortOrder = 50)]
         [TGUI(Label = "L2 Regularization Weight")]
         [TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)]
-        public float L2RegularizerWeight = AveragedDefault.L2RegularizerWeight;
+        public float L2Regularization = AveragedDefault.L2Regularization;

         /// <summary>
         /// Extra weight given to more recent updates.
@@ -86,7 +86,7 @@ public abstract class AveragedLinearOptions : OnlineLinearOptions
         /// Default is <see langword="false" />.
         /// </value>
         [Argument(ArgumentType.AtMostOnce, HelpText = "Whether Recency Gain is multiplicative (vs. additive)", ShortName = "rgm")]
-        public bool RecencyGainMulti = false;
+        public bool RecencyGainMultiplicative = false;

         /// <summary>
         /// Determines whether to do averaging or not.
@@ -109,7 +109,7 @@ internal class AveragedDefault : OnlineLinearOptions.OnlineDefault
         {
             public const float LearningRate = 1;
             public const bool DecreaseLearningRate = false;
-            public const float L2RegularizerWeight = 0;
+            public const float L2Regularization = 0;
         }

         internal abstract IComponentFactory<IScalarOutputLoss> LossFunctionFactory { get; }
@@ -186,7 +186,7 @@ public override void FinishIteration(IChannel ch)
             // Finalize things
             if (Averaged)
             {
-                if (_args.DoLazyUpdates && NumNoUpdates > 0)
+                if (_args.LazyUpdates && NumNoUpdates > 0)
                 {
                     // Update the total weights to include the final loss=0 updates
                     VectorUtils.AddMult(in Weights, NumNoUpdates * WeightsScale, ref TotalWeights);
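Context for the rename: with LazyUpdates, examples whose loss is zero leave the weights untouched, so the averaged-model accumulator can be brought up to date in one step instead of once per example. With N = NumNoUpdates, s = WeightsScale, and w = Weights, the AddMult call above computes

```latex
\mathrm{TotalWeights} \mathrel{+}= N \cdot s \cdot w
```

which equals adding the unchanged (scaled) weight vector N times; the matching TotalBias update appears in the ProcessDataInstance hunk below.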
@@ -221,10 +221,10 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, fl
             // REVIEW: Should this be biasUpdate != 0?
             // This loss does not incorporate L2 if present, but the chance of that addition to the loss
             // exactly cancelling out loss is remote.
-            if (loss != 0 || _args.L2RegularizerWeight > 0)
+            if (loss != 0 || _args.L2Regularization > 0)
             {
                 // If doing lazy weights, we need to update the totalWeights and totalBias before updating weights/bias
-                if (_args.DoLazyUpdates && _args.Averaged && NumNoUpdates > 0 && TotalMultipliers * _args.AveragedTolerance <= PendingMultipliers)
+                if (_args.LazyUpdates && _args.Averaged && NumNoUpdates > 0 && TotalMultipliers * _args.AveragedTolerance <= PendingMultipliers)
                 {
                     VectorUtils.AddMult(in Weights, NumNoUpdates * WeightsScale, ref TotalWeights);
                     TotalBias += Bias * NumNoUpdates * WeightsScale;
@@ -242,7 +242,7 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, fl

                 // Perform the update to weights and bias.
                 VectorUtils.AddMult(in feat, biasUpdate / WeightsScale, ref Weights);
-                WeightsScale *= 1 - 2 * _args.L2RegularizerWeight; // L2 regularization.
+                WeightsScale *= 1 - 2 * _args.L2Regularization; // L2 regularization.
                 ScaleWeightsIfNeeded();
                 Bias += biasUpdate;
                 PendingMultipliers += Math.Abs(biasUpdate);
@@ -251,7 +251,7 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, fl
             // Add to averaged weights and increment the count.
             if (Averaged)
             {
-                if (!_args.DoLazyUpdates)
+                if (!_args.LazyUpdates)
                     IncrementAverageNonLazy();
                 else
                     NumNoUpdates++;
@@ -282,7 +282,7 @@ private void IncrementAverageNonLazy()
             VectorUtils.AddMult(in Weights, Gain * WeightsScale, ref TotalWeights);
             TotalBias += Gain * Bias;
             NumWeightUpdates += Gain;
-            Gain = (_args.RecencyGainMulti ? Gain * _args.RecencyGain : Gain + _args.RecencyGain);
+            Gain = (_args.RecencyGainMultiplicative ? Gain * _args.RecencyGain : Gain + _args.RecencyGain);

             // If gains got too big, rescale!
             if (Gain > 1000)
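For the renamed RecencyGainMultiplicative flag: each non-lazy averaging step adds Gain · w to the running totals, and the gain for the next update evolves per the line above. A sketch of the recurrence, with r = RecencyGain and g_k the gain at update k:

```latex
g_{k+1} =
\begin{cases}
  g_k \cdot r & \text{multiplicative (RecencyGainMultiplicative = true)}\\
  g_k + r     & \text{additive (default)}
\end{cases}
```

With additive r > 0 the gain grows roughly linearly, so later weight vectors carry more weight in the average; the Gain > 1000 rescale above keeps the accumulators within floating-point range.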
@@ -303,11 +303,11 @@ private protected AveragedLinearTrainer(AveragedLinearOptions options, IHostEnvi
             Contracts.CheckUserArg(!options.ResetWeightsAfterXExamples.HasValue || options.ResetWeightsAfterXExamples > 0, nameof(options.ResetWeightsAfterXExamples), UserErrorPositive);

             // Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible.
-            Contracts.CheckUserArg(0 <= options.L2RegularizerWeight && options.L2RegularizerWeight < 0.5, nameof(options.L2RegularizerWeight), "must be in range [0, 0.5)");
+            Contracts.CheckUserArg(0 <= options.L2Regularization && options.L2Regularization < 0.5, nameof(options.L2Regularization), "must be in range [0, 0.5)");
             Contracts.CheckUserArg(options.RecencyGain >= 0, nameof(options.RecencyGain), UserErrorNonNegative);
             Contracts.CheckUserArg(options.AveragedTolerance >= 0, nameof(options.AveragedTolerance), UserErrorNonNegative);
             // Verify user didn't specify parameters that conflict
-            Contracts.Check(!options.DoLazyUpdates || !options.RecencyGainMulti && options.RecencyGain == 0, "Cannot have both recency gain and lazy updates.");
+            Contracts.Check(!options.LazyUpdates || !options.RecencyGainMultiplicative && options.RecencyGain == 0, "Cannot have both recency gain and lazy updates.");

             AveragedLinearTrainerOptions = options;
         }
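The [0, 0.5) bound on the renamed L2Regularization follows directly from the @@ -242 hunk above, where WeightsScale is multiplied by 1 − 2λ (λ = L2Regularization) on every update. After n updates:

```latex
w_n = (1 - 2\lambda)^n \, w_0, \qquad \lambda = 0.5 \;\Rightarrow\; w_n = 0 \text{ for all } n \ge 1
```

so λ must stay strictly below 0.5 for the weights to remain nonzero, and below that bound the shrinkage is geometric.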
@@ -126,23 +126,23 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, Options options)
         /// <param name="featureColumnName">The name of the feature column.</param>
         /// <param name="learningRate">The learning rate.</param>
         /// <param name="decreaseLearningRate">Whether to decrease learning rate as iterations progress.</param>
-        /// <param name="l2RegularizerWeight">L2 Regularization Weight.</param>
+        /// <param name="l2Regularization">Weight of L2 regularization term.</param>
         /// <param name="numIterations">The number of training iterations.</param>
         internal AveragedPerceptronTrainer(IHostEnvironment env,
             string labelColumnName = DefaultColumnNames.Label,
             string featureColumnName = DefaultColumnNames.Features,
             IClassificationLoss lossFunction = null,
             float learningRate = Options.AveragedDefault.LearningRate,
             bool decreaseLearningRate = Options.AveragedDefault.DecreaseLearningRate,
-            float l2RegularizerWeight = Options.AveragedDefault.L2RegularizerWeight,
+            float l2Regularization = Options.AveragedDefault.L2Regularization,
             int numIterations = Options.AveragedDefault.NumIterations)
             : this(env, new Options
             {
                 LabelColumnName = labelColumnName,
                 FeatureColumnName = featureColumnName,
                 LearningRate = learningRate,
                 DecreaseLearningRate = decreaseLearningRate,
-                L2RegularizerWeight = l2RegularizerWeight,
+                L2Regularization = l2Regularization,
                 NumberOfIterations = numIterations,
                 LossFunction = new TrivialFactory(lossFunction ?? new HingeLoss())
             })
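Caller-facing effect of this rename, as a sketch. The shape of the catalog extension below (including the numberOfIterations parameter name) is an assumption about the public API this internal constructor backs; only l2Regularization and the defaults for learningRate, decreaseLearningRate, and lossFunction come from the diff.

```csharp
// Sketch: the public catalog extension that forwards to this constructor.
// Parameter names other than l2Regularization are assumptions, not from the diff.
using Microsoft.ML;

public static class AveragedPerceptronRenameSketch
{
    public static void Configure(MLContext mlContext)
    {
        var trainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
            labelColumnName: "Label",
            featureColumnName: "Features",
            lossFunction: null,        // defaults to HingeLoss, per the diff
            learningRate: 1f,          // Options.AveragedDefault.LearningRate
            decreaseLearningRate: false,
            l2Regularization: 0f,      // renamed from l2RegularizerWeight
            numberOfIterations: 10);
    }
}
```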
@@ -95,22 +95,22 @@ public override LinearRegressionModelParameters CreatePredictor()
         /// <param name="featureColumn">Name of the feature column.</param>
         /// <param name="learningRate">The learning rate.</param>
         /// <param name="decreaseLearningRate">Decrease learning rate as iterations progress.</param>
-        /// <param name="l2RegularizerWeight">L2 Regularization Weight.</param>
+        /// <param name="l2Regularization">Weight of L2 regularization term.</param>
         /// <param name="numIterations">Number of training iterations through the data.</param>
Review thread on the numIterations parameter:
- Please also check other online learners' APIs.
- (In reply to 263972702) I don't see any stop rule in
- (In reply to 263975389) I am too lazy to dig into the code, but my suggestion just reflects this parameter's description.
- (In reply to 264346723) But you suggest the same description as I have right now....
         /// <param name="lossFunction">The custom loss functions. Defaults to <see cref="SquaredLoss"/> if not provided.</param>
         internal OnlineGradientDescentTrainer(IHostEnvironment env,
Review comment on the trainer name (suggested change):
- A descent algorithm ensures the decrease of the function value per iteration. However, this is not true for most stochastic gradient learners. #Pending

A short sketch of this distinction follows the hunk below.
||||||
string labelColumn = DefaultColumnNames.Label, | ||||||
string featureColumn = DefaultColumnNames.Features, | ||||||
float learningRate = Options.OgdDefaultArgs.LearningRate, | ||||||
bool decreaseLearningRate = Options.OgdDefaultArgs.DecreaseLearningRate, | ||||||
float l2RegularizerWeight = Options.OgdDefaultArgs.L2RegularizerWeight, | ||||||
float l2Regularization = Options.OgdDefaultArgs.L2Regularization, | ||||||
int numIterations = Options.OgdDefaultArgs.NumIterations, | ||||||
IRegressionLoss lossFunction = null) | ||||||
: this(env, new Options | ||||||
{ | ||||||
LearningRate = learningRate, | ||||||
DecreaseLearningRate = decreaseLearningRate, | ||||||
L2RegularizerWeight = l2RegularizerWeight, | ||||||
L2Regularization= l2Regularization, | ||||||
NumberOfIterations = numIterations, | ||||||
LabelColumnName = labelColumn, | ||||||
FeatureColumnName = featureColumn, | ||||||
|
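To illustrate the reviewer's descent-method point above: a batch gradient descent step moves against the gradient of the full objective and, for a suitable step size, cannot increase it, whereas an online/stochastic step follows a single example's gradient. With the full objective over n examples and one online step on example i:

```latex
L(w) = \frac{1}{n}\sum_{j=1}^{n} \ell_j(w), \qquad
w' = w - \eta \, \nabla \ell_i(w)
```

The step reduces the per-example loss ℓ_i locally, but L(w') > L(w) is entirely possible, so "descent" is not a per-iteration guarantee for this trainer.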