@@ -165,11 +165,13 @@ private static VersionInfo GetVersionInfo()
165165 // verWrittenCur: 0x00010002, // Invert hash key values, hash fix
166166 verWrittenCur : 0x00010003 , // Get rid of writing float size in model context and change saving format
167167 verReadableCur : 0x00010003 ,
168- verWeCanReadBack : 0x00010003 ,
168+ verWeCanReadBack : 0x00010002 ,
169169 loaderSignature : LoaderSignature ,
170170 loaderAssemblyName : typeof ( NgramHashingTransformer ) . Assembly . FullName ) ;
171171 }
172172
173+ private const int VersionTransformer = 0x00010003 ;
174+
173175 /// <summary>
174176 /// Describes how the transformer handles one pair of mulitple inputs - singular output columns.
175177 /// </summary>
@@ -242,6 +244,7 @@ public ColumnInfo(string[] inputs, string output,
242244 InvertHash = invertHash ;
243245 RehashUnigrams = rehashUnigrams ;
244246 }
247+
245248 internal ColumnInfo ( ModelLoadContext ctx )
246249 {
247250 Contracts . AssertValue ( ctx ) ;
@@ -275,6 +278,36 @@ internal ColumnInfo(ModelLoadContext ctx)
275278 AllLengths = ctx . Reader . ReadBoolByte ( ) ;
276279 }
277280
281+ internal ColumnInfo ( ModelLoadContext ctx , string [ ] inputs , string output )
282+ {
283+ Contracts . AssertValue ( ctx ) ;
284+ Contracts . CheckValue ( inputs , nameof ( inputs ) ) ;
285+ Contracts . CheckParam ( ! inputs . Any ( r => string . IsNullOrWhiteSpace ( r ) ) , nameof ( inputs ) ,
286+ "Contained some null or empty items" ) ;
287+ Inputs = inputs ;
288+ Output = output ;
289+ // *** Binary format ***
290+ // string Output;
291+ // int: NgramLength
292+ // int: SkipLength
293+ // int: HashBits
294+ // uint: Seed
295+ // byte: Rehash
296+ // byte: Ordered
297+ // byte: AllLengths
298+ NgramLength = ctx . Reader . ReadInt32 ( ) ;
299+ Contracts . CheckDecode ( 0 < NgramLength && NgramLength <= NgramBufferBuilder . MaxSkipNgramLength ) ;
300+ SkipLength = ctx . Reader . ReadInt32 ( ) ;
301+ Contracts . CheckDecode ( 0 <= SkipLength && SkipLength <= NgramBufferBuilder . MaxSkipNgramLength ) ;
302+ Contracts . CheckDecode ( SkipLength <= NgramBufferBuilder . MaxSkipNgramLength - NgramLength ) ;
303+ HashBits = ctx . Reader . ReadInt32 ( ) ;
304+ Contracts . CheckDecode ( 1 <= HashBits && HashBits <= 30 ) ;
305+ Seed = ctx . Reader . ReadUInt32 ( ) ;
306+ RehashUnigrams = ctx . Reader . ReadBoolByte ( ) ;
307+ Ordered = ctx . Reader . ReadBoolByte ( ) ;
308+ AllLengths = ctx . Reader . ReadBoolByte ( ) ;
309+ }
310+
278311 internal void Save ( ModelSaveContext ctx )
279312 {
280313 Contracts . AssertValue ( ctx ) ;
@@ -416,19 +449,56 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
416449 private static IRowMapper Create ( IHostEnvironment env , ModelLoadContext ctx , Schema inputSchema )
417450 => Create ( env , ctx ) . MakeRowMapper ( inputSchema ) ;
418451
419- private NgramHashingTransformer ( IHostEnvironment env , ModelLoadContext ctx ) :
452+ private NgramHashingTransformer ( IHostEnvironment env , ModelLoadContext ctx , bool loadLegacy = false ) :
420453 base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( NgramHashingTransformer ) ) )
421454 {
422455 Host . CheckValue ( ctx , nameof ( ctx ) ) ;
423- ctx . CheckAtModel ( GetVersionInfo ( ) ) ;
456+ if ( loadLegacy )
457+ {
458+ int cbFloat = ctx . Reader . ReadInt32 ( ) ;
459+ Host . CheckDecode ( cbFloat == sizeof ( float ) ) ;
460+ }
424461 var columnsLength = ctx . Reader . ReadInt32 ( ) ;
462+ Contracts . CheckDecode ( columnsLength > 0 ) ;
425463 var columns = new ColumnInfo [ columnsLength ] ;
464+ if ( ! loadLegacy )
465+ {
466+ // *** Binary format ***
467+ // int number of columns
468+ // columns
469+ for ( int i = 0 ; i < columnsLength ; i ++ )
470+ columns [ i ] = new ColumnInfo ( ctx ) ;
471+ }
472+ else
473+ {
474+ // *** Binary format ***
475+ // int: number of added columns
476+ // for each added column
477+ // int: id of output column name
478+ // int: number of input column names
479+ // int[]: ids of input column names
480+ var outputs = new string [ columnsLength ] ;
481+ var inputs = new string [ columnsLength ] [ ] ;
482+ for ( int i = 0 ; i < columnsLength ; i ++ )
483+ {
484+ outputs [ i ] = ctx . LoadNonEmptyString ( ) ;
426485
427- // *** Binary format ***
428- // int number of columns
429- // columns
430- for ( int i = 0 ; i < columnsLength ; i ++ )
431- columns [ i ] = new ColumnInfo ( ctx ) ;
486+ int csrc = ctx . Reader . ReadInt32 ( ) ;
487+ Contracts . CheckDecode ( csrc > 0 ) ;
488+ inputs [ i ] = new string [ csrc ] ;
489+ for ( int j = 0 ; j < csrc ; j ++ )
490+ {
491+ string src = ctx . LoadNonEmptyString ( ) ;
492+ inputs [ i ] [ j ] = src ;
493+ }
494+ }
495+
496+ // *** Binary format ***
497+ // int number of columns
498+ // columns
499+ for ( int i = 0 ; i < columnsLength ; i ++ )
500+ columns [ i ] = new ColumnInfo ( ctx , inputs [ i ] , outputs [ i ] ) ;
501+ }
432502 _columns = columns . ToImmutableArray ( ) ;
433503 TextModelHelper . LoadAll ( Host , ctx , columnsLength , out _slotNames , out _slotNamesTypes ) ;
434504 }
@@ -469,7 +539,8 @@ private static NgramHashingTransformer Create(IHostEnvironment env, ModelLoadCon
469539 {
470540 Contracts . CheckValue ( env , nameof ( env ) ) ;
471541 var host = env . Register ( nameof ( NgramHashingTransformer ) ) ;
472- return new NgramHashingTransformer ( host , ctx ) ;
542+ ctx . CheckAtModel ( GetVersionInfo ( ) ) ;
543+ return new NgramHashingTransformer ( host , ctx , ctx . Header . ModelVerWritten < VersionTransformer ) ;
473544 }
474545
475546 private protected override IRowMapper MakeRowMapper ( Schema schema ) => new Mapper ( this , schema ) ;
0 commit comments