Add error checking to CUDA version of getNeighborPairs #80
Conversation
The Autograd class is not allowed to hold any state, right?
Can't one add a host node for that?

On Mon, 16 Jan 2023 at 17:38, Raimondas Galvelis commented on this pull request:

> In src/pytorch/neighbors/getNeighborPairsCUDA.cu (#80 (comment)):
>
>     @@ -151,6 +166,13 @@ public:
>                     get_accessor<scalar_t, 2>(deltas),
>                     get_accessor<scalar_t, 1>(distances),
>                     get_accessor<scalar_t, 2>(box_vectors));
>     +    cudaEventRecord(event, stream);
>     +    cudaEventSynchronize(event);
>     +    // Check the error flag
>     +    TORCH_CHECK(tooManyNeighborsErrorFlag == 0, "Some particle has too many neighbours, found " +
>
> This won't work with CUDA Graphs, because the CPU code will be executed only during a graph creation.
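For context, here is a minimal, self-contained sketch of the "host node" idea (illustrative only, not the NNPOps code; the kernel, flag, and function names are made up). A host function enqueued with cudaLaunchHostFunc runs on the stream and is captured as a host node if that stream is later recorded into a CUDA graph. The documented restriction that no CUDA API calls may be made inside such a function is what limits this approach, as noted further down.

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical error flag in managed memory. Reading it from the host
// callback assumes a system with concurrentManagedAccess (e.g. Linux on
// Pascal or newer GPUs).
__managed__ int errorFlag = 0;

// Stand-in for the neighbor kernel: flag an error if the pair storage is too small.
__global__ void findPairsStub(int numAtoms, int maxPairs) {
    int pairs = numAtoms * (numAtoms - 1) / 2;
    if (pairs > maxPairs) atomicExch(&errorFlag, pairs);
}

// Host function: may inspect memory and print, but must not call the CUDA API.
void CUDART_CB checkFlagHost(void* userData) {
    const int* flag = static_cast<const int*>(userData);
    if (*flag != 0)
        std::fprintf(stderr, "Too many pairs found: %d\n", *flag);
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    findPairsStub<<<1, 1, 0, stream>>>(4, 4);               // 4 atoms -> 6 pairs, storage for 4
    cudaLaunchHostFunc(stream, checkFlagHost, &errorFlag);  // becomes a host node under graph capture
    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    return 0;
}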
What about moving the host-side checking of the flag into the backward pass? There will usually be a lot of other kernels launched between the two, so waiting on the event at that point won't hurt performance. The disadvantage is that the check will be missed if the backward pass is skipped, for example if someone computes only energy but not forces.
What if a user runs just the forward pass?
Right, that's the disadvantage. It would give people error checking in most of the common use cases without significantly hurting performance. But there are use cases where error checking would be skipped.
Since this is going to be part of a graph, the check should go after launching the graph, and whoever launches the graph should be the one checking for errors in its execution. Which function is going to build and launch the graph? An exception from inside a CUDA graph is problematic; one solution is to trigger a cudaError during graph execution. On the other hand, we could report the error in two ways: a direct call to forward can check the host flag after the launch and throw an exception, while a cudaError can be raised if the forward kernel is being launched as part of a CUDA graph execution. EDIT: one cannot call any CUDA API function from inside a callback, so I do not know how to raise a cudaError.
…true will force the function to synchronize and throw an exception if some error was found, so it can be caught. The default will throw the error asynchronously, which will crash the program. In both cases a meaningful message is printed.
This commit introduces a new optional bool flag, check_errors, to getNeighborPairs. The default (False) still checks for errors but throws asynchronously, printing a meaningful message and crashing the program. Inside a CUDA graph the False behavior is forced: the error is thrown asynchronously and the code crashes with a meaningful message. With check_errors=False, error checking is virtually free: there is no synchronization penalty, and since the error flag lives in managed memory there should be no memory-transfer overhead at all when no error occurs.
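To make the asynchronous path concrete, here is a standalone sketch of one way an uncatchable error with a readable message can be produced without any host-side synchronization. It uses a managed flag plus a device-side assert; this is only one possible mechanism, not necessarily what the commit implements, and all names are illustrative.

#include <cassert>
#include <cstdio>
#include <cuda_runtime.h>

__managed__ int tooManyNeighborsErrorFlag = 0;  // written atomically by the kernel

__global__ void countPairsStub(int numAtoms, int maxPairs) {
    int pairs = numAtoms * (numAtoms - 1) / 2;
    if (pairs > maxPairs) {
        atomicExch(&tooManyNeighborsErrorFlag, pairs);
        // Device-side assert (compiled without -DNDEBUG): prints file, line and
        // thread, then poisons the context so the program fails at the next CUDA
        // call with cudaErrorAssert. No host synchronization is needed to detect it.
        assert(pairs <= maxPairs && "Too many neighbor pairs found");
    }
}

int main() {
    countPairsStub<<<1, 1>>>(4, 4);  // 4 atoms -> 6 pairs, storage for 4
    // The launch returns immediately; the failure only surfaces asynchronously.
    cudaError_t err = cudaDeviceSynchronize();
    std::printf("status: %s\n", cudaGetErrorString(err));
    return err == cudaSuccess ? 0 : 1;
}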
While this would make the error synchronous, it would not be catchable AFAIK. Do you know if PyTorch defines something that can be detected on the C++ side when using debug mode? That way I could make the exception synchronous AND catchable in debug mode only. Also, I believe the meaning of check_errors should be the opposite of what you wrote.
In both cases the kernel could write a special value (say -1 or NaN) to, for instance
I followed @raimis' suggestion and added a bit of my own, ending up with the following:

- Add a new optional flag, sync_exceptions, on top of the current check_errors.
- Three behaviors are possible:
  1. Default (both False). The operation is CUDA-graph compatible and an uncatchable exception is thrown if the number of pairs is too high.
  2. check_errors=True. The operation is CUDA-graph compatible. No exception is thrown and the number of found pairs is returned, which can be higher than max_number_pairs.
  3. check_errors=False and sync_exceptions=True. The operation is NOT CUDA-graph compatible. It synchronizes to check for errors and throws a catchable exception if necessary.

The current unit test for this functionality might help understand how it works:

def test_too_many_neighbors(device, dtype):
    if not pt.cuda.is_available() and device == 'cuda':
        pytest.skip('No GPU')
    # 4 points result in 6 pairs, but there is storage for just 4.
    positions = pt.zeros((4, 3,), device=device, dtype=dtype)
    with pytest.raises(RuntimeError):
        # check_errors=False will throw due to exceeding neighbours.
        # sync_exceptions=True makes this exception catchable at the
        # expense of performance (even when no error occurred).
        getNeighborPairs(positions, cutoff=1, max_num_neighbors=1, check_errors=False, sync_exceptions=True)
        pt.cuda.synchronize()
    # check_errors=True will never throw due to exceeding neighbours,
    # but will return the number of pairs found.
    # sync_exceptions is ignored in this case.
    neighbors, deltas, distances, number_found_pairs = getNeighborPairs(positions, cutoff=1, max_num_neighbors=1, check_errors=True)
    assert number_found_pairs == 6
I made some changes to make getNeighborPairs CUDA-graph compatible; now one can do something like:

device = 'cuda'
dtype = pt.float32
num_atoms = 100
positions = 10 * pt.randn((num_atoms, 3), device=device, dtype=dtype)
cutoff = 5
graph = pt.cuda.CUDAGraph()
with pt.cuda.graph(graph):
    neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=num_atoms*num_atoms)
graph.replay()
pt.cuda.synchronize()
We decided to change the interface so that the number of pairs is always returned, meaning the user can now easily check whether the maximum number of pairs was exceeded. This changes the restrictions of the original problem a bit. We wanted the user to be informed (via an exception, for instance) when the number of pairs found is larger than the maximum allowed. Alas, informing the user in a recoverable way requires synchronizing (slow and incompatible with CUDA graphs), so I believe it is sensible that this functionality is guarded behind a flag. The best we can do, AFAIK, is let the user choose between a fast, CUDA-graph-compatible call that cannot report the error in a recoverable way and a synchronizing call that raises a catchable exception.
Right now we can do that with only the check_errors flag. If you guys are ok with it, I will remove sync_exceptions.

The only option I see, if we really do not want the code to crash, is to let the results be silently wrong when check_errors is False, passing on to the user the responsibility to check the num_pairs return value. In my opinion this function should not let the code progress further if num_pairs > max_num_neighbors. The user is probably not going to bother checking, and the danger of being silently wrong is not tolerable.

For me, the ideal use case for this function would be something as follows:

import torch as pt
from NNPOps.neighbors import getNeighborPairs

positions = pt.tensor(...)
max_num_neighbors = 1
# Find the maximum number of neighbors
while True:
    try:
        getNeighborPairs(positions, cutoff=3.0, max_num_neighbors=max_num_neighbors, check_errors=True)
    except RuntimeError:
        max_num_neighbors += 32
        continue
    break
# Fast and CUDA-graph-compatible calls that should never raise, but will crash if an error occurs
neigh, deltas, distances, num_pairs = getNeighborPairs(positions, cutoff=3.0, max_num_neighbors=max_num_neighbors)

But we could also make it something like this:

import torch as pt
from NNPOps.neighbors import getNeighborPairs

positions = pt.tensor(...)
max_num_neighbors = 1
# Find the maximum number of neighbors. This call will never raise, but may be silently wrong.
neigh, deltas, distances, num_pairs = getNeighborPairs(positions, cutoff=3.0, max_num_neighbors=max_num_neighbors)
if num_pairs > max_num_neighbors:
    max_num_neighbors = num_pairs + 32
# This will also never raise and may also be silently wrong. This call and the one above are fast and CUDA-graph compatible.
neigh, deltas, distances, num_pairs = getNeighborPairs(positions, cutoff=3.0, max_num_neighbors=max_num_neighbors)
# This will raise if necessary, but is not CUDA-graph compatible.
neigh, deltas, distances, num_pairs = getNeighborPairs(positions, cutoff=3.0, max_num_neighbors=max_num_neighbors, check_errors=True)

Let me know what you think!
I don't think a crash is ever a good way of reporting an error. I would vote for combining the flags so you have a single option.
Then check_errors=True would be the default. If you agree that the responsibility to check errors in CUDA graph mode should fall on the user, I will go ahead and implement @peastman's proposal.
Correct.
…with CUDA graphs. If check_errors=False (the default), getNeighborPairs does not check for errors and is compatible with graphs. If check_errors=True, the function raises if necessary, but it is incompatible with graphs.
This simplifies the logic greatly: no kernel-side atomic error flag is required, and the graph can be constructed without a host node.
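As a rough illustration of the simplified scheme, here is a self-contained sketch (illustrative names only, not the actual NNPOps kernel): the kernel always counts the pairs it finds with a plain atomic counter and stores only those that fit in the preallocated buffers, and the only synchronization is on the optional check_errors path, which is exactly the part that cannot be captured in a CUDA graph.

#include <cstdio>
#include <stdexcept>
#include <vector>
#include <cuda_runtime.h>

// One thread per candidate pair (i, j) with i < j, flattened into one index.
__global__ void findPairsStub(const float* pos, int numAtoms, float cutoff2,
                              int maxPairs, int2* pairs, int* numFoundPairs) {
    int p = blockIdx.x * blockDim.x + threadIdx.x;
    int totalCandidates = numAtoms * (numAtoms - 1) / 2;
    if (p >= totalCandidates) return;
    int i = 0;
    while (p >= numAtoms - 1 - i) { p -= numAtoms - 1 - i; ++i; }
    int j = i + 1 + p;
    float dx = pos[3 * i] - pos[3 * j];
    float dy = pos[3 * i + 1] - pos[3 * j + 1];
    float dz = pos[3 * i + 2] - pos[3 * j + 2];
    if (dx * dx + dy * dy + dz * dz > cutoff2) return;
    int slot = atomicAdd(numFoundPairs, 1);              // the count is always correct
    if (slot < maxPairs) pairs[slot] = make_int2(i, j);  // store only what fits
}

int main() {
    const int numAtoms = 4, maxPairs = 4;            // 4 atoms -> 6 pairs, storage for 4
    const bool check_errors = true;                  // mimics the getNeighborPairs flag
    std::vector<float> hostPos(3 * numAtoms, 0.0f);  // all atoms at the origin
    float* pos; int2* pairs; int* numFoundPairs;
    cudaMalloc(&pos, hostPos.size() * sizeof(float));
    cudaMalloc(&pairs, maxPairs * sizeof(int2));
    cudaMallocManaged(&numFoundPairs, sizeof(int));
    *numFoundPairs = 0;
    cudaMemcpy(pos, hostPos.data(), hostPos.size() * sizeof(float), cudaMemcpyHostToDevice);

    findPairsStub<<<1, 64>>>(pos, numAtoms, 1.0f, maxPairs, pairs, numFoundPairs);

    try {
        if (check_errors) {
            // This synchronization is the only reason check_errors=True is not
            // CUDA-graph compatible; in the extension it would be a TORCH_CHECK.
            cudaDeviceSynchronize();
            if (*numFoundPairs > maxPairs)
                throw std::runtime_error("Too many neighbor pairs found");
        }
    } catch (const std::exception& e) {
        std::fprintf(stderr, "caught: %s\n", e.what());
    }

    // Without check_errors, the caller just receives numFoundPairs as an extra
    // return value and is responsible for checking it.
    cudaDeviceSynchronize();
    std::printf("found %d pairs (storage for %d)\n", *numFoundPairs, maxPairs);
    cudaFree(pos); cudaFree(pairs); cudaFree(numFoundPairs);
    return 0;
}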
Enforce that the found number of pairs is less than num_pairs
…enforced. Right now this does not pass, since the function allows an atom to have more neighbors than max_num_neighbors as long as num_found_pairs < num_atoms*max_num_neighbors.
… neighbors per particle) to max_num_pairs (maximum number of total pairs).
This is ready for review again.
Ping @peastman
This addresses all the issues I raised. Looks good to me now!
@RaulPPelaez I suppose this is ready to merge?
This is ready for merge.
I have added a flag in managed memory. It is written atomically if too many neighbors are found for some particle.
It is checked using an event just after the kernel executes.
All tests are passing (even the too_many_neighbors one, on the GPU).
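A minimal standalone sketch of this scheme, assuming illustrative names (this is not the actual NNPOps code): the kernel writes a managed flag atomically when a particle exceeds the available neighbor storage, and the host waits on an event recorded right after the launch and then inspects the flag. Reading the managed flag from the host after an event synchronization assumes a system with concurrentManagedAccess (e.g. Linux on Pascal or newer GPUs).

#include <cstdio>
#include <cuda_runtime.h>

__managed__ int tooManyNeighborsErrorFlag = 0;

// Stand-in for the real kernel: every atom "finds" numAtoms - 1 neighbors.
__global__ void findNeighborsStub(int numAtoms, int maxNeighborsPerAtom) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= numAtoms) return;
    int neighbors = numAtoms - 1;
    if (neighbors > maxNeighborsPerAtom)
        atomicExch(&tooManyNeighborsErrorFlag, neighbors);  // atomic write to the managed flag
}

int main() {
    cudaStream_t stream;
    cudaEvent_t event;
    cudaStreamCreate(&stream);
    cudaEventCreate(&event);

    findNeighborsStub<<<1, 32, 0, stream>>>(4, 1);  // 3 neighbors each, storage for 1

    cudaEventRecord(event, stream);
    cudaEventSynchronize(event);  // wait only for the work recorded so far
    // In the PyTorch extension this check would be a TORCH_CHECK; here we just report it.
    if (tooManyNeighborsErrorFlag != 0)
        std::fprintf(stderr, "Some particle has too many neighbours, found %d\n",
                     tooManyNeighborsErrorFlag);

    cudaEventDestroy(event);
    cudaStreamDestroy(stream);
    return tooManyNeighborsErrorFlag == 0 ? 0 : 1;
}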