
[TODO]: Address issue with torch and torch_scatter installation order - replace torch_scatter functions with native torch? #580

Closed
kevingreenman opened this issue Jan 10, 2024 · 8 comments

@kevingreenman (Member) commented Jan 10, 2024

The issue of installing torch and torch-scatter in one go is not something we're equipped to address; it's a larger problem faced by many others. We should probably add a note to the README about the correct installation order (a concrete sketch follows below):

  1. activate your desired environment (Python >= 3.11)
  2. install torch
  3. install torch_scatter
  4. pip install chemprop (or pip install . if installing from source)

Originally posted by @davidegraff in #567 (comment)
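
A sketch of what that sequence could look like on the command line (the conda environment name is an illustrative assumption, not an official instruction):

$ conda create -n chemprop python=3.11   # step 1: any environment manager works
$ conda activate chemprop
$ pip install torch                      # step 2: torch must come first
$ pip install torch-scatter              # step 3: its build expects torch to already be importable
$ pip install chemprop                   # step 4: or `pip install .` from a source checkout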

@kevingreenman added the todo label Jan 10, 2024
@kevingreenman added this to the v2.0.0 milestone Jan 10, 2024
@kevingreenman self-assigned this Jan 11, 2024
@kevingreenman (Member, Author):

Also note, from another comment, that we might need special installation instructions for torch_scatter on Mac (installing a pre-compiled binary rather than building from PyPI): #566 (comment)

@JacksonBurns (Member):

@kevingreenman we might not need torch_scatter anymore; according to this comment from the maintainer (rusty1s/pytorch_scatter#379 (comment)), all of its functionality is now in PyTorch.

From a cursory glance through the source code, we only use scatter sum, mean, and softmax. The first two are directly implemented in PyTorch now (https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_), and I'm sure we could find a workaround for the last (e.g., calling exp on the argument to a scatter-reduce with sum).
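
For illustration, a minimal sketch (not chemprop code; the tensor shapes are assumptions) of covering scatter sum and mean with the native scatter_reduce_ API:

import torch

src = torch.arange(4.0)[:, None] * torch.ones(4, 6)
index = torch.tensor([0, 1, 0, 1])                 # output row for each input row
index_2d = index[:, None].repeat(1, src.shape[1])  # scatter_reduce_ needs index shaped like src
dim_size = 2

# scatter sum: include_self=False so the zero-initialized buffer is ignored
sums = torch.zeros(dim_size, src.shape[1]).scatter_reduce_(0, index_2d, src, reduce="sum", include_self=False)
# scatter mean works the same way
means = torch.zeros(dim_size, src.shape[1]).scatter_reduce_(0, index_2d, src, reduce="mean", include_self=False)
print(sums)   # rows: all 2.'s and all 4.'s
print(means)  # rows: all 1.'s and all 2.'s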

Just an idea - this might end up being more work.

@JacksonBurns (Member) commented Feb 1, 2024

We can actually just take the implementation of scatter softmax from torch_scatter (but using PyTorch's built-in scatter) and put it in chemprop: https://github.com/rusty1s/pytorch_scatter/blob/c095c62e4334fcd05e4ac3c4bb09d285960d6be6/torch_scatter/composite/softmax.py#L9
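
For reference, a minimal sketch of what that vendored function could look like using only native torch; the function name is hypothetical, and it assumes the 2-D, dim=0 case described above:

import torch

def scatter_softmax(src, index, dim_size):
    # hypothetical vendored replacement; follows the same recipe as
    # torch_scatter's composite softmax: subtract the per-group max,
    # exponentiate, then normalize by the per-group sum
    index_2d = index[:, None].repeat(1, src.shape[1])
    group_max = torch.full((dim_size, src.shape[1]), float("-inf")).scatter_reduce_(0, index_2d, src, reduce="amax")
    centered = src - group_max.gather(0, index_2d)  # recenter for numerical stability
    exp = centered.exp()
    denom = torch.zeros(dim_size, src.shape[1]).scatter_add(0, index_2d, exp)
    return exp / denom.gather(0, index_2d)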

@kevingreenman (Member, Author):

Thanks for pointing that out! Curious what @davidegraff thinks about this, since he's the one who added torch_scatter in #470.

@davidegraff (Contributor):

I don't think it's a simple drop-in replacement, as the APIs for torch_scatter.scatter_sum and torch.scatter_add are fairly different, but it seems it can be done:

import torch
from torch_scatter import scatter_sum

X = torch.arange(4)[:, None] * torch.ones(4, 6)  # column vector broadcast: row i is all i's
dim_size = 2
print(X)
# tensor([[0., 0., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 1., 1.],
#         [2., 2., 2., 2., 2., 2.],
#         [3., 3., 3., 3., 3., 3.]])
index_torch_scatter = torch.tensor([0, 1, 0, 1])
index_torch = index_torch_scatter[:, None].repeat(1, 6)
print(scatter_sum(X, index_torch_scatter, dim=0, dim_size=dim_size))
# tensor([[2., 2., 2., 2., 2., 2.],
#         [4., 4., 4., 4., 4., 4.]])
print(torch.zeros(dim_size, X.shape[1]).scatter_add(0, index_torch, X))
# tensor([[2., 2., 2., 2., 2., 2.],
#         [4., 4., 4., 4., 4., 4.]])

Testing with random inputs also seems to work:

dim_size = 64
X = torch.randn(256, 256)
index_torch_scatter = torch.randint(dim_size, size=(X.shape[0],))
index_torch = index_torch_scatter[:, None].repeat(1, X.shape[1])
Z_torch_scatter = scatter_sum(X, index_torch_scatter, dim=0, dim_size=dim_size)
Z_torch = torch.zeros(dim_size, X.shape[1]).scatter_add(0, index_torch, X)
print(torch.isclose(Z_torch_scatter, Z_torch).all())
# tensor(True)

And FWIW, native torch seems to be slightly faster too:

from timeit import timeit
NUMBER = 10000
print(timeit('scatter_sum(X, index_torch_scatter, dim=0, dim_size=dim_size)', globals=globals(), number=NUMBER))
# 0.4636257500387728
print(timeit('torch.zeros(dim_size, X.shape[1]).scatter_add(0, index_torch, X)', globals=globals(), number=NUMBER))
# 0.4535887080710381

If anyone wants to tackle this and fix our environment build problems, the second code snippet should provide a path forward: for every instance of scatter_sum, replace it with the following lines (using the argument names of scatter_sum):

index = index[:, None].repeat(1, src.shape[1])
torch.zeros(dim_size, src.shape[1]).scatter_add(0, index, src)

Note: in many places, dim_size is inferred by torch_scatter, but it should always be equal to len(bmg.V) unless otherwise stated.
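
Wrapped up as a helper, that recipe might look like the following sketch (the function name is hypothetical; it assumes the 2-D, dim=0 case and an explicitly passed dim_size):

import torch
from torch import Tensor

def scatter_sum_native(src: Tensor, index: Tensor, dim_size: int) -> Tensor:
    """Native-torch stand-in for torch_scatter.scatter_sum(src, index, dim=0, dim_size=dim_size)."""
    index_2d = index[:, None].repeat(1, src.shape[1])
    return torch.zeros(dim_size, src.shape[1], dtype=src.dtype, device=src.device).scatter_add(0, index_2d, src)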

If anyone does decide to tackle this, I would first write a unit test to ensure the torch-native reimplementation is mathematically equivalent to the torch_scatter implementation. If this passes for enough random molecules, then you can overwrite the torch_scatter implementation and delete the test.
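
As a sketch of that test (the names and shapes are illustrative), something along these lines would do:

import pytest
import torch
from torch_scatter import scatter_sum

@pytest.mark.parametrize("seed", range(20))
def test_native_scatter_sum_matches_torch_scatter(seed):
    # compare the native recipe above against torch_scatter on random inputs
    torch.manual_seed(seed)
    dim_size = 64
    src = torch.randn(256, 256)
    index = torch.randint(dim_size, size=(src.shape[0],))

    expected = scatter_sum(src, index, dim=0, dim_size=dim_size)

    index_2d = index[:, None].repeat(1, src.shape[1])
    actual = torch.zeros(dim_size, src.shape[1]).scatter_add(0, index_2d, src)

    assert torch.allclose(expected, actual)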

@kevingreenman changed the title from "[TODO]: Add note to README and/or docs about torch and torch_scatter installation order" to "[TODO]: Address issue with torch and torch_scatter installation order - replace torch_scatter functions with native torch?" Mar 2, 2024
@kevingreenman modified the milestones: v2.0.0-rc.1, v2.0.0 Mar 2, 2024
@kevingreenman added the bug label Mar 3, 2024
@kevingreenman (Member, Author) commented Mar 3, 2024

I've tested our installation instructions on my Mac and on 5 different Linux machines.

They worked without issue on my Mac (2021 M1 Pro), but we found in #695 that even updating the CI to install torch and torch-scatter in the correct order does not resolve the automated build step error on Mac.

On Linux, I was able to install with no issues on 3/5 machines. The other two encountered issues at the torch-scatter step. However, these are different errors from the typical one we see when torch is not installed first (ModuleNotFoundError: No module named 'torch').

On one machine (called slater, for my reference), I get:

RuntimeError:
      The detected CUDA version (11.6) mismatches the version that was used to compile
      PyTorch (12.1). Please make sure to use the same CUDA versions.

      [end of output]

  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building wheel for torch-scatter
  Running setup.py clean for torch-scatter
Failed to build torch-scatter
ERROR: Could not build wheels for torch-scatter, which is required to install pyproject.toml-based projects

This machine is running the following:

NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0

On the other machine (called kohn, for my reference), I get:

error: [Errno 2] No such file or directory: '/usr/local/cuda/bin/nvcc'
      [end of output]

  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building wheel for torch-scatter
  Running setup.py clean for torch-scatter
Failed to build torch-scatter
ERROR: Could not build wheels for torch-scatter, which is required to install pyproject.toml-based projects

This machine currently has an issue with 1 of its 3 GPUs, but I'm not sure why that would lead to this error.

Based on these results, I think we should definitely try to move away from the torch-scatter dependency by replacing its functions with native torch alternatives between now and the April MLPDS meeting.

@JacksonBurns (Member):

I am going to take a whack at this. I have been trying to get the CI to work (#714) but cannot get torch_scatter to install on macOS at all.

@davidegraff (Contributor):

For macOS, I have to install torch-scatter like so:

$ TORCH=2.0.0
$ CUDA=cpu
$ pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html --no-index
