1+ // Licensed to the .NET Foundation under one or more agreements.
2+ // The .NET Foundation licenses this file to you under the MIT license.
3+ // See the LICENSE file in the project root for more information.
4+
5+ using Microsoft . ML . Core . Data ;
6+ using Microsoft . ML . Runtime ;
7+ using Microsoft . ML . Runtime . Data ;
8+ using Microsoft . ML . StaticPipe . Runtime ;
9+ using Microsoft . ML . Transforms . Text ;
10+ using System ;
11+ using System . Collections . Generic ;
12+
13+ namespace Microsoft . ML . StaticPipe
14+ {
/// <summary>
/// Information on the result of fitting a LDA transform.
/// </summary>
public sealed class LdaFitResult
{
    /// <summary>
    /// The LDA summary this result was constructed with.
    /// </summary>
    public LatentDirichletAllocationTransformer.LdaSummary LdaTopicSummary;

    /// <summary>
    /// For user defined delegates that accept instances of the containing type.
    /// </summary>
    /// <param name="result">The fit result handed to the delegate.</param>
    public delegate void OnFit(LdaFitResult result);

    /// <summary>
    /// Creates a fit result that exposes the given LDA summary.
    /// </summary>
    /// <param name="ldaTopicSummary">The summary stored in <see cref="LdaTopicSummary"/>.</param>
    public LdaFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary)
        => LdaTopicSummary = ldaTopicSummary;
}
32+
33+ public static class LdaStaticExtensions
34+ {
/// <summary>
/// Per-column bundle of LDA settings plus the optional fit callback. All fields are
/// assigned once in the constructor and are read-only afterwards.
/// </summary>
private struct Config
{
    // LDA settings, stored exactly as passed to the constructor.
    public readonly int NumTopic;
    public readonly float AlphaSum;
    public readonly float Beta;
    public readonly int MHStep;
    public readonly int NumIter;
    public readonly int LikelihoodInterval;
    public readonly int NumThread;
    public readonly int NumMaxDocToken;
    public readonly int NumSummaryTermPerTopic;
    public readonly int NumBurninIter;
    public readonly bool ResetRandomGenerator;

    // Callback receiving the LDA summary; may be null when no callback was supplied.
    public readonly Action<LatentDirichletAllocationTransformer.LdaSummary> OnFit;

    /// <summary>
    /// Stores every argument in the field of the matching name.
    /// </summary>
    public Config(
        int numTopic,
        float alphaSum,
        float beta,
        int mhStep,
        int numIter,
        int likelihoodInterval,
        int numThread,
        int numMaxDocToken,
        int numSummaryTermPerTopic,
        int numBurninIter,
        bool resetRandomGenerator,
        Action<LatentDirichletAllocationTransformer.LdaSummary> onFit)
    {
        OnFit = onFit;

        NumTopic = numTopic;
        AlphaSum = alphaSum;
        Beta = beta;
        MHStep = mhStep;

        NumIter = numIter;
        NumBurninIter = numBurninIter;
        LikelihoodInterval = likelihoodInterval;

        NumThread = numThread;
        NumMaxDocToken = numMaxDocToken;
        NumSummaryTermPerTopic = numSummaryTermPerTopic;
        ResetRandomGenerator = resetRandomGenerator;
    }
}
70+
/// <summary>
/// Adapts a user-facing <see cref="LdaFitResult.OnFit"/> delegate into an action over the LDA summary.
/// </summary>
/// <param name="onFit">The user delegate; may be null.</param>
/// <returns>
/// Null when <paramref name="onFit"/> is null; otherwise an action that wraps the summary it receives
/// in a new <see cref="LdaFitResult"/> and passes that to <paramref name="onFit"/>.
/// </returns>
private static Action<LatentDirichletAllocationTransformer.LdaSummary> Wrap(LdaFitResult.OnFit onFit)
{
    if (onFit != null)
        return summary => onFit(new LdaFitResult(summary));

    // No user callback was supplied, so there is nothing to adapt.
    return null;
}
78+
/// <summary>
/// View over an LDA output column that exposes its input column and its settings.
/// </summary>
private interface ILdaCol
{
    /// <summary>The LDA settings attached to this column.</summary>
    Config Config { get; }

    /// <summary>The pipeline column supplied as input.</summary>
    PipelineColumn Input { get; }
}
84+
/// <summary>
/// Float-vector output column that remembers the input column and the LDA settings it was created with.
/// </summary>
private sealed class ImplVector : Vector<float>, ILdaCol
{
    /// <summary>
    /// Registers the column with the shared reconciler and records its input and settings.
    /// </summary>
    /// <param name="input">The input column; also passed to the base constructor.</param>
    /// <param name="config">The LDA settings for this column.</param>
    public ImplVector(PipelineColumn input, Config config)
        : base(Rec.Inst, input)
    {
        Config = config;
        Input = input;
    }

    /// <summary>The LDA settings given at construction.</summary>
    public Config Config { get; }

    /// <summary>The input column given at construction.</summary>
    public PipelineColumn Input { get; }
}
95+
/// <summary>
/// Reconciler that turns the requested LDA output columns into a single
/// <see cref="LatentDirichletAllocationEstimator"/>.
/// </summary>
private sealed class Rec : EstimatorReconciler
{
    /// <summary>Shared reconciler instance.</summary>
    public static readonly Rec Inst = new Rec();

    /// <summary>
    /// Builds one column info per output column and creates the estimator over all of them.
    /// </summary>
    /// <param name="env">Host environment passed to the estimator.</param>
    /// <param name="toOutput">The output columns; each is cast to <see cref="ILdaCol"/>.</param>
    /// <param name="inputNames">Lookup from an input pipeline column to its name.</param>
    /// <param name="outputNames">Lookup from an output pipeline column to its name.</param>
    /// <param name="usedNames">Not read by this reconciler.</param>
    /// <returns>
    /// The estimator itself when no column carries a fit callback; otherwise the result of
    /// attaching the combined callbacks through <c>WithOnFitDelegate</c>.
    /// </returns>
    public override IEstimator<ITransformer> Reconcile(
        IHostEnvironment env,
        PipelineColumn[] toOutput,
        IReadOnlyDictionary<PipelineColumn, string> inputNames,
        IReadOnlyDictionary<PipelineColumn, string> outputNames,
        IReadOnlyCollection<string> usedNames)
    {
        int columnCount = toOutput.Length;
        var columns = new LatentDirichletAllocationTransformer.ColumnInfo[columnCount];
        Action<LatentDirichletAllocationTransformer> fitCallbacks = null;

        for (int index = 0; index < columnCount; ++index)
        {
            var ldaColumn = (ILdaCol)toOutput[index];
            Config settings = ldaColumn.Config;

            columns[index] = new LatentDirichletAllocationTransformer.ColumnInfo(
                inputNames[ldaColumn.Input],
                outputNames[toOutput[index]],
                settings.NumTopic,
                settings.AlphaSum,
                settings.Beta,
                settings.MHStep,
                settings.NumIter,
                settings.LikelihoodInterval,
                settings.NumThread,
                settings.NumMaxDocToken,
                settings.NumSummaryTermPerTopic,
                settings.NumBurninIter,
                settings.ResetRandomGenerator);

            var columnCallback = settings.OnFit;
            if (columnCallback == null)
                continue;

            // Copy the loop counter: capturing it directly would see toOutput.Length once the callback runs.
            int capturedIndex = index;
            fitCallbacks += transformer => columnCallback(transformer.GetLdaDetails(capturedIndex));
        }

        var estimator = new LatentDirichletAllocationEstimator(env, columns);
        if (fitCallbacks != null)
            return estimator.WithOnFitDelegate(fitCallbacks);

        return estimator;
    }
}
139+
/// <include file='doc.xml' path='doc/members/member[@name="LightLDA"]/*' />
/// <param name="input">A vector of floats representing the document.</param>
/// <param name="numTopic">The number of topics.</param>
/// <param name="alphaSum">Dirichlet prior on document-topic vectors.</param>
/// <param name="beta">Dirichlet prior on vocab-topic vectors.</param>
/// <param name="mhstep">Number of Metropolis-Hastings steps.</param>
/// <param name="numIterations">Number of iterations.</param>
/// <param name="likelihoodInterval">Compute log likelihood over local dataset on this iteration interval.</param>
/// <param name="numThreads">The number of training threads. Default value depends on number of logical processors.</param>
/// <param name="numMaxDocToken">The threshold of maximum count of tokens per doc.</param>
/// <param name="numSummaryTermPerTopic">The number of words to summarize the topic.</param>
/// <param name="numBurninIterations">The number of burn-in iterations.</param>
/// <param name="resetRandomGenerator">Reset the random number generator for each document.</param>
/// <param name="onFit">Optional delegate called upon fitting with the <see cref="LdaFitResult"/> for this column.</param>
/// <returns>A float vector column carrying the input column and the supplied LDA settings.</returns>
public static Vector<float> ToLdaTopicVector(this Vector<float> input,
    int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic,
    Single alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum,
    Single beta = LatentDirichletAllocationEstimator.Defaults.Beta,
    int mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep,
    int numIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations,
    int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval,
    int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads,
    int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken,
    int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic,
    int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations,
    bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator,
    LdaFitResult.OnFit onFit = null)
{
    Contracts.CheckValue(input, nameof(input));

    // Adapt the user delegate (null stays null) before bundling it with the settings.
    var fitCallback = Wrap(onFit);
    var settings = new Config(
        numTopic,
        alphaSum,
        beta,
        mhstep,
        numIterations,
        likelihoodInterval,
        numThreads,
        numMaxDocToken,
        numSummaryTermPerTopic,
        numBurninIterations,
        resetRandomGenerator,
        fitCallback);

    return new ImplVector(input, settings);
}
173+ }
174+ }
0 commit comments