Disjoint MNIST Dataset? #31

Closed
kmader opened this issue Dec 4, 2019 · 3 comments

@kmader
Contributor

kmader commented Dec 4, 2019

So my intuitive guess for how the MNIST dataset worked was that you classify the graph built from just the positive pixels into the digits 0-9. The example below is a 2.
[image: graph built from the positive pixels of an example 2]

Is that an interesting problem to use as a classification benchmark, or is the setting with a fixed adjacency matrix and varying features better suited?
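
For concreteness, here is a minimal sketch of the construction I had in mind. The binarization threshold, the choice of k, and the use of normalized coordinates as node features are all arbitrary choices of mine, not anything taken from Spektral:

```python
# Sketch: build a graph from only the positive pixels of one MNIST digit.
# Threshold and k are assumptions, not part of any existing benchmark.
import numpy as np
from sklearn.neighbors import kneighbors_graph

def digit_to_graph(img, threshold=0.3, k=8, use_coords=True):
    """Return (adjacency, node_features) for one 28x28 image with values in [0, 1]."""
    rows, cols = np.nonzero(img > threshold)            # keep only "ink" pixels
    coords = np.stack([rows, cols], axis=1).astype(float)
    # Connect each positive pixel to its k nearest positive pixels.
    adj = kneighbors_graph(coords, n_neighbors=k, mode="connectivity")
    adj = adj.maximum(adj.T)                            # symmetrize
    if use_coords:
        # Normalized (row, col) coordinates plus intensity as node features.
        feats = np.concatenate([coords / 27.0, img[rows, cols, None]], axis=1)
    else:
        feats = img[rows, cols, None]                   # intensity only
    return adj, feats
```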

@danielegrattarola
Owner

I've never seen what you suggest in the literature.
The 8-NN graph setting is taken from Defferrard et al. (2016), and it's interesting because it's a problem of graph signal classification: same topology, different features.

A GNN should not be able to distinguish the graph in the picture from a 5. If you remove the black pixels it should become almost impossible to classify MNIST, unless you encode spatial positions explicitly.
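
To make the 8-NN setting above concrete, here is a rough sketch. The actual construction in Defferrard et al. weights the edges with a Gaussian kernel; this unweighted version is only meant to show the idea of one shared topology with a different signal per image:

```python
# Rough sketch of the fixed-grid "graph signal" setting: a single 8-NN
# adjacency over the 28x28 pixel grid, shared by all samples; each image is
# just a different signal (pixel intensities) on the same nodes.
import numpy as np
from sklearn.neighbors import kneighbors_graph

grid = np.array([(i, j) for i in range(28) for j in range(28)], dtype=float)
A = kneighbors_graph(grid, n_neighbors=8, mode="connectivity")
A = A.maximum(A.T)                    # one fixed, symmetric adjacency

def image_to_signal(img):
    """Node features for one 28x28 image: the intensity of each grid pixel."""
    return img.reshape(-1, 1)         # shape (784, 1), node order = grid order
```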

@kmader
Contributor Author

kmader commented Dec 5, 2019

Thanks for the link. I am very new to graph NNs and am just trying to get a feel for them. The point about 2 and 5 being indistinguishable makes sense, but they don't actually get confused very often (they aren't, however, predicted very accurately).
[image: classification results]

The model seems to learn to recognize tips and twists very well, but it struggles to count them, which might be a weakness of the global attention layer. The figure below shows the activation at each node for each channel.
[image: per-node activations for each channel]
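
To illustrate why counting might be hard, here is a toy numpy example. This is not Spektral's actual attention layer (its gating may differ); it just shows that a softmax-weighted average is a convex combination of node features, so adding a second "tip" barely moves it, while a sum readout scales with the number of detections:

```python
# Toy illustration: softmax-attention readout vs. sum readout.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attention_pool(H, a):
    """Weighted average of node features H, with attention scores H @ a."""
    w = softmax(H @ a)
    return w @ H

def sum_pool(H):
    return H.sum(axis=0)

H1 = np.array([[1.0, 0.0], [0.0, 1.0]])               # one "tip", one "twist"
H2 = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])   # two "tips", one "twist"
a = np.array([0.5, 0.5])

print(attention_pool(H1, a), attention_pool(H2, a))   # (0.5, 0.5) vs (0.67, 0.33)
print(sum_pool(H1), sum_pool(H2))                     # (1, 1) vs (2, 1): the count shows up
```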

@danielegrattarola
Owner

Yeah, that 34% accuracy without coordinates does not surprise me. From a topology perspective, the GNN should only be able to distinguish three classes, depending on how many "holes" there are in the graph: {1, 2, 3, 5, 7}, {4, 6, 9, 0}, and {8}. I am not too sure about this claim, though; I would need to think about it.
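
A quick sketch of what I mean by "holes", assuming a symmetric adjacency without self-loops and a perfectly clean ink graph (real handwritten digits will be noisier than this):

```python
# The number of independent cycles in a graph (its first Betti number) is
# edges - nodes + connected components. On clean digit graphs this would give
# 0 holes -> {1, 2, 3, 5, 7}, 1 hole -> {0, 4, 6, 9}, 2 holes -> {8}.
import scipy.sparse as sp
from scipy.sparse.csgraph import connected_components

def num_holes(adj):
    """First Betti number of an undirected graph given as a symmetric
    scipy sparse adjacency matrix without self-loops."""
    adj = sp.csr_matrix(adj)
    n_nodes = adj.shape[0]
    n_edges = adj.nnz // 2                        # each edge is stored twice
    n_components, _ = connected_components(adj, directed=False)
    return n_edges - n_nodes + n_components
```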

With spatial coordinates, instead, it becomes a point cloud, so it makes more sense. 87% seems a bit low, considering that even very simple GNNs easily get to 99% in the original grid setting, but that may be due to a number of factors.

Nice work! To avoid confusion, I am not sure I would include it as a benchmark dataset in Spektral, but you're doing a nice job for sure.
