@@ -787,7 +787,8 @@ private IEnumerable<T> GetTermsAndIds<T>(int iinfo, out long[] termIds)
787
787
private void CastInputToString < T > ( OnnxContext ctx , out OnnxNode node , out long [ ] termIds , string srcVariableName , int iinfo ,
788
788
string opType , string labelEncoderOutput )
789
789
{
790
- var castOutput = ctx . AddIntermediateVariable ( TextDataViewType . Instance , "castOutput" ) ;
790
+ var srcShape = ctx . RetrieveShapeOrNull ( srcVariableName ) ;
791
+ var castOutput = ctx . AddIntermediateVariable ( new VectorDataViewType ( TextDataViewType . Instance , ( int ) srcShape [ 1 ] ) , "castOutput" ) ;
791
792
var castNode = ctx . CreateNode ( "Cast" , srcVariableName , castOutput , ctx . GetNodeName ( "Cast" ) , "" ) ;
792
793
var t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . String ) . ToType ( ) ;
793
794
castNode . AddAttribute ( "to" , t ) ;
@@ -799,7 +800,8 @@ private void CastInputToString<T>(OnnxContext ctx, out OnnxNode node, out long[]
799
800
private void CastInputToFloat < T > ( OnnxContext ctx , out OnnxNode node , out long [ ] termIds , string srcVariableName , int iinfo ,
800
801
string opType , string labelEncoderOutput )
801
802
{
802
- var castOutput = ctx . AddIntermediateVariable ( NumberDataViewType . Single , "castOutput" ) ;
803
+ var srcShape = ctx . RetrieveShapeOrNull ( srcVariableName ) ;
804
+ var castOutput = ctx . AddIntermediateVariable ( new VectorDataViewType ( NumberDataViewType . Single , ( int ) srcShape [ 1 ] ) , "castOutput" ) ;
803
805
var castNode = ctx . CreateNode ( "Cast" , srcVariableName , castOutput , ctx . GetNodeName ( "Cast" ) , "" ) ;
804
806
var t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . Single ) . ToType ( ) ;
805
807
castNode . AddAttribute ( "to" , t ) ;
@@ -813,7 +815,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
813
815
long [ ] termIds ;
814
816
string opType = "LabelEncoder" ;
815
817
OnnxNode castNode ;
816
- var labelEncoderOutput = ctx . AddIntermediateVariable ( NumberDataViewType . Int64 , "LabelEncoderOutput" ) ;
818
+ var labelEncoderOutput = ctx . AddIntermediateVariable ( new VectorDataViewType ( NumberDataViewType . Int64 , _types [ iinfo ] . GetValueCount ( ) ) , "LabelEncoderOutput" ) ;
817
819
818
820
var type = info . TypeSrc . GetItemType ( ) ;
819
821
if ( type . Equals ( TextDataViewType . Instance ) )
0 commit comments