Skip to content

Commit

Permalink
Scrub static LDA
Browse files Browse the repository at this point in the history
  • Loading branch information
wschin committed Mar 8, 2019
1 parent 9b84259 commit a433ef2
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 50 deletions.
2 changes: 1 addition & 1 deletion docs/samples/Microsoft.ML.Samples/Dynamic/LdaTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace Microsoft.ML.Samples.Dynamic
{
public static class LdaTransform
public static class LatentDirichletAllocationTransform
{
public static void Example()
{
Expand Down
94 changes: 47 additions & 47 deletions src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,65 +11,65 @@ namespace Microsoft.ML.StaticPipe
/// <summary>
/// Information on the result of fitting an LDA transform.
/// </summary>
public sealed class LdaFitResult
public sealed class LatentDirichletAllocationFitResult
{
/// <summary>
/// For user-defined delegates that accept instances of the containing type.
/// </summary>
/// <param name="result"></param>
public delegate void OnFit(LdaFitResult result);
public delegate void OnFit(LatentDirichletAllocationFitResult result);

public LatentDirichletAllocationTransformer.LdaSummary LdaTopicSummary;
public LdaFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary)
public LatentDirichletAllocationFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary)
{
LdaTopicSummary = ldaTopicSummary;
}
}

public static class LdaStaticExtensions
public static class LatentDirichletAllocationStaticExtensions
{
private struct Config
{
public readonly int NumTopic;
public readonly int NumberOfTopics;
public readonly Single AlphaSum;
public readonly Single Beta;
public readonly int MHStep;
public readonly int NumIter;
public readonly int SamplingStepCount;
public readonly int MaximumNumberOfIterations;
public readonly int LikelihoodInterval;
public readonly int NumThread;
public readonly int NumMaxDocToken;
public readonly int NumSummaryTermPerTopic;
public readonly int NumBurninIter;
public readonly int NumberOfThreads;
public readonly int MaximumTokenCountPerDocument;
public readonly int NumberOfSummaryTermsPerTopic;
public readonly int NumberOfBurninIterations;
public readonly bool ResetRandomGenerator;

public readonly Action<LatentDirichletAllocationTransformer.LdaSummary> OnFit;

public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval,
int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator,
public Config(int numberOfTopics, Single alphaSum, Single beta, int samplingStepCount, int maximumNumberOfIterations, int likelihoodInterval,
int numberOfThreads, int maximumTokenCountPerDocument, int numberOfSummaryTermsPerTopic, int numberOfBurninIterations, bool resetRandomGenerator,
Action<LatentDirichletAllocationTransformer.LdaSummary> onFit)
{
NumTopic = numTopic;
NumberOfTopics = numberOfTopics;
AlphaSum = alphaSum;
Beta = beta;
MHStep = mhStep;
NumIter = numIter;
SamplingStepCount = samplingStepCount;
MaximumNumberOfIterations = maximumNumberOfIterations;
LikelihoodInterval = likelihoodInterval;
NumThread = numThread;
NumMaxDocToken = numMaxDocToken;
NumSummaryTermPerTopic = numSummaryTermPerTopic;
NumBurninIter = numBurninIter;
NumberOfThreads = numberOfThreads;
MaximumTokenCountPerDocument = maximumTokenCountPerDocument;
NumberOfSummaryTermsPerTopic = numberOfSummaryTermsPerTopic;
NumberOfBurninIterations = numberOfBurninIterations;
ResetRandomGenerator = resetRandomGenerator;

OnFit = onFit;
}
}

private static Action<LatentDirichletAllocationTransformer.LdaSummary> Wrap(LdaFitResult.OnFit onFit)
private static Action<LatentDirichletAllocationTransformer.LdaSummary> Wrap(LatentDirichletAllocationFitResult.OnFit onFit)
{
if (onFit == null)
return null;

return ldaTopicSummary => onFit(new LdaFitResult(ldaTopicSummary));
return ldaTopicSummary => onFit(new LatentDirichletAllocationFitResult(ldaTopicSummary));
}

private interface ILdaCol
Expand Down Expand Up @@ -107,16 +107,16 @@ private sealed class Rec : EstimatorReconciler

infos[i] = new LatentDirichletAllocationEstimator.ColumnOptions(outputNames[toOutput[i]],
inputNames[tcol.Input],
tcol.Config.NumTopic,
tcol.Config.NumberOfTopics,
tcol.Config.AlphaSum,
tcol.Config.Beta,
tcol.Config.MHStep,
tcol.Config.NumIter,
tcol.Config.SamplingStepCount,
tcol.Config.MaximumNumberOfIterations,
tcol.Config.LikelihoodInterval,
tcol.Config.NumThread,
tcol.Config.NumMaxDocToken,
tcol.Config.NumSummaryTermPerTopic,
tcol.Config.NumBurninIter,
tcol.Config.NumberOfThreads,
tcol.Config.MaximumTokenCountPerDocument,
tcol.Config.NumberOfSummaryTermsPerTopic,
tcol.Config.NumberOfBurninIterations,
tcol.Config.ResetRandomGenerator);

if (tcol.Config.OnFit != null)
Expand All @@ -136,36 +136,36 @@ private sealed class Rec : EstimatorReconciler

/// <include file='../Microsoft.ML.Transforms/Text/doc.xml' path='doc/members/member[@name="LightLDA"]/*' />
/// <param name="input">A vector of floats representing the document.</param>
/// <param name="numTopic">The number of topics.</param>
/// <param name="numberOfTopics">The number of topics.</param>
/// <param name="alphaSum">Dirichlet prior on document-topic vectors.</param>
/// <param name="beta">Dirichlet prior on vocab-topic vectors.</param>
/// <param name="mhstep">Number of Metropolis Hasting step.</param>
/// <param name="numIterations">Number of iterations.</param>
/// <param name="samplingStepCount">Number of Metropolis-Hastings steps.</param>
/// <param name="maximumNumberOfIterations">Number of iterations.</param>
/// <param name="likelihoodInterval">Compute log likelihood over local dataset on this iteration interval.</param>
/// <param name="numThreads">The number of training threads. Default value depends on number of logical processors.</param>
/// <param name="numMaxDocToken">The threshold of maximum count of tokens per doc.</param>
/// <param name="numSummaryTermPerTopic">The number of words to summarize the topic.</param>
/// <param name="numBurninIterations">The number of burn-in iterations.</param>
/// <param name="numberOfThreads">The number of training threads. Default value depends on number of logical processors.</param>
/// <param name="maximumTokenCountPerDocument">The threshold of maximum count of tokens per doc.</param>
/// <param name="numberOfSummaryTermsPerTopic">The number of words to summarize the topic.</param>
/// <param name="numberOfBurninIterations">The number of burn-in iterations.</param>
/// <param name="resetRandomGenerator">Reset the random number generator for each document.</param>
/// <param name="onFit">Called upon fitting with the learnt LDA topic summary on the dataset.</param>
public static Vector<float> ToLdaTopicVector(this Vector<float> input,
int numTopic = LatentDirichletAllocationEstimator.Defaults.NumberOfTopics,
public static Vector<float> ToLatentDirichletAllocationTopicVector(this Vector<float> input,
int numberOfTopics = LatentDirichletAllocationEstimator.Defaults.NumberOfTopics,
Single alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum,
Single beta = LatentDirichletAllocationEstimator.Defaults.Beta,
int mhstep = LatentDirichletAllocationEstimator.Defaults.SamplingStepCount,
int numIterations = LatentDirichletAllocationEstimator.Defaults.MaximumNumberOfIterations,
int samplingStepCount = LatentDirichletAllocationEstimator.Defaults.SamplingStepCount,
int maximumNumberOfIterations = LatentDirichletAllocationEstimator.Defaults.MaximumNumberOfIterations,
int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval,
int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads,
int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken,
int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic,
int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations,
int numberOfThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads,
int maximumTokenCountPerDocument = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken,
int numberOfSummaryTermsPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic,
int numberOfBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations,
bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator,
LdaFitResult.OnFit onFit = null)
LatentDirichletAllocationFitResult.OnFit onFit = null)
{
Contracts.CheckValue(input, nameof(input));
return new ImplVector(input,
new Config(numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic,
numBurninIterations, resetRandomGenerator, Wrap(onFit)));
new Config(numberOfTopics, alphaSum, beta, samplingStepCount, maximumNumberOfIterations, likelihoodInterval, numberOfThreads, maximumTokenCountPerDocument, numberOfSummaryTermsPerTopic,
numberOfBurninIterations, resetRandomGenerator, Wrap(onFit)));
}
}
}
1 change: 0 additions & 1 deletion src/Microsoft.ML.Transforms/Text/LdaTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.TextAnalytics;
using Microsoft.ML.Transforms.Text;

Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ public void LdaTopicModel()
var est = data.MakeNewEstimator()
.Append(r => (
r.label,
topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 3, numSummaryTermPerTopic:5, alphaSum: 10, onFit: m => ldaSummary = m.LdaTopicSummary)));
topics: r.text.ToBagofWords().ToLatentDirichletAllocationTopicVector(numberOfTopics: 3, numberOfSummaryTermsPerTopic:5, alphaSum: 10, onFit: m => ldaSummary = m.LdaTopicSummary)));

var transformer = est.Fit(data);
var tdata = transformer.Transform(data);
Expand Down

0 comments on commit a433ef2

Please sign in to comment.