
SMEM mixed bcast/iter with thread binding fails to validate/produce correct code #1418

Open
naoyam opened this issue Jan 28, 2022 · 4 comments

naoyam commented Jan 28, 2022

I'm not sure what the correct behavior should be, but this fusion currently fails the validation at the beginning of lowering.

TEST_F(NVFuserTest, FusionBroadcastConcretization4_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = sum(tv0, {1});
  auto tv2 = broadcast(tv1, {false, true});
  auto tv3 = add(tv2, tv0);
  fusion.addOutput(tv3);

  tv1->axis(1)->parallelize(ParallelType::TIDx);

  tv2->merge(0, 1);
  tv2->axis(0)->parallelize(ParallelType::TIDx);

  tv3->merge(0, 1);
  tv3->axis(0)->parallelize(ParallelType::TIDx);

  fusion.printMath();
  fusion.printKernel();
}
Inputs:
  T0_g[ iS0{i1}, iS1{i2} ], float
Outputs:
  T3_g[ ithreadIdx.x9{( i1 * i2 )} ], float

%kernel_math {
T1_l[ iS2{i1}, rthreadIdx.x3{i2} ] = reduction( T0_g[ iS0{i1}, iS1{i2} ], op = add, initial value = double(0) )
T2_l[ ithreadIdx.x8{( i1 * 1 )} ] = broadcast( T1_l[ iS2{i1}, rthreadIdx.x3{i2} ] )
T3_g[ ithreadIdx.x9{( i1 * i2 )} ]
   = T2_l[ ithreadIdx.x8{( i1 * 1 )} ]
   + T0_g[ iS0{i1}, iS1{i2} ];
}

unknown file: Failure
C++ exception with description "predicated_parallel_types.none()INTERNAL ASSERT FAILED at "../torch/csrc/jit/codegen/cuda/lower_validation.cpp":492, please report a bug to PyTorch. Invalid parallelization of tensor t2. The tensor is parallelized with threadIdx.x, but it's invalid to use the types as the tensor is also predicated with them., thread prd: threadIdx.x

The issue is not (just) the validation. Even if the validation is skipped, invalid code is generated:

__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T3) {
  alignas(4) extern __shared__ char array[];
  void* shared_mem = array;
  float T1[T0.size[0]];
  #pragma unroll 1
  for(nvfuser_index_t i16 = 0; i16 < T0.size[0]; ++i16) {
    T1[i16] = 0;
  }
  #pragma unroll 1
  for(nvfuser_index_t i13 = 0; i13 < T0.size[0]; ++i13) {
    blockReduce<true, false, false>(
      T1[i13],
      T0[(i13 * T0.stride[0]) + (((nvfuser_index_t)threadIdx.x) * T0.stride[1])],
      [](float &a, float b) { a = a + b; },
      threadIdx,
      blockDim,
      static_cast<float*>(shared_mem),
      (((nvfuser_index_t)threadIdx.x) < T0.size[1]),
      float(0));
  }
  float T2[1];
  T2[0]
     = T1[(((nvfuser_index_t)threadIdx.x) / 1)];
  if ((((((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1])) && (((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1]))) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T3[(((nvfuser_index_t)threadIdx.x) * 1)]
      = T2[(((nvfuser_index_t)threadIdx.x) / T0.size[1])]
      + T0[((((nvfuser_index_t)threadIdx.x) / T0.size[1]) * T0.stride[0]) + ((((nvfuser_index_t)threadIdx.x) % T0.size[1]) * T0.stride[1])];
  }
}

Notice that there's no blockBroadcast. This is not due to the concrete broadcast domain PR (#1412). It's because merging a broadcast domain with a non-broadcast domain results in a non-broadcast domain, so T2 has no broadcast leaf domain, and no broadcast runtime call is generated.

This problem only happens when a broadcast domain requires an actual parallel broadcast, which means the tensor must be predicated with the same parallel type as used on the broadcast domain.
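
For context, what the generated code is missing for T2 is a block-level broadcast through shared memory. A minimal hand-written CUDA sketch of that pattern, assuming the value to broadcast lives in the thread with threadIdx.x == 0 (illustrative only; this is not nvFuser's blockBroadcast runtime function):

__device__ float broadcastFromThreadX0(float val, float* smem_buf) {
  if (threadIdx.x == 0) {
    smem_buf[0] = val;  // the owning thread publishes its value
  }
  __syncthreads();      // RAW sync: make the published value visible to all threads
  float out = smem_buf[0];
  __syncthreads();      // WAR sync: safe to reuse the buffer afterwards
  return out;
}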

csarofeen commented Jan 28, 2022

blockBroadcast cannot support this type of use case. I would expect parallel validation to fail in this case instead of the predicate check complaining here. This is a change in the parallelization strategy where T3 has a completely different parallel scheme than T2. The only way this can be supported is when T2 is in shared memory; then we should do the right thing, which we almost do when that check is removed:

__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T3) {
  alignas(4) extern __shared__ char array[];
  unsigned offset = 0;
  void* shared_mem = array;
  offset += ((blockDim.x * blockDim.y * blockDim.z) * sizeof(float));
  float T1[T0.size[0]];
  #pragma unroll 1
  for(nvfuser_index_t i16 = 0; i16 < T0.size[0]; ++i16) {
    T1[i16] = 0;
  }
  offset = alignBufferSize(offset,4);
  float* T2 = reinterpret_cast<float*>(array + offset);
  offset += ((T0.size[0] * 1) * sizeof(float));
  #pragma unroll 1
  for(nvfuser_index_t i13 = 0; i13 < T0.size[0]; ++i13) {
    blockReduce<true, false, false>(
      T1[i13],
      T0[(i13 * T0.stride[0]) + (((nvfuser_index_t)threadIdx.x) * T0.stride[1])],
      [](float &a, float b) { a = a + b; },
      threadIdx,
      blockDim,
      static_cast<float*>(shared_mem),
      (((nvfuser_index_t)threadIdx.x) < T0.size[1]),
      float(0));
  }
  if (((((nvfuser_index_t)threadIdx.x) < (T0.size[0] * 1)) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T2[((nvfuser_index_t)threadIdx.x)]
       = T1[(((nvfuser_index_t)threadIdx.x) / 1)];
  }
  if ((((((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1])) && (((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1]))) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T3[(((nvfuser_index_t)threadIdx.x) * 1)]
      = T2[(((nvfuser_index_t)threadIdx.x) / T0.size[1])]
      + T0[((((nvfuser_index_t)threadIdx.x) / T0.size[1]) * T0.stride[0]) + ((((nvfuser_index_t)threadIdx.x) % T0.size[1]) * T0.stride[1])];
  }
  __barrier_sync(0);
}

The issues here are:

  1. There's a RAW race between writing T2 and its use in T3 (see the sketch after this list).
    • I've been thinking that we should detect which tensors actually need RAW protection in a build pass that runs before the RAW sync insertion pass. Such a pass would be very similar to the parallelization validation pass, and would just mark what type of communication (smem or gmem) is required to satisfy the parallelization scheme.
  2. The predicate ((nvfuser_index_t)threadIdx.x) == 0 shouldn't exist on the T2 and T3 expressions.
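
A minimal sketch of the ordering the generated code above is missing: the T2 write needs a block-level sync before T2 is read by the T3 expression, and the spurious == 0 predicates are dropped. Illustrative only (it assumes enough threads are launched to cover the flattened domains), not the actual fix:

  if (((nvfuser_index_t)threadIdx.x) < (T0.size[0] * 1)) {
    T2[((nvfuser_index_t)threadIdx.x)]
       = T1[(((nvfuser_index_t)threadIdx.x) / 1)];
  }
  __syncthreads();  // RAW sync: T2 must be visible to all threads before it is read
  if (((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1])) {
    T3[(((nvfuser_index_t)threadIdx.x) * 1)]
      = T2[(((nvfuser_index_t)threadIdx.x) / T0.size[1])]
      + T0[((((nvfuser_index_t)threadIdx.x) / T0.size[1]) * T0.stride[0]) + ((((nvfuser_index_t)threadIdx.x) % T0.size[1]) * T0.stride[1])];
  }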

@csarofeen csarofeen added bug and removed question labels Jan 28, 2022
@csarofeen csarofeen self-assigned this Jan 28, 2022
@csarofeen csarofeen changed the title Merging broadcast and non-broadcast domains SMEM mixed bcast/iter with thread binding fails to validate/produce correct code Jan 28, 2022
@csarofeen

This bug isn't high priority as we don't do parallelization schemes like this in practice, but it's not great that this produces such incorrect code.

@csarofeen

A disabled test case for this issue was added in #1412.

@naoyam
Copy link
Collaborator Author

naoyam commented Jan 28, 2022

I think another thing we may need to think about is that once a broadcast domain is merged with a non-broadcast domain, it becomes a non-broadcast domain, i.e., IterDomain::isBroadcast() returns false. We often just look at the leaf IDs to see if there's any broadcast (e.g., TensorDomain::hasBroadcast()), which would return false even when the root domain has a broadcast.
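
A minimal sketch of a check that would still detect the broadcast in this situation, by walking the root domain instead of the leaf domain (accessor names such as getRootDomain() and isBroadcast() are assumed to match the existing IR API):

bool hasRootBroadcast(const TensorDomain* td) {
  // A merged (broadcast x iteration) leaf ID reports isBroadcast() == false,
  // so look for a broadcast among the root IterDomains instead.
  for (IterDomain* id : td->getRootDomain()) {
    if (id->isBroadcast()) {
      return true;
    }
  }
  return false;
}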

What's not really clear to me is what the merged domain of a broadcast and a non-broadcast domain would really mean. It doesn't matter when the broadcast is trivial, e.g., when it's on shared memory, not concretized, or not predicated. Otherwise, I'm not sure what the right semantics would be. If we don't care because such cases won't be necessary, we might just want to disallow merging broadcast and non-broadcast domains when the broadcast is not trivial.

naoyam added a commit that referenced this issue Jan 28, 2022
This partially addresses issue #1418. The repro still fails due to the
missing RAW sync.