
Specify synchronization semantics #57

Closed · kkraus14 opened this issue Dec 3, 2020 · 62 comments

kkraus14 commented Dec 3, 2020

From the CUDA side, nearly all of the DL frameworks and array libraries that support DLPack use CUDA streams and stream-order both their computations and memory allocations. In its current form, DLPack doesn't specify any synchronization semantics, nor does it provide a way for a producer-consumer pair to exchange the information necessary to keep computations stream-ordered.

I imagine there's a similar problem in other contexts as well (OpenCL, ROCm, etc.) where maybe it's possible to generalize an approach.

cc @oleksandr-pavlyk

@veritas9872

I am also curious as to whether DLPack is blocking.

leofang commented Dec 28, 2020

I would like to add that another data exchange format, the CUDA Array Interface (CAI), has a recent revision that @kkraus14, I, and many others contributed to; see the Synchronization section of its v3 protocol. It turns out to be nontrivial to ensure the data being exchanged is correct with respect to the semantics of CUDA (or HIP or OpenCL, if one cares). DLPack currently lacks this level of consideration and implicitly assumes that the Producer and the Consumer exchanging the data live on the same (whatever) stream/queue. In hindsight this is a very strong assumption.

I understand no one has complained about the lack of such a synchronization requirement in DLPack (one would expect this to have been a serious issue since the advent of DLPack, yet in practice even on the CUDA Array Interface side we never received a single bug report about the unguaranteed synchronization behavior). But if the Python Array API Consortium is going to adopt DLPack as the official zero-copy standard, this will need to be rigorously sorted out, as was done in CUDA Array Interface v3.

cc: @rgommers (for awareness)

veritas9872 commented Dec 28, 2020

I think that making synchronization optional might also be an option. In most high-level CUDA workloads, as @leofang has already mentioned, most people do use the array on the same stream. However, forcing synchronization is time-consuming and goes against one of the main benefits of CUDA programming: asynchronous host/device execution.

tqchen commented Dec 28, 2020

Thanks @leofang @kkraus14. The issue of synchronization has been thought about, and it is certainly possible to come up with a specification that allows more asynchronous behavior. See also my previous related comment on this topic: data-apis/consortium-feedback#1 (comment)

The main design goal is to get folks to agree on a common set of things. The main problem is not about finding a solution for synchronization, but about getting frameworks to agree on the mechanism.

Most modern frameworks have their own opinions about allocation and scheduling, and in many cases these opinions vary. For example, both TF-RT and MXNet have their own internal schedulers. New abstractions are also being proposed (e.g. graph-based scheduling).

Based on that observation, we concluded that it is hard (yet) to get people to agree on a synchronization and allocation mechanism, and resorted to a simpler convention -- we require producer and consumer to be on the same stream.

Most protocols that bring in synchronization behavior will inevitably embed a stream or context into the data structure. This defers the synchronization burden to the runtime implementation of the consuming system. While that is certainly more flexible, it does bring engineering cost to the developers who are supposed to consume the DLPack array.

Additionally, only a few expert programmers may want to think about synchronization, while a broader range of programmers may not even consider this aspect and only use one stream throughout their applications. Embedding a context in DLPack means this complexity falls on the crowd rather than on a few experts. It also brings extra burden to compilers: a compiler would need to generate optional synchronization code based on the streams, which is non-trivial.

Finally, the simple protocol (asking producer and consumer to be on the same stream) does not prevent async computation, since no synchronization is needed if both are on the same stream (and for most apps they are). Alternative APIs can also be designed to request a DLPack array that is presented on a certain stream, for example an optional API like the one below:

array.to_dlpack_on_stream(stream_number)

Such an API puts the burden of synchronization on the producer, which is not necessarily more complex than putting it on the consumer, but reduces the mental burden on the consumer side (think about common developers who might want to consume the array).

To summarize, I think it is good to have a discussion about a synchronization API. However, based on the current state, I think it would be useful to de-couple the discussion of synchronization (which is harder to reach consensus on) from the data structure (which most frameworks already agree on). I would also argue that simplicity is not necessarily a bad thing here, since we could build APIs that are just as powerful but drive adoption in the common cases.

tqchen commented Dec 28, 2020

To summarize things further, let us discuss the following two variants of possible APIs to support synchronization:

  • S0: Always enforce being on the same stream; optionally pass in a stream number/context that the consumer expects the data to be visible on.
consumer_arr = consumer.from_dlpack(producer.to_dlpack_on_stream(stream_number))
  • S1: Embed the stream into the data structure.
consumer_arr = consumer.from_dlpack(producer.to_dlpack_with_stream_embedded())

From a flexibility PoV, an S1-style API is certainly more flexible. For example, the consumer can choose to continue running computation on the streams provided by the producer.

The reality is, however, that most application/framework developers (rightfully) have opinions about which streams to run computation on (tied to their own internal schedulers). So the best thing a consumer can do is run a dependency sync to bring the array onto its internal stream and then consume the data, if the developer knows how to do so. In many cases developers may not even want to think about synchronization at all and operate on the default stream -- we certainly want to be able to support these developers when possible. S1 also brings the burden of considering complications when, say, two operands sit on different streams.

In the case of an S0-style API, simplicity is also a merit -- more developers can agree on and implement such a protocol correctly. While being limited, it won't slow down computation, because it can still support asynchronous computation when possible. It also separates opinions about the data structure from opinions about synchronization.

leofang commented Dec 29, 2020

Hi @veritas9872 @tqchen, thanks for your replies! Looks like you have put serious thought into this, @tqchen, and it's great to have you write it up here for future reference 🙂

I would like to kindly nudge you to read the Synchronization section of the CUDA Array Interface (CAI) v3, as many of the factors you raised are already considered thoroughly and covered there. A brief summary:

  • Clear definitions of Producer/Consumer/User are given (so please refer to the CAI doc to make sure we speak the same language 😅)
  • A synchronization mechanism is specified (via exposing the Producer's stream; see the CAI doc for the rationale), and a synchronization at the exchange point (by the Consumer) is required by default.
  • Optionally, libraries have a defined way to avoid synchronization (by not providing a stream pointer) if that fits the library's semantics.
  • Optionally, libraries are allowed to implement their own approaches to override the default behavior (for example, Numba has NUMBA_CUDA_ARRAY_INTERFACE_SYNC, and CuPy has CUPY_CUDA_ARRAY_INTERFACE_SYNC and CUPY_CUDA_ARRAY_INTERFACE_EXPORT_VERSION).
  • Optionally, libraries can consume CAI arrays but ignore the synchronization requirement (the default behavior before CAI v3); this should be a rare case with clear documentation written up for its users. The only notable library I am aware of is mpi4py, since there we do not have control over any CUDA functionality and simply hand the device pointers to the underlying CUDA-aware MPI library.

All these considerations together guarantee the correctness of computations, which was deemed critical across the libraries at the time (PyTorch, TF, Numba, CuPy, RAPIDS, etc.), while leaving flexibility to expert users, in particular advanced HPC/ML users. The CAI documentation includes a fairly complicated example demonstrating that without such a synchronization requirement a computation would almost certainly go wrong.

My conjecture for why CAI (before v3) and DLPack have worked well so far is that most libraries live on CUDA's legacy default stream, which implicitly synchronizes/serializes kernels launched on it. But given that more libraries are adopting the per-thread default stream, this will eventually become a real concern.

If DLPack is going to assume the Producer and Consumer live on the same stream, that is fine with me, but it should be clearly stated in the documentation, which is currently not the case. Also, we would need a way to enforce this if correctness is the top concern, which circles back to the 2nd point above: we need access to the Producer's stream pointer.

The main problem is not about finding a solution for synchronization, but about getting frameworks to agree on the mechanism.

I totally agree. We had a very hard time when revising v3. Library and framework maintainers have their own opinions.

But it's very odd to me that you mentioned TF, because TF was one of the main parties unwilling to adopt CAI unless we clearly specified the synchronization behavior (hence the v3 revision). So it is difficult for me to understand how they are OK with DLPack's (currently) unsafe behavior but gave us tons of headaches (and in the end it's still unclear to me whether they'd implement CAI v3).

On the API consideration, I think one important distinction between CAI and DLPack is that we do not expose the intermediate object to the User, as all of the information is simply dumped into a dict and generated (consumed) internally by the Producer (Consumer). The benefit is threefold:

  • Users do not need to care about nor manage intermediate objects
  • There are no object ownership issues (as seen in many libraries' first adoption of DLPack)
  • A stream pointer can simply be included in the dict (CAI v3)

So I think this is the reason you have to consider the S0 vs S1 styles above. To us it's really not an issue if the struct could have an additional field to hold the stream pointer (following S1), and Users can simply write (pseudo) code as simple as this (see the CAI doc for a better, practical example):

a = producer.array([1, 2, 3])
b = consumer.asarray(a)  # <-- sync happens here, if a stream pointer is given

That is, there are no explicit calls to to_CAI() and from_CAI(); the struct lives in the dict a.__cuda_array_interface__. All the burden of syncing is on the Producer/Consumer maintainers, but it's really minimal. For example, CuPy's Producer code is here, and its Consumer code is here -- just a few lines. In fact I've seen shorter ones, so to me it should not be an obstacle (disclaimer: I maintain this code). (In comparison, CuPy's to_dlpack() is here and from_dlpack() is here; many more lines...)
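
For readers unfamiliar with CAI v3, here is a rough sketch of what a Producer-side property could look like. The `_shape`/`_device_ptr`/`_stream_ptr` attributes are assumptions made up for illustration; the authoritative field definitions are in the CAI v3 documentation.

class ProducerArray:
    # A minimal, illustrative Producer; real libraries compute these fields
    # from their internal array metadata.
    @property
    def __cuda_array_interface__(self):
        return {
            "shape": self._shape,                # e.g. (3,)
            "typestr": "<f4",                    # little-endian float32
            "data": (self._device_ptr, False),   # (device pointer, read-only flag)
            "strides": None,                     # None means C-contiguous
            "version": 3,
            "stream": self._stream_ptr,          # CAI v3: Consumer synchronizes on this stream
        }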

It is necessary to revise DLPack's struct. Suppose instead we added an optional stream argument to to_dlpack() (but not to from_dlpack(), since all that matters is the Producer's stream), following more or less your S0 style:

def to_dlpack(stream=None):
    pass
def from_dlpack(object):
    pass

this has three disadvantages:

  1. it puts the burden of managing streams on the Users, which I don't appreciate
  2. it's not necessarily possible for the Producer to expose its internal stream handles (as in TF's case), so Users just cannot access them
  3. the synchronization would happen when the DLPack intermediate object is generated, not when it is actually consumed, because the stream pointer is not carried by the object

Given the above reasons, I just don't think your S0 style would work if stream synchronization is considered an issue.

@rgommers

On the API consideration, I think one important distinction between CAI and DLPack is that we do not expose the intermediate object to the User

A note just on this: we discussed that recently and agreed to change the Python API to from_dlpack(x) and have x.__dlpack__ in order not to have a user-level capsule object: data-apis/consortium-feedback#1 (comment).

szha commented Dec 31, 2020

Taking a step back from CUDA streams: to support synchronization I think DLPack may need to support other modes of scheduling as well (e.g. barriers). It might make sense to start with a more limited scenario first (i.e. always synchronize) that can be implemented in all abstractions, and then add more interfaces to relax it.

tqchen commented Dec 31, 2020

Right now the options are clearly summarized as the S0 and S1 design choices. The choice of Python API has converged on what @rgommers described.

The synchronization approach of CAI mentioned by @leofang is summarized under S1 (I tried to summarize it after reading the CAI v3 proposal). Right now, DLPack adopts simple synchronization semantics (per S0), specified as:

To keep things simple, producer and consumer are required to operate on the same stream,
otherwise an explicit synchronization is needed to make the data visible to the right stream.

It would be great to try to talk about S0 and S1 style API further, perhaps from the framework implementer's PoV.

tqchen commented Dec 31, 2020

Here is my (personal) take on S1-style APIs (e.g. CAI v3). First of all, it is certainly possible to specify the runtime-dependent synchronization behavior correctly in an S1 API (CAI v3 should be such an example).

Notably, such a specification also brings additional complications to the implementer that consumes a DLPack tensor. Specifically the following scenarios:

  • K0: Compiler builders need to consider arrays coming from two different streams and add a runtime synchronization mechanism
  • K1: Application developers that consume a DLPack tensor need to consider stream synchronization (not that trivial, given the complication of the producer-consumer pairs being considered, as well as the target system's async behavior)

Both K0 and K1 are possible barriers for a fully compatible implementation.

In contrast, an S0-style API, while limited, is simpler to implement overall. We are more likely to get fully compatible implementations from frameworks.

So our question here is not really about which one is more advanced and flexible (we all agree that S1 is more flexible and "advanced", while S0 is simpler and limits the possible ways to synchronize), but about the trade-off between simplicity and flexibility.

szha commented Dec 31, 2020

The reason I don't think either S0 or S1 reaches the heart of the problem is that they differ only in where to store the reference that's necessary for synchronization, which we could coordinate as part of defining the synchronization semantics in DLPack. To me, the key difference between the current DLPack and the potential synchronization extension is that in an asynchronous setting the provider needs to return a promise object. Such a promise object can be a reference that the supplying and consuming frameworks both know how to synchronize on, such as a stream number, but that's not always true. It's entirely possible that the provider supplies a promise whose synchronization depends on a mechanism the consumer has no knowledge of.

More generally, the synchronization method need not be coupled with the memory. For example, say two imaginary frameworks TF and MX are running on a platform with both a TPU and a GPU, and TF passes an array promise to MX whose computation happens on the TPU first. It's entirely possible that TF is the only framework that knows how to implement synchronization on the TPU. Requiring frameworks to know how to synchronize may not even be possible, which means that DLPack may have to choose not to support such hardware even if the memory space is addressable by all frameworks.

In general, I think it makes sense for the promise object to supply a query function for whether the data is ready. Beyond that, I think of CUDA streams, or other synchronization methods commonly supported by multiple frameworks, more as special cases.

leofang commented Dec 31, 2020

Wait... this brings up a question I've been wondering about: what is DLPack's requirement for cross-architecture transfer (say CPU <-> GPU)? I thought that, for the purpose of enabling zero copy with no surprises, it should error out, and such transfers must be done by other means (usually a sane library would implement them). Is such behavior documented somewhere?

szha commented Dec 31, 2020

I believe DLPack only does one job at the moment, which is to facilitate the exchange of data in commonly addressable memory spaces. It does not yet handle data transfer across devices. What I described above is the scenario where the final memory address is being filled asynchronously as part of the final step of computation on a special device. I thought this would come up in a future version of CUDA when async mem copy is supported, but it is in fact already supported in CUDA 11.1.

leofang commented Dec 31, 2020

Yeah I think cudaMallocAsync etc appear in CUDA 11.2.

leofang commented Dec 31, 2020

The scenario is a more complex version of what was considered in CAI v3, so in this regard I agree with you that CAI v3 is a special case, but then it was meant for CUDA only 😗 I brought up CAI v3 mainly to emphasize and echo the original issue: the synchronization (or not) behavior at the interoperation points should be considered, decided, and documented.

szha commented Dec 31, 2020

Yes, for practical reasons I think it's valid (and most likely desired), and I'm happy to continue the discussion on coordinating CUDA synchronization as a special case. My comments above are meant to make it clear that the same abstraction doesn't necessarily apply generally.

tqchen commented Dec 31, 2020

Great discussion. Personally I do not have a strong opinion on the style of possible synchronizations. Both memory copies and other operations (allocations) can have a stream-based abstraction, so the S0 specification is still sound (although limited).

One observation from the current discussion is that while most of us agree on the data structure, there are many opinions about a possible synchronization API (e.g. an explicitly specified stream, a stream embedded in the structure, or an opaque future/promise). Each of them brings some level of complexity along with its pros and cons.

My main point is to ideally de-couple the standardization of synchronization (which is more complicated) from standardization of data structure itself.

This would, however, mean we should not couple synchronization-related data structures (e.g. a stream) into DLTensor, and should limit synchronization to an S0-level requirement at the DLPack level. I would also argue that the S0-level semantics is simple and self-contained for now, and can handle async settings (via the default stream or the same stream).

Additional discussions about standardizing a synchronization API are great; we could open a separate thread for that :)

tqchen commented Dec 31, 2020

Let me attempt to summarize the proposals so far and their pros and cons

S0: Producer Sync based on Stream Specified by Consumer

def __dlpack__(self, stream=None):
    """Get a DLTensor capsule that is visible in `stream`.

    The producer takes charge of adding a dependency synchronization to `stream`.
    If `stream` is None, it defaults to the default stream (0 in CUDA).
    """
    pass

Pros

  • A0: No additional info in the data structure is needed.
  • A1: No explicit sync is needed for the most common case (default stream)
  • A2: There exists a simple implementation of producer and consumer that maintains correctness and efficiency for most cases
    • The producer can always perform a full synchronization if a non-default stream is passed in.
    • The consumer can always pass in the default stream to begin with, which is usually the most common use case.

Cons:

  • K0: Less flexible than S1 in some cases, as the sync strategy is not deferred to the consumer.

A further simplification of S0 would be to require the exchange to always sync to the default stream (note this is still different from an explicit full synchronization); this would remove the stream argument from the dlpack function.

S1: Producer Stores the stream in the Exchange Data Structure, Defer Sync to Consumer

Mostly the API style described by @leofang, i.e. CAI-style synchronization.

class DLManagedTensor:
    stream: int

def __dlpack__(self):
    """Return an exchange structure that stores the stream,"""
    pass

Pros:

  • B0: More flexible in the possible syncs one can implement (defers sync to the consumer)

Cons:

  • K2: Harder to implement the consumer logic, as the consumer needs to handle synchronization
  • K3: Harder for compilers (they need runtime handling of data from different streams)

S2: Producer Stores an opaque Promise object in the Exchange Data Structure, Defer Sync to Consumer

Proposal by @szha

class DLManagedTensor:
    sync: function<()>
    ready: function<()>

def __dlpack__(self):
    """Return an exchange structure that stores the sync/ready callbacks."""
    pass

Pros:

  • Can handle non-CUDA cases of asynchrony (e.g. the producer has its own internal scheduler)

Cons:

  • Sync happens on the CPU side; it cannot make use of driver-level synchronization as in the stream case
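
Purely as an illustration of the S2 idea (not a proposed struct layout; the field and helper names below are made up), the promise could be sketched in Python as:

from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class PromisedTensor:
    tensor: Any                    # the usual DLPack exchange structure/capsule
    ready: Callable[[], bool]      # non-blocking query: has the producer finished writing?
    sync: Callable[[], None]       # host-side wait until the data is valid

def consume(promise: PromisedTensor):
    # Consumer-side sketch: wait (on the host) only if the data is not ready yet.
    if not promise.ready():
        promise.sync()
    return promise.tensor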

tqchen commented Dec 31, 2020

My current proposal is to start simple and adopt an S0-style API. The simplicity and the benefit brought by A1 should be sufficient for most of the high-performance cases.

In reality I believe the default stream is the only thing different frameworks could have agreed on, given that frameworks have their own opinions about internal stream contexts. This means the power of B0 in S1 would rarely be used.

Additionally, A2 is critical for framework adoption, as we want the simple support to be implemented correctly for most of the cases, while additional effort can be put into stronger support if necessary.

rgommers commented Jan 1, 2021

S0 with def __dlpack__(self, stream=None): sounds good to me -- simplicity and correctness while remaining performant for most cases seems like a good trade-off. A couple of questions:

  • Should we document which device types this applies to? CUDA and ROCm only, or anything else?
  • If stream is an integer, does that identify streams uniquely? E.g. if the producer passes in 3, is that guaranteed to be stream 3 on the consumer side?
  • Is 0 unambiguous enough, given there's legacy and per-thread synchronization behaviour?
  • Just to make sure: the consumer library is supposed to pass this in, not the user, so the stream keyword will not be present in the from_dlpack signature?

tqchen commented Jan 1, 2021

@rgommers it applies to driver APIs that have a global stream space whose streams can be converted to integers, so for now that's the CUDA and ROCm cases. In those cases, streams are guaranteed to refer to the same thing within the same process. For other driver APIs, the producer and consumer need to agree on a convention for the "default stream".

The lack of a global default is a limitation of other APIs when exchange is needed -- in a single-app scenario, being able to create an app-specific context was of course considered flexible. However, such flexibility limits the sharing of context between apps that standardization brings. The producer and consumer need to agree on a default context in this case. A possible later extension could be to create such a global context table for other device APIs that each application can access.

'None' (instead of '0') indicates the default stream, so it is unambiguous. But I believe '0' is also fine.

You are right that from_dlpack does not have to contain a stream argument, so the user does not need to think about it.

kkraus14 commented Jan 3, 2021

If the stream isn't contained in the DLPack structure (and presumably not cleaned up in the deleter function) and we're passing the stream as an integer, who is responsible for maintaining the lifetime of that stream? I.e.:

import my_array_lib
import other_array_lib  # hypothetical consumer library, added here for illustration

my_arr = my_array_lib.create_array_on_stream([1, 2, 3], stream=my_array_lib.create_stream())
# my_arr internally holds a reference to the passed-in `stream` object for lifetime management

my_dlpack = my_arr.to_dlpack(stream=int(my_arr.stream))
my_arr = None
some_other_arr = other_array_lib.from_dlpack(my_dlpack)
# What state are we in now? Is the stream that ordered the work still alive?

From what I can tell, the only way to guarantee that the stream isn't prematurely destroyed here is to either synchronize it before creating some_other_arr, or hold a reference to some_other_arr. If the handoff was to a C library then there's no clean way for it to hold that reference.

DLPack already has lifetime-management semantics for the memory in the form of the deleter; I'd argue we need something similar for any stream resources used. We could say that synchronization needs to happen in the producer before handing off the dlpack_tensor object, but that introduces a lot of unnecessary synchronization and arguably makes DLPack less effective as an interchange mechanism between libraries.

rgommers commented Jan 3, 2021

We could say that synchronization needs to happen in the producer before handing off the dlpack_tensor object, but that introduces a lot of unnecessary synchronization and arguably makes DLPack less effective as an interchange mechanism between libraries.

Isn't this just a way of saying "I prefer S1 over S0", or " I want optimal performance even when both libraries use non-default streams, at the cost of implementation complexity"?

The trade-off here has some unstated assumptions about how often this kind of data interchange is used. If one is gluing together two ML models by exchanging some final or intermediate output once in a while, the occasional extra sync when both of those models use non-default streams isn't a big deal. If you have use cases that call from_dlpack a lot, the trade-off of implementation complexity vs. performance may look different.

rgommers commented Jan 3, 2021

As a meta-comment: I do not think it's necessarily the right goal to end up with essentially a copy of __cuda_array_interface__ on the C side of DLPack. Given that __cuda_array_interface__ gets away with passing a Python int, it seems like passing a C pointer and managing its lifetime for DLPack should not be necessary (given the CuPy/RAPIDS needs at least).

Other options include:

  • Add stream handling on the Python side of DLPack in the array API standard.
  • Keep DLPack to S0 API, and let the user use __cuda_array_interface__ directly (they'd call, e.g., asarray instead of from_dlpack).
  • Write a new hybrid Python-level from_xxx function that auto-selects between __dlpack__ and __cuda_array_interface__, and possibly __array_interface__ and other such device-specific mechanisms in the future. Let's say they should be named __devicename_array_interface__.

kkraus14 commented Jan 3, 2021

The trade-off here has some unstated assumptions about how often this kind of data interchange is used. If one is gluing together two ML models by exchanging some final or intermediate output once in a while, the occasional extra sync when both of those models use non-default streams isn't a big deal.

This isn't true in many cases. Take, for example, a library which uses the default stream asynchronously, which is not uncommon. If I'm forced to synchronize the stream to make the memory immediately usable by any other stream, I'm required to synchronize the entire device, which prevents me from doing any other work on the device until that synchronization has finished. I'd argue this is a big deal.

As a meta-comment: I do not think it's necessarily the right goal to end up with essentially a copy of __cuda_array_interface__ on the C side of DLPack.

I 100% agree that it shouldn't be a copy, but if we're interested in being an interchange protocol for devices that have an asynchronous execution model and an asynchronous memory model, then we should really support that asynchrony in the interchange protocol.

Given that __cuda_array_interface__ gets away with passing a Python int, it seems like passing a C pointer and managing its lifetime for DLPack should not be necessary (given the CuPy/RAPIDS needs at least).

I personally don't like that we're passing things like memory pointers / streams around as integers in Python when we could be passing some type of object-oriented wrapper that handles the lifetime management instead, but we followed the spec for __array_interface__ to start with and then followed suit in adding the stream parameter.

tqchen commented Jan 4, 2021

We are getting back to the argument of S1 style API and S0 style API :)

Although it would be great to dissect the argument a bit. Right now the discussion goes back to full synchronization vs async handling, and implies that "S0 == sync the entire device, S1 == async handling". This is not necessarily true.

I want to highlight that S0 does not imply synchronization of the entire device, and was never meant to say so. The decision between synchronization and async handling can be made by the implementer in either the S0 or the S1 setting.

Let us discuss how async exchange can be handled in S0:

  • When both sides use the default stream, no sync is needed (this is the simplest setting)
  • When each side uses its own stream, async exchange can still be done via stream dependency queuing:
// The event can also be created on the fly, or cached in a synchronizer object.
// We could build an auxiliary function callable from the Python side if that helps the frameworks.
void PushStreamDep(cudaStream_t src, cudaStream_t dst) {
    cudaEvent_t event;
    cudaEventCreate(&event);
    cudaEventRecord(event, src);
    cudaStreamWaitEvent(dst, event, 0);
    cudaEventDestroy(event);
}

PushStreamDep makes sure the computation on the src stream is observed from the dst stream. When producer and consumer run on different streams, synchronization can be done in two ways:

  • W0: in __dlpack__(self, consumer_stream), the Producer calls PushStreamDep(producer_stream, consumer_stream)
  • W1: synchronize through the default stream:
    • The consumer calls __dlpack__(self, default_stream)
    • The producer calls PushStreamDep(producer_stream, default_stream), so the data will be visible from the default stream
    • After the consumer obtains the data, it calls PushStreamDep(default_stream, consumer_stream), so the data becomes visible on the corresponding consumer stream.

Note that neither W0 nor W1 requires synchronization of the entire device, and both fall under the S0 API. The only difference from S1 is that PushStreamDep is called inside __dlpack__ (at which point both the producer and consumer streams are alive). This also avoids the stream lifecycle management problem @kkraus14 mentioned: common frameworks want to manage the lifecycle of their own streams, and frameworks are not readily smart enough to handle streams exported from another framework.

In both W0 and W1, producer and consumer streams are managed only by the producer and consumer respectively, without having to worry about stream lifecycles due to exporting. This is another piece of simplicity brought by the S0-style handling.
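
For reference, here is a rough Python analogue of PushStreamDep, written with CuPy's stream/event wrappers purely for illustration (the idea above only assumes the underlying CUDA calls, not CuPy):

import cupy as cp

def push_stream_dep(src: cp.cuda.Stream, dst: cp.cuda.Stream) -> None:
    # Record an event on the source stream and make the destination stream
    # wait on it; this enqueues a dependency without blocking the host.
    event = cp.cuda.Event(disable_timing=True)
    event.record(src)
    dst.wait_event(event)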

To summarize, I think we all agree that it is ideal to handle exchange asynchronously when possible. Both S0 and S1 style APIs have mechanisms to do so.

kkraus14 commented Jan 4, 2021

When both sides use the default stream, no sync is needed (this is the simplest setting)

Another case we covered in __cuda_array_interface__ is when the data is already synchronized and can therefore be consumed by any stream without any unnecessary synchronization or event waiting. I.e. someone produced something on a stream but is continuing to do additional work on that stream; we don't want to unnecessarily synchronize the stream or wait on an event in it (unless the event could be guaranteed to be inserted into the stream at the correct place relative to the memory in question, which sounds unlikely) and be blocked by unrelated work.

We decided to handle this by using None for the case where the data does not require any synchronization, and we explicitly disallowed 0, since the default stream is ambiguous between the legacy default stream (1) and the per-thread default stream (2).
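
To make that convention concrete, here is a sketch of how a consumer might interpret the CAI v3 `stream` field being described (the `wait_on_stream` and `wrap_pointer` helpers are made up for illustration):

def consume_cai(producer_obj):
    # Sketch of a Consumer interpreting the CAI v3 `stream` field.
    desc = producer_obj.__cuda_array_interface__
    stream = desc.get("stream")          # absent or None: no synchronization required
    if stream is not None:
        if stream == 0:
            raise ValueError("0 is disallowed: 1 = legacy default, 2 = per-thread default")
        wait_on_stream(stream)           # make the consumer's work wait on the producer's stream
    ptr, readonly = desc["data"]
    return wrap_pointer(ptr, desc["shape"], desc["typestr"], readonly)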

W0: in __dlpack__(self, consumer_stream), the Producer calls PushStreamDep(producer_stream, consumer_stream)

What would C/C++ libraries do? How would a producer C/C++ library and a consumer C/C++ library exchange stream information in a standard way? CUDA streams may well be an implementation detail of a library that is never exposed to someone developing against it.

If we want to avoid putting streams, events, etc. into the dlpack_tensor struct then that's perfectly fine, but standardizing how this information is exchanged really feels like it should be part of whatever protocol/standard is defined.

Producer calls PushStreamDep(producer_stream, default_stream), so the data will be visible from the default stream

Unfortunately, calling cudaStreamWaitEvent on the default stream will synchronize the entire device, similar to using cudaStreamSynchronize (https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html#stream-sync-behavior__default-stream), so this would always synchronize the entire device.

tqchen commented Jan 14, 2021

@kkraus14 I believe a common utility might be useful, although it should be simple enough for frameworks to just have their own implementations.

leofang commented Jan 27, 2021

Can someone remind me why we want to sync on the Consumer's stream, not the Producer's? Isn't the whole point of synchronization to wait until the Producer is ready to hand out the data? Did I remember it wrong?

tqchen commented Jan 27, 2021

@leofang The "sync" (PushDep) happens asynchrously, and marks a dependency in the queue, which won't block the producer or consumer, at least in the context of CUDA

leofang commented Jan 27, 2021

Thanks for the quick reply and for the quotation marks around "sync", @tqchen! But if the Producer does not expose a stream or event handle (in CUDA), how does the Consumer establish such a stream order (the dependency you referred to)?

tqchen commented Jan 27, 2021

In the current API the producer implements producer.__dlpack__, and the producer is in charge of calling PushStreamDep to make sure the data is visible to the consumer's stream.

@oleksandr-pavlyk

I'm coming to this discussion late, and I have not yet had a chance to internalize all the prior comments.

The way I understand the problem, the producer may still have kernels updating the data referenced by producer.__dlpack__. The suggested way to avoid a race condition, without the producer always waiting for those kernels to finish, is for the consumer library to submit work to the same stream as the producer (relying on a stream executing its kernels in order).

If my understanding is correct, this does not generalize to SYCL. For the consumer to ensure that SYCL kernels submitted by the producer have finished updating the memory of the array being shared, the consumer needs to receive a sycl::event from the producer and use it to declare that the kernels the consumer submits to work on the array depend on this event, signaling the SYCL runtime to block submission of these kernels on that event.

tqchen commented Jan 28, 2021

@oleksandr-pavlyk I am not that familiar with SYCL. If SYCL has constructs like fences (barriers) that can depend on events, an analogy to CUDA's void PushStreamDep(cudaStream_t src, cudaStream_t dst) would be for the producer to submit a fence operation (or a dummy kernel) to the consumer_stream (perhaps a context or queue in OpenCL's terms), so that later operations on the same consumer queue see the effect of the data.

leofang commented Jan 28, 2021

Ah right, thanks for the reminder @tqchen. PushStreamDep seems to be the last missing piece for my understanding, though we should make it clear that it's a shorthand construct from #57 (comment), not a CUDA API 🙂

@oleksandr-pavlyk I think if SYCL's queues and events are analogous to CUDA's streams and events, PushStreamDep should work fine? Can these SYCL handles be passed as integers in Python, or do they have to be opaque objects?

oleksandr-pavlyk commented Feb 2, 2021

sycl::queue can be in-order (like a CUDA stream) or out-of-order (the default).

It is not possible in SYCL to insert a barrier to an out-of-order queue that will make all further submissions to the queue dependent on that event.

It is possible to do it synchronously from the host with producer_queue.wait().

In view of this, an implementation of PushStreamDep in SYCL is only possible if consumer_queue is in-order, i.e. consumer_queue.is_in_order() returns true.

To synchronize in an out-of-order queue requires passing an event, which will be used when submitting new kernels (e.g. see USM-version of GEMM in oneMKL).

Regarding whether SYCL handles can be passed as integers in Python: I am afraid not (just as shared pointers cannot be passed as integers in Python). The dpctl Python package exposes SyclQueue, SyclEvent, and SyclDevice objects, among others.

The opaque pointers to SYCL objects can of course be passed around as named PyCapsule objects.

tqchen commented Feb 2, 2021

Thanks @oleksandr-pavlyk. I agree that in such cases we would require exchanges to go through in-order queues provided by the consumer.

To adapt to the case where the consumer wants to operate on out-of-order queues: the consumer can first create an in-order queue just for the exchange, create an event X after the exchange finishes, and use X as a dependency for future out-of-order submissions that depend on the data. Likely this won't create additional overhead other than the event dependency tracking.

This way of adapting would de-couple the issues of synchronization and lifecycle management from the data structure.

Speaking from a framework implementer's PoV: generalization to an arbitrary event dependency chain, while flexible, creates additional overhead during exchange for lifecycle management, asynchronization conventions, etc. Having a layered, simpler view (e.g. a default in-order queue) would help SYCL in the long run, and here we can learn from CUDA (simple and good enough for most cases).

@oleksandr-pavlyk

@tqchen In SYCL, the consumer would always need to provide an in-order queue to ensure synchronization (since queues are out-of-order by default). Not passing any queue would require producer_queue.wait() to be executed.

There is a SYCL extension proposed to enable one to asynchronously wait on an event on the entire queue, which may allow to_dlpack_on_stream to handle out-of-order queues as well.

Even so, using an in-order queue is going to be more performant.

tqchen commented Feb 6, 2021

Let me try to summarize the current state of the discussion and point out another missing piece (per-GPU streams). Here is a complete proposal under S0.

S0a: Producer Sync based on Stream Specified by Consumer

def __dlpack__(self, stream=None):
    """Get a DLTensor capsule that is visible in `stream`.

    The producer takes charge of adding a dependency synchronization to `stream`.
    If `stream` is None, it defaults to the legacy default stream.
    """
    pass

def __dlpack_device__(self) -> Tuple[int, int]:
    """Return a (device_type, device_id) pair in the DLPack convention."""

# consumer code (pseudocode):
def consumer.from_dlpack(producer):
    device = producer.__dlpack_device__()
    consumer_stream = consumer.find_exchange_stream(device)
    dlpack_caps = producer.__dlpack_stream__(consumer_stream)
    return convert_to_consumer_array(dlpack_caps)

Note that most systems associate different streams with particular devices, so if we want to use a non-default stream, knowing which stream to synchronize to is important. As discussed earlier, the main benefits of this style of API are:

  • It separates out the stream lifecycle management problem: there is no need to pass a stream's ownership from producer to consumer (which would cause the additional problem of attaching a stream-management function).
  • It has a good default case to start with: syncing on the legacy default stream, which is well-defined.

Given the need to move on quickly, we could also start with a reduced version and continue the discussion on stream exchange.

S0reduced: Producer Always Syncs to the Legacy Default Stream

def __dlpack__(self):
    """Get a DLTensor capsule that is visible in the legacy default stream.

    The producer takes charge of adding a dependency synchronization to the
    legacy default stream of the corresponding device.
    """
    pass

This way the __dlpack_device__ function is not needed, and the behavior is well-defined and consistent between the GPU and CPU cases.
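
To make the producer side of S0a concrete, here is a rough sketch of what an implementation could look like (CuPy-style stream/event objects and the `_data`/`_stream` attributes are illustrative assumptions, not part of the proposal):

import cupy as cp

class ProducerArray:
    # self._data: a cupy.ndarray holding the buffer
    # self._stream: the cupy.cuda.Stream the producer computes on

    def __dlpack_device__(self):
        # (device_type, device_id); 2 is the DLPack enum value for CUDA GPUs (kDLGPU).
        return (2, self._data.device.id)

    def __dlpack__(self, stream=None):
        # Per S0a, stream=None means the legacy default stream.
        target = cp.cuda.Stream.null if stream is None else cp.cuda.ExternalStream(stream)
        if target.ptr != self._stream.ptr:
            # PushStreamDep-style dependency: the target stream waits on an event
            # recorded on the producer's stream; nothing blocks the host.
            event = cp.cuda.Event(disable_timing=True)
            event.record(self._stream)
            target.wait_event(event)
        return self._data.toDlpack()   # export the capsule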

@oleksandr-pavlyk

I think S0a is the way to go.

@rgommers

Good catch regarding the need for __dlpack_device__. Regarding the consumer code, two minor tweaks:

    # This call is a consumer-internal implementation detail, doesn't need to have a standardized name
    consumer_stream = _find_exchange_stream(device)
    dlpack_caps = producer.__dlpack__(consumer_stream)  # This was a typo, there's no `__dlpack_stream__`

I think S0a is the way to go.

I agree, adding this second method makes sense and it's not too likely there's something that would need changing later.

@kkraus14

It seems like we've aligned on the Python side: a __dlpack_device__ protocol and a __dlpack__ protocol which takes a stream as input, where the producer is responsible for making the data it produces safe to use on the given stream.

Presumably we need to solve this same problem for C libraries now. Should I open a new issue or do we want to continue in this issue?

tqchen commented Mar 17, 2021

Thanks @kkraus14, how about we open a new issue? Given that there is no standardized interface for C library exchange, perhaps we could add that as an interface recommendation?

@tqchen tqchen closed this as completed Aug 16, 2021
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this issue Sep 13, 2021
Summary:
Partially Fixes #55090
Depends on #55365

Inspired by dmlc/dlpack#57 (comment)

Questions, in PyTorch we can't create streams or easily synchronize them from just an integer. Should we add an [`ExternalStream`](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.ExternalStream.html) object like the one we have in CuPy?

TODO: Add tests

Would like some feedback as this design needs quite a few iterations
rgommers leofang

Pull Request resolved: #57110

Reviewed By: saketh-are

Differential Revision: D30761481

Pulled By: mruberry

fbshipit-source-id: e85d78df3c1f8defc2a698878da89cd843cb1209