Skip to content

Commit

Permalink
nit
Browse files Browse the repository at this point in the history
  • Loading branch information
najeeb-kazmi committed Dec 11, 2018
1 parent f4e35b2 commit a5aa3b9
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
28 changes: 14 additions & 14 deletions src/Microsoft.ML.FastTree/FastTreeRanking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@
"frrank",
"btrank")]

[assembly: LoadableClass(typeof(FastTreerankingModelParameters), null, typeof(SignatureLoadModel),
[assembly: LoadableClass(typeof(FastTreeRankingModelParameters), null, typeof(SignatureLoadModel),
"FastTree Ranking Executor",
FastTreerankingModelParameters.LoaderSignature)]
FastTreeRankingModelParameters.LoaderSignature)]

[assembly: LoadableClass(typeof(void), typeof(FastTree), null, typeof(SignatureEntryPointModule), "FastTree")]

namespace Microsoft.ML.Trainers.FastTree
{
/// <include file='doc.xml' path='doc/members/member[@name="FastTree"]/*' />
public sealed partial class FastTreeRankingTrainer
: BoostingFastTreeTrainerBase<FastTreeRankingTrainer.Arguments, RankingPredictionTransformer<FastTreerankingModelParameters>, FastTreerankingModelParameters>
: BoostingFastTreeTrainerBase<FastTreeRankingTrainer.Arguments, RankingPredictionTransformer<FastTreeRankingModelParameters>, FastTreeRankingModelParameters>
{
internal const string LoadNameValue = "FastTreeRanking";
internal const string UserNameValue = "FastTree (Boosted Trees) Ranking";
Expand Down Expand Up @@ -112,7 +112,7 @@ protected override float GetMaxLabel()
return GetLabelGains().Length - 1;
}

private protected override FastTreerankingModelParameters TrainModelCore(TrainContext context)
private protected override FastTreeRankingModelParameters TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var trainData = context.TrainingSet;
Expand All @@ -126,7 +126,7 @@ private protected override FastTreerankingModelParameters TrainModelCore(TrainCo
TrainCore(ch);
FeatureCount = trainData.Schema.Feature.Type.ValueCount;
}
return new FastTreerankingModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs);
return new FastTreeRankingModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs);
}

private Double[] GetLabelGains()
Expand Down Expand Up @@ -454,10 +454,10 @@ protected override string GetTestGraphHeader()
return headerBuilder.ToString();
}

protected override RankingPredictionTransformer<FastTreerankingModelParameters> MakeTransformer(FastTreerankingModelParameters model, Schema trainSchema)
=> new RankingPredictionTransformer<FastTreerankingModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
protected override RankingPredictionTransformer<FastTreeRankingModelParameters> MakeTransformer(FastTreeRankingModelParameters model, Schema trainSchema)
=> new RankingPredictionTransformer<FastTreeRankingModelParameters>(Host, model, trainSchema, FeatureColumn.Name);

public RankingPredictionTransformer<FastTreerankingModelParameters> Train(IDataView trainData, IDataView validationData = null)
public RankingPredictionTransformer<FastTreeRankingModelParameters> Train(IDataView trainData, IDataView validationData = null)
=> TrainTransformer(trainData, validationData);

protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
Expand Down Expand Up @@ -1104,7 +1104,7 @@ private static extern unsafe void GetDerivatives(
}
}

public sealed class FastTreerankingModelParameters : TreeEnsembleModelParameters
public sealed class FastTreeRankingModelParameters : TreeEnsembleModelParameters
{
internal const string LoaderSignature = "FastTreeRankerExec";
internal const string RegistrationName = "FastTreeRankingPredictor";
Expand All @@ -1121,7 +1121,7 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010004,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(FastTreerankingModelParameters).Assembly.FullName);
loaderAssemblyName: typeof(FastTreeRankingModelParameters).Assembly.FullName);
}

protected override uint VerNumFeaturesSerialized => 0x00010002;
Expand All @@ -1130,12 +1130,12 @@ private static VersionInfo GetVersionInfo()

protected override uint VerCategoricalSplitSerialized => 0x00010005;

public FastTreerankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
public FastTreeRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
{
}

private FastTreerankingModelParameters(IHostEnvironment env, ModelLoadContext ctx)
private FastTreeRankingModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, RegistrationName, ctx, GetVersionInfo())
{
}
Expand All @@ -1146,9 +1146,9 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static FastTreerankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
private static FastTreeRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
return new FastTreerankingModelParameters(env, ctx);
return new FastTreeRankingModelParameters(env, ctx);
}

public override PredictionKind PredictionKind => PredictionKind.Ranking;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public IPredictor CombineModels(IEnumerable<IPredictor> models)
case PredictionKind.Regression:
return new FastTreeRegressionModelParameters(_host, ensemble, featureCount, null);
case PredictionKind.Ranking:
return new FastTreerankingModelParameters(_host, ensemble, featureCount, null);
return new FastTreeRankingModelParameters(_host, ensemble, featureCount, null);
default:
_host.Assert(false);
throw _host.ExceptNotSupp();
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/TreeTrainersStatic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public static Scalar<float> FastTree<TVal>(this RankingContext.RankingTrainers c
int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves,
double learningRate = Defaults.LearningRates,
Action<FastTreeRankingTrainer.Arguments> advancedSettings = null,
Action<FastTreerankingModelParameters> onFit = null)
Action<FastTreeRankingModelParameters> onFit = null)
{
CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings, onFit);

Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.StaticPipelineTesting/Training.cs
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ public void FastTreeRanking()
c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)),
separator: '\t', hasHeader: true);

FastTreerankingModelParameters pred = null;
FastTreeRankingModelParameters pred = null;

var est = reader.MakeNewEstimator()
.Append(r => (r.label, r.features, groupId: r.groupId.ToKey()))
Expand Down

0 comments on commit a5aa3b9

Please sign in to comment.