Problem with ValueError #27
Hi @ademait, I can't seem to reproduce your bug here. Could you send the full error message? It would also be helpful if you shared how you defined … Here's my attempt at reproducing your bug:
```python
from typing import Any, Dict, List

import haiku as hk
import jax
import jax.numpy as jnp
import jraph

# Node features are written as float literals, so `nodes` ends up as float32.
graph = jraph.GraphsTuple(
    nodes=jnp.asarray([
        [0.0000000e+00, 1.5747571e+12, 6.0000000e+00],
        [1.0000000e+00, 1.5701138e+12, 2.0000000e+00],
        [2.0000000e+00, 1.5747571e+12, 2.0000000e+00],
        [3.0000000e+00, 1.5747555e+12, 3.0000000e+00],
        [4.0000000e+00, 1.5701127e+12, 7.0000000e+00],
        [5.0000000e+00, 0.0000000e+00, 1.0000000e+00],
        [6.0000000e+00, 0.0000000e+00, 1.0000000e+00]]),
    edges=jnp.asarray([1, 1, 1, 1, 2, 2]),
    receivers=jnp.asarray([3, 1, 2, 0, 4, 3]),
    senders=jnp.asarray([5, 6, 6, 6, 3, 0]),
    globals=None,
    n_node=jnp.asarray([7]),
    n_edge=jnp.asarray([6]))

dataset = [{'input_graph': graph}]  # Dummy dataset


def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Network function."""
    net = jraph.GraphConvolution(
        update_node_fn=lambda x: jax.nn.relu(hk.Linear(100)(x)),
        add_self_edges=True,
    )
    return net(graph)


def train(dataset: List[Dict[str, Any]], num_train_steps: int) -> hk.Params:
    """Training loop."""
    # Transform impure `net_fn` to pure functions with hk.transform.
    net = hk.without_apply_rng(hk.transform(net_fn))
    # Get a candidate graph and label to initialize the network.
    graph = dataset[0]['input_graph']
    # Initialize the network.
    params = net.init(jax.random.PRNGKey(42), graph)
    # The rest of the training loop was not included in this comment.
    return params
```
Running …
Hi @salfaris, thanks for the answer. It seems the problem was how I declared the variables with …
Hi everyone 👋,
First of all, thank you for the great work on this library!
I'm having some trouble understanding how to build my own GNN. I'm trying to do graph classification: I followed this tutorial, and now I'm trying to apply this network example to my own data.
This is an example of a graph:
When I try to initialize the network, it raises a `ValueError`:
`ValueError: data type <class 'numpy.int32'> not inexact`
The error comes from the last line of this code block (`net.init(jax.random.PRNGKey(42), graph)`):
The graph `dataset[0]['input_graph']` is the one shown above. After reading the docs and some Stack Overflow threads, and searching on Google, I haven't found anything that helps me understand or resolve this error.
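For context, this exact message is what NumPy's `np.finfo` raises when it is given a dtype that is not floating point ("inexact"). One plausible path inside `net.init`, stated as an assumption rather than a confirmed trace: as far as I know, Haiku's `hk.Linear` creates its weights with the dtype of its input, so integer node features would ask a float-only weight initializer for `int32` weights. The sketch below reproduces the message with NumPy alone; `describe_dtype` is a hypothetical helper written only for this illustration.

```python
import numpy as np


def describe_dtype(dtype) -> str:
    """Return float info for `dtype`, or the ValueError text if it is not a float type."""
    try:
        info = np.finfo(dtype)  # np.finfo only accepts inexact (floating/complex) dtypes
        return f"{np.dtype(dtype).name}: eps={info.eps}"
    except ValueError as err:
        return f"{np.dtype(dtype).name}: ValueError: {err}"


# Integer features, e.g. node features declared with integer literals.
int_nodes = np.asarray([[0, 6], [1, 2]], dtype=np.int32)
# The same features after an explicit cast to float32.
float_nodes = int_nodes.astype(np.float32)

print(describe_dtype(int_nodes.dtype))    # reports the "not inexact" ValueError
print(describe_dtype(float_nodes.dtype))  # reports float32 machine epsilon
```

If the graph's `nodes` array turns out to be `int32`, casting it to `float32` before calling `net.init` is the first thing worth trying.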
I have some doubts about the data types in the `GraphsTuple`. I tried to change the `int32` type to Python's native `int` type (as Nate suggests in this Stack Overflow thread), but I couldn't change the types. Could the `float32` type of the `nodes` field also be the problem?
I'm submitting this issue because I haven't found any useful resource to help me debug this error. I hope that's not an inconvenience, and that it helps others resolve this error faster.
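On the data-type question: only the feature fields (`nodes`, `edges`, `globals`) feed learnable layers and need a floating-point dtype, while `senders`, `receivers`, `n_node` and `n_edge` are indices and counts that are used as segment ids and should stay integers. A minimal sketch of casting just the feature fields, assuming they are plain arrays rather than nested dicts. `cast_features_to_float` and `ToyGraph` are hypothetical names; `ToyGraph` only mirrors the field names of `jraph.GraphsTuple` so the example runs with NumPy alone.

```python
from typing import NamedTuple, Optional

import numpy as np


class ToyGraph(NamedTuple):
    """Hypothetical stand-in with the same field names as jraph.GraphsTuple."""
    nodes: Optional[np.ndarray]
    edges: Optional[np.ndarray]
    receivers: np.ndarray
    senders: np.ndarray
    globals: Optional[np.ndarray]
    n_node: np.ndarray
    n_edge: np.ndarray


def cast_features_to_float(graph, dtype=np.float32):
    """Cast nodes/edges/globals to `dtype`; leave the integer index fields unchanged.

    Assumes the feature fields are plain arrays (NumPy and jax.numpy arrays
    both provide `.astype`) and that `graph` is a NamedTuple with `_replace`.
    """
    updates = {}
    for field in ("nodes", "edges", "globals"):
        value = getattr(graph, field)
        if value is not None:
            updates[field] = value.astype(dtype)
    return graph._replace(**updates)


graph = ToyGraph(
    nodes=np.asarray([[0, 6], [1, 2], [2, 2]], dtype=np.int32),  # integer features
    edges=np.asarray([1, 1], dtype=np.int32),
    receivers=np.asarray([1, 2], dtype=np.int32),
    senders=np.asarray([0, 0], dtype=np.int32),
    globals=None,
    n_node=np.asarray([3], dtype=np.int32),
    n_edge=np.asarray([2], dtype=np.int32),
)

fixed = cast_features_to_float(graph)
for name, value in fixed._asdict().items():
    print(name, None if value is None else value.dtype)
```

Since `jraph.GraphsTuple` is also a NamedTuple, the same call should work on it directly before `net.init(...)`; if the features are dicts of arrays, `jax.tree_util.tree_map` over the feature fields would be needed instead of `.astype`.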
Thank you!