Are higher-order ops like fold and map_fn implemented? #143
Comments
Yes, for some. TF Scala already supports TensorFlow functions, and they are used for some higher-order ops, such as …
It seems that Dataset.map and Dataset.filter do not operate on Tensors? The closest I found is whileLoop, which can be used to implement those HOFs. These higher-order functions matter because they are central to so-called "differentiable programming", the underlying theory of supervised learning (though not of reinforcement learning). With them we could easily combine, for example, an LSTM and a CNN with CTC, just like in ordinary functional programming!
Yes, the dataset ops operate on datasets over nested structures of tensors, but you can use …
As a temporary solution:

```scala
def map_fn[I: TF, O: TF](
    tensor: Output[I],
    fn: (Output[I], Output[Int]) => Output[O]
): Output[O] = {
  // Length of the leading axis, i.e. the number of slices to map over.
  val len: Output[Int] = tf.shape(tensor)(0)
  // Unstack the input along axis 0 into a TensorArray of slices.
  val input = TensorArray.create[I](len).unstack(tensor)
  tf.whileLoop[(TensorArray[O], Output[Int]), (Shape, Shape)](
    predicateFn = _._2 < len,
    bodyFn = { case (res, i) =>
      (res.write(i, fn(input.read(i), i)), i + 1)
    },
    loopVariables = (TensorArray.create[O](len), 0)
  ) match {
    // Re-stack the per-slice results into a single tensor.
    case (result, _) => result.stack()
  }
}

// Usage example: add (i * 10 + 10) to row i.
val in  = tf.placeholder[Float](Shape(-1, -1))
val out = map_fn[Float, Float](in, (row, i) => row + (i * 10 + 10).castTo[Float])

val sess = Session()
try {
  val inFeed: Tensor[Float] = Tensor(
    Tensor(1.0f, 2.0f, 3.0f),
    Tensor(4.0f, 5.0f, 6.0f)
  )
  val result = sess.run(FeedMap(in -> inFeed), out)
  println(result.summarize())
} finally {
  sess.close()
}
```

The problem is the Tensorflow4Python implementation of … Hope this still helps.
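The issue title also asks about fold. A left fold over axis 0 can be sketched the same way as the map_fn above. This is an untested sketch assuming the same tensorflow_scala API (`tf.whileLoop`, `TensorArray`); `fold_fn` is a hypothetical name, not part of the library:

```scala
// Untested sketch: left fold over axis 0, assuming the same tensorflow_scala
// API (tf.whileLoop, TensorArray) as the map_fn above. `fold_fn` is a
// hypothetical name, not a library function.
def fold_fn[I: TF, A: TF](
    tensor: Output[I],
    init: Output[A],
    fn: (Output[A], Output[I]) => Output[A]
): Output[A] = {
  val len: Output[Int] = tf.shape(tensor)(0)
  val input = TensorArray.create[I](len).unstack(tensor)
  tf.whileLoop[(Output[A], Output[Int]), (Shape, Shape)](
    predicateFn = _._2 < len,
    // Thread the accumulator through the loop instead of a TensorArray.
    bodyFn = { case (acc, i) => (fn(acc, input.read(i)), i + 1) },
    loopVariables = (init, 0)
  ) match {
    case (acc, _) => acc
  }
}

// e.g. a row-wise sum along axis 0:
// val total = fold_fn[Float, Float](in, tf.zeros[Float](Shape(3)), _ + _)
```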
It would be pretty cool if this worked:

```scala
val newTensor: Output[Double] =
  for {
    row <- myTensor
    if tf.norm(row) > 2
    x   <- row
  } yield 2 * x
```

As soon as I get around to it, I'm going to try and see if that's at all possible. EDIT: That absolutely works, see gist.
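The reason this can work at all: Scala desugars `for`/`yield` into calls to `map`, `flatMap`, and `withFilter`, so any tensor-like type that defines those methods gets the syntax for free. A minimal plain-Scala sketch, with no TensorFlow involved; `Rows` and `norm` are made-up stand-ins, and the inner `x <- row` generator is omitted for brevity:

```scala
// Minimal sketch: a for-comprehension only needs map/flatMap/withFilter on the
// receiver. `Rows` is a made-up stand-in for a rank-2 tensor, not a real API.
final case class Rows(data: Vector[Vector[Double]]) {
  def map(f: Vector[Double] => Vector[Double]): Rows = Rows(data.map(f))
  def flatMap(f: Vector[Double] => Rows): Rows       = Rows(data.flatMap(r => f(r).data))
  def withFilter(p: Vector[Double] => Boolean): Rows = Rows(data.filter(p))
}

// Euclidean norm of a row (stand-in for tf.norm).
def norm(row: Vector[Double]): Double = math.sqrt(row.map(x => x * x).sum)

val myTensor = Rows(Vector(Vector(1.0, 1.0), Vector(3.0, 4.0)))

// Desugars to: myTensor.withFilter(norm(_) > 2).map(_.map(2.0 * _))
val result = for {
  row <- myTensor
  if norm(row) > 2
} yield row.map(2.0 * _)
// Keeps only rows with norm > 2, then doubles each element.
```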