
isinstance(numpy.zeros(1), jax.numpy.ndarray) returns True #2014

@honnibal

Description


Calling isinstance on a numpy.ndarray instance with jax.numpy.ndarray as the class currently returns True. I think I can see how this happens: I suspect JAX doesn't actually use that type for its own arrays, so jax.numpy.ndarray may just be numpy's ndarray re-exported. That makes it a hassle to check where an array came from, though.

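```python
import numpy
import jax.numpy  # also binds the top-level name `jax`


def test_numpy_ndarray_is_not_instance_of_jax_numpy_ndarray():
    # A plain NumPy array should not be reported as a JAX array.
    assert not isinstance(numpy.zeros(1), jax.numpy.ndarray)
```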
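For checking array provenance, here is a minimal sketch. It assumes a recent JAX release (newer than the one this issue was filed against) in which `jax.Array` is the public base type of JAX arrays and plain NumPy arrays are not instances of it. `array_provenance` is a hypothetical helper, not part of JAX or NumPy.

```python
import numpy as np
import jax
import jax.numpy as jnp


def array_provenance(x):
    """Hypothetical helper: report whether x is a JAX array, a NumPy array, or neither."""
    # Assumption: recent JAX, where concrete JAX arrays are instances of jax.Array
    # and numpy.ndarray instances are not. Check JAX first so the result is unambiguous.
    if isinstance(x, jax.Array):
        return "jax"
    if isinstance(x, np.ndarray):
        return "numpy"
    return "other"


print(array_provenance(np.zeros(1)))   # numpy
print(array_provenance(jnp.zeros(1)))  # jax
print(array_provenance([0.0]))         # other
```

Ordering the checks this way keeps the helper correct even if a JAX array type were ever to also satisfy an `np.ndarray` check.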

By the way, what's the preferred way to convert data from JAX to NumPy? I've found jax.device_get() by poking around, but I don't think it's documented.
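A minimal sketch of two conversion routes, assuming a recent JAX release: `numpy.asarray` on a JAX array returns a host-side `numpy.ndarray`, and `jax.device_get` does the same for every leaf of a pytree (here a dict). The dict keys `w` and `b` are purely illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)  # a JAX array on the default device

# Route 1: np.asarray uses the array's __array__ protocol and returns a numpy.ndarray
# backed by host memory.
a = np.asarray(x)

# Route 2: jax.device_get (the function mentioned above) converts every array leaf
# of a pytree, so a whole dict of parameters can be fetched in one call.
tree = jax.device_get({"w": x, "b": jnp.ones(2)})

print(type(a))          # <class 'numpy.ndarray'>
print(type(tree["w"]))  # <class 'numpy.ndarray'>
```

`np.asarray` is the simpler choice for a single array; `jax.device_get` is convenient when the data is a nested structure of arrays.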

Thanks for the great project!

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), question (Questions for the JAX team)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests