SMEM mixed bcast/iter with thread binding fails to validate/produce correct code #1418
Comments
```cuda
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T3) {
  alignas(4) extern __shared__ char array[];
  unsigned offset = 0;
  void* shared_mem = array;
  offset += ((blockDim.x * blockDim.y * blockDim.z) * sizeof(float));
  float T1[T0.size[0]];
  #pragma unroll 1
  for (nvfuser_index_t i16 = 0; i16 < T0.size[0]; ++i16) {
    T1[i16] = 0;
  }
  offset = alignBufferSize(offset, 4);
  float* T2 = reinterpret_cast<float*>(array + offset);
  offset += ((T0.size[0] * 1) * sizeof(float));
  #pragma unroll 1
  for (nvfuser_index_t i13 = 0; i13 < T0.size[0]; ++i13) {
    blockReduce<true, false, false>(
        T1[i13],
        T0[(i13 * T0.stride[0]) + (((nvfuser_index_t)threadIdx.x) * T0.stride[1])],
        [](float& a, float b) { a = a + b; },
        threadIdx,
        blockDim,
        static_cast<float*>(shared_mem),
        (((nvfuser_index_t)threadIdx.x) < T0.size[1]),
        float(0));
  }
  if (((((nvfuser_index_t)threadIdx.x) < (T0.size[0] * 1)) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T2[((nvfuser_index_t)threadIdx.x)]
        = T1[(((nvfuser_index_t)threadIdx.x) / 1)];
  }
  if ((((((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1])) && (((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1]))) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T3[(((nvfuser_index_t)threadIdx.x) * 1)]
        = T2[(((nvfuser_index_t)threadIdx.x) / T0.size[1])]
        + T0[((((nvfuser_index_t)threadIdx.x) / T0.size[1]) * T0.stride[0]) + ((((nvfuser_index_t)threadIdx.x) % T0.size[1]) * T0.stride[1])];
  }
  __barrier_sync(0);
}
```

The issues here are:
This bug isn't high priority, as we don't do parallelization schemes like this in practice, but it's not great that this produces such incorrect code.
A disabled test case for this issue was added in #1412.
I think another thing we may need to think about is that once a broadcast domain is merged with a non-broadcast domain, it becomes a non-broadcast domain. What's not really clear to me is what the merged domain of a broadcast and a non-broadcast domain would really mean. It doesn't matter when the broadcast is trivial, e.g., when it's on shared memory, not concretized, or not predicated. Otherwise, I'm not sure what the right semantics would be. If we don't care, as that won't be necessary in practice, we might just want to disable merging of broadcast and non-broadcast domains when the broadcast is not trivial.
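The merge behavior described above can be sketched with a toy model. All names and types here are hypothetical illustrations, not nvFuser's actual classes or API:

```python
# Toy model of iteration-domain merging. A broadcast domain has extent 1
# (until concretized); merging it with a non-broadcast domain produces a
# plain iteration domain, dropping the broadcast property.
from dataclasses import dataclass

@dataclass
class IterDomain:
    extent: int          # 1 for a non-concretized broadcast domain
    is_broadcast: bool

def merge(outer: IterDomain, inner: IterDomain) -> IterDomain:
    # The merged domain is a broadcast only if both inputs are broadcasts.
    return IterDomain(outer.extent * inner.extent,
                      outer.is_broadcast and inner.is_broadcast)

iter_dom = IterDomain(extent=8, is_broadcast=False)
bcast_dom = IterDomain(extent=1, is_broadcast=True)
merged = merge(iter_dom, bcast_dom)
# merged has extent 8 and is_broadcast=False: downstream lowering sees no
# broadcast domain on the tensor, so no broadcast runtime call is emitted.
```

Under this model, the question raised above is what `merged` should mean when the broadcast input is not trivial, since the non-broadcast result silently hides the broadcast.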
This partially addresses issue #1418. The repro still fails due to the missing RAW sync.
I'm not sure what the correct behavior should be, but this fusion currently fails at the validation at the beginning of lowering.
The issue is not (just) the validation. Even if the validation is skipped, invalid code is generated:
Notice that there's no blockBroadcast call. This is not due to the concrete broadcast domain PR (#1412). It's because merging broadcast and non-broadcast domains results in a non-broadcast domain, so T2 has no broadcast domain, and no broadcast runtime call is generated.
This problem only happens when a broadcast domain requires an actual parallel broadcast, which means the tensor must be predicated with the same parallel type as used on the broadcast domain.
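The condition in the last sentence can be sketched as a small predicate. This is a hypothetical model for illustration (the function name and the string parallel-type labels are made up, not nvFuser's API):

```python
# Toy model: a runtime (parallel) broadcast such as blockBroadcast is
# required exactly when a broadcast domain is bound to a thread dimension
# and the producing expression is predicated on that same parallel type,
# so only a subset of the threads hold a valid value.
def needs_parallel_broadcast(broadcast_parallel_types: set,
                             predicate_parallel_types: set) -> bool:
    # Overlap means some threads are masked out by the predicate while the
    # broadcast domain expects every thread along that dimension to have
    # the value; a plain register/shared-memory read is then insufficient.
    return bool(broadcast_parallel_types & predicate_parallel_types)

# The T2 case above: broadcast bound to threadIdx.x, producer guarded by
# a predicate on threadIdx.x (threadIdx.x == 0) -> broadcast call needed.
assert needs_parallel_broadcast({"threadIdx.x"}, {"threadIdx.x"})
# Trivial broadcast: no overlap with the predicate, no runtime call needed.
assert not needs_parallel_broadcast({"threadIdx.x"}, set())
```

In the generated kernel above this condition holds, yet the broadcast was lost during the merge, so the required runtime call is never emitted.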