performance issue on dynamic shaped tensor #2795

Open
jjsjann123 opened this issue Aug 16, 2024 · 17 comments · May be fixed by #2835

@jjsjann123
Collaborator

The rewritten RoPE example produces quite different indexing depending on whether the inputs q / cos / sin are defined with static or dynamic shapes.

I think this comes from an inconsistent fusion definition: when we switch the inputs to dynamic shapes, the follow-up slice operations still use static end indices instead of the symbolic input extents, so we cannot collapse the indexing after the slice. For example:

q_rope = fd.ops.slice(q, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])
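To make the "collapse" argument concrete, here is a pure-Python sketch (not nvFuser code; all function names are hypothetical) of merging the two outer dimensions of a sliced `[B, H, K]` view of a contiguous input, using 3 dims instead of 4 for brevity. Merging (b, h) into a single index is only valid when the slice covers the whole H extent of the input, which is the fact that cannot be proven when the slice end is a literal like 16 but the input extent is symbolic.

```python
def offset_per_dim(b, h, k, in_h, in_k):
    """Offset into a contiguous [B, in_h, in_k] input, one term per dimension."""
    return (b * in_h + h) * in_k + k


def offset_merged(flat, out_k, in_k):
    """Offset computed after merging (b, h) of the sliced [B, out_h, out_k] view
    into a single outer index: one div/mod instead of two."""
    outer, k = divmod(flat, out_k)
    return outer * in_k + k


def merge_is_valid(out_b, out_h, out_k, in_h, in_k):
    """Brute-force check that merged indexing matches per-dimension indexing
    for every element of the sliced view."""
    for flat in range(out_b * out_h * out_k):
        bh, k = divmod(flat, out_k)
        b, h = divmod(bh, out_h)
        if offset_merged(flat, out_k, in_k) != offset_per_dim(b, h, k, in_h, in_k):
            return False
    return True


# Slice keeps all 16 heads of a [2, 16, 32] input and the first 8 columns.
print(merge_is_valid(2, 16, 8, in_h=16, in_k=32))  # True: outer dims can merge
# Slice keeps only 8 of the 16 heads: merged offsets would be wrong.
print(merge_is_valid(2, 8, 8, in_h=16, in_k=32))   # False
```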

Currently the Python API only accepts static numbers for these indices, not Val*. Supporting Val* would also require the definition in lowering to be updated.
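Why passing a Val* matters can be shown with a toy model (a sketch only; `Sym`, `provably_equal` and `is_full_extent_slice` are hypothetical and not nvFuser's expression simplifier). A slice end can be proven equal to the input extent only if both are the same constant or the same symbol; a literal end against a symbolic input extent is "unknown", so the slice must be treated as partial.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Sym:
    """Hypothetical symbolic scalar, e.g. the runtime extent of an input axis."""
    name: str


Extent = Union[int, Sym]


def provably_equal(a: Extent, b: Extent) -> bool:
    # Equal only if both are the same constant or the very same symbol.
    # An int compared against a Sym is unknown, so the answer must be False.
    return a == b


def is_full_extent_slice(start: Extent, end: Extent, input_extent: Extent) -> bool:
    """True if slicing [start, end) provably keeps the whole input axis."""
    return provably_equal(start, 0) and provably_equal(end, input_extent)


heads = Sym("q.size(1)")
print(is_full_extent_slice(0, 16, 16))         # static input, static end: True
print(is_full_extent_slice(0, 16, heads))      # dynamic input, literal end: False
print(is_full_extent_slice(0, heads, heads))   # dynamic input, symbolic end: True
```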

I'm opening this issue for myself. I think it's worth rewriting the script below to see whether we can recover full performance with symbolic shapes once the proper definition is produced.

import torch
from nvfuser import FusionDefinition, DataType

bsz = 2
block_size = 1024
n_head = 16
head_size = 32
rope_n_elem = 8

def rope_fusion(fd: FusionDefinition) -> None:
    q = fd.define_tensor(
        #shape=[bsz, n_head, block_size, head_size],
        shape=[-1, -1, -1, -1],
        contiguity=[True, True, True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[3, 2, 1, 0],
    )
    cos = fd.define_tensor(
        #shape=[block_size, rope_n_elem],
        shape=[-1, -1],
        contiguity=[True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[1, 0],
    )
    sin = fd.define_tensor(
        #shape=[block_size, rope_n_elem],
        shape=[-1, -1],
        contiguity=[True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[1, 0],
    )

    offset_0 = rope_n_elem // 2

    q_rope = fd.ops.slice(q, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])
    q_remainder = fd.ops.slice(q, start_indices=[0, 0, 0, rope_n_elem], end_indices=[bsz, n_head, block_size, head_size], strides=[1, 1, 1, 1])
    q_remainder = fd.ops.pad(q_remainder, list(reversed([0, 0, 0, 0, 0, 0, 0, rope_n_elem])))

    q_left = fd.ops.slice(q_rope, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, offset_0], strides=[1, 1, 1, 1])
    q_left = fd.ops.pad(q_left, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem, rope_n_elem - offset_0])))
    q_right = fd.ops.slice(q_rope, start_indices=[0, 0, 0, offset_0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])
    q_right = fd.ops.pad(q_right, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem + offset_0, 0])))

    # note: these slices are identical to q_left and q_right (only the padding differs). We should be able to merge them back.
    q_left_cos = fd.ops.slice(q_rope, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, offset_0], strides=[1, 1, 1, 1])
    q_left_cos = fd.ops.pad(q_left_cos, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem + offset_0, 0])))
    q_right_cos = fd.ops.slice(q_rope, start_indices=[0, 0, 0, offset_0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])
    q_right_cos = fd.ops.pad(q_right_cos, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem, rope_n_elem - offset_0])))

    # slice cos/sin
    cos_left = fd.ops.slice(cos, start_indices=[0, 0], end_indices=[block_size, offset_0], strides=[1, 1])
    cos_left = fd.ops.pad(cos_left, list(reversed([0, 0, head_size - offset_0, 0])))
    cos_left = fd.ops.broadcast_in_dim(cos_left, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])
    cos_right = fd.ops.slice(cos, start_indices=[0, offset_0], end_indices=[block_size, rope_n_elem], strides=[1, 1])
    cos_right = fd.ops.pad(cos_right, list(reversed([0, 0, head_size - rope_n_elem, offset_0])))
    cos_right = fd.ops.broadcast_in_dim(cos_right, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])

    sin_left = fd.ops.slice(sin, start_indices=[0, 0], end_indices=[block_size, offset_0], strides=[1, 1])
    sin_left = fd.ops.pad(sin_left, list(reversed([0, 0, head_size - offset_0, 0])))
    sin_left = fd.ops.broadcast_in_dim(sin_left, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])
    sin_right = fd.ops.slice(sin, start_indices=[0, offset_0], end_indices=[block_size, rope_n_elem], strides=[1, 1])
    sin_right = fd.ops.pad(sin_right, list(reversed([0, 0, head_size - rope_n_elem, offset_0])))
    sin_right = fd.ops.broadcast_in_dim(sin_right, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])

    q0 = (-q_right) * sin_left + cos_left * q_left_cos
    q1 = q_left * sin_right + cos_right * q_right_cos
    q_out = q0 + q1 + q_remainder
    q_out = fd.ops.cast(q_out, dtype=DataType.BFloat16)
    q0 = fd.ops.cast(q0, dtype=DataType.BFloat16)
    
    fd.add_output(q_out)

with FusionDefinition() as fd:
    rope_fusion(fd)

inputs = [
    torch.randn((bsz, n_head, block_size, head_size), dtype=torch.bfloat16, device="cuda:0"),
    torch.randn((block_size, rope_n_elem), dtype=torch.bfloat16, device="cuda:0"),
    torch.randn((block_size, rope_n_elem), dtype=torch.bfloat16, device="cuda:0"),
]

o = fd.execute(inputs)[0]
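As a sanity check of what the fusion above computes, here is a pure-Python sketch (no nvFuser or torch; the helper names are hypothetical) that mirrors the slice/pad/mul/add graph of `rope_fusion` on a single row of length `head_size` and compares it with the usual rotate-half RoPE formula. It assumes nvFuser's pad widths are ordered `[left, right]` starting from the last dimension, which is why the repro reverses its pad lists.

```python
import math
import random


def slice_pad(row, start, end, left, right):
    """1-D analogue of fd.ops.slice on the last dim followed by fd.ops.pad
    with (left, right) zero padding."""
    return [0.0] * left + row[start:end] + [0.0] * right


def rope_slice_pad(q, cos, sin, head_size, rope_n_elem):
    """Mirror the slice/pad/mul/add graph of rope_fusion for one row of q."""
    half = rope_n_elem // 2  # offset_0 in the repro
    tail = head_size - rope_n_elem
    q_left = slice_pad(q, 0, half, rope_n_elem - half, tail)
    q_right = slice_pad(q, half, rope_n_elem, 0, tail + half)
    q_left_cos = slice_pad(q, 0, half, 0, tail + half)
    q_right_cos = slice_pad(q, half, rope_n_elem, rope_n_elem - half, tail)
    q_remainder = slice_pad(q, rope_n_elem, head_size, rope_n_elem, 0)
    cos_left = slice_pad(cos, 0, half, 0, head_size - half)
    cos_right = slice_pad(cos, half, rope_n_elem, half, tail)
    sin_left = slice_pad(sin, 0, half, 0, head_size - half)
    sin_right = slice_pad(sin, half, rope_n_elem, half, tail)
    out = []
    for i in range(head_size):
        q0 = (-q_right[i]) * sin_left[i] + cos_left[i] * q_left_cos[i]
        q1 = q_left[i] * sin_right[i] + cos_right[i] * q_right_cos[i]
        out.append(q0 + q1 + q_remainder[i])
    return out


def rope_reference(q, cos, sin, head_size, rope_n_elem):
    """Rotate-half RoPE on the first rope_n_elem entries; the rest pass through."""
    half = rope_n_elem // 2
    x = q[:rope_n_elem]
    rotated = [-v for v in x[half:]] + x[:half]
    rope = [x[i] * cos[i] + rotated[i] * sin[i] for i in range(rope_n_elem)]
    return rope + q[rope_n_elem:head_size]


rng = random.Random(0)
head_size, rope_n_elem = 32, 8
q_row = [rng.uniform(-1, 1) for _ in range(head_size)]
cos_row = [rng.uniform(-1, 1) for _ in range(rope_n_elem)]
sin_row = [rng.uniform(-1, 1) for _ in range(rope_n_elem)]
got = rope_slice_pad(q_row, cos_row, sin_row, head_size, rope_n_elem)
want = rope_reference(q_row, cos_row, sin_row, head_size, rope_n_elem)
print(all(math.isclose(g, w, abs_tol=1e-12) for g, w in zip(got, want)))  # True
```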

jjsjann123 self-assigned this Aug 16, 2024
@jjsjann123
Collaborator Author

cc @zasdfgbnm: I don't think anything is needed from you at the moment. I'll update this issue after I've checked the performance with the new definition.

@jacobhinkle
Collaborator

Here is a diff of the generated pointwise kernels (static vs. dynamic) on my 3090 Ti:

--- static.cu   2024-08-19 10:23:21.977784983 -0400
+++ dynamic.cu  2024-08-19 10:23:58.741144923 -0400
@@ -10697,68 +10697,96 @@
 }
 
 } // namespace fused_reduction
-__global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, Tensor<__bfloat, 4, 4> T48) {
+__global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, nvfuser_index_t i0, nvfuser_index_t i1, nvfuser_index_t i2, Tensor<__bfloat, 4, 4> T48) {
   NVFUSER_DEFINE_MAGIC_ZERO;
-  nvfuser_index_t i0;
-  i0 = ((nvfuser_index_t)threadIdx.x) + (((nvfuser_index_t)blockDim.x) * ((nvfuser_index_t)blockIdx.y));
-  nvfuser_index_t i1;
-  i1 = 8 * (i0 % 4);
-  nvfuser_index_t i2;
-  i2 = i0 / 4;
   nvfuser_index_t i3;
-  i3 = -4 + i1;
+  i3 = 8 * ((nvfuser_index_t)threadIdx.x);
   nvfuser_index_t i4;
-  i4 = i3 + (T26.alloc_stride[0LL] * i2);
+  i4 = (8 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.y);
   nvfuser_index_t i5;
-  i5 = ((-4 + ((1024 * T6.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T6.alloc_stride[2LL] * i2);
+  i5 = i3 + i4;
   nvfuser_index_t i6;
-  i6 = i1 + (T14.alloc_stride[0LL] * i2);
+  i6 = 28 + T26.logical_size[1LL];
   nvfuser_index_t i7;
-  i7 = (((1024 * T10.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x)) + i1) + (T10.alloc_stride[2LL] * i2);
+  i7 = ((nvfuser_index_t)blockIdx.x) / T10.logical_size[1LL];
   nvfuser_index_t i8;
-  i8 = (((1024 * T8.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x)) + i1) + (T8.alloc_stride[2LL] * i2);
+  i8 = ((nvfuser_index_t)blockIdx.x) % T10.logical_size[1LL];
   nvfuser_index_t i9;
-  i9 = i1 + (T22.alloc_stride[0LL] * i2);
+  i9 = (-4 + (T6.alloc_stride[0LL] * i7)) + (T6.alloc_stride[1LL] * i8);
   nvfuser_index_t i10;
-  i10 = ((-8 + ((1024 * T4.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T4.alloc_stride[2LL] * i2);
+  i10 = 28 + T6.logical_size[3LL];
   nvfuser_index_t i11;
-  i11 = i3 + (T18.alloc_stride[0LL] * i2);
+  i11 = 28 + T14.logical_size[1LL];
   nvfuser_index_t i12;
-  i12 = ((-4 + ((1024 * T12.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T12.alloc_stride[2LL] * i2);
+  i12 = (T10.alloc_stride[0LL] * i7) + (T10.alloc_stride[1LL] * i8);
   nvfuser_index_t i13;
-  i13 = 8 * ((nvfuser_index_t)threadIdx.x);
+  i13 = 28 + T10.logical_size[3LL];
   nvfuser_index_t i14;
-  i14 = (8 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.y);
+  i14 = (T8.alloc_stride[0LL] * i7) + (T8.alloc_stride[1LL] * i8);
   nvfuser_index_t i15;
-  i15 = (i13 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i14;
-  bool b16;
-  b16 = (i13 + i14) < 32768;
-  if ((((i13 + 7) + i14) < 32768)) {
-    Array<__bfloat, 8, 8> T50;
+  i15 = 28 + T8.logical_size[3LL];
+  nvfuser_index_t i16;
+  i16 = 28 + T22.logical_size[1LL];
+  nvfuser_index_t i17;
+  i17 = (-8 + (T4.alloc_stride[0LL] * i7)) + (T4.alloc_stride[1LL] * i8);
+  nvfuser_index_t i18;
+  i18 = 8 + T4.logical_size[3LL];
+  nvfuser_index_t i19;
+  i19 = 28 + T18.logical_size[1LL];
+  nvfuser_index_t i20;
+  i20 = (-4 + (T12.alloc_stride[0LL] * i7)) + (T12.alloc_stride[1LL] * i8);
+  nvfuser_index_t i21;
+  i21 = 28 + T12.logical_size[3LL];
+  nvfuser_index_t i22;
+  i22 = (i3 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i4;
+  nvfuser_index_t i23;
+  i23 = 24 * T10.logical_size[2LL];
+  nvfuser_index_t i24;
+  i24 = ((max(4, (min(i0, 8)))) * T10.logical_size[2LL]) + i23;
+  bool b25;
+  b25 = i5 < i24;
+  nvfuser_index_t i26;
+  i26 = 28 * T10.logical_size[2LL];
+  bool b27;
+  b27 = i5 < (i26 + (T10.logical_size[2LL] * T6.logical_size[3LL]));
+  bool b28;
+  b28 = i5 < (i26 + (T10.logical_size[2LL] * T14.logical_size[1LL]));
+  bool b29;
+  b29 = i5 < ((T10.logical_size[2LL] * T10.logical_size[3LL]) + i26);
+  bool b30;
+  b30 = i5 < (((max(4, (min((max(0LL, (min(i1, 8)))), 8)))) * T10.logical_size[2LL]) + i23);
+  bool b31;
+  b31 = i5 < (i26 + (T10.logical_size[2LL] * T22.logical_size[1LL]));
+  bool b32;
+  b32 = i5 < ((max(8, (min(i1, 32)))) * T10.logical_size[2LL]);
+  bool b33;
+  b33 = i5 < (((max(4, (min(i2, 8)))) * T10.logical_size[2LL]) + i23);
+  if ((((i3 + 7) + i4) < i24)) {
+    Array<__bfloat, 8, 8> T54;
     #pragma unroll
-    for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
-      nvfuser_index_t i18;
-      i18 = i17 + nvfuser_zero;
+    for(nvfuser_index_t i34 = 0; i34 < 8; ++i34) {
+      nvfuser_index_t i35;
+      i35 = i5 + (i34 + nvfuser_zero);
       __bfloat T27[1];
       T27[0] = 0;
       T27[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T26[(i4 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i35 / i6))) + (i35 % i6))] : 0.0000e+00;
       __bfloat T28[1];
       T28[0]
          = T27[0];
-      __bfloat T29[1];
-      T29[0]
+      __bfloat T52[1];
+      T52[0]
          = T28[0];
       __bfloat T7[1];
       T7[0] = 0;
       T7[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T6[(i5 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i35 / i10))) + (i35 % i10))] : 0.0000e+00;
       float T38[1];
       T38[0]
          = __bfloat2float(T7[0]);
       float T39[1];
       T39[0]
-         = __bfloat2float(T29[0]);
+         = __bfloat2float(T52[0]);
       float T40[1];
       T40[0]
         = T38[0]
@@ -10766,23 +10794,23 @@
       __bfloat T15[1];
       T15[0] = 0;
       T15[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T14[(i6 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) < T14.logical_size[1LL])) ? T14[((T14.alloc_stride[0LL] * (i35 / i11)) + (i35 % i11))] : 0.0000e+00;
       __bfloat T16[1];
       T16[0]
          = T15[0];
-      __bfloat T17[1];
-      T17[0]
+      __bfloat T51[1];
+      T51[0]
          = T16[0];
       __bfloat T11[1];
       T11[0] = 0;
       T11[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T10[(i7 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) < T10.logical_size[3LL])) ? T10[((i12 + (T10.alloc_stride[2LL] * (i35 / i13))) + (i35 % i13))] : 0.0000e+00;
       float T35[1];
       T35[0]
          = __bfloat2float(T11[0]);
       float T34[1];
       T34[0]
-         = __bfloat2float(T17[0]);
+         = __bfloat2float(T51[0]);
       float T36[1];
       T36[0]
         = T34[0]
@@ -10790,7 +10818,7 @@
       __bfloat T9[1];
       T9[0] = 0;
       T9[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T8[(i8 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) < T8.logical_size[3LL])) ? T8[((i14 + (T8.alloc_stride[2LL] * (i35 / i15))) + (i35 % i15))] : 0.0000e+00;
       float T30[1];
       T30[0]
          = __bfloat2float(T9[0]);
@@ -10800,16 +10828,16 @@
       __bfloat T23[1];
       T23[0] = 0;
       T23[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T22[(i9 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) < T22.logical_size[1LL])) ? T22[((T22.alloc_stride[0LL] * (i35 / i16)) + (i35 % i16))] : 0.0000e+00;
       __bfloat T24[1];
       T24[0]
          = T23[0];
-      __bfloat T25[1];
-      T25[0]
+      __bfloat T50[1];
+      T50[0]
          = T24[0];
       float T32[1];
       T32[0]
-         = __bfloat2float(T25[0]);
+         = __bfloat2float(T50[0]);
       float T33[1];
       T33[0]
         = T31[0]
@@ -10817,30 +10845,30 @@
       __bfloat T5[1];
       T5[0] = 0;
       T5[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) < 24)) ? T4[(i10 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i35 / i18))) + (i35 % i18))] : 0.0000e+00;
       float T46[1];
       T46[0]
          = __bfloat2float(T5[0]);
       __bfloat T19[1];
       T19[0] = 0;
       T19[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T18[(i11 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i35 / i19))) + (i35 % i19))] : 0.0000e+00;
       __bfloat T20[1];
       T20[0]
          = T19[0];
-      __bfloat T21[1];
-      T21[0]
+      __bfloat T53[1];
+      T53[0]
          = T20[0];
       __bfloat T13[1];
       T13[0] = 0;
       T13[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T12[(i12 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i35 / i21))) + (i35 % i21))] : 0.0000e+00;
       float T42[1];
       T42[0]
          = __bfloat2float(T13[0]);
       float T41[1];
       T41[0]
-         = __bfloat2float(T21[0]);
+         = __bfloat2float(T53[0]);
       float T43[1];
       T43[0]
         = T41[0]
@@ -10861,78 +10889,78 @@
       T47[0]
         = T45[0]
         + T46[0];
-      T50[i17]
+      T54[i34]
          = __float2bfloat(T47[0]);
     }
     NVFUSER_UPDATE_MAGIC_ZERO;
-    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i15], &T50[0]);
+    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
   } else {
-    Array<__bfloat, 8, 8> T50;
+    Array<__bfloat, 8, 8> T54;
     #pragma unroll
-    for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
-      nvfuser_index_t i19;
-      i19 = i17 + nvfuser_zero;
+    for(nvfuser_index_t i34 = 0; i34 < 8; ++i34) {
+      nvfuser_index_t i36;
+      i36 = i5 + (i34 + nvfuser_zero);
       __bfloat T27[1];
       T27[0] = 0;
-      if (b16) {
+      if (b25) {
         T27[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T26[(i4 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i36 / i6))) + (i36 % i6))] : 0.0000e+00;
       }
       __bfloat T28[1];
       T28[0]
          = T27[0];
-      __bfloat T29[1];
-      T29[0]
+      __bfloat T52[1];
+      T52[0]
          = T28[0];
       __bfloat T7[1];
       T7[0] = 0;
-      if (b16) {
+      if (b27) {
         T7[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T6[(i5 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i36 / i10))) + (i36 % i10))] : 0.0000e+00;
       }
       float T38[1];
       T38[0]
          = __bfloat2float(T7[0]);
       float T39[1];
       T39[0]
-         = __bfloat2float(T29[0]);
+         = __bfloat2float(T52[0]);
       float T40[1];
       T40[0]
         = T38[0]
         * T39[0];
       __bfloat T15[1];
       T15[0] = 0;
-      if (b16) {
+      if (b28) {
         T15[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T14[(i6 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) < T14.logical_size[1LL])) ? T14[((T14.alloc_stride[0LL] * (i36 / i11)) + (i36 % i11))] : 0.0000e+00;
       }
       __bfloat T16[1];
       T16[0]
          = T15[0];
-      __bfloat T17[1];
-      T17[0]
+      __bfloat T51[1];
+      T51[0]
          = T16[0];
       __bfloat T11[1];
       T11[0] = 0;
-      if (b16) {
+      if (b29) {
         T11[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T10[(i7 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) < T10.logical_size[3LL])) ? T10[((i12 + (T10.alloc_stride[2LL] * (i36 / i13))) + (i36 % i13))] : 0.0000e+00;
       }
       float T35[1];
       T35[0]
          = __bfloat2float(T11[0]);
       float T34[1];
       T34[0]
-         = __bfloat2float(T17[0]);
+         = __bfloat2float(T51[0]);
       float T36[1];
       T36[0]
         = T34[0]
         * T35[0];
       __bfloat T9[1];
       T9[0] = 0;
-      if (b16) {
+      if (b30) {
         T9[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T8[(i8 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) < T8.logical_size[3LL])) ? T8[((i14 + (T8.alloc_stride[2LL] * (i36 / i15))) + (i36 % i15))] : 0.0000e+00;
       }
       float T30[1];
       T30[0]
@@ -10942,56 +10970,56 @@
          = -T30[0];
       __bfloat T23[1];
       T23[0] = 0;
-      if (b16) {
+      if (b31) {
         T23[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T22[(i9 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) < T22.logical_size[1LL])) ? T22[((T22.alloc_stride[0LL] * (i36 / i16)) + (i36 % i16))] : 0.0000e+00;
       }
       __bfloat T24[1];
       T24[0]
          = T23[0];
-      __bfloat T25[1];
-      T25[0]
+      __bfloat T50[1];
+      T50[0]
          = T24[0];
       float T32[1];
       T32[0]
-         = __bfloat2float(T25[0]);
+         = __bfloat2float(T50[0]);
       float T33[1];
       T33[0]
         = T31[0]
         * T32[0];
       __bfloat T5[1];
       T5[0] = 0;
-      if (b16) {
+      if (b32) {
         T5[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) < 24)) ? T4[(i10 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i36 / i18))) + (i36 % i18))] : 0.0000e+00;
       }
       float T46[1];
       T46[0]
          = __bfloat2float(T5[0]);
       __bfloat T19[1];
       T19[0] = 0;
-      if (b16) {
+      if (b33) {
         T19[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T18[(i11 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i36 / i19))) + (i36 % i19))] : 0.0000e+00;
       }
       __bfloat T20[1];
       T20[0]
          = T19[0];
-      __bfloat T21[1];
-      T21[0]
+      __bfloat T53[1];
+      T53[0]
          = T20[0];
       __bfloat T13[1];
       T13[0] = 0;
-      if (b16) {
+      if (b30) {
         T13[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T12[(i12 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i36 / i21))) + (i36 % i21))] : 0.0000e+00;
       }
       float T42[1];
       T42[0]
          = __bfloat2float(T13[0]);
       float T41[1];
       T41[0]
-         = __bfloat2float(T21[0]);
+         = __bfloat2float(T53[0]);
       float T43[1];
       T43[0]
         = T41[0]
@@ -11012,12 +11040,12 @@
       T47[0]
         = T45[0]
         + T46[0];
-      T50[i17]
+      T54[i34]
          = __float2bfloat(T47[0]);
     }
     NVFUSER_UPDATE_MAGIC_ZERO;
-    if (b16) {
-      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i15], &T50[0]);
+    if ((i5 < 32768)) {
+      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
     }
   }
 }

The static kernel was generated using the commented-out shape lines in the repro posted above. It achieves about 5x higher bandwidth than the dynamic one (8 us vs. 38 us).
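For a rough sense of scale, the timings can be converted into an effective bandwidth. This is a back-of-the-envelope sketch under an assumed minimal traffic model (q, cos and sin each read once, q_out written once, all bf16); the kernel actually reads q through several sliced views, so this is only a lower bound on bytes moved. With the repro's sizes it gives about 4.2 MB, i.e. roughly 528 GB/s for the static kernel vs. 111 GB/s for the dynamic one, a ratio of 4.75.

```python
BYTES_PER_BF16 = 2


def min_bytes_moved(bsz, n_head, block_size, head_size, rope_n_elem):
    """Lower bound on DRAM traffic: read q, cos, sin once and write q_out once."""
    q_elems = bsz * n_head * block_size * head_size
    cos_elems = sin_elems = block_size * rope_n_elem
    out_elems = q_elems
    return (q_elems + cos_elems + sin_elems + out_elems) * BYTES_PER_BF16


def effective_bandwidth_gb_per_s(nbytes, seconds):
    return nbytes / seconds / 1e9


nbytes = min_bytes_moved(bsz=2, n_head=16, block_size=1024, head_size=32, rope_n_elem=8)
static_bw = effective_bandwidth_gb_per_s(nbytes, 8e-6)    # 8 us static kernel
dynamic_bw = effective_bandwidth_gb_per_s(nbytes, 38e-6)  # 38 us dynamic kernel
print(nbytes, round(static_bw), round(dynamic_bw), round(static_bw / dynamic_bw, 2))
```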

The dynamic kernel has a longer preamble, but the expressions inside the loops are also somewhat more complex. For example (zoomed in, with line breaks inserted):

-- static.cu   2024-08-19 10:23:21.977784983 -0400
+++ dynamic.cu  2024-08-19 10:23:58.741144923 -0400
@@ -10697,68 +10697,96 @@
       __bfloat T27[1];
       T27[0] = 0;
       T27[0] = ((((((((((nvfuser_index_t)blockIdx.y) *
                        ((nvfuser_index_t)blockDim.x)) +
                       ((nvfuser_index_t)threadIdx.x)) *
                      8) +
-                    (i17 + nvfuser_zero)) %
-                   32) -
+                    (i34 + nvfuser_zero)) %
+                   ((T26.logical_size[1LL] + 4) + 24)) -
                   4) >= 0) &&
                 (((((((((nvfuser_index_t)blockIdx.y) *
                        ((nvfuser_index_t)blockDim.x)) +
                       ((nvfuser_index_t)threadIdx.x)) *
                      8) +
-                    (i17 + nvfuser_zero)) %
-                   32) -
-                  4) < 4))
-          ? T26[(i4 + i18)]
+                    (i34 + nvfuser_zero)) %
+                   ((T26.logical_size[1LL] + 4) + 24)) -
+                  4) < T26.logical_size[1LL]))
+          ? T26[((-4 + (T26.alloc_stride[0LL] * (i35 / i6))) + (i35 % i6))]
           : 0.0000e+00;
       __bfloat T28[1];
       T28[0] = T27[0];

In this context i35 is a loop index, so we might not be able to simplify the last diff line much, but we are also not hoisting ((T26.logical_size[1LL] + 4) + 24) for some reason...
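To make the missed hoisting concrete, here is a small hand-written C++ sketch (not nvFuser output) of the padded-axis position computed inside the 8-way unrolled loop, before and after hoisting the loop-invariant divisor. `base` stands for the thread's starting element offset and `logical_size_1` for `T26.logical_size[1LL]`; both names and both functions are hypothetical. The two versions return identical positions, so hoisting only avoids recomputing the divisor in every iteration.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using index_t = std::int64_t;  // stand-in for nvfuser_index_t

// Mirrors the dynamic kernel: the padded extent (logical_size_1 + 4) + 24
// is recomputed inside every iteration of the unrolled loop.
std::vector<index_t> padPositionsNotHoisted(index_t base, index_t logical_size_1) {
  std::vector<index_t> positions;
  for (index_t i = 0; i < 8; ++i) {
    positions.push_back(((base + i) % ((logical_size_1 + 4) + 24)) - 4);
  }
  return positions;
}

// Same computation with the loop-invariant divisor hoisted out of the loop.
std::vector<index_t> padPositionsHoisted(index_t base, index_t logical_size_1) {
  const index_t padded_extent = (logical_size_1 + 4) + 24;  // loop invariant
  std::vector<index_t> positions;
  positions.reserve(8);
  for (index_t i = 0; i < 8; ++i) {
    positions.push_back(((base + i) % padded_extent) - 4);
  }
  return positions;
}
```

With `logical_size_1 = 4` the padded extent is 32, so `padPositionsHoisted(8, 4)[3]` is `(11 % 32) - 4 = 7`.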

@jacobhinkle
Collaborator

As for the preamble, there are lots of max and min calls in the dynamic kernel, which could be avoided using #511 (I'm looking at updating this). As discussed last week, we could temporarily make all sliced input extents and all slice ranges constant at concretization, which I think would give us a kernel similar to static.cu above.

@jjsjann123
Collaborator Author

For our own sanity, here's a simplified C++ test. Indexing isn't simplified even when the slice is given the correct extent Val.

Regarding @jacobhinkle's WAR (workaround) in #511: as I understand it, we wouldn't need this form of definition, and the performance of the original Python repro shouldn't regress with dynamic shapes.

I'm creating this repro for @zasdfgbnm. I'm assuming the definition here should be enough to tell us that we are doing two non-overlapping slices, so indexing could perhaps be simplified even without concretization...

TEST_F(NVFuserTest, DynamicShapedPad) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  auto zero = fusion->zeroVal();
  auto one = fusion->oneVal();

  std::vector<int64_t> shape{32, 1024, 16};

#if 0  // Change to 1 to build the static-shape variant.
  // Static shapes: concrete input tensor and constant extents.
  auto tv0 = makeContigConcreteTensor(shape);
  auto dim0 = IrBuilder::create<Val>(32, DataType::Index);
  auto dim1 = IrBuilder::create<Val>(1024, DataType::Index);
  auto dim2 = IrBuilder::create<Val>(16, DataType::Index);
  auto val_slice = IrBuilder::create<Val>(8, DataType::Index);
  auto val_remain = IrBuilder::create<Val>(8, DataType::Index);
#else
  // Dynamic shapes: extents are symbolic Vals taken from the input tensor.
  auto tv0 = makeContigTensor(3);
  auto dim0 = tv0->axis(0)->extent();
  auto dim1 = tv0->axis(1)->extent();
  auto dim2 = tv0->axis(2)->extent();
  auto val_slice = IrBuilder::create<Val>(8, DataType::Index);
  auto val_remain = sub(dim2, val_slice);  // dim2 - 8
#endif
  Slice slice_dim_0{zero, dim0, one};
  Slice slice_dim_1{zero, dim1, one};
  Slice slice_dim_2_l{zero, val_slice, one};
  Slice slice_dim_2_r{val_slice, dim2, one};

  std::vector<Slice> slice_l_ind = {slice_dim_0, slice_dim_1, slice_dim_2_l};
  std::vector<Slice> slice_r_ind = {slice_dim_0, slice_dim_1, slice_dim_2_r};

  auto slice_l = slice(tv0, slice_l_ind);
  auto slice_r = slice(tv0, slice_r_ind);

  fusion->addInput(tv0);

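  // Pad widths list the innermost dimension first: {left, right, ...}.
  auto rope_l = pad(slice_r, {zero, val_slice, zero, zero, zero, zero});
  // Negate after padding, rather than before as in the commented-out line,
  // to avoid segmentation.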
  // auto rope_r = pad(neg(slice_l), {val_remain, zero, zero, zero, zero, zero});
  auto rope_r = neg(pad(slice_l, {val_remain, zero, zero, zero, zero, zero}));

  auto o = add(rope_l, rope_r);

  fusion->addOutput(o);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn(shape, options);
  std::vector<c10::IValue> aten_inputs({t0});

  FusionExecutorCache executor_cache(std::move(fusion));
  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);

  testValidate(
      executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__);
}
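For reference, here is a minimal host-side C++ sketch of what the fusion above computes along the innermost axis (the two leading axes are untouched). It assumes nvFuser's pad widths list the innermost dimension first as {left, right}, so `pad(slice_r, {zero, val_slice, ...})` appends `val_slice` zeros on the right and `pad(slice_l, {val_remain, zero, ...})` prepends `val_remain` zeros on the left. The function name `rotateNegReference` is hypothetical. Because the non-padded regions of the two operands are disjoint, every output element reads exactly one input element, which is the property the indexing could exploit.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host reference for one innermost row x of length n = dim2, split at s
// (val_slice = 8 in the test, so val_remain = n - s):
//   rope_l[j] = x[j + s]          for j <  n - s, else 0  (pad right by s)
//   rope_r[j] = -x[j - (n - s)]   for j >= n - s, else 0  (pad left by n - s)
//   o[j]      = rope_l[j] + rope_r[j]
std::vector<float> rotateNegReference(const std::vector<float>& x, std::size_t s) {
  const std::size_t n = x.size();
  assert(s <= n);
  std::vector<float> o(n, 0.0f);
  for (std::size_t j = 0; j < n; ++j) {
    // The two non-padded regions are disjoint, so for each j exactly one
    // term reads from x and the other is a padding zero.
    const float rope_l = (j < n - s) ? x[j + s] : 0.0f;
    const float rope_r = (j >= n - s) ? -x[j - (n - s)] : 0.0f;
    o[j] = rope_l + rope_r;
  }
  return o;
}
```

With `x = {0, 1, ..., 15}` and `s = 8`, the first half of the output is `x[8..15]` and the second half is `-x[0..7]`.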

@zasdfgbnm
Collaborator

zasdfgbnm commented Aug 20, 2024

Thanks @jjsjann123 for providing this repro. I do believe there are some opportunities to simplify symbolically here:

For example, the predicate of T3 = pad(T2) looks like:

Static shape:

-8 + (4 * (threadIdx.x % 4)) < -i0

Dynamic shape:

  ((((4 * threadIdx.x) + (512 * blockIdx.x)) + i0) %
   ((8 * T1.logical_size[1LL]) +
    (T1.logical_size[1LL] * T2.logical_size[2LL]))) %
          (8 + T2.logical_size[2LL]) <
      T2.logical_size[2LL]

For the dynamic shape case, let:

a = ((4 * threadIdx.x) + (512 * blockIdx.x)) + i0;
b = T1.logical_size[1LL];
c = 8 + T2.logical_size[2LL];

then the predicate is:

a % (b * c) % c < T2.logical_size[2LL]

which clearly can be simplified as:

a % c < T2.logical_size[2LL], i.e. (((4 * threadIdx.x) + (512 * blockIdx.x)) + i0) % (8 + T2.logical_size[2LL]) < T2.logical_size[2LL]

This is clearly not as good as the static shape case, but still an improvement.
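The simplification rests on the identity (a % (b * c)) % c == a % c, which holds because b * c is a multiple of c. A brute-force C++ check over small non-negative values (a sketch, not nvFuser's expression simplifier; the function name is hypothetical) confirms it:

```cpp
#include <cassert>
#include <cstdint>

// Returns true if (a % (b * c)) % c == a % c for every 0 <= a < max_a,
// 1 <= b < max_b and 1 <= c < max_c. Operands are non-negative, as the
// thread/block indices and extents in the predicate are.
bool nestedModIdentityHolds(std::int64_t max_a, std::int64_t max_b, std::int64_t max_c) {
  for (std::int64_t a = 0; a < max_a; ++a) {
    for (std::int64_t b = 1; b < max_b; ++b) {
      for (std::int64_t c = 1; c < max_c; ++c) {
        if ((a % (b * c)) % c != a % c) {
          return false;
        }
      }
    }
  }
  return true;
}
```

Here b plays the role of `T1.logical_size[1LL]` and c of `8 + T2.logical_size[2LL]`, since `(8 * b) + (b * T2.logical_size[2LL])` factors as `b * c`.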

Kernel diff: https://www.diffchecker.com/vK2pS9ak/

zasdfgbnm added a commit that referenced this issue Aug 20, 2024
Found this issue while reading #2795
@jjsjann123
Collaborator Author

Yeah, if there's no low-hanging fruit, I don't think it matters at this point, since we are going down the path of @jacobhinkle's plan to use static shapes during concretization.

We can revisit this if we decide to push it further afterwards.

@jacobhinkle
Collaborator

static shapes during concretization.

BTW, while implementing this I noticed that many of the resizes are dynamic but are actually trivial for the provided inputs:

    ?S9{( fmax(0, ( fmin(i0, 2) )) )}rf (index=0) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S11{( fmax(0, ( fmin(i1, 16) )) )}rf (index=1) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S13{( fmax(0, ( fmin(i2, 1024) )) )}rf (index=2) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S15{( fmax(0, ( fmin(i3, 8) )) )}rf (index=3) is a resize of input extent 32 with left_pad=0 and right_pad=-24
    ?S43{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=4) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S45{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=5) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S47{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=6) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S49{( ( fmax(4, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 8) )) ) - 4 )}rf (index=7) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S112{( fmax(0, ( fmin(i7, 1024) )) )}rf (index=8) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S114{( fmax(0, ( fmin(i8, 4) )) )}rf (index=9) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S82{( fmax(0, ( fmin(i5, 1024) )) )}rf (index=10) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S84{( fmax(0, ( fmin(i6, 4) )) )}rf (index=11) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S56{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=12) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S58{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=13) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S60{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=14) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S62{( fmax(0, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 4) )) )}rf (index=15) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S30{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=16) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S32{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=17) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S34{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=18) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S36{( fmax(0, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 4) )) )}rf (index=19) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S127{( fmax(0, ( fmin(i7, 1024) )) )}rf (index=20) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S129{( ( fmax(4, ( fmin(i8, 8) )) ) - 4 )}rf (index=21) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S97{( fmax(0, ( fmin(i5, 1024) )) )}rf (index=22) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S99{( ( fmax(4, ( fmin(i6, 8) )) ) - 4 )}rf (index=23) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S69{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=24) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S71{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=25) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S73{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=26) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S75{( ( fmax(4, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 8) )) ) - 4 )}rf (index=27) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S17{( fmax(0, ( fmin(i0, 2) )) )}rf (index=28) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S19{( fmax(0, ( fmin(i1, 16) )) )}rf (index=29) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S21{( fmax(0, ( fmin(i2, 1024) )) )}rf (index=30) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S23{( ( fmax(8, ( fmin(i3, 32) )) ) - 8 )}rf (index=31) is a resize of input extent 32 with left_pad=-8 and right_pad=0

By my count 21 out of these 32 resized axes are not actually resized at all. Using static shapes, not only will the expressions be simpler, but we will catch every one of these trivial resizes and we will not predicate that access. I'll have that as part of concretization in a PR soon.
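A minimal sketch, assuming concrete input extents and pads are known at concretization, of how each printed resize can be classified. The `Resize` struct and the helper functions are hypothetical, not nvFuser APIs; the example values follow the list above, e.g. index=2 (extent 1024, no pads) and index=3 (extent 32 with right_pad=-24, whose symbolic extent is `fmax(0, fmin(i3, 8))`).

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical record of one resize: an axis of input_extent elements grown
// (positive pad) or shrunk (negative pad) on each side.
struct Resize {
  std::int64_t input_extent;
  std::int64_t left_pad;
  std::int64_t right_pad;
};

// Output extent of a resize.
std::int64_t resizedExtent(const Resize& r) {
  return r.left_pad + r.input_extent + r.right_pad;
}

// With concrete pads a no-op resize is easy to spot: its accesses need no
// predicate and no index offset.
bool isTrivialResize(const Resize& r) {
  return r.left_pad == 0 && r.right_pad == 0;
}

// Concrete value of a symbolic slice extent such as fmax(0, fmin(i3, 8)),
// i.e. a slice [0, stop) of an axis whose runtime extent is `extent`.
std::int64_t sliceExtent(std::int64_t extent, std::int64_t stop) {
  return std::max<std::int64_t>(0, std::min(extent, stop));
}
```

For index=3, `resizedExtent({32, 0, -24})` and `sliceExtent(32, 8)` both evaluate to 8, so the `fmax`/`fmin` clamp folds away once the extent is constant.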

@jjsjann123
Collaborator Author

By my count 21 out of these 32 resized axes are not actually resized at all. Using static shapes, not only will the expressions be simpler, but we will catch every one of these trivial resizes and we will not predicate that access. I'll have that as part of concretization in a PR soon.

Thanks for pointing that out. Yes, that's expected; this is one of the mismatches with thunder's static program. Slices with [..., ] are baked in as constants as well... We need to rewrite that logic later.

@jacobhinkle jacobhinkle linked a pull request Aug 23, 2024 that will close this issue
@jacobhinkle
Collaborator

As of #2714, with the repro in the description of this issue, we went from 35 us on main to 13 us. Here is the diff of the generated kernel:

 __global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, nvfuser_index_t i0, nvfuser_index_t i1, nvfuser_index_t i2, Tensor<__bfloat, 4, 4> T48) {
   NVFUSER_DEFINE_MAGIC_ZERO;
   nvfuser_index_t i3;
   i3 = 8 * ((nvfuser_index_t)threadIdx.x);
   nvfuser_index_t i4;
   i4 = (8 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.y);
   nvfuser_index_t i5;
   i5 = i3 + i4;
   nvfuser_index_t i6;
-  i6 = 28 + T26.logical_size[1LL];
+  i6 = 28 + T18.logical_size[1LL];
   nvfuser_index_t i7;
   i7 = ((nvfuser_index_t)blockIdx.x) / T10.logical_size[1LL];
   nvfuser_index_t i8;
   i8 = ((nvfuser_index_t)blockIdx.x) % T10.logical_size[1LL];
   nvfuser_index_t i9;
   i9 = (-4 + (T6.alloc_stride[0LL] * i7)) + (T6.alloc_stride[1LL] * i8);
   nvfuser_index_t i10;
-  i10 = 28 + T6.logical_size[3LL];
+  i10 = 28 + T10.logical_size[3LL];
   nvfuser_index_t i11;
-  i11 = 28 + T14.logical_size[1LL];
+  i11 = (T10.alloc_stride[0LL] * i7) + (T10.alloc_stride[1LL] * i8);
   nvfuser_index_t i12;
-  i12 = (T10.alloc_stride[0LL] * i7) + (T10.alloc_stride[1LL] * i8);
+  i12 = (T8.alloc_stride[0LL] * i7) + (T8.alloc_stride[1LL] * i8);
   nvfuser_index_t i13;
-  i13 = 28 + T10.logical_size[3LL];
+  i13 = (-8 + (T4.alloc_stride[0LL] * i7)) + (T4.alloc_stride[1LL] * i8);
   nvfuser_index_t i14;
-  i14 = (T8.alloc_stride[0LL] * i7) + (T8.alloc_stride[1LL] * i8);
+  i14 = 8 + T4.logical_size[3LL];
   nvfuser_index_t i15;
-  i15 = 28 + T8.logical_size[3LL];
+  i15 = (-4 + (T12.alloc_stride[0LL] * i7)) + (T12.alloc_stride[1LL] * i8);
   nvfuser_index_t i16;
-  i16 = 28 + T22.logical_size[1LL];
-  nvfuser_index_t i17;
-  i17 = (-8 + (T4.alloc_stride[0LL] * i7)) + (T4.alloc_stride[1LL] * i8);
-  nvfuser_index_t i18;
-  i18 = 8 + T4.logical_size[3LL];
-  nvfuser_index_t i19;
-  i19 = 28 + T18.logical_size[1LL];
-  nvfuser_index_t i20;
-  i20 = (-4 + (T12.alloc_stride[0LL] * i7)) + (T12.alloc_stride[1LL] * i8);
-  nvfuser_index_t i21;
-  i21 = 28 + T12.logical_size[3LL];
-  nvfuser_index_t i22;
-  i22 = (i3 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i4;
+  i16 = (i3 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i4;
-  nvfuser_index_t i23;
-  i23 = 24 * T10.logical_size[2LL];
-  nvfuser_index_t i24;
-  i24 = ((max(4, (min(i0, 8)))) * T10.logical_size[2LL]) + i23;
-  nvfuser_index_t i25;
-  i25 = (7 + i3) + i4;
-  bool b26;
-  b26 = i25 < i24;
-  nvfuser_index_t i27;
-  i27 = 28 * T10.logical_size[2LL];
-  bool b28;
-  b28 = i25 < (i27 + (T10.logical_size[2LL] * T6.logical_size[3LL]));
-  bool b29;
-  b29 = i25 < (i27 + (T10.logical_size[2LL] * T14.logical_size[1LL]));
-  bool b30;
-  b30 = i25 < ((T10.logical_size[2LL] * T10.logical_size[3LL]) + i27);
-  bool b31;
+  bool b17;
-  b31 = i25 < (((max(4, (min((max(0LL, (min(i1, 8)))), 8)))) * T10.logical_size[2LL]) + i23);
-  bool b32;
-  b32 = i25 < (i27 + (T10.logical_size[2LL] * T22.logical_size[1LL]));
-  bool b33;
-  b33 = i25 < ((max(8, (min(i1, 32)))) * T10.logical_size[2LL]);
-  bool b34;
-  b34 = i25 < (((max(4, (min(i2, 8)))) * T10.logical_size[2LL]) + i23);
+  b17 = ((7 + i3) + i4) < 32768;
-  if ((((i3 + 7) + i4) < i24)) {
+  if ((((i3 + 7) + i4) < 32768)) {
     Array<__bfloat, 8, 8> T54;
     #pragma unroll
-    for(nvfuser_index_t i35 = 0; i35 < 8; ++i35) {
+    for(nvfuser_index_t i18 = 0; i18 < 8; ++i18) {
-      nvfuser_index_t i36;
+      nvfuser_index_t i19;
-      i36 = i5 + (i35 + nvfuser_zero);
+      i19 = i5 + (i18 + nvfuser_zero);
+      nvfuser_index_t i20;
+      i20 = i19 % i6;
+      nvfuser_index_t i21;
+      i21 = -4 + i20;
+      nvfuser_index_t i22;
+      i22 = i19 / i6;
+      bool b23;
+      b23 = (i21 >= 0) && (i21 < T18.logical_size[1LL]);
+      nvfuser_index_t i24;
+      i24 = i19 % i10;
+      nvfuser_index_t i25;
+      i25 = i19 / i10;
+      bool b26;
+      b26 = i24 < T10.logical_size[3LL];
-      nvfuser_index_t i37;
+      nvfuser_index_t i27;
-      i37 = i36 % i6;
+      i27 = i19 % i14;
-      nvfuser_index_t i38;
+      nvfuser_index_t i28;
-      i38 = -4 + i37;
+      i28 = -8 + i27;
-      nvfuser_index_t i39;
-      i39 = i36 % i10;
-      nvfuser_index_t i40;
-      i40 = -4 + i39;
-      nvfuser_index_t i41;
-      i41 = i36 % i11;
-      nvfuser_index_t i42;
-      i42 = i36 % i13;
-      nvfuser_index_t i43;
-      i43 = i36 % i15;
-      nvfuser_index_t i44;
-      i44 = i36 % i16;
-      nvfuser_index_t i45;
-      i45 = i36 % i18;
-      nvfuser_index_t i46;
-      i46 = -8 + i45;
-      nvfuser_index_t i47;
-      i47 = i36 % i19;
-      nvfuser_index_t i48;
-      i48 = -4 + i47;
-      nvfuser_index_t i49;
-      i49 = i36 % i21;
-      nvfuser_index_t i50;
-      i50 = -4 + i49;
       __bfloat T27[1];
       T27[0] = 0;
       T27[0]
-         = ((i38 >= 0) && (i38 < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i36 / i6))) + i37)] : 0.0000e+00;
+         = b23 ? T26[(i21 + (T26.alloc_stride[0LL] * i22))] : 0.0000e+00;
       __bfloat T28[1];
       T28[0]
          = T27[0];
       __bfloat T52[1];
       T52[0]
          = T28[0];
       __bfloat T7[1];
       T7[0] = 0;
       T7[0]
-         = ((i40 >= 0) && (i40 < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i36 / i10))) + i39)] : 0.0000e+00;
+         = b23 ? T6[((i9 + i20) + (T6.alloc_stride[2LL] * i22))] : 0.0000e+00;
       float T38[1];
       T38[0]
          = __bfloat2float(T7[0]);
       float T39[1];
       T39[0]
          = __bfloat2float(T52[0]);
       float T40[1];
       T40[0]
         = T38[0]
         * T39[0];
       __bfloat T15[1];
       T15[0] = 0;
       T15[0]
-         = (i41 < T14.logical_size[1LL]) ? T14[((T14.alloc_stride[0LL] * (i36 / i11)) + i41)] : 0.0000e+00;
+         = b26 ? T14[(i24 + (T14.alloc_stride[0LL] * i25))] : 0.0000e+00;
       __bfloat T16[1];
       T16[0]
          = T15[0];
       __bfloat T51[1];
       T51[0]
          = T16[0];
       __bfloat T11[1];
       T11[0] = 0;
       T11[0]
-         = (i42 < T10.logical_size[3LL]) ? T10[((i12 + (T10.alloc_stride[2LL] * (i36 / i13))) + i42)] : 0.0000e+00;
+         = b26 ? T10[((i11 + (T10.alloc_stride[2LL] * i25)) + i24)] : 0.0000e+00;
       float T35[1];
       T35[0]
          = __bfloat2float(T11[0]);
       float T34[1];
       T34[0]
          = __bfloat2float(T51[0]);
       float T36[1];
       T36[0]
         = T34[0]
         * T35[0];
       __bfloat T9[1];
       T9[0] = 0;
       T9[0]
-         = (i43 < T8.logical_size[3LL]) ? T8[((i14 + (T8.alloc_stride[2LL] * (i36 / i15))) + i43)] : 0.0000e+00;
+         = b26 ? T8[((i12 + i24) + (T8.alloc_stride[2LL] * i25))] : 0.0000e+00;
       float T30[1];
       T30[0]
          = __bfloat2float(T9[0]);
       float T31[1];
       T31[0]
          = -T30[0];
       __bfloat T23[1];
       T23[0] = 0;
       T23[0]
-         = (i44 < T22.logical_size[1LL]) ? T22[((T22.alloc_stride[0LL] * (i36 / i16)) + i44)] : 0.0000e+00;
+         = b26 ? T22[(i24 + (T22.alloc_stride[0LL] * i25))] : 0.0000e+00;
       __bfloat T24[1];
       T24[0]
          = T23[0];
       __bfloat T50[1];
       T50[0]
          = T24[0];
       float T32[1];
       T32[0]
          = __bfloat2float(T50[0]);
       float T33[1];
       T33[0]
         = T31[0]
         * T32[0];
       __bfloat T5[1];
       T5[0] = 0;
       T5[0]
-         = ((i46 >= 0) && (i46 < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i36 / i18))) + i45)] : 0.0000e+00;
+         = ((i28 >= 0) && (i28 < T4.logical_size[3LL])) ? T4[((i13 + (T4.alloc_stride[2LL] * (i19 / i14))) + i27)] : 0.0000e+00;
       float T46[1];
       T46[0]
          = __bfloat2float(T5[0]);
       __bfloat T19[1];
       T19[0] = 0;
       T19[0]
-         = ((i48 >= 0) && (i48 < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i36 / i19))) + i47)] : 0.0000e+00;
+         = b23 ? T18[((-4 + (T18.alloc_stride[0LL] * i22)) + i20)] : 0.0000e+00;
       __bfloat T20[1];
       T20[0]
          = T19[0];
       __bfloat T53[1];
       T53[0]
          = T20[0];
       __bfloat T13[1];
       T13[0] = 0;
       T13[0]
-         = ((i50 >= 0) && (i50 < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i36 / i21))) + i49)] : 0.0000e+00;
+         = b23 ? T12[((i15 + i20) + (T12.alloc_stride[2LL] * i22))] : 0.0000e+00;
       float T42[1];
       T42[0]
          = __bfloat2float(T13[0]);
       float T41[1];
       T41[0]
          = __bfloat2float(T53[0]);
       float T43[1];
       T43[0]
         = T41[0]
         * T42[0];
       float T44[1];
       T44[0]
         = T40[0]
         + T43[0];
       float T37[1];
       T37[0]
         = T33[0]
         + T36[0];
       float T45[1];
       T45[0]
         = T37[0]
         + T44[0];
       float T47[1];
       T47[0]
         = T45[0]
         + T46[0];
-      T54[i35]
+      T54[i18]
          = __float2bfloat(T47[0]);
     }
     NVFUSER_UPDATE_MAGIC_ZERO;
-    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
+    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i16], &T54[0]);
   } else {
     Array<__bfloat, 8, 8> T54;
     #pragma unroll
-    for(nvfuser_index_t i35 = 0; i35 < 8; ++i35) {
+    for(nvfuser_index_t i18 = 0; i18 < 8; ++i18) {
+      nvfuser_index_t i29;
+      i29 = i5 + (i18 + nvfuser_zero);
+      nvfuser_index_t i30;
+      i30 = i29 % i6;
-      nvfuser_index_t i51;
+      nvfuser_index_t i31;
-      i51 = i5 + (i35 + nvfuser_zero);
+      i31 = -4 + i30;
-      nvfuser_index_t i52;
+      nvfuser_index_t i32;
-      i52 = i51 % i6;
+      i32 = i29 / i6;
+      bool b33;
+      b33 = (i31 >= 0) && (i31 < T18.logical_size[1LL]);
-      nvfuser_index_t i53;
+      nvfuser_index_t i34;
-      i53 = -4 + i52;
+      i34 = i29 % i10;
-      nvfuser_index_t i54;
+      nvfuser_index_t i35;
-      i54 = i51 % i10;
+      i35 = i29 / i10;
-      nvfuser_index_t i55;
-      i55 = -4 + i54;
-      nvfuser_index_t i56;
-      i56 = i51 % i11;
+      bool b36;
+      b36 = i34 < T10.logical_size[3LL];
-      nvfuser_index_t i57;
+      nvfuser_index_t i37;
-      i57 = i51 % i13;
+      i37 = i29 % i14;
-      nvfuser_index_t i58;
+      nvfuser_index_t i38;
-      i58 = i51 % i15;
-      nvfuser_index_t i59;
-      i59 = i51 % i16;
-      nvfuser_index_t i60;
-      i60 = i51 % i18;
-      nvfuser_index_t i61;
-      i61 = -8 + i60;
+      i38 = -8 + i37;
-      nvfuser_index_t i62;
-      i62 = i51 % i19;
-      nvfuser_index_t i63;
-      i63 = -4 + i62;
-      nvfuser_index_t i64;
-      i64 = i51 % i21;
-      nvfuser_index_t i65;
-      i65 = -4 + i64;
       __bfloat T27[1];
       T27[0] = 0;
-      if (b26) {
+      if (b17) {
         T27[0]
-           = ((i53 >= 0) && (i53 < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i51 / i6))) + i52)] : 0.0000e+00;
+           = b33 ? T26[(i31 + (T26.alloc_stride[0LL] * i32))] : 0.0000e+00;
       }
       __bfloat T28[1];
       T28[0]
          = T27[0];
       __bfloat T52[1];
       T52[0]
          = T28[0];
       __bfloat T7[1];
       T7[0] = 0;
-      if (b28) {
+      if (b17) {
         T7[0]
-           = ((i55 >= 0) && (i55 < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i51 / i10))) + i54)] : 0.0000e+00;
+           = b33 ? T6[((i9 + i30) + (T6.alloc_stride[2LL] * i32))] : 0.0000e+00;
       }
       float T38[1];
       T38[0]
          = __bfloat2float(T7[0]);
       float T39[1];
       T39[0]
          = __bfloat2float(T52[0]);
       float T40[1];
       T40[0]
         = T38[0]
         * T39[0];
       __bfloat T15[1];
       T15[0] = 0;
-      if (b29) {
+      if (b17) {
         T15[0]
-           = (i56 < T14.logical_size[1LL]) ? T14[((T14.alloc_stride[0LL] * (i51 / i11)) + i56)] : 0.0000e+00;
+           = b36 ? T14[(i34 + (T14.alloc_stride[0LL] * i35))] : 0.0000e+00;
       }
       __bfloat T16[1];
       T16[0]
          = T15[0];
       __bfloat T51[1];
       T51[0]
          = T16[0];
       __bfloat T11[1];
       T11[0] = 0;
-      if (b30) {
+      if (b17) {
         T11[0]
-           = (i57 < T10.logical_size[3LL]) ? T10[((i12 + (T10.alloc_stride[2LL] * (i51 / i13))) + i57)] : 0.0000e+00;
+           = b36 ? T10[((i11 + (T10.alloc_stride[2LL] * i35)) + i34)] : 0.0000e+00;
       }
       float T35[1];
       T35[0]
          = __bfloat2float(T11[0]);
       float T34[1];
       T34[0]
          = __bfloat2float(T51[0]);
       float T36[1];
       T36[0]
         = T34[0]
         * T35[0];
       __bfloat T9[1];
       T9[0] = 0;
-      if (b31) {
+      if (b17) {
         T9[0]
-           = (i58 < T8.logical_size[3LL]) ? T8[((i14 + (T8.alloc_stride[2LL] * (i51 / i15))) + i58)] : 0.0000e+00;
+           = b36 ? T8[((i12 + i34) + (T8.alloc_stride[2LL] * i35))] : 0.0000e+00;
       }
       float T30[1];
       T30[0]
          = __bfloat2float(T9[0]);
       float T31[1];
       T31[0]
          = -T30[0];
       __bfloat T23[1];
       T23[0] = 0;
-      if (b32) {
+      if (b17) {
         T23[0]
-           = (i59 < T22.logical_size[1LL]) ? T22[((T22.alloc_stride[0LL] * (i51 / i16)) + i59)] : 0.0000e+00;
+           = b36 ? T22[(i34 + (T22.alloc_stride[0LL] * i35))] : 0.0000e+00;
       }
       __bfloat T24[1];
       T24[0]
          = T23[0];
       __bfloat T50[1];
       T50[0]
          = T24[0];
       float T32[1];
       T32[0]
          = __bfloat2float(T50[0]);
       float T33[1];
       T33[0]
         = T31[0]
         * T32[0];
       __bfloat T5[1];
       T5[0] = 0;
-      if (b33) {
+      if (b17) {
         T5[0]
-           = ((i61 >= 0) && (i61 < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i51 / i18))) + i60)] : 0.0000e+00;
+           = ((i38 >= 0) && (i38 < T4.logical_size[3LL])) ? T4[((i13 + (T4.alloc_stride[2LL] * (i29 / i14))) + i37)] : 0.0000e+00;
       }
       float T46[1];
       T46[0]
          = __bfloat2float(T5[0]);
       __bfloat T19[1];
       T19[0] = 0;
-      if (b34) {
+      if (b17) {
         T19[0]
-           = ((i63 >= 0) && (i63 < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i51 / i19))) + i62)] : 0.0000e+00;
+           = b33 ? T18[((-4 + (T18.alloc_stride[0LL] * i32)) + i30)] : 0.0000e+00;
       }
       __bfloat T20[1];
       T20[0]
          = T19[0];
       __bfloat T53[1];
       T53[0]
          = T20[0];
       __bfloat T13[1];
       T13[0] = 0;
-      if (b31) {
+      if (b17) {
         T13[0]
-           = ((i65 >= 0) && (i65 < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i51 / i21))) + i64)] : 0.0000e+00;
+           = b33 ? T12[((i15 + i30) + (T12.alloc_stride[2LL] * i32))] : 0.0000e+00;
       }
       float T42[1];
       T42[0]
          = __bfloat2float(T13[0]);
       float T41[1];
       T41[0]
          = __bfloat2float(T53[0]);
       float T43[1];
       T43[0]
         = T41[0]
         * T42[0];
       float T44[1];
       T44[0]
         = T40[0]
         + T43[0];
       float T37[1];
       T37[0]
         = T33[0]
         + T36[0];
       float T45[1];
       T45[0]
         = T37[0]
         + T44[0];
       float T47[1];
       T47[0]
         = T45[0]
         + T46[0];
-      T54[i35]
+      T54[i18]
          = __float2bfloat(T47[0]);
     }
     NVFUSER_UPDATE_MAGIC_ZERO;
-    if ((i25 < 32768)) {
+    if (b17) {
-      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
+      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i16], &T54[0]);
     }
   }
 }

I think there are simpler conditions for the padding predicates, and they're getting hoisted. With static shapes, this runtime can be improved further to about 8 us.

BTW, looking at a larger problem size (changing head_size from 32 to 256 and bsz from 2 to 256), we see 28.1 ms vs. 10.7 ms, a 2.6x speedup similar to the smaller problem size.

@jjsjann123
Collaborator Author

Patching this from the thunder side per our earlier conversation: Lightning-AI/lightning-thunder#1096
That PR effectively makes the program statically shaped in thunder until we pull through dynamic shape support in thunder.

@jacobhinkle
Collaborator

jacobhinkle commented Sep 16, 2024

After looking at this more closely, I believe the root of this problem is that our index simplification cannot deduce that e.g. i6 is divisible by 8 (the vectorization size), which means we cannot simplify expressions like this so that they no longer depend on the loop index i18 which is less than 8:

      i19 = i5 + (i18 + nvfuser_zero);
      nvfuser_index_t i20;
      i20 = i19 % i6;
      nvfuser_index_t i21;
      i21 = -4 + i20;
      nvfuser_index_t i22;
      i22 = i19 / i6;

Notice that we should already be able to prove that i5 is a multiple of 8, so we just need to know that i6 is also divisible by 8 to simplify i20 to i18 + 8*((i5 / 8) % (i6 / 8)) and i22 to i5 / i6, which does not depend on i18.
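The proposed rewrite can be checked exhaustively for small values. The sketch below is plain C++, not nvFuser's simplifyExpr; `vec` stands for the vectorization factor 8, `nvfuser_zero` is taken to be 0, and the function name is hypothetical. It compares the original i20/i22 with the simplified forms for every i5 and i6 that are multiples of `vec`:

```cpp
#include <cassert>
#include <cstdint>

using index_t = std::int64_t;  // stand-in for nvfuser_index_t

// Checks, for vec-aligned i5 and i6 and every unrolled index 0 <= i18 < vec:
//   (i5 + i18) % i6 == i18 + vec * ((i5 / vec) % (i6 / vec))   (i20)
//   (i5 + i18) / i6 == i5 / i6                                 (i22)
bool vecAlignedSimplificationHolds(index_t vec, index_t max_i5, index_t max_i6) {
  for (index_t i5 = 0; i5 < max_i5; i5 += vec) {
    for (index_t i6 = vec; i6 < max_i6; i6 += vec) {
      for (index_t i18 = 0; i18 < vec; ++i18) {
        const index_t i19 = i5 + i18;
        const index_t i20 = i19 % i6;
        const index_t i22 = i19 / i6;
        if (i20 != i18 + vec * ((i5 / vec) % (i6 / vec))) return false;
        if (i22 != i5 / i6) return false;
      }
    }
  }
  return true;
}
```

Without the divisibility fact the rewrite breaks: with i5 = 8, i6 = 12 and i18 = 4, `(i5 + i18) % i6` is 0 while the simplified form gives 4.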

I am not sure yet whether it would be enough for us to provide proof that the vectorization size divides some of the extents, since in this case there are slices and pads, so we get the constant offset in i6 = 28 + T18.logical_size[1LL]; above. Here T18.logical_size[1LL] is equal to 4, which is not itself divisible by 8, but once we add 28 it becomes divisible. I think if we figure out how to prove this, we will get a big speedup.
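One possible way to capture the "28 + 4" case, assuming the residue of T18.logical_size[1LL] modulo the vectorization factor is known (for example from a concrete extent), is to track each scalar's value modulo 8 rather than only facts of the form "8 divides x". A residue survives addition, whereas a plain divisibility fact does not: 4 is not divisible by 8, yet 28 + 4 is. The `Residue` type and its helpers are a hypothetical illustration, not an nvFuser data structure.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical abstract value: a scalar's residue modulo `modulus`
// (e.g. the vectorization factor), or nullopt when it is unknown.
struct Residue {
  std::int64_t modulus;
  std::optional<std::int64_t> value;  // in [0, modulus) when known

  static Residue known(std::int64_t x, std::int64_t m) {
    return Residue{m, ((x % m) + m) % m};
  }
  static Residue unknown(std::int64_t m) { return Residue{m, std::nullopt}; }

  bool provablyDivisible() const { return value.has_value() && *value == 0; }
};

// (x + y) mod m = ((x mod m) + (y mod m)) mod m; unknown if either is unknown.
Residue add(const Residue& x, const Residue& y) {
  assert(x.modulus == y.modulus);
  if (!x.value || !y.value) return Residue::unknown(x.modulus);
  return Residue::known(*x.value + *y.value, x.modulus);
}

// A provably divisible operand makes the product divisible even when the
// other operand is unknown.
Residue mul(const Residue& x, const Residue& y) {
  assert(x.modulus == y.modulus);
  if (x.provablyDivisible() || y.provablyDivisible()) return Residue::known(0, x.modulus);
  if (!x.value || !y.value) return Residue::unknown(x.modulus);
  return Residue::known(*x.value * *y.value, x.modulus);
}
```

If the residue of T18.logical_size[1LL] is unknown, `add` returns an unknown residue and nothing can be concluded, which matches the concern above that proving divisibility of the input extents alone may not be enough.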

@naoyam
Collaborator

naoyam commented Sep 16, 2024

Does i6 refer to the extent of a vectorized domain?

@jacobhinkle
Collaborator

Does i6 refer to the extent of a vectorized domain?

I believe so. We segment into 6 NoOp segments that do the slices and one Pointwise segment that takes those sliced tensors as inputs. T18 is one of those intermediates (maybe cos_right?), and in the Pointwise segment we pad it by 4 on the left and 24 on the right. So yes, I think i6 is the padded size of T18, which is the same as the size of the vectorized output.

@jacobhinkle
Collaborator

Maybe we could take all the vectorizable inputs and outputs and assume that the product of the inner contiguous extents up to the breakpoint is divisible by the vectorization factor. Beyond that, I'm not sure yet how we should represent divisibility in a way that simplifyExpr can use.
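A minimal sketch of that assumption, checked on concrete extents, with hypothetical function names. It assumes the pointwise breakpoint splits the logical domain into an outer and an inner part, and that the inner part (the product of the extents from the breakpoint inward) is what gets vectorized. In this repro the inner block appears to be block_size * head_size = 1024 * 32 = 32768, the constant visible in the kernel's `32768 * blockIdx.x` offset. It does not address how the fact would be encoded for simplifyExpr.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Product of extents[breakpoint], ..., extents.back(): the inner, flattened
// extent of a 2D pointwise schedule (assumption stated above).
std::int64_t innerProduct(const std::vector<std::int64_t>& extents, std::size_t breakpoint) {
  assert(breakpoint <= extents.size());
  std::int64_t product = 1;
  for (std::size_t i = breakpoint; i < extents.size(); ++i) {
    product *= extents[i];
  }
  return product;
}

// The proposed assumption: the vectorization factor divides the inner product.
bool innerProductDivisible(const std::vector<std::int64_t>& extents,
                           std::size_t breakpoint, std::int64_t vec) {
  return innerProduct(extents, breakpoint) % vec == 0;
}
```

For the output shape [bsz, n_head, block_size, head_size] = [2, 16, 1024, 32] the inner product is divisible by 8, whereas an unpadded [1024, 4] input like T18 is not, which is why the fact has to come from the vectorized tensors rather than from every input.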

@naoyam
Collaborator

naoyam commented Sep 16, 2024

We do obtain that information but I don't think it's used by the expr simplifier.

https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/analysis/divisible_split.cpp#L23

@jacobhinkle
Collaborator

I'm not sure yet how we should represent divisibility in a way that we can use it in simplifyExpr.

Thinking ahead a bit, I think we could do this with an "e-graph analysis" in a similar way to how constant folding is done with e-graphs. We would build an e-graph of integer scalars and merge exact mapped extents and cache it during indexing similar to how we cache proofs currently in the simplification Context. Each e-class would hold an unordered set of e-classes that are its known (overlapping) factors. When merging e-classes we would union their factor sets. We would also union when propagating through multiplication. We would propagate through addition and subtraction by intersecting the factor sets. Then in simplifyExpr we can query this data structure for common divisors between two scalars in O(n) and check if any scalar divides another in O(1) time.
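A rough C++ sketch of that factor analysis, with a union-find standing in for the e-graph's e-classes. `FactorAnalysis` and its methods are hypothetical, not nvFuser code. Two simplifications relative to the proposal: factor sets are canonicalized lazily at query time, so `divides` is not O(1) (a real version would rebuild the sets on merge), and facts are not re-propagated to sums or products created before a merge.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical factor analysis over equivalence classes of integer scalars.
// Each class keeps the set of scalars known to divide it.
class FactorAnalysis {
 public:
  // Adds an unconstrained scalar; every scalar divides itself.
  int newScalar() {
    const int id = static_cast<int>(parent_.size());
    parent_.push_back(id);
    factors_.push_back({id});
    return id;
  }

  // Records a == b (e.g. exact-mapped extents): union the factor sets.
  void merge(int a, int b) {
    const int ra = find(a);
    const int rb = find(b);
    if (ra == rb) return;
    parent_[rb] = ra;
    factors_[ra].insert(factors_[rb].begin(), factors_[rb].end());
  }

  // p = x * y: every factor of x or of y divides p (union).
  int mul(int x, int y) {
    const int p = newScalar();
    addFactorsOf(p, x);
    addFactorsOf(p, y);
    return p;
  }

  // s = x + y or s = x - y: only common factors divide s (intersection).
  int addOrSub(int x, int y) {
    const int s = newScalar();
    for (int f : commonFactors(x, y)) factors_[find(s)].insert(f);
    return s;
  }

  // Known common divisors of x and y.
  std::set<int> commonFactors(int x, int y) {
    const std::set<int> fx = canonicalFactors(x);
    const std::set<int> fy = canonicalFactors(y);
    std::set<int> common;
    for (int f : fx) {
      if (fy.count(f) > 0) common.insert(f);
    }
    return common;
  }

  // Is d known to divide x?
  bool divides(int d, int x) { return canonicalFactors(x).count(find(d)) > 0; }

 private:
  int find(int x) {
    while (parent_[x] != x) {
      parent_[x] = parent_[parent_[x]];  // path halving
      x = parent_[x];
    }
    return x;
  }

  void addFactorsOf(int p, int x) {
    const std::set<int>& src = factors_[find(x)];
    factors_[find(p)].insert(src.begin(), src.end());
  }

  // Maps stored factors to their current class representatives.
  std::set<int> canonicalFactors(int x) {
    std::set<int> out;
    for (int f : factors_[find(x)]) out.insert(find(f));
    return out;
  }

  std::vector<int> parent_;
  std::vector<std::set<int>> factors_;
};
```

In the kernel above, `vec` would be the vectorization factor and `mul(vec, k)` an aligned offset like i5; merging an extent with a known multiple of `vec` (as exact mapping would) is what lets `divides(vec, extent)` succeed.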

@jacobhinkle
Collaborator

We do obtain that information but I don't think it's used by the expr simplifier.

https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/analysis/divisible_split.cpp#L23

Ah yes. I thought that only handled reshape splits, but it also handles vectorized splits. That's great.
