-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
performance issue on dynamic shaped tensor #2795
Comments
cc'ing @zasdfgbnm, I don't think there's any actionable item needed on your side at the moment. I'll update this after I've checked the performance with the new definition. |
Here is a diff of the generated pointwise kernels on my 3090Ti: --- static.cu 2024-08-19 10:23:21.977784983 -0400
+++ dynamic.cu 2024-08-19 10:23:58.741144923 -0400
@@ -10697,68 +10697,96 @@
}
} // namespace fused_reduction
-__global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, Tensor<__bfloat, 4, 4> T48) {
+__global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, nvfuser_index_t i0, nvfuser_index_t i1, nvfuser_index_t i2, Tensor<__bfloat, 4, 4> T48) {
NVFUSER_DEFINE_MAGIC_ZERO;
- nvfuser_index_t i0;
- i0 = ((nvfuser_index_t)threadIdx.x) + (((nvfuser_index_t)blockDim.x) * ((nvfuser_index_t)blockIdx.y));
- nvfuser_index_t i1;
- i1 = 8 * (i0 % 4);
- nvfuser_index_t i2;
- i2 = i0 / 4;
nvfuser_index_t i3;
- i3 = -4 + i1;
+ i3 = 8 * ((nvfuser_index_t)threadIdx.x);
nvfuser_index_t i4;
- i4 = i3 + (T26.alloc_stride[0LL] * i2);
+ i4 = (8 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.y);
nvfuser_index_t i5;
- i5 = ((-4 + ((1024 * T6.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T6.alloc_stride[2LL] * i2);
+ i5 = i3 + i4;
nvfuser_index_t i6;
- i6 = i1 + (T14.alloc_stride[0LL] * i2);
+ i6 = 28 + T26.logical_size[1LL];
nvfuser_index_t i7;
- i7 = (((1024 * T10.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x)) + i1) + (T10.alloc_stride[2LL] * i2);
+ i7 = ((nvfuser_index_t)blockIdx.x) / T10.logical_size[1LL];
nvfuser_index_t i8;
- i8 = (((1024 * T8.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x)) + i1) + (T8.alloc_stride[2LL] * i2);
+ i8 = ((nvfuser_index_t)blockIdx.x) % T10.logical_size[1LL];
nvfuser_index_t i9;
- i9 = i1 + (T22.alloc_stride[0LL] * i2);
+ i9 = (-4 + (T6.alloc_stride[0LL] * i7)) + (T6.alloc_stride[1LL] * i8);
nvfuser_index_t i10;
- i10 = ((-8 + ((1024 * T4.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T4.alloc_stride[2LL] * i2);
+ i10 = 28 + T6.logical_size[3LL];
nvfuser_index_t i11;
- i11 = i3 + (T18.alloc_stride[0LL] * i2);
+ i11 = 28 + T14.logical_size[1LL];
nvfuser_index_t i12;
- i12 = ((-4 + ((1024 * T12.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T12.alloc_stride[2LL] * i2);
+ i12 = (T10.alloc_stride[0LL] * i7) + (T10.alloc_stride[1LL] * i8);
nvfuser_index_t i13;
- i13 = 8 * ((nvfuser_index_t)threadIdx.x);
+ i13 = 28 + T10.logical_size[3LL];
nvfuser_index_t i14;
- i14 = (8 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.y);
+ i14 = (T8.alloc_stride[0LL] * i7) + (T8.alloc_stride[1LL] * i8);
nvfuser_index_t i15;
- i15 = (i13 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i14;
- bool b16;
- b16 = (i13 + i14) < 32768;
- if ((((i13 + 7) + i14) < 32768)) {
- Array<__bfloat, 8, 8> T50;
+ i15 = 28 + T8.logical_size[3LL];
+ nvfuser_index_t i16;
+ i16 = 28 + T22.logical_size[1LL];
+ nvfuser_index_t i17;
+ i17 = (-8 + (T4.alloc_stride[0LL] * i7)) + (T4.alloc_stride[1LL] * i8);
+ nvfuser_index_t i18;
+ i18 = 8 + T4.logical_size[3LL];
+ nvfuser_index_t i19;
+ i19 = 28 + T18.logical_size[1LL];
+ nvfuser_index_t i20;
+ i20 = (-4 + (T12.alloc_stride[0LL] * i7)) + (T12.alloc_stride[1LL] * i8);
+ nvfuser_index_t i21;
+ i21 = 28 + T12.logical_size[3LL];
+ nvfuser_index_t i22;
+ i22 = (i3 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i4;
+ nvfuser_index_t i23;
+ i23 = 24 * T10.logical_size[2LL];
+ nvfuser_index_t i24;
+ i24 = ((max(4, (min(i0, 8)))) * T10.logical_size[2LL]) + i23;
+ bool b25;
+ b25 = i5 < i24;
+ nvfuser_index_t i26;
+ i26 = 28 * T10.logical_size[2LL];
+ bool b27;
+ b27 = i5 < (i26 + (T10.logical_size[2LL] * T6.logical_size[3LL]));
+ bool b28;
+ b28 = i5 < (i26 + (T10.logical_size[2LL] * T14.logical_size[1LL]));
+ bool b29;
+ b29 = i5 < ((T10.logical_size[2LL] * T10.logical_size[3LL]) + i26);
+ bool b30;
+ b30 = i5 < (((max(4, (min((max(0LL, (min(i1, 8)))), 8)))) * T10.logical_size[2LL]) + i23);
+ bool b31;
+ b31 = i5 < (i26 + (T10.logical_size[2LL] * T22.logical_size[1LL]));
+ bool b32;
+ b32 = i5 < ((max(8, (min(i1, 32)))) * T10.logical_size[2LL]);
+ bool b33;
+ b33 = i5 < (((max(4, (min(i2, 8)))) * T10.logical_size[2LL]) + i23);
+ if ((((i3 + 7) + i4) < i24)) {
+ Array<__bfloat, 8, 8> T54;
#pragma unroll
- for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
- nvfuser_index_t i18;
- i18 = i17 + nvfuser_zero;
+ for(nvfuser_index_t i34 = 0; i34 < 8; ++i34) {
+ nvfuser_index_t i35;
+ i35 = i5 + (i34 + nvfuser_zero);
__bfloat T27[1];
T27[0] = 0;
T27[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T26[(i4 + i18)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i35 / i6))) + (i35 % i6))] : 0.0000e+00;
__bfloat T28[1];
T28[0]
= T27[0];
- __bfloat T29[1];
- T29[0]
+ __bfloat T52[1];
+ T52[0]
= T28[0];
__bfloat T7[1];
T7[0] = 0;
T7[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T6[(i5 + i18)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i35 / i10))) + (i35 % i10))] : 0.0000e+00;
float T38[1];
T38[0]
= __bfloat2float(T7[0]);
float T39[1];
T39[0]
- = __bfloat2float(T29[0]);
+ = __bfloat2float(T52[0]);
float T40[1];
T40[0]
= T38[0]
@@ -10766,23 +10794,23 @@
__bfloat T15[1];
T15[0] = 0;
T15[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T14[(i6 + i18)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) < T14.logical_size[1LL])) ? T14[((T14.alloc_stride[0LL] * (i35 / i11)) + (i35 % i11))] : 0.0000e+00;
__bfloat T16[1];
T16[0]
= T15[0];
- __bfloat T17[1];
- T17[0]
+ __bfloat T51[1];
+ T51[0]
= T16[0];
__bfloat T11[1];
T11[0] = 0;
T11[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T10[(i7 + i18)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) < T10.logical_size[3LL])) ? T10[((i12 + (T10.alloc_stride[2LL] * (i35 / i13))) + (i35 % i13))] : 0.0000e+00;
float T35[1];
T35[0]
= __bfloat2float(T11[0]);
float T34[1];
T34[0]
- = __bfloat2float(T17[0]);
+ = __bfloat2float(T51[0]);
float T36[1];
T36[0]
= T34[0]
@@ -10790,7 +10818,7 @@
__bfloat T9[1];
T9[0] = 0;
T9[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T8[(i8 + i18)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) < T8.logical_size[3LL])) ? T8[((i14 + (T8.alloc_stride[2LL] * (i35 / i15))) + (i35 % i15))] : 0.0000e+00;
float T30[1];
T30[0]
= __bfloat2float(T9[0]);
@@ -10800,16 +10828,16 @@
__bfloat T23[1];
T23[0] = 0;
T23[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T22[(i9 + i18)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) < T22.logical_size[1LL])) ? T22[((T22.alloc_stride[0LL] * (i35 / i16)) + (i35 % i16))] : 0.0000e+00;
__bfloat T24[1];
T24[0]
= T23[0];
- __bfloat T25[1];
- T25[0]
+ __bfloat T50[1];
+ T50[0]
= T24[0];
float T32[1];
T32[0]
- = __bfloat2float(T25[0]);
+ = __bfloat2float(T50[0]);
float T33[1];
T33[0]
= T31[0]
@@ -10817,30 +10845,30 @@
__bfloat T5[1];
T5[0] = 0;
T5[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) < 24)) ? T4[(i10 + i18)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i35 / i18))) + (i35 % i18))] : 0.0000e+00;
float T46[1];
T46[0]
= __bfloat2float(T5[0]);
__bfloat T19[1];
T19[0] = 0;
T19[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T18[(i11 + i18)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i35 / i19))) + (i35 % i19))] : 0.0000e+00;
__bfloat T20[1];
T20[0]
= T19[0];
- __bfloat T21[1];
- T21[0]
+ __bfloat T53[1];
+ T53[0]
= T20[0];
__bfloat T13[1];
T13[0] = 0;
T13[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T12[(i12 + i18)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i35 / i21))) + (i35 % i21))] : 0.0000e+00;
float T42[1];
T42[0]
= __bfloat2float(T13[0]);
float T41[1];
T41[0]
- = __bfloat2float(T21[0]);
+ = __bfloat2float(T53[0]);
float T43[1];
T43[0]
= T41[0]
@@ -10861,78 +10889,78 @@
T47[0]
= T45[0]
+ T46[0];
- T50[i17]
+ T54[i34]
= __float2bfloat(T47[0]);
}
NVFUSER_UPDATE_MAGIC_ZERO;
- loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i15], &T50[0]);
+ loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
} else {
- Array<__bfloat, 8, 8> T50;
+ Array<__bfloat, 8, 8> T54;
#pragma unroll
- for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
- nvfuser_index_t i19;
- i19 = i17 + nvfuser_zero;
+ for(nvfuser_index_t i34 = 0; i34 < 8; ++i34) {
+ nvfuser_index_t i36;
+ i36 = i5 + (i34 + nvfuser_zero);
__bfloat T27[1];
T27[0] = 0;
- if (b16) {
+ if (b25) {
T27[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T26[(i4 + i19)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i36 / i6))) + (i36 % i6))] : 0.0000e+00;
}
__bfloat T28[1];
T28[0]
= T27[0];
- __bfloat T29[1];
- T29[0]
+ __bfloat T52[1];
+ T52[0]
= T28[0];
__bfloat T7[1];
T7[0] = 0;
- if (b16) {
+ if (b27) {
T7[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T6[(i5 + i19)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i36 / i10))) + (i36 % i10))] : 0.0000e+00;
}
float T38[1];
T38[0]
= __bfloat2float(T7[0]);
float T39[1];
T39[0]
- = __bfloat2float(T29[0]);
+ = __bfloat2float(T52[0]);
float T40[1];
T40[0]
= T38[0]
* T39[0];
__bfloat T15[1];
T15[0] = 0;
- if (b16) {
+ if (b28) {
T15[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T14[(i6 + i19)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) < T14.logical_size[1LL])) ? T14[((T14.alloc_stride[0LL] * (i36 / i11)) + (i36 % i11))] : 0.0000e+00;
}
__bfloat T16[1];
T16[0]
= T15[0];
- __bfloat T17[1];
- T17[0]
+ __bfloat T51[1];
+ T51[0]
= T16[0];
__bfloat T11[1];
T11[0] = 0;
- if (b16) {
+ if (b29) {
T11[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T10[(i7 + i19)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) < T10.logical_size[3LL])) ? T10[((i12 + (T10.alloc_stride[2LL] * (i36 / i13))) + (i36 % i13))] : 0.0000e+00;
}
float T35[1];
T35[0]
= __bfloat2float(T11[0]);
float T34[1];
T34[0]
- = __bfloat2float(T17[0]);
+ = __bfloat2float(T51[0]);
float T36[1];
T36[0]
= T34[0]
* T35[0];
__bfloat T9[1];
T9[0] = 0;
- if (b16) {
+ if (b30) {
T9[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T8[(i8 + i19)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) < T8.logical_size[3LL])) ? T8[((i14 + (T8.alloc_stride[2LL] * (i36 / i15))) + (i36 % i15))] : 0.0000e+00;
}
float T30[1];
T30[0]
@@ -10942,56 +10970,56 @@
= -T30[0];
__bfloat T23[1];
T23[0] = 0;
- if (b16) {
+ if (b31) {
T23[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T22[(i9 + i19)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) < T22.logical_size[1LL])) ? T22[((T22.alloc_stride[0LL] * (i36 / i16)) + (i36 % i16))] : 0.0000e+00;
}
__bfloat T24[1];
T24[0]
= T23[0];
- __bfloat T25[1];
- T25[0]
+ __bfloat T50[1];
+ T50[0]
= T24[0];
float T32[1];
T32[0]
- = __bfloat2float(T25[0]);
+ = __bfloat2float(T50[0]);
float T33[1];
T33[0]
= T31[0]
* T32[0];
__bfloat T5[1];
T5[0] = 0;
- if (b16) {
+ if (b32) {
T5[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) < 24)) ? T4[(i10 + i19)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i36 / i18))) + (i36 % i18))] : 0.0000e+00;
}
float T46[1];
T46[0]
= __bfloat2float(T5[0]);
__bfloat T19[1];
T19[0] = 0;
- if (b16) {
+ if (b33) {
T19[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T18[(i11 + i19)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i36 / i19))) + (i36 % i19))] : 0.0000e+00;
}
__bfloat T20[1];
T20[0]
= T19[0];
- __bfloat T21[1];
- T21[0]
+ __bfloat T53[1];
+ T53[0]
= T20[0];
__bfloat T13[1];
T13[0] = 0;
- if (b16) {
+ if (b30) {
T13[0]
- = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T12[(i12 + i19)] : 0.0000e+00;
+ = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i36 / i21))) + (i36 % i21))] : 0.0000e+00;
}
float T42[1];
T42[0]
= __bfloat2float(T13[0]);
float T41[1];
T41[0]
- = __bfloat2float(T21[0]);
+ = __bfloat2float(T53[0]);
float T43[1];
T43[0]
= T41[0]
@@ -11012,12 +11040,12 @@
T47[0]
= T45[0]
+ T46[0];
- T50[i17]
+ T54[i34]
= __float2bfloat(T47[0]);
}
NVFUSER_UPDATE_MAGIC_ZERO;
- if (b16) {
- loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i15], &T50[0]);
+ if ((i5 < 32768)) {
+ loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
}
}
} The static kernel is just using the commented lines in the repro posted above. It achieves about 5x higher BW compared to dynamic (runs in 8 us vs. 38 us). There is more in the preamble for dynamic shapes, but inside the loops the expressions also have slightly more going on. For example (zoomed in, with line breaks inserted): --- static.cu 2024-08-19 10:23:21.977784983 -0400
+++ dynamic.cu 2024-08-19 10:23:58.741144923 -0400
@@ -10697,68 +10697,96 @@
__bfloat T27[1];
T27[0] = 0;
T27[0] = ((((((((((nvfuser_index_t)blockIdx.y) *
((nvfuser_index_t)blockDim.x)) +
((nvfuser_index_t)threadIdx.x)) *
8) +
- (i17 + nvfuser_zero)) %
- 32) -
+ (i34 + nvfuser_zero)) %
+ ((T26.logical_size[1LL] + 4) + 24)) -
4) >= 0) &&
(((((((((nvfuser_index_t)blockIdx.y) *
((nvfuser_index_t)blockDim.x)) +
((nvfuser_index_t)threadIdx.x)) *
8) +
- (i17 + nvfuser_zero)) %
- 32) -
- 4) < 4))
- ? T26[(i4 + i18)]
+ (i34 + nvfuser_zero)) %
+ ((T26.logical_size[1LL] + 4) + 24)) -
+ 4) < T26.logical_size[1LL]))
+ ? T26[((-4 + (T26.alloc_stride[0LL] * (i35 / i6))) + (i35 % i6))]
: 0.0000e+00;
__bfloat T28[1];
T28[0] = T27[0]; In this context |
As for the preamble, there are lots of |
For our own sanity, here's a simplified cpp test. Indexing isn't being simplified even when the slice is passed the correct extent val. Regarding @jacobhinkle's WAR in #511: as I understand it, we wouldn't need this form of definition, and the performance in the original python repro shouldn't regress with dynamic shapes. I'm creating this repro for @zasdfgbnm; I'm assuming the definition here should be enough to tell us that we are doing two non-overlapping slices, and that the indexing could perhaps be simplified even without concretization...
|
Thanks @jjsjann123 for providing this repro. I do believe there are some opportunities to simplify symbolically here: For example, the predicate of Static shape: -8 + (4 * (threadIdx.x % 4)) < -i0 Dynamic shape: ((((4 * threadIdx.x) + (512 * blockIdx.x)) + i0) %
((8 * T1.logical_size[1LL]) +
(T1.logical_size[1LL] * T2.logical_size[2LL]))) %
(8 + T2.logical_size[2LL]) <
T2.logical_size[2LL] For the dynamic shape case, let: a = ((4 * threadIdx.x) + (512 * blockIdx.x)) + i0;
b = T1.logical_size[1LL];
c = 8 + T2.logical_size[2LL]; then the predicate is: a % (b * c) % c < T2.logical_size[2LL] which, because b * c is a multiple of c, can clearly be simplified to a % c < T2.logical_size[2LL], i.e.: (((4 * threadIdx.x) + (512 * blockIdx.x)) + i0) % (8 + T2.logical_size[2LL]) < T2.logical_size[2LL] This is clearly not as good as the static shape case, but it is still an improvement. Kernel diff: https://www.diffchecker.com/vK2pS9ak/ |
Found this issue while reading #2795
Yeah, if there's no low-hanging fruit, I don't think it matters at this point, since we are going down the path of @jacobhinkle's plan for static shapes during concretization. We can revisit this if we decide to push it further afterwards. |
BTW in implementing this I just noticed that a lot of the resizes are dynamic but for the provided inputs are actually trivial:
By my count 21 out of these 32 resized axes are not actually resized at all. Using static shapes, not only will the expressions be simpler, but we will catch every one of these trivial resizes and we will not predicate that access. I'll have that as part of concretization in a PR soon. |
Thanks for pointing that out. Yes, that's expected; this is one of the mismatches in thunder's static program. slice with |
As of #2714, with the repro in the description of this issue, we went from 35 us on __global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, nvfuser_index_t i0, nvfuser_index_t i1, nvfuser_index_t i2, Tensor<__bfloat, 4, 4> T48) {
NVFUSER_DEFINE_MAGIC_ZERO;
nvfuser_index_t i3;
i3 = 8 * ((nvfuser_index_t)threadIdx.x);
nvfuser_index_t i4;
i4 = (8 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.y);
nvfuser_index_t i5;
i5 = i3 + i4;
nvfuser_index_t i6;
- i6 = 28 + T26.logical_size[1LL];
+ i6 = 28 + T18.logical_size[1LL];
nvfuser_index_t i7;
i7 = ((nvfuser_index_t)blockIdx.x) / T10.logical_size[1LL];
nvfuser_index_t i8;
i8 = ((nvfuser_index_t)blockIdx.x) % T10.logical_size[1LL];
nvfuser_index_t i9;
i9 = (-4 + (T6.alloc_stride[0LL] * i7)) + (T6.alloc_stride[1LL] * i8);
nvfuser_index_t i10;
- i10 = 28 + T6.logical_size[3LL];
+ i10 = 28 + T10.logical_size[3LL];
nvfuser_index_t i11;
- i11 = 28 + T14.logical_size[1LL];
+ i11 = (T10.alloc_stride[0LL] * i7) + (T10.alloc_stride[1LL] * i8);
nvfuser_index_t i12;
- i12 = (T10.alloc_stride[0LL] * i7) + (T10.alloc_stride[1LL] * i8);
+ i12 = (T8.alloc_stride[0LL] * i7) + (T8.alloc_stride[1LL] * i8);
nvfuser_index_t i13;
- i13 = 28 + T10.logical_size[3LL];
+ i13 = (-8 + (T4.alloc_stride[0LL] * i7)) + (T4.alloc_stride[1LL] * i8);
nvfuser_index_t i14;
- i14 = (T8.alloc_stride[0LL] * i7) + (T8.alloc_stride[1LL] * i8);
+ i14 = 8 + T4.logical_size[3LL];
nvfuser_index_t i15;
- i15 = 28 + T8.logical_size[3LL];
+ i15 = (-4 + (T12.alloc_stride[0LL] * i7)) + (T12.alloc_stride[1LL] * i8);
nvfuser_index_t i16;
- i16 = 28 + T22.logical_size[1LL];
- nvfuser_index_t i17;
- i17 = (-8 + (T4.alloc_stride[0LL] * i7)) + (T4.alloc_stride[1LL] * i8);
- nvfuser_index_t i18;
- i18 = 8 + T4.logical_size[3LL];
- nvfuser_index_t i19;
- i19 = 28 + T18.logical_size[1LL];
- nvfuser_index_t i20;
- i20 = (-4 + (T12.alloc_stride[0LL] * i7)) + (T12.alloc_stride[1LL] * i8);
- nvfuser_index_t i21;
- i21 = 28 + T12.logical_size[3LL];
- nvfuser_index_t i22;
- i22 = (i3 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i4;
+ i16 = (i3 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i4;
- nvfuser_index_t i23;
- i23 = 24 * T10.logical_size[2LL];
- nvfuser_index_t i24;
- i24 = ((max(4, (min(i0, 8)))) * T10.logical_size[2LL]) + i23;
- nvfuser_index_t i25;
- i25 = (7 + i3) + i4;
- bool b26;
- b26 = i25 < i24;
- nvfuser_index_t i27;
- i27 = 28 * T10.logical_size[2LL];
- bool b28;
- b28 = i25 < (i27 + (T10.logical_size[2LL] * T6.logical_size[3LL]));
- bool b29;
- b29 = i25 < (i27 + (T10.logical_size[2LL] * T14.logical_size[1LL]));
- bool b30;
- b30 = i25 < ((T10.logical_size[2LL] * T10.logical_size[3LL]) + i27);
- bool b31;
+ bool b17;
- b31 = i25 < (((max(4, (min((max(0LL, (min(i1, 8)))), 8)))) * T10.logical_size[2LL]) + i23);
- bool b32;
- b32 = i25 < (i27 + (T10.logical_size[2LL] * T22.logical_size[1LL]));
- bool b33;
- b33 = i25 < ((max(8, (min(i1, 32)))) * T10.logical_size[2LL]);
- bool b34;
- b34 = i25 < (((max(4, (min(i2, 8)))) * T10.logical_size[2LL]) + i23);
+ b17 = ((7 + i3) + i4) < 32768;
- if ((((i3 + 7) + i4) < i24)) {
+ if ((((i3 + 7) + i4) < 32768)) {
Array<__bfloat, 8, 8> T54;
#pragma unroll
- for(nvfuser_index_t i35 = 0; i35 < 8; ++i35) {
+ for(nvfuser_index_t i18 = 0; i18 < 8; ++i18) {
- nvfuser_index_t i36;
+ nvfuser_index_t i19;
- i36 = i5 + (i35 + nvfuser_zero);
+ i19 = i5 + (i18 + nvfuser_zero);
+ nvfuser_index_t i20;
+ i20 = i19 % i6;
+ nvfuser_index_t i21;
+ i21 = -4 + i20;
+ nvfuser_index_t i22;
+ i22 = i19 / i6;
+ bool b23;
+ b23 = (i21 >= 0) && (i21 < T18.logical_size[1LL]);
+ nvfuser_index_t i24;
+ i24 = i19 % i10;
+ nvfuser_index_t i25;
+ i25 = i19 / i10;
+ bool b26;
+ b26 = i24 < T10.logical_size[3LL];
- nvfuser_index_t i37;
+ nvfuser_index_t i27;
- i37 = i36 % i6;
+ i27 = i19 % i14;
- nvfuser_index_t i38;
+ nvfuser_index_t i28;
- i38 = -4 + i37;
+ i28 = -8 + i27;
- nvfuser_index_t i39;
- i39 = i36 % i10;
- nvfuser_index_t i40;
- i40 = -4 + i39;
- nvfuser_index_t i41;
- i41 = i36 % i11;
- nvfuser_index_t i42;
- i42 = i36 % i13;
- nvfuser_index_t i43;
- i43 = i36 % i15;
- nvfuser_index_t i44;
- i44 = i36 % i16;
- nvfuser_index_t i45;
- i45 = i36 % i18;
- nvfuser_index_t i46;
- i46 = -8 + i45;
- nvfuser_index_t i47;
- i47 = i36 % i19;
- nvfuser_index_t i48;
- i48 = -4 + i47;
- nvfuser_index_t i49;
- i49 = i36 % i21;
- nvfuser_index_t i50;
- i50 = -4 + i49;
__bfloat T27[1];
T27[0] = 0;
T27[0]
- = ((i38 >= 0) && (i38 < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i36 / i6))) + i37)] : 0.0000e+00;
+ = b23 ? T26[(i21 + (T26.alloc_stride[0LL] * i22))] : 0.0000e+00;
__bfloat T28[1];
T28[0]
= T27[0];
__bfloat T52[1];
T52[0]
= T28[0];
__bfloat T7[1];
T7[0] = 0;
T7[0]
- = ((i40 >= 0) && (i40 < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i36 / i10))) + i39)] : 0.0000e+00;
+ = b23 ? T6[((i9 + i20) + (T6.alloc_stride[2LL] * i22))] : 0.0000e+00;
float T38[1];
T38[0]
= __bfloat2float(T7[0]);
float T39[1];
T39[0]
= __bfloat2float(T52[0]);
float T40[1];
T40[0]
= T38[0]
* T39[0];
__bfloat T15[1];
T15[0] = 0;
T15[0]
- = (i41 < T14.logical_size[1LL]) ? T14[((T14.alloc_stride[0LL] * (i36 / i11)) + i41)] : 0.0000e+00;
+ = b26 ? T14[(i24 + (T14.alloc_stride[0LL] * i25))] : 0.0000e+00;
__bfloat T16[1];
T16[0]
= T15[0];
__bfloat T51[1];
T51[0]
= T16[0];
__bfloat T11[1];
T11[0] = 0;
T11[0]
- = (i42 < T10.logical_size[3LL]) ? T10[((i12 + (T10.alloc_stride[2LL] * (i36 / i13))) + i42)] : 0.0000e+00;
+ = b26 ? T10[((i11 + (T10.alloc_stride[2LL] * i25)) + i24)] : 0.0000e+00;
float T35[1];
T35[0]
= __bfloat2float(T11[0]);
float T34[1];
T34[0]
= __bfloat2float(T51[0]);
float T36[1];
T36[0]
= T34[0]
* T35[0];
__bfloat T9[1];
T9[0] = 0;
T9[0]
- = (i43 < T8.logical_size[3LL]) ? T8[((i14 + (T8.alloc_stride[2LL] * (i36 / i15))) + i43)] : 0.0000e+00;
+ = b26 ? T8[((i12 + i24) + (T8.alloc_stride[2LL] * i25))] : 0.0000e+00;
float T30[1];
T30[0]
= __bfloat2float(T9[0]);
float T31[1];
T31[0]
= -T30[0];
__bfloat T23[1];
T23[0] = 0;
T23[0]
- = (i44 < T22.logical_size[1LL]) ? T22[((T22.alloc_stride[0LL] * (i36 / i16)) + i44)] : 0.0000e+00;
+ = b26 ? T22[(i24 + (T22.alloc_stride[0LL] * i25))] : 0.0000e+00;
__bfloat T24[1];
T24[0]
= T23[0];
__bfloat T50[1];
T50[0]
= T24[0];
float T32[1];
T32[0]
= __bfloat2float(T50[0]);
float T33[1];
T33[0]
= T31[0]
* T32[0];
__bfloat T5[1];
T5[0] = 0;
T5[0]
- = ((i46 >= 0) && (i46 < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i36 / i18))) + i45)] : 0.0000e+00;
+ = ((i28 >= 0) && (i28 < T4.logical_size[3LL])) ? T4[((i13 + (T4.alloc_stride[2LL] * (i19 / i14))) + i27)] : 0.0000e+00;
float T46[1];
T46[0]
= __bfloat2float(T5[0]);
__bfloat T19[1];
T19[0] = 0;
T19[0]
- = ((i48 >= 0) && (i48 < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i36 / i19))) + i47)] : 0.0000e+00;
+ = b23 ? T18[((-4 + (T18.alloc_stride[0LL] * i22)) + i20)] : 0.0000e+00;
__bfloat T20[1];
T20[0]
= T19[0];
__bfloat T53[1];
T53[0]
= T20[0];
__bfloat T13[1];
T13[0] = 0;
T13[0]
- = ((i50 >= 0) && (i50 < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i36 / i21))) + i49)] : 0.0000e+00;
+ = b23 ? T12[((i15 + i20) + (T12.alloc_stride[2LL] * i22))] : 0.0000e+00;
float T42[1];
T42[0]
= __bfloat2float(T13[0]);
float T41[1];
T41[0]
= __bfloat2float(T53[0]);
float T43[1];
T43[0]
= T41[0]
* T42[0];
float T44[1];
T44[0]
= T40[0]
+ T43[0];
float T37[1];
T37[0]
= T33[0]
+ T36[0];
float T45[1];
T45[0]
= T37[0]
+ T44[0];
float T47[1];
T47[0]
= T45[0]
+ T46[0];
- T54[i35]
+ T54[i18]
= __float2bfloat(T47[0]);
}
NVFUSER_UPDATE_MAGIC_ZERO;
- loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
+ loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i16], &T54[0]);
} else {
Array<__bfloat, 8, 8> T54;
#pragma unroll
- for(nvfuser_index_t i35 = 0; i35 < 8; ++i35) {
+ for(nvfuser_index_t i18 = 0; i18 < 8; ++i18) {
+ nvfuser_index_t i29;
+ i29 = i5 + (i18 + nvfuser_zero);
+ nvfuser_index_t i30;
+ i30 = i29 % i6;
- nvfuser_index_t i51;
+ nvfuser_index_t i31;
- i51 = i5 + (i35 + nvfuser_zero);
+ i31 = -4 + i30;
- nvfuser_index_t i52;
+ nvfuser_index_t i32;
- i52 = i51 % i6;
+ i32 = i29 / i6;
+ bool b33;
+ b33 = (i31 >= 0) && (i31 < T18.logical_size[1LL]);
- nvfuser_index_t i53;
+ nvfuser_index_t i34;
- i53 = -4 + i52;
+ i34 = i29 % i10;
- nvfuser_index_t i54;
+ nvfuser_index_t i35;
- i54 = i51 % i10;
+ i35 = i29 / i10;
- nvfuser_index_t i55;
- i55 = -4 + i54;
- nvfuser_index_t i56;
- i56 = i51 % i11;
+ bool b36;
+ b36 = i34 < T10.logical_size[3LL];
- nvfuser_index_t i57;
+ nvfuser_index_t i37;
- i57 = i51 % i13;
+ i37 = i29 % i14;
- nvfuser_index_t i58;
+ nvfuser_index_t i38;
- i58 = i51 % i15;
- nvfuser_index_t i59;
- i59 = i51 % i16;
- nvfuser_index_t i60;
- i60 = i51 % i18;
- nvfuser_index_t i61;
- i61 = -8 + i60;
+ i38 = -8 + i37;
- nvfuser_index_t i62;
- i62 = i51 % i19;
- nvfuser_index_t i63;
- i63 = -4 + i62;
- nvfuser_index_t i64;
- i64 = i51 % i21;
- nvfuser_index_t i65;
- i65 = -4 + i64;
__bfloat T27[1];
T27[0] = 0;
- if (b26) {
+ if (b17) {
T27[0]
- = ((i53 >= 0) && (i53 < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i51 / i6))) + i52)] : 0.0000e+00;
+ = b33 ? T26[(i31 + (T26.alloc_stride[0LL] * i32))] : 0.0000e+00;
}
__bfloat T28[1];
T28[0]
= T27[0];
__bfloat T52[1];
T52[0]
= T28[0];
__bfloat T7[1];
T7[0] = 0;
- if (b28) {
+ if (b17) {
T7[0]
- = ((i55 >= 0) && (i55 < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i51 / i10))) + i54)] : 0.0000e+00;
+ = b33 ? T6[((i9 + i30) + (T6.alloc_stride[2LL] * i32))] : 0.0000e+00;
}
float T38[1];
T38[0]
= __bfloat2float(T7[0]);
float T39[1];
T39[0]
= __bfloat2float(T52[0]);
float T40[1];
T40[0]
= T38[0]
* T39[0];
__bfloat T15[1];
T15[0] = 0;
- if (b29) {
+ if (b17) {
T15[0]
- = (i56 < T14.logical_size[1LL]) ? T14[((T14.alloc_stride[0LL] * (i51 / i11)) + i56)] : 0.0000e+00;
+ = b36 ? T14[(i34 + (T14.alloc_stride[0LL] * i35))] : 0.0000e+00;
}
__bfloat T16[1];
T16[0]
= T15[0];
__bfloat T51[1];
T51[0]
= T16[0];
__bfloat T11[1];
T11[0] = 0;
- if (b30) {
+ if (b17) {
T11[0]
- = (i57 < T10.logical_size[3LL]) ? T10[((i12 + (T10.alloc_stride[2LL] * (i51 / i13))) + i57)] : 0.0000e+00;
+ = b36 ? T10[((i11 + (T10.alloc_stride[2LL] * i35)) + i34)] : 0.0000e+00;
}
float T35[1];
T35[0]
= __bfloat2float(T11[0]);
float T34[1];
T34[0]
= __bfloat2float(T51[0]);
float T36[1];
T36[0]
= T34[0]
* T35[0];
__bfloat T9[1];
T9[0] = 0;
- if (b31) {
+ if (b17) {
T9[0]
- = (i58 < T8.logical_size[3LL]) ? T8[((i14 + (T8.alloc_stride[2LL] * (i51 / i15))) + i58)] : 0.0000e+00;
+ = b36 ? T8[((i12 + i34) + (T8.alloc_stride[2LL] * i35))] : 0.0000e+00;
}
float T30[1];
T30[0]
= __bfloat2float(T9[0]);
float T31[1];
T31[0]
= -T30[0];
__bfloat T23[1];
T23[0] = 0;
- if (b32) {
+ if (b17) {
T23[0]
- = (i59 < T22.logical_size[1LL]) ? T22[((T22.alloc_stride[0LL] * (i51 / i16)) + i59)] : 0.0000e+00;
+ = b36 ? T22[(i34 + (T22.alloc_stride[0LL] * i35))] : 0.0000e+00;
}
__bfloat T24[1];
T24[0]
= T23[0];
__bfloat T50[1];
T50[0]
= T24[0];
float T32[1];
T32[0]
= __bfloat2float(T50[0]);
float T33[1];
T33[0]
= T31[0]
* T32[0];
__bfloat T5[1];
T5[0] = 0;
- if (b33) {
+ if (b17) {
T5[0]
- = ((i61 >= 0) && (i61 < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i51 / i18))) + i60)] : 0.0000e+00;
+ = ((i38 >= 0) && (i38 < T4.logical_size[3LL])) ? T4[((i13 + (T4.alloc_stride[2LL] * (i29 / i14))) + i37)] : 0.0000e+00;
}
float T46[1];
T46[0]
= __bfloat2float(T5[0]);
__bfloat T19[1];
T19[0] = 0;
- if (b34) {
+ if (b17) {
T19[0]
- = ((i63 >= 0) && (i63 < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i51 / i19))) + i62)] : 0.0000e+00;
+ = b33 ? T18[((-4 + (T18.alloc_stride[0LL] * i32)) + i30)] : 0.0000e+00;
}
__bfloat T20[1];
T20[0]
= T19[0];
__bfloat T53[1];
T53[0]
= T20[0];
__bfloat T13[1];
T13[0] = 0;
- if (b31) {
+ if (b17) {
T13[0]
- = ((i65 >= 0) && (i65 < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i51 / i21))) + i64)] : 0.0000e+00;
+ = b33 ? T12[((i15 + i30) + (T12.alloc_stride[2LL] * i32))] : 0.0000e+00;
}
float T42[1];
T42[0]
= __bfloat2float(T13[0]);
float T41[1];
T41[0]
= __bfloat2float(T53[0]);
float T43[1];
T43[0]
= T41[0]
* T42[0];
float T44[1];
T44[0]
= T40[0]
+ T43[0];
float T37[1];
T37[0]
= T33[0]
+ T36[0];
float T45[1];
T45[0]
= T37[0]
+ T44[0];
float T47[1];
T47[0]
= T45[0]
+ T46[0];
- T54[i35]
+ T54[i18]
= __float2bfloat(T47[0]);
}
NVFUSER_UPDATE_MAGIC_ZERO;
- if ((i25 < 32768)) {
+ if (b17) {
- loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
+ loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i16], &T54[0]);
}
}
} I think there are simpler conditions for the padding predicates, and they're getting hoisted. With static shapes, this runtime can be improved further to about 8 us. BTW looking at larger problem size (changing head size from 32 to 256 and bsz from 2 to 256), we have 28.1 ms vs 10.7 ms (2.6x speedup similar to the smaller problem size). |
patch this from thunder side per our earlier conversation: Lightning-AI/lightning-thunder#1096 |
After looking at this more closely, I believe the root of this problem is that our index simplification cannot deduce that e.g. i19 = i5 + (i18 + nvfuser_zero);
nvfuser_index_t i20;
i20 = i19 % i6;
nvfuser_index_t i21;
i21 = -4 + i20;
nvfuser_index_t i22;
i22 = i19 / i6; Notice that we should already be able to prove that […]. I am not sure yet whether it would be enough for us to provide proof that vec size divides some of the extents, since in this case there are slices and pads, so we get the constant offset |
Does |
I believe so. We segment into 6 NoOp segments that do slices and one Pointwise that has those sliced inputs as the intermediates. |
Maybe we could take all the vectorizable inputs and outputs and assume that the product of the inner contiguous extents up to the breakpoint is divisible by the vectorization. After that I'm not sure yet how we should represent divisibility in a way that we can use it in |
We do obtain that information but I don't think it's used by the expr simplifier. https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/analysis/divisible_split.cpp#L23 |
Thinking ahead a bit, I think we could do this with an "e-graph analysis" in a similar way to how constant folding is done with e-graphs. We would build an e-graph of integer scalars and merge exact mapped extents and cache it during indexing similar to how we cache proofs currently in the simplification Context. Each e-class would hold an unordered set of e-classes that are its known (overlapping) factors. When merging e-classes we would union their factor sets. We would also union when propagating through multiplication. We would propagate through addition and subtraction by intersecting the factor sets. Then in simplifyExpr we can query this data structure for common divisors between two scalars in O(n) and check if any scalar divides another in O(1) time. |
Ah yes. I was thinking that it only handled reshape splits, but it also handles vectorized splits. That's great. |
The re-written rope example has quite different indexing when the inputs `q`/`cos`/`sin` are defined with static versus dynamic shapes. I think this is coming from the inconsistent fusion definition, i.e. when we switch to having the inputs defined with dynamic shapes, the follow-up `slice` operations aren't using the symbolic slice extent, so we cannot collapse indexing after the slice.
q_rope = fd.ops.slice(q, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])
Currently the Python API only allows static numbers at this point, but no `Val*` yet. This would also require the definition in lowering to be updated. I'm opening this issue for myself. I think it's worth re-writing the script below to see if we can get full perf back with symbolic shapes when the proper definition is produced.
The text was updated successfully, but these errors were encountered: