Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public Arguments()
BasePredictors = new[]
{
ComponentFactoryUtils.CreateFromFunction(
env => new LinearSvm(env))
env => new LinearSvmTrainer(env))
};
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private TScalarTrainer CreateTrainer()
{
return Args.PredictorType != null ?
Args.PredictorType.CreateComponent(Host) :
new LinearSvm(Host, new LinearSvm.Arguments());
new LinearSvmTrainer(Host, new LinearSvmTrainer.Arguments());
}

private protected IDataView MapLabelsCore<T>(ColumnType type, InPredicate<T> equalsTarget, RoleMappedData data)
Expand Down
78 changes: 36 additions & 42 deletions src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,21 @@
using Microsoft.ML.Numeric;
using Microsoft.ML.Trainers.Online;
using Microsoft.ML.Training;
using Float = System.Single;

[assembly: LoadableClass(LinearSvm.Summary, typeof(LinearSvm), typeof(LinearSvm.Arguments),
[assembly: LoadableClass(LinearSvmTrainer.Summary, typeof(LinearSvmTrainer), typeof(LinearSvmTrainer.Arguments),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
LinearSvm.UserNameValue,
LinearSvm.LoadNameValue,
LinearSvm.ShortName)]
LinearSvmTrainer.UserNameValue,
LinearSvmTrainer.LoadNameValue,
LinearSvmTrainer.ShortName)]

[assembly: LoadableClass(typeof(void), typeof(LinearSvm), null, typeof(SignatureEntryPointModule), "LinearSvm")]
[assembly: LoadableClass(typeof(void), typeof(LinearSvmTrainer), null, typeof(SignatureEntryPointModule), "LinearSvm")]

