SliceVertex value has different size than vertex #417

Closed
robertperrotta opened this issue Dec 20, 2018 · 3 comments

Comments

@robertperrotta
Contributor

The shape reported by a SliceVertex matches the shape of the input vertex, with the sliced dimension set to size 1. The value of the SliceVertex, however, drops that dimension entirely instead of setting it to 1.

import io.improbable.keanu.tensor.dbl.DoubleTensor
import io.improbable.keanu.vertices.ConstantVertex

fun main(args: Array<String>) {
    
    val a = ConstantVertex.of(DoubleTensor.eye(2))
    println(a.shape.toList())
    // [2, 2]

    val b = a.slice(0, 0)
    println(b.shape.toList())
    println(b.calculate().shape.toList())
    // vertex shape: [1, 2]; value shape: [2]

    val c = a.slice(1, 0)
    println(c.shape.toList())
    println(c.calculate().shape.toList())
    // vertex shape: [2, 1]; value shape: [2]

}

The tensor returned by SliceVertex.calculate should have a shape equal to SliceVertex.initialShape.

I also found some strange behavior while writing this minimal example to reproduce the bug.

import io.improbable.keanu.tensor.dbl.DoubleTensor
import io.improbable.keanu.vertices.ConstantVertex


fun main(args: Array<String>) {


    val slice = ConstantVertex.of(DoubleTensor.eye(2))
            .slice(1, 0)

    // Raises exception
    try {
        println(slice.plus(ConstantVertex.of(DoubleTensor.ones(2))))
    } catch (err: Exception) {
        println(err.message)
    }

    // Raises a different exception
    try {
        val d = ConstantVertex.of(DoubleTensor.ones(2, 1))
        println(slice.plus(d).calculate())
    } catch (err: Exception) {
        println(err.message)
    }

    // No longer raises exception (but is the same code as the first try-catch block)
    try {
        println(slice.plus(ConstantVertex.of(DoubleTensor.ones(2))))
    } catch (err: Exception) {
        println(err.message)
    }

}
@gordoncaleb
Contributor

Thanks for reporting this! The problem is that the initial shape is being calculated incorrectly: the slice operation should drop the dimension. Whether slice should drop the dimension or set it to 1 is up for debate, but at the very least the vertex shape and the value shape should be consistent, as you highlighted above. There's a fix on PR #413 to make them consistent. Do you think slice should drop the dimension or set it to 1? The goal was to keep the operation's behaviour as close to numpy as possible, but I think numpy's slicing offers a way to do either.
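
To make the two options concrete, here is a minimal sketch of the shape arithmetic only. The helper names `sliceShapeDropping` and `sliceShapeKeeping` are hypothetical and not part of Keanu; they just compute the result shape of taking a single index along one dimension under each convention. In numpy terms, the first corresponds to indexing like `a[0, :]` and the second to `a[0:1, :]`.

```kotlin
// Hypothetical helpers for illustration only; these are NOT Keanu APIs.
// Each takes a tensor shape and the dimension being sliced at a single index.

/** Result shape when the sliced dimension is dropped (numpy: a[0, :]). */
fun sliceShapeDropping(shape: List<Int>, dimension: Int): List<Int> {
    require(dimension in shape.indices) {
        "dimension $dimension is out of range for shape $shape"
    }
    // Keep every dimension except the sliced one.
    return shape.filterIndexed { i, _ -> i != dimension }
}

/** Result shape when the sliced dimension is kept with size 1 (numpy: a[0:1, :]). */
fun sliceShapeKeeping(shape: List<Int>, dimension: Int): List<Int> {
    require(dimension in shape.indices) {
        "dimension $dimension is out of range for shape $shape"
    }
    // Replace the sliced dimension's size with 1 and leave the rest unchanged.
    return shape.mapIndexed { i, size -> if (i == dimension) 1 else size }
}
```

For the 2x2 tensor in the report, the dropping convention gives `[2]` for either dimension, which is what the computed value had. The keeping convention gives `[1, 2]` and `[2, 1]`, which is what the vertex shape reported. Whichever is chosen, both the vertex and its value would need to follow the same one.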

@robertperrotta
Contributor Author

Good point. I like the consistency with numpy, so dropping that dimension sounds good to me.

@gordoncaleb
Contributor

@robertperrotta This should be fixed and released as of 0.0.17. Let us know if it still isn't what you expect.
