
GINConv use example #65

Closed
JMcsLk opened this issue Jun 19, 2020 · 10 comments


JMcsLk commented Jun 19, 2020

Hello @danielegrattarola, could you please add an example showing how to use the GINConv layer? I have a problem passing a tensor (the output of a Keras Input layer) to this layer during model definition. It is connected to the propagate method in the MessagePassing class:

Model Structure:

X_in = Input(shape=(F, ))
A_in = Input(shape=(N, ), sparse=True)
gc1 = GINConv(channels=300, mlp_activation='relu',)([X_in, A_in])

Error location:
self.index_i = A.indices[:, 0]

Error Type:
TypeError: 'SparseTensor' object is not subscriptable.

JMcsLk changed the title from "MessagePassing use example" to "GINConv use example" on Jun 23, 2020
@danielegrattarola (Owner)

Hi,

what version of TensorFlow/Keras are you using?
I am running the following and it works as expected on TF 2.2:

import numpy as np
import scipy.sparse as sp
from tensorflow.keras.layers import Input
from spektral.layers import GINConv
from spektral.layers import ops

A = sp.rand(10, 10)                    # random sparse adjacency matrix
A = ops.sp_matrix_to_sp_tensor(A)      # convert to a tf.SparseTensor
X = np.random.randn(10, 5)             # node features: 10 nodes, 5 features each

out = GINConv(300, activation='relu')([X, A])

Cheers


JMcsLk commented Jun 24, 2020

Thank you @danielegrattarola, it seems the problem was connected with the Keras version. After upgrading, your example works as expected.

Could you just tell me how to preprocess the label vector for disjoint mode? I'd like to connect a GlobalSumPool layer to GINConv and then a fully connected one. After converting the A and X matrices of shapes (2200, 68, 68) and (2200, 68, 12) to disjoint mode, I obtain (149600, 149600) and (149600, 12) ones, which produces a mismatch with y, which still has length 2200. How can I use the "I" vector to resolve this problem?

Thanks!

@danielegrattarola (Owner)

I'm glad it works.

If you have your disjoint graph it should be sufficient to pass I to the GlobalSumPool layer:

X1 = GINConv(12, activation='relu')([X, A])
out = GlobalSumPool()([X1, I])  # Shape = (2200, 12)
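
For illustration, here is a minimal sketch of what the I vector looks like in disjoint mode (graph sizes here are hypothetical):

import numpy as np

# One integer per node, telling GlobalSumPool which graph each node belongs to.
n_nodes = [68, 68, 68]                           # e.g. three graphs of 68 nodes each
I = np.repeat(np.arange(len(n_nodes)), n_nodes)
# I -> [0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2], shape (204,)
# GlobalSumPool()([X1, I]) sums node features per graph, so the output has one
# row per graph and matches y (one label per graph, e.g. 2200 rows in total).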

Cheers


JMcsLk commented Jun 24, 2020

Thank you @danielegrattarola.

Is it possible to use GINConv multiple times in this setup? If so, what is the mechanism for pooling the A and I elements?

@danielegrattarola (Owner)

Yes, you can stack multiple layers.
Of course, once you use a global pooling layer you will lose the graph structure so you will not be able to apply message passing afterwards.

If you want to gradually reduce the size of the graph you can use "standard" pooling methods like MinCutPool or TopKPool; these return a reduced X, A, and I, and you can apply GIN again afterwards.
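
For illustration, a rough (untested) sketch of how the stacking could look in disjoint mode; the exact layer signatures may differ between Spektral versions, so check the docs for your version:

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from spektral.layers import GINConv, TopKPool, GlobalSumPool

F = 12        # node feature size (hypothetical)
n_out = 3     # number of classes (hypothetical)

X_in = Input(shape=(F,))
A_in = Input(shape=(None,), sparse=True)
I_in = Input(shape=(), dtype='int64')

X_1 = GINConv(64, activation='relu')([X_in, A_in])
X_1, A_1, I_1 = TopKPool(ratio=0.5)([X_1, A_in, I_in])   # reduced X, A and I
X_2 = GINConv(64, activation='relu')([X_1, A_1])         # message passing again
out = GlobalSumPool()([X_2, I_1])
out = Dense(n_out, activation='softmax')(out)

model = Model(inputs=[X_in, A_in, I_in], outputs=out)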

Cheers


JMcsLk commented Jun 29, 2020

Thanks @danielegrattarola. I have an interesting observation for you. I was analyzing your disjoint-mode example with GraphConvSkip and TopKPool. In the train_step method I wanted to check some predictions by printing them for each batch in the training loop. To make it clearer, in the fitting loop I placed something like this:

for b in batches:
    current_batch += 1
    print(current_batch)  # <-- added line
    X_, A_, I_ = numpy_to_disjoint(*b[:-1])
    A_ = ops.sp_matrix_to_sp_tensor(A_)
    y_ = b[-1]
    outs = train_step(X_, A_, I_, y_)

And in the train_step method:

def train_step(X_, A_, I_, y_):
    with tf.GradientTape() as tape:
        predictions = model([X_, A_, I_], training=True)
        print(predictions)

Could you please tell me why I'm getting this kind of print output:
Fitting model

1
Tensor("model/dense/Softmax:0", shape=(None, 3), dtype=float32)
Tensor("model/dense/Softmax:0", shape=(None, 3), dtype=float32)
2
3
4
(...)
19
20
etc...

It looks like not every batch calls the predict method; I'm not sure whether this has a negative effect on the learning process of the model.

When I print the prediction shapes in the evaluate method like this:

def evaluate(A_list, X_list, y_list, ops_list, batch_size):
    batches = batch_iterator([X_list, A_list, y_list], batch_size=batch_size)
    output = []
    for b in batches:
        X, A, I = numpy_to_disjoint(*b[:-1])
        A = ops.sp_matrix_to_sp_tensor(A)
        y = b[-1]
        pred = model([X, A, I], training=False)
        print(pred.shape)

As a continuation of the training process I get this kind of report:

80
81
82
(...)
92
93
"Evaluation..."
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(2, 3)

The next epochs generate this:

1
2
3
(...)
92
93
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(16, 3)
(2, 3)

Edit: I found it might be related to @tf.function.

@danielegrattarola (Owner)

> Edit: I found it might be related to @tf.function.

That is exactly correct.
When you compile a function with tf.function, any Python code (like print) runs only the first time, while the function is being traced, but not on subsequent calls to the compiled function.
This is true unless the shapes of the inputs change, in which case a "re-tracing" is triggered and the Python code runs again. Have a look at this.
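
A small self-contained sketch of that behaviour: the Python print fires only when the function is traced, while tf.print runs on every call:

import tensorflow as tf

@tf.function
def step(x):
    print('tracing, shape =', x.shape)          # Python side effect: only at trace time
    tf.print('running, shape =', tf.shape(x))   # graph op: runs on every call
    return x * 2

step(tf.ones((16, 3)))   # first call: both lines are printed (tracing)
step(tf.ones((16, 3)))   # same shape: only the tf.print line
step(tf.ones((2, 3)))    # new input shape: re-tracing, both lines again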


JMcsLk commented Jul 16, 2020

Thank you for your recent help @danielegrattarola !

I was analyzing another of your examples with the GIN layer (https://github.com/danielegrattarola/spektral/blob/master/examples/graph_prediction/tud_disjoint.py) and want to ask some questions:

  • Why do you use MSE as a loss function in a classification problem? Wouldn't binary or categorical crossentropy be a better way to represent the loss?

  • I saw that the MessagePassing layer can also use edge features in its propagation method. Are these features irrelevant in this analysis (I've only seen the A and X matrices as input)?

  • Is it possible to use Embedding layers based on the properties of nodes and edges inside the GIN class before they are propagated?

@danielegrattarola (Owner)

@JMcsLk

> Why do you use MSE as a loss function in a classification problem? Wouldn't binary or categorical crossentropy be a better way to represent the loss?

That's because I copy-pasted the code from the QM9 regression example and forgot to change the loss :D

> I saw that the MessagePassing layer can also use edge features in its propagation method. Are these features irrelevant in this analysis (I've only seen the A and X matrices as input)?

Currently there are no methods based on the MessagePassing class that take advantage of the edge attributes, but the class supports them if you want to define your own layer. If you want to use edge attributes you can use the EdgeConditionedConv layer.
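
For illustration, a rough sketch of feeding edge attributes to EdgeConditionedConv; the shapes are hypothetical and the expected input format may vary between versions, so check the layer's docs:

import numpy as np
from spektral.layers import EdgeConditionedConv

N, F, S = 10, 5, 4                         # nodes, node features, edge features
X = np.random.randn(N, F)                  # node features
A = np.random.randint(0, 2, (N, N)).astype('float32')   # dense adjacency matrix
E = np.random.randn(N, N, S)               # one attribute vector per node pair

out = EdgeConditionedConv(32)([X, A, E])   # output shape (N, 32)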

> Is it possible to use Embedding layers based on the properties of nodes and edges inside the GIN class before they are propagated?

GINConv is just another Keras layer. You can modify the inputs of the layer however you like. You can also re-define the GIN class to behave according to your needs; the class is self-contained in its own file and you should be able to copy-paste it and modify it easily.
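
For illustration, a rough sketch of the first option: embed integer node types outside the layer and feed the result to GINConv (names and sizes here are hypothetical):

from tensorflow.keras.layers import Input, Embedding
from spektral.layers import GINConv

N_TYPES, EMB = 20, 16                    # hypothetical vocabulary / embedding size

X_in = Input(shape=(), dtype='int64')    # one integer node type per node
A_in = Input(shape=(None,), sparse=True)

X_emb = Embedding(N_TYPES, EMB)(X_in)    # float node features, shape (n_nodes, EMB)
out = GINConv(64, activation='relu')([X_emb, A_in])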

Also, the MessagePassing class is exactly meant to simplify the definition of new layers, so you can have a look at that!


JMcsLk commented Jul 16, 2020

OK, I understand, thanks. Cool that I could be helpful in some way (the loss). :)

Have a great day!
