
Problem with ValueError #27

Closed
ademait opened this issue Mar 24, 2022 · 2 comments

Comments

@ademait

ademait commented Mar 24, 2022

Hi everyone 👋,

First of all, thank you for the great work on this library!
I'm having some trouble understanding and building my own GNN for graph classification. I followed this tutorial, and now I'm trying to apply its network example to my own data.

This is an example of a graph:

GraphsTuple(nodes=DeviceArray([[0.0000000e+00, 1.5747571e+12, 6.0000000e+00],
             [1.0000000e+00, 1.5701138e+12, 2.0000000e+00],
             [2.0000000e+00, 1.5747571e+12, 2.0000000e+00],
             [3.0000000e+00, 1.5747555e+12, 3.0000000e+00],
             [4.0000000e+00, 1.5701127e+12, 7.0000000e+00],
             [5.0000000e+00, 0.0000000e+00, 1.0000000e+00],
             [6.0000000e+00, 0.0000000e+00, 1.0000000e+00]],            dtype=float32), edges=DeviceArray([1, 1, 1, 1, 2, 2], dtype=int32), receivers=DeviceArray([3, 1, 2, 0, 4, 3], dtype=int32), senders=DeviceArray([5, 6, 6, 6, 3, 0], dtype=int32), globals=None, n_node=DeviceArray([7], dtype=int32), n_edge=DeviceArray([6], dtype=int32))

When I try to initialize the network, it raises this error: ValueError: data type <class 'numpy.int32'> not inexact.
The error is raised by the last line of this code block (net.init(jax.random.PRNGKey(42), graph)):

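from typing import Any, Dict, List

import haiku as hk
import jax


def train(dataset: List[Dict[str, Any]], num_train_steps: int) -> hk.Params:
  """Training loop (excerpt: only the initialization step is shown)."""

  # Transform impure `net_fn` (defined elsewhere, not shown) to pure functions with hk.transform.
  net = hk.without_apply_rng(hk.transform(net_fn))
  # Get a candidate graph and label to initialize the network.
  graph = dataset[0]['input_graph']

  # Initialize the network. This is the call that raises the ValueError.
  params = net.init(jax.random.PRNGKey(42), graph)
  # ... the optimization loop over `num_train_steps` is omitted here ...
  return params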

The graph dataset[0]['input_graph'] is the one shown above.
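For context, the wording of this message matches the ValueError that NumPy's np.finfo raises when asked for floating-point limits of an integer dtype (JAX's jnp.finfo builds on it). A plausible, but not confirmed, reading is that an int32 array reached code that needs a float dtype, for example a layer whose parameter dtype follows its input. A minimal sketch that reproduces just the message; `finfo_error_message` is a hypothetical helper:

```python
import numpy as np


def finfo_error_message(dtype) -> str:
    """Return the ValueError text np.finfo raises for `dtype`, or '' if it succeeds."""
    try:
        np.finfo(dtype)  # only defined for inexact (floating-point/complex) dtypes
    except ValueError as err:
        return str(err)
    return ""


# For int32 this returns a message ending in "not inexact", the same wording
# as the error in this issue; for float32 it returns an empty string.
print(finfo_error_message(np.int32))
print(finfo_error_message(np.float32))
```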

After reading the docs and some Stack Overflow threads, and searching on Google, I haven't found anything that explains or resolves this error.

I suspect the data types in the GraphsTuple. I tried converting the int32 values to Python's native int type (as Nate suggests in this Stack Overflow thread), but I couldn't change the types. Or could the float32 type of the nodes field be the problem?
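A GraphsTuple mixes index arrays (senders, receivers, n_node, n_edge), which are expected to be integers, with feature arrays (nodes, edges, globals), which are what the network's layers consume. One quick way to test the dtype suspicion is to list the feature fields that are not floating point. A minimal sketch, assuming the feature fields are plain arrays (not nested dicts) and that the graph is any NamedTuple with jraph.GraphsTuple's field names; `integer_feature_fields` is a hypothetical helper, not part of jraph:

```python
import numpy as np

# GraphsTuple fields holding indices or counts; these should stay integer.
INDEX_FIELDS = {"senders", "receivers", "n_node", "n_edge"}


def integer_feature_fields(graph) -> list:
    """Return names of feature fields (nodes/edges/globals) whose dtype is not inexact.

    Works with NumPy or JAX arrays, since both expose a NumPy-compatible `.dtype`.
    Fields set to None are skipped.
    """
    flagged = []
    for name, value in graph._asdict().items():
        if name in INDEX_FIELDS or value is None:
            continue
        if not np.issubdtype(value.dtype, np.inexact):
            flagged.append(name)
    return flagged
```

For the graph shown above, this would flag only edges (int32): nodes is already float32 and globals is None.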

I'm opening this issue because I haven't found any useful resource to help me debug this error. I hope that's not an inconvenience, and that it helps others resolve this error faster.

Thank you!

@salfaris
Contributor

Hi @ademait,

I can't seem to reproduce your bug. Could you send the full error message? It would also help if you shared how you defined dataset and net.


Here's my attempt at reproducing your bug.

  • Defining the graph
import jax.numpy as jnp
import jraph

graph = jraph.GraphsTuple(
    nodes=jnp.asarray([
        [0.0000000e+00, 1.5747571e+12, 6.0000000e+00],
        [1.0000000e+00, 1.5701138e+12, 2.0000000e+00],
        [2.0000000e+00, 1.5747571e+12, 2.0000000e+00],
        [3.0000000e+00, 1.5747555e+12, 3.0000000e+00],
        [4.0000000e+00, 1.5701127e+12, 7.0000000e+00],
        [5.0000000e+00, 0.0000000e+00, 1.0000000e+00],
        [6.0000000e+00, 0.0000000e+00, 1.0000000e+00]]),
    edges=jnp.asarray([1, 1, 1, 1, 2, 2]),
    receivers=jnp.asarray([3, 1, 2, 0, 4, 3]),
    senders=jnp.asarray([5, 6, 6, 6, 3, 0]),
    globals=None,
    n_node=jnp.asarray([7]),
    n_edge=jnp.asarray([6]))
dataset = [{'input_graph': graph}]  # Dummy dataset
  • Defining arbitrary net
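import haiku as hk
import jax
import jraph


def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Network function: a single GCN layer (it updates nodes and does not read graph.edges)."""
  net = jraph.GraphConvolution(
      update_node_fn=lambda x: jax.nn.relu(hk.Linear(100)(x)),
      add_self_edges=True,
  )
  return net(graph)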
  • Running train
from typing import Any, Dict, List

import haiku as hk
import jax


def train(dataset: List[Dict[str, Any]], num_train_steps: int) -> hk.Params:
  """Training loop (only the initialization step is shown)."""

  # Transform impure `net_fn` to pure functions with hk.transform.
  net = hk.without_apply_rng(hk.transform(net_fn))
  # Get a candidate graph and label to initialize the network.
  graph = dataset[0]['input_graph']

  # Initialize the network.
  params = net.init(jax.random.PRNGKey(42), graph)
  return params

Running train works without any errors for me.

@ademait
Author

ademait commented Jun 27, 2022

Hi @salfaris, thanks for the answer.

It seems the problem was how I declared the variables with int32 types. Now the error is gone. 👍
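The thread doesn't say exactly what was changed. One plausible version of the fix, assuming the culprit was an int32 feature array such as edges, is to keep the index fields as integers and cast the feature fields to float32 before calling net.init. Since jraph.GraphsTuple is a NamedTuple, `_replace` returns a modified copy. `cast_features_to_float32` is a hypothetical helper; adding a trailing feature axis to 1-D features is a further assumption about how a downstream layer consumes them (hk.Linear, for instance, mixes the last axis, so a [n_edge] array would be treated as one vector rather than n_edge scalar features).

```python
import numpy as np


def cast_features_to_float32(graph, add_feature_axis: bool = True):
    """Return a copy of `graph` with nodes/edges/globals cast to float32.

    Index fields (senders, receivers, n_node, n_edge) are left untouched.
    Assumes feature fields are arrays (NumPy or JAX) or None, not nested dicts.
    If `add_feature_axis` is True, 1-D features of shape [N] become [N, 1].
    """
    updates = {}
    for name in ("nodes", "edges", "globals"):
        value = getattr(graph, name)
        if value is None:
            continue
        value = value.astype(np.float32)  # works for NumPy and JAX arrays alike
        if add_feature_axis and value.ndim == 1:
            value = value[:, None]  # one scalar feature per node/edge/graph
        updates[name] = value
    return graph._replace(**updates)
```

Usage would look like graph = cast_features_to_float32(dataset[0]['input_graph']) followed by params = net.init(jax.random.PRNGKey(42), graph). Declaring the feature arrays with an explicit float dtype when building the GraphsTuple achieves the same thing up front.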

@ademait ademait closed this as completed Jun 27, 2022