namespace Microsoft.ML.Trainers.Online
{
/// <summary>
/// Linear SVM that implements PEGASOS for training. See: http://ttic.uchicago.edu/~shai/papers/ShalevSiSr07.pdf
/// </summary>
public sealed class LinearSvm : OnlineLinearTrainer<BinaryPredictionTransformer<LinearBinaryModelParameters>, LinearBinaryModelParameters>
public sealed class LinearSvmTrainer : OnlineLinearTrainer<BinaryPredictionTransformer<LinearBinaryModelParameters>, LinearBinaryModelParameters>
{
internal const string LoadNameValue = "LinearSVM";
internal const string ShortName = "svm";
Expand All @@ -47,7 +46,7 @@ public sealed class Arguments : OnlineLinearArguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Regularizer constant", ShortName = "lambda", SortOrder = 50)]
[TGUI(SuggestedSweeps = "0.00001-0.1;log;inc:10")]
[TlcModule.SweepableFloatParamAttribute("Lambda", 0.00001f, 0.1f, 10, isLogScale: true)]
public Float Lambda = (Float)0.001;
public float Lambda = 0.001f;

[Argument(ArgumentType.AtMostOnce, HelpText = "Batch size", ShortName = "batch", SortOrder = 190)]
[TGUI(Label = "Batch Size")]
Expand Down Expand Up @@ -78,16 +77,16 @@ private sealed class TrainState : TrainStateBase
// weightsUpdate/weightsUpdateScale/biasUpdate are similar to weights/weightsScale/bias, in that
// all elements of weightsUpdate are considered to be multiplied by weightsUpdateScale, and the
// bias update term is not considered to be multiplied by the scale.
private VBuffer<Float> _weightsUpdate;
private Float _weightsUpdateScale;
private Float _biasUpdate;
private VBuffer<float> _weightsUpdate;
private float _weightsUpdateScale;
private float _biasUpdate;

private readonly int _batchSize;
private readonly bool _noBias;
private readonly bool _performProjection;
private readonly float _lambda;

public TrainState(IChannel ch, int numFeatures, LinearModelParameters predictor, LinearSvm parent)
public TrainState(IChannel ch, int numFeatures, LinearModelParameters predictor, LinearSvmTrainer parent)
: base(ch, numFeatures, predictor, parent)
{
_batchSize = parent.Args.BatchSize;
Expand All @@ -101,7 +100,7 @@ public TrainState(IChannel ch, int numFeatures, LinearModelParameters predictor,
if (predictor == null)
VBufferUtils.Densify(ref Weights);

_weightsUpdate = VBufferUtils.CreateEmpty<Float>(numFeatures);
_weightsUpdate = VBufferUtils.CreateEmpty<float>(numFeatures);

}

Expand All @@ -119,7 +118,7 @@ private void BeginBatch()
VBufferUtils.Resize(ref _weightsUpdate, _weightsUpdate.Length, 0);
}

private void FinishBatch(in VBuffer<Float> weightsUpdate, Float weightsUpdateScale)
private void FinishBatch(in VBuffer<float> weightsUpdate, float weightsUpdateScale)
{
if (_numBatchExamples > 0)
UpdateWeights(in weightsUpdate, weightsUpdateScale);
Expand All @@ -129,19 +128,19 @@ private void FinishBatch(in VBuffer<Float> weightsUpdate, Float weightsUpdateSca
/// <summary>
/// Observe an example and update weights if necessary.
/// </summary>
public override void ProcessDataInstance(IChannel ch, in VBuffer<Float> feat, Float label, Float weight)
public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, float label, float weight)
{
base.ProcessDataInstance(ch, in feat, label, weight);

// compute the update and update if needed
Float output = Margin(in feat);
Float trueOutput = (label > 0 ? 1 : -1);
Float loss = output * trueOutput - 1;
float output = Margin(in feat);
float trueOutput = (label > 0 ? 1 : -1);
float loss = output * trueOutput - 1;

// Accumulate the update if there is a loss and we have larger batches.
if (_batchSize > 1 && loss < 0)
{
Float currentBiasUpdate = trueOutput * weight;
float currentBiasUpdate = trueOutput * weight;
_biasUpdate += currentBiasUpdate;
// Only aggregate in the case where we're handling multiple instances.
if (_weightsUpdate.GetValues().Length == 0)
Expand All @@ -160,7 +159,7 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<Float> feat, Fl
Contracts.Assert(_weightsUpdate.GetValues().Length == 0);
// If we aren't aggregating multiple instances, just use the instance's
// vector directly.
Float currentBiasUpdate = trueOutput * weight;
float currentBiasUpdate = trueOutput * weight;
_biasUpdate += currentBiasUpdate;
FinishBatch(in feat, currentBiasUpdate);
}
Expand All @@ -174,13 +173,13 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<Float> feat, Fl
/// Updates the weights at the end of the batch. Since weightsUpdate can be an instance
/// feature vector, this function should not change the contents of weightsUpdate.
/// </summary>
private void UpdateWeights(in VBuffer<Float> weightsUpdate, Float weightsUpdateScale)
private void UpdateWeights(in VBuffer<float> weightsUpdate, float weightsUpdateScale)
{
Contracts.Assert(_batch > 0);

// REVIEW: This is really odd - normally lambda is small, so the learning rate is initially huge!?!?!
// Changed from the paper's recommended rate = 1 / (lambda * t) to rate = 1 / (1 + lambda * t).
Float rate = 1 / (1 + _lambda * _batch);
float rate = 1 / (1 + _lambda * _batch);

// w_{t+1/2} = (1 - eta*lambda) w_t + eta/k * totalUpdate
WeightsScale *= 1 - rate * _lambda;
Expand All @@ -194,7 +193,7 @@ private void UpdateWeights(in VBuffer<Float> weightsUpdate, Float weightsUpdateS
// w_{t+1} = min{1, 1/sqrt(lambda)/|w_{t+1/2}|} * w_{t+1/2}
if (_performProjection)
{
Float normalizer = 1 / (MathUtils.Sqrt(_lambda) * VectorUtils.Norm(Weights) * Math.Abs(WeightsScale));
float normalizer = 1 / (MathUtils.Sqrt(_lambda) * VectorUtils.Norm(Weights) * Math.Abs(WeightsScale));
if (normalizer < 1)
{
// REVIEW: Why would we not scale _bias if we're scaling the weights?
Expand All @@ -208,7 +207,7 @@ private void UpdateWeights(in VBuffer<Float> weightsUpdate, Float weightsUpdateS
/// <summary>
/// Return the raw margin from the decision hyperplane.
/// </summary>
public override Float Margin(in VBuffer<Float> feat)
public override float Margin(in VBuffer<float> feat)
=> Bias + VectorUtils.DotProduct(in feat, in Weights) * WeightsScale;

public override LinearBinaryModelParameters CreatePredictor()
Expand All @@ -222,21 +221,21 @@ public override LinearBinaryModelParameters CreatePredictor()
protected override bool NeedCalibration => true;

/// <summary>
/// Initializes a new instance of <see cref="LinearSvm"/>.
/// Initializes a new instance of <see cref="LinearSvmTrainer"/>.
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="labelColumn">The name of the label column. </param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="weightsColumn">The optional name of the weights column.</param>
/// <param name="numIterations">The number of training iterations.</param>
/// <param name="advancedSettings">A delegate to supply more advanced arguments to the algorithm.</param>
public LinearSvm(IHostEnvironment env,
public LinearSvmTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weightsColumn = null,
int numIterations = Arguments.OnlineDefaultArgs.NumIterations,
Action<Arguments> advancedSettings = null)
:this(env, InvokeAdvanced(advancedSettings, new Arguments
: this(env, InvokeAdvanced(advancedSettings, new Arguments
{
LabelColumn = labelColumn,
FeatureColumn = featureColumn,
Expand All @@ -246,8 +245,8 @@ public LinearSvm(IHostEnvironment env,
{
}

internal LinearSvm(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
internal LinearSvmTrainer(IHostEnvironment env, Arguments args)
: base(args, env, UserNameValue, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
{
Contracts.CheckUserArg(args.Lambda > 0, nameof(args.Lambda), UserErrorPositive);
Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive);
Expand All @@ -261,9 +260,8 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
{
return new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
};
}

Expand All @@ -274,14 +272,7 @@ private protected override void CheckLabels(RoleMappedData data)
}

private protected override TrainStateBase MakeState(IChannel ch, int numFeatures, LinearModelParameters predictor)
{
return new TrainState(ch, numFeatures, predictor, this);
}

private static SchemaShape.Column MakeLabelColumn(string labelColumn)
{
return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
}
=> new TrainState(ch, numFeatures, predictor, this);

[TlcModule.EntryPoint(Name = "Trainers.LinearSvmBinaryClassifier", Desc = "Train a linear SVM.", UserName = UserNameValue, ShortName = ShortName)]
public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvironment env, Arguments input)
Expand All @@ -292,12 +283,15 @@ public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvir
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
() => new LinearSvm(host, input),
() => new LinearSvmTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
}

protected override BinaryPredictionTransformer<LinearBinaryModelParameters> MakeTransformer(LinearBinaryModelParameters model, Schema trainSchema)
=> new BinaryPredictionTransformer<LinearBinaryModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
=> new BinaryPredictionTransformer<LinearBinaryModelParameters>(Host, model, trainSchema, FeatureColumn.Name);

/// <summary>
/// Trains on <paramref name="trainData"/> by delegating to <c>TrainTransformer</c>,
/// forwarding <paramref name="initialPredictor"/> as its <c>initPredictor</c> argument.
/// </summary>
/// <param name="trainData">The data view passed to <c>TrainTransformer</c> as the training data.</param>
/// <param name="initialPredictor">
/// Optional predictor forwarded to <c>TrainTransformer</c>; defaults to <c>null</c>.
/// NOTE(review): presumably used as a starting point for training — confirm against <c>TrainTransformer</c>.
/// </param>
/// <returns>The <see cref="BinaryPredictionTransformer{TModel}"/> returned by <c>TrainTransformer</c>.</returns>
public BinaryPredictionTransformer<LinearBinaryModelParameters> Train(IDataView trainData, IPredictor initialPredictor = null)
=> TrainTransformer(trainData, initPredictor: initialPredictor);
}
}
8 changes: 4 additions & 4 deletions src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ public static Pkpd PairwiseCoupling(this MulticlassClassificationContext.Multicl
}

/// <summary>
/// Predict a target using a linear binary classification model trained with the <see cref="LinearSvm"/> trainer.
/// Predict a target using a linear binary classification model trained with the <see cref="LinearSvmTrainer"/> trainer.
/// </summary>
/// <remarks>
/// <para>
Expand All @@ -403,15 +403,15 @@ public static Pkpd PairwiseCoupling(this MulticlassClassificationContext.Multicl
/// <param name="weightsColumn">The optional name of the weights column.</param>
/// <param name="numIterations">The number of training iterations.</param>
/// <param name="advancedSettings">A delegate to supply more advanced arguments to the algorithm.</param>
public static LinearSvm LinearSupportVectorMachines(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
public static LinearSvmTrainer LinearSupportVectorMachines(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weightsColumn = null,
int numIterations = OnlineLinearArguments.OnlineDefaultArgs.NumIterations,
Action<LinearSvm.Arguments> advancedSettings = null)
Action<LinearSvmTrainer.Arguments> advancedSettings = null)
{
Contracts.CheckValue(ctx, nameof(ctx));
return new LinearSvm(CatalogUtils.GetEnvironment(ctx), labelColumn, featureColumn, weightsColumn, numIterations, advancedSettings);
return new LinearSvmTrainer(CatalogUtils.GetEnvironment(ctx), labelColumn, featureColumn, weightsColumn, numIterations, advancedSettings);
}
}
}
2 changes: 1 addition & 1 deletion test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Trainers.LightGbmBinaryClassifier Train a LightGBM binary classification model.
Trainers.LightGbmClassifier Train a LightGBM multi class model. Microsoft.ML.LightGBM.LightGbm TrainMultiClass Microsoft.ML.LightGBM.LightGbmArguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
Trainers.LightGbmRanker Train a LightGBM ranking model. Microsoft.ML.LightGBM.LightGbm TrainRanking Microsoft.ML.LightGBM.LightGbmArguments Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput
Trainers.LightGbmRegressor LightGBM Regression Microsoft.ML.LightGBM.LightGbm TrainRegression Microsoft.ML.LightGBM.LightGbmArguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
Trainers.LinearSvmBinaryClassifier Train a linear SVM. Microsoft.ML.Trainers.Online.LinearSvm TrainLinearSvm Microsoft.ML.Trainers.Online.LinearSvm+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Trainers.LinearSvmBinaryClassifier Train a linear SVM. Microsoft.ML.Trainers.Online.LinearSvmTrainer TrainLinearSvm Microsoft.ML.Trainers.Online.LinearSvmTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Trainers.LogisticRegressionBinaryClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Learners.LogisticRegression TrainBinary Microsoft.ML.Learners.LogisticRegression+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Trainers.LogisticRegressionClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Learners.LogisticRegression TrainMultiClass Microsoft.ML.Learners.MulticlassLogisticRegression+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
Trainers.NaiveBayesClassifier Train a MultiClassNaiveBayesTrainer. Microsoft.ML.Trainers.MultiClassNaiveBayesTrainer TrainMultiClassNaiveBayesTrainer Microsoft.ML.Trainers.MultiClassNaiveBayesTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.Tests/Scenarios/OvaTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public void OvaLinearSvm()
var data = mlContext.Data.Cache(reader.Read(GetDataPath(dataPath)));

// Pipeline
var pipeline = new Ova(mlContext, new LinearSvm(mlContext, numIterations: 100), useProbabilities: false);
var pipeline = new Ova(mlContext, new LinearSvmTrainer(mlContext, numIterations: 100), useProbabilities: false);

var model = pipeline.Fit(data);
var predictions = model.Transform(data);
Expand Down
Loading