diff --git a/CodingConventions.md b/CodingConventions.md index d1806c4d0..52348b672 100644 --- a/CodingConventions.md +++ b/CodingConventions.md @@ -588,3 +588,10 @@ changes) while working on a feature and even in "WIP" pull requests, as long as the pieces are recombined (e.g., through an interactive rebase) into logical units when the feature is ready for merging. Force-pushing in PR branches is fine. + +Coding Conventions for writing Tensor Comprehensions +==================================================== + +Please see the following documentation +[entry](https://facebookresearch.github.io/TensorComprehensions/coding_conventions.html) +on how to write Tensor Comprehensions in a standard legible fashion. diff --git a/benchmarks/MLP_model.cc b/benchmarks/MLP_model.cc index bd7676294..366a12f5a 100644 --- a/benchmarks/MLP_model.cc +++ b/benchmarks/MLP_model.cc @@ -64,23 +64,23 @@ DEFINE_uint32(Q, 2, "W4_h"); // float(E1, D) LUT1, int32(B, L1) I1, // float(E2, D) LUT2, int32(B, L2) I2) -> (O1, O2) // { -// O1(i, j) +=! LUT1(I1(i, k), j) -// O2(i, j) +=! LUT2(I2(i, k), j) +// O1(b, d) +=! LUT1(I1(b, r_l1), d) +// O2(b, d) +=! LUT2(I2(b, r_l2), d) // } // def _3FCRELU( // float(B,M) I, float(O,N) W2, float(O) B2, // float(P,O) W3, float(P) B3, float(Q,P) W4, // float(Q) B4) -> (O1, O2, O3, O4) // { -// O2(b, o) = B2(o) -// O2(b, o) += O1(b, n) * W2(o, n) -// O2(b, o) = fmax(O2(b, o), 0) -// O3(b, p) = B3(p) -// O3(b, p) += O2(b, o) * W3(p, o) -// O3(b, p) = fmax(O3(b, p), 0) -// O4(b, q) = B4(q) -// O4(b, q) += O3(b, p) * W4(q, p) -// O4(b, q) = fmax(O4(b, q), 0) +// O2(b, o) = B2(o) +// O2(b, o) += O1(b, n) * W2(o, n) +// O2(b, o) = fmax(O2(b, o), 0) +// O3(b, p) = B3(p) +// O3(b, p) += O2(b, o) * W3(p, o) +// O3(b, p) = fmax(O3(b, p), 0) +// O4(b, q) = B4(q) +// O4(b, q) += O3(b, p) * W4(q, p) +// O4(b, q) = fmax(O4(b, q), 0) // } // def prod_model(float(E1, D) LUT1, int32(B, L1) I1, // float(E2, D) LUT2, int32(B, L2) I2, @@ -91,15 +91,15 @@ DEFINE_uint32(Q, 2, "W4_h"); // float(Q,P) W4, float(Q) B4) // -> (C1, C2, C3, I, O1, O2, O3, O4) // { -// (C1, C2) = _2LUT(LUT1, I1, LUT2, I2) -// C3(b, wy) += I3(b, wxx) * W(wy, wxx) -// I(b, m) = Concat(C1, C2, C3) // not in TC atm -// O1(b, n) = B1(n) -// O1(b, n) += I(b, m) * W1(n, m) -// O1(b, n) = fmax(O1(b, n), 0) +// (C1, C2) = _2LUT(LUT1, I1, LUT2, I2) +// C3(b, wy) +=! I3(b, r_wx) * W(wy, r_wx) +// I(b, m) = Concat(C1, C2, C3) // not in TC atm +// O1(b, n) = B1(n) +// O1(b, n) +=! I(b, m) * W1(n, m) +// O1(b, n) = fmax(O1(b, n), 0) // (O2, O3, O4) = -// _3FCRELU(I, W1, B1, W2, B2, W3, B3, W4, B4) -// # O4 goes out to binary classifier, omitted here +// _3FCRELU(I, W1, B1, W2, B2, W3, B3, W4, B4) +// # O4 goes out to binary classifier, omitted here // } class ProductionModel : public Benchmark { @@ -191,9 +191,9 @@ void ProductionModel::run1LUT( std::vector inputs = {LUT1, IDX1}; std::string tc = R"( - def _1LUT(float(E1, D) LUT1, int32(B, L1) I1) -> (O1) { - O1(i, j) +=! LUT1(I1(i, k), j) - } +def _1LUT(float(E1, D) LUT1, int32(B, L1) I1) -> (O1) { + O1(b, d) +=! LUT1(I1(b, r_l1), d) +} )"; std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) + @@ -294,10 +294,10 @@ void ProductionModel::run2LUT( std::vector inputs = {LUT1, IDX1, LUT2, IDX2}; std::string tc = R"( - def _2LUT(float(E1, D) LUT1, int32(B, L1) I1, float(E2, D) LUT2, int32(B, L2) I2) -> (O1, O2) { - O1(i, j) +=! LUT1(I1(i, k), j) - O2(i, j) +=! LUT2(I2(i, k), j) - } +def _2LUT(float(E1, D) LUT1, int32(B, L1) I1, float(E2, D) LUT2, int32(B, L2) I2) -> (O1, O2) { + O1(b, d) +=! LUT1(I1(b, r_l1), d) + O2(b, d) +=! LUT2(I2(b, r_l2), d) +} )"; std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) + @@ -353,9 +353,9 @@ void ProductionModel::runC3( std::vector inputs = {I, W}; std::string tc = R"TC( - def _C3(float(B,WX) I, float(WY, WX) W) -> (C3) { - C3(b, wy) +=! I(b, wxx) * W(wy, wxx) - } +def _C3(float(B,WX) I, float(WY, WX) W) -> (C3) { + C3(b, wy) +=! I(b, r_wx) * W(wy, r_wx) +} )TC"; std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) + @@ -408,11 +408,11 @@ void ProductionModel::runMLP1( std::vector inputs = {I, W1, B1}; std::string tc = R"TC( - def mlp1(float(B,M) I, float(M, N) W1, float(N) B1) -> (O1) { - O1(b, n) +=! I(b, mm) * W1(mm, n) - O1(b, n) = O1(b, n) + B1(n) - O1(b, n) = fmax(O1(b, n), 0) - } +def mlp1(float(B,M) I, float(M, N) W1, float(N) B1) -> (O1) { + O1(b, n) +=! I(b, r_m) * W1(r_m, n) + O1(b, n) = O1(b, n) + B1(n) + O1(b, n) = fmax(O1(b, n), 0) +} )TC"; std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) + @@ -474,17 +474,17 @@ void ProductionModel::runMLP3( std::vector inputs = {I, W2, B2, W3, B3, W4, B4}; std::string tc = R"TC( - def mlp3(float(B,N) I, float(O,N) W2, float(O) B2, float(P,O) W3, float(P) B3, float(Q,P) W4, float(Q) B4) -> (O2, O3, O4) { - O2(b, o) +=! I(b, n) * W2(o, n) - O2(b, o) = O2(b, o) + B2(o) - O2(b, o) = fmax(O2(b, o), 0) +def mlp3(float(B,N) I, float(O,N) W2, float(O) B2, float(P,O) W3, float(P) B3, float(Q,P) W4, float(Q) B4) -> (O2, O3, O4) { + O2(b, o) +=! I(b, n) * W2(o, n) + O2(b, o) = O2(b, o) + B2(o) + O2(b, o) = fmax(O2(b, o), 0) O3(b, p) +=! O2(b, o) * W3(p, o) - O3(b, p) = O3(b, p) + B3(p) - O3(b, p) = fmax(O3(b, p), 0) + O3(b, p) = O3(b, p) + B3(p) + O3(b, p) = fmax(O3(b, p), 0) O4(b, q) +=! O3(b, p) * W4(q, p) - O4(b, q) = O4(b, q) + B4(q) - O4(b, q) = fmax(O4(b, q), 0) - } + O4(b, q) = O4(b, q) + B4(q) + O4(b, q) = fmax(O4(b, q), 0) +} )TC"; std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) + diff --git a/benchmarks/batchmatmul.cc b/benchmarks/batchmatmul.cc index 2bfd9fb0d..186363649 100644 --- a/benchmarks/batchmatmul.cc +++ b/benchmarks/batchmatmul.cc @@ -76,9 +76,9 @@ void BatchMatMul::runBatchMatMul( std::vector inputs = {X, Y}; std::string tc = R"( - def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) { - Z(b, n, k) +=! X(b, n, mm) * Y(b, mm, k) - } +def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) { + Z(b, n, k) +=! X(b, n, r_m) * Y(b, r_m, k) +} )"; std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) + diff --git a/benchmarks/group_convolution.cc b/benchmarks/group_convolution.cc index 9b41fa7fb..38fe8b2b4 100644 --- a/benchmarks/group_convolution.cc +++ b/benchmarks/group_convolution.cc @@ -122,13 +122,13 @@ void GroupConvolution::runGroupConvolution( .resize_({G, F}); std::vector inputs = {tI, tW, tB}; std::string tc = R"( - def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B) - -> (O) - { +def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B) +-> (O) +{ O(n, g, f, h, w) +=! - I(n, g, c, h + kh, w + kw) * W1(g, f, c, kh, kw) - O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) - } + I(n, g, r_c, h + r_kh, w + r_kw) * W1(g, f, r_c, r_kh, r_kw) + O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) +} )"; std::string suffix = std::string("_N_") + std::to_string(FLAGS_N) + diff --git a/benchmarks/tmm.cc b/benchmarks/tmm.cc index 79610efc7..03b8884ab 100644 --- a/benchmarks/tmm.cc +++ b/benchmarks/tmm.cc @@ -73,9 +73,9 @@ void TransposedMatMul::runTransposedMatMul( std::vector inputs = {A, B}; std::string tc = R"TC( - def tmm(float(M,K) A, float(N,K) B) -> (C) { - C(m, n) +=! A(m, kk) * B(n, kk) - } +def tmm(float(M,K) A, float(N,K) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(n, r_k) +} )TC"; std::string suffix = std::string("_M_") + std::to_string(FLAGS_M) + diff --git a/docs/doxygen/index.md b/docs/doxygen/index.md index 47adbbfeb..24a9f93d3 100644 --- a/docs/doxygen/index.md +++ b/docs/doxygen/index.md @@ -13,7 +13,7 @@ with a few basic functionalities. Tensor Comprehension Notation ----------------------------- -TC borrow three ideas from Einstein notation that make expressions concise: +TC borrows three ideas from Einstein notation that make expressions concise: 1. Loop index variables are defined implicitly by using them in an expression and their range is aggressively inferred based on what they index. 2. Indices that appear on the right of an expression but not on the left are assumed to be reduction dimensions. @@ -21,22 +21,22 @@ TC borrow three ideas from Einstein notation that make expressions concise: Let's start with a simple example is a matrix vector product: - def mv(float(R,C) A, float(C) B) -> (o) { - o(i) +=! A(i,j) * B(j) + def mv(float(R,C) A, float(C) x) -> (o) { + o(r) +=! A(r,r_c) * x(r_c) } `A` and `x` are input tensors. `o` is an output tensor. -The statement `o(i) += A(i,j) * b(j)` introduces two index variables `i` and `j`. -Their range is inferred by their use indexing `A` and `B`. `i = [0,R)`, `j = [0,C)`. -Because `j` only appears on the right side, -stores into `o` will reduce over `j` with the reduction specified for the loop. +The statement `o(r) +=! A(r,r_c) * x(r_c)` introduces two index variables `r` and `r_c`. +Their range is inferred by their use indexing `A` and `x`. `r = [0,R)`, `r_c = [0,C)`. +Because `r_c` only appears on the righthand side, +stores into `o` will reduce over `r_c` with the reduction specified for the loop. Reductions can occur across multiple variables, but they all share the same kind of associative reduction (e.g. +=) to maintain invariant (3). `mv` computes the same thing as this C++ loop: for(int i = 0; i < R; i++) { o(i) = 0.0f; for(int j = 0; j < C; j++) { - o(i) += A(i,j) * B(j); + o(i) += A(i,j) * x(j); } } @@ -50,7 +50,7 @@ We provide a few basic examples. **Simple matrix-vector**: def mv(float(R,C) A, float(C) B) -> (o) { - o(i) += A(i,j) * B(j) + o(r) +=! A(r,r_c) * B(r_c) } **Simple matrix-multiply:** @@ -59,21 +59,20 @@ Note the layout for B is transposed and matches the traditional layout of the weight matrix in a linear layer): def mm(float(X,Y) A, float(Y,Z) B) -> (R) { - R(i,j) += A(i,j) * B(j,k) + R(x,z) +=! A(x,r_y) * B(r_y,z) } **Simple 2-D convolution (no stride, no padding):** def conv(float(B,IP,H,W) input, float(OP,IP,KH,KW) weight) -> (output) { - output(b, op, h, w) += input(b, ip, h + kh, w + kw) * weight(op, ip, kh, kw) + output(b, op, h, w) +=! input(b, r_ip, h + r_kh, w + r_kw) * weight(op, r_ip, r_kh, r_kw) } **Simple 2D max pooling:** -Note the similarity with a convolution with a -"select"-style kernel): +Note the similarity with a convolution with a "select"-style kernel: def maxpool2x2(float(B,C,H,W) input) -> (output) { - output(b,c,i,j) max= input(b,c,2*i + kw, 2*j + kh) - where kw = [0, 2[, kh = [0, 2[ + output(b,c,h,w) max=! input(b,c,2*h + r_kw, 2*w + r_kh) + where r_kw in 0:2, r_kh in 0..2 } diff --git a/docs/source/coding_conventions.rst b/docs/source/coding_conventions.rst new file mode 100644 index 000000000..3031713d9 --- /dev/null +++ b/docs/source/coding_conventions.rst @@ -0,0 +1,120 @@ +Coding Conventions +================== + +In order to increase readability across Tensor Comprehensions written by +multiple authors and to reduce the amount of surprising behavior, the +following conventions should be adopted when writing TC. Generally in TC, one +should increment nesting by 4 whitespaces at each level and align tensor names +and indices where appropriate to make memory access patterns emerge. Since +these two goals can easily be conflicting, use your best judgement to tradeoff +between the two goals. Such examples are provided below. + +Use indices named after parameters +---------------------------------- + +Use upper-case names for parameters and capital-case names for input/output tensors. +Use lower-case names for indices to match the name of the parameter +corresponding to the dimension upon which they iterate. +In other words, prefer: + +.. code:: + + def copy2d(float(M, N) I) -> (O) { + O(m, n) = I(m, n) + } + +to: + +.. code:: + + def copy2d(float(M, N) I) -> (O) { + O(i, j) = I(i, j) + } + +Prefix reduction index names with :code:`r_` +-------------------------------------------- + +By definition, reduction indices are the ones that appear on the RHS of a TC +expression but not on the LHS. On larger expressions it can get challenging to easily +detect the reduction variables by mentally parsing the set of indices on the +RHS and subtracting the set of indices on the LHS from it. To alleviate such +issues, name the reduction variables with a :code:`r_` prefix. +In other words, prefer: + +.. code:: + + def matmul(float(M, K) A, float(K, N) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(r_k, n) + } + +to: + +.. code:: + + def matmul(float(M, K) A, float(K, N) B) -> (C) { + C(m, n) +=! A(m, k) * B(k, n) + } + +Filter non-rectangular regions with data-dependencies +----------------------------------------------------- + +TC semantics are restricted to (hyper-)rectangular iteration spaces. +This is a hard requirement to ensure range inference is non-ambiguous (see inference_). +To simulate non-rectangular iteration spaces, one can use the following: + +.. code:: + + def matmul(float(M, K) L, float(K, M) U) -> (LU) { + LU(m1, m2) +=! (r_k >= m1 and r_k =< m2) ? L(m1, r_k) * U(r_k, m2) : 0 + } + +However, non-(hyper)-rectangular iteration spaces (e.g. triangular) are +incompatible with range inference and will fail the semantic checks in the TC +compiler: + +.. code:: + + def matmul(float(M, K) L, float(K, M) U) -> (LU) { + LU(m1, m2) +=! L(m1, r_k) * U(r_k, m2) where r_k in m1:M, r_k in 0:m2+1 + } + +The reader may remark that this is an inefficient way of writing +matrix-multiplication of triangular matrices. +Lowering such operations efficiently from TC is the subject of future work. + +Prefix gradient tensors names with :code:`d_` +--------------------------------------------- + +When implementing backward operations, pass the inputs to the backwards pass +in the same order as the outputs of the forward pass and use the same tensor +name prefixed by :code:`d_`. For instance: + +.. code:: + + def conv(float(N,C,H,W) I, float(M,C,KH,KW) Wt) -> (O) { + ... + } + + def conv_bw(float(N,C,H,W) I, float(M,C,KH,KW) Wt, float(N,M,HO,WO) d_O) -> (d_I) { + ... + } + +A more complex example +---------------------- + +The following shows a possible implementation for a more complex forward and +backward example. Notice the proper alignment of indices in the backward pass +and the emergence of an antidiagonal pattern in the reduction accesses: + +.. code:: + + def matmul(float(M,K) A, float(K,N) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(r_k, n) + } + def matmul_bw(float(M,K) A, float(K,N) B, float(M,N) d_C) -> (d_A, d_B){ + d_A(m, k) +=! d_C( m, r_n) * B( k, r_n) + d_B(k, n) +=! d_C(r_m, n) * A(r_m, k) + } + +Reasoning on such reduction patterns at the level of TC has already proven +valuable in other circumstances. diff --git a/docs/source/framework/caffe2_integration/integration_with_example.rst b/docs/source/framework/caffe2_integration/integration_with_example.rst index a4773c762..0b948feb5 100644 --- a/docs/source/framework/caffe2_integration/integration_with_example.rst +++ b/docs/source/framework/caffe2_integration/integration_with_example.rst @@ -50,8 +50,8 @@ For demonstration purpose, we will pick a simple example for :code:`matmul` laye dyndep.InitOpsLibrary(os.path.join(os.environ.get("CONDA_PREFIX"), "lib/libtc_c2.so")) lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M,K) A, float(N,K) B) -> (output) { + output(m, n) +=! A(m, r_k) * B(n, r_k) } """ mat1, mat2 = np.random.rand(100, 400), np.random.rand(400, 500) @@ -68,4 +68,4 @@ Future ------ The integration with Caffe2 is very basic at the moment. We do not provide autotuner -support for Caffe2 and welcome contributions from community. +support for Caffe2 at the moment and welcome contributions from the community. diff --git a/docs/source/framework/pytorch_integration/autograd_with_tc.rst b/docs/source/framework/pytorch_integration/autograd_with_tc.rst index d1d5ab047..b59ec0e5e 100644 --- a/docs/source/framework/pytorch_integration/autograd_with_tc.rst +++ b/docs/source/framework/pytorch_integration/autograd_with_tc.rst @@ -27,11 +27,11 @@ Examples from torch.nn.parameter import Parameter CONV_LANG = """ def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {{ - O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw) + O(n, m, h, w) +=! I(n, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(m, r_c, r_kh, r_kw) }} - def convolution_grad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) O_grad) -> (I_grad, W1_grad) {{ - I_grad(n, c, h, w) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * W1(m, c, kh, kw) - W1_grad(m, c, kh, kw) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * I(n, c, h, w) + def convolution_grad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) d_O) -> (d_I, d_W1) {{ + d_I(n, c, h, w) +=! d_O( n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw) + d_W1(m, c, kh, kw) +=! d_O(r_n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w) }} """ N, C, H, W, O, kH, kW, sH, sW = 32, 4, 56, 56, 16, 1, 1, 1, 1 @@ -66,11 +66,11 @@ them, the example for that would be: from torch.nn.parameter import Parameter CONV_LANG = """ def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {{ - O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw) + O(n, m, h, w) +=! I(n, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(m, r_c, r_kh, r_kw) }} - def convolution_grad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) O_grad) -> (I_grad, W1_grad) {{ - I_grad(n, c, h, w) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * W1(m, c, kh, kw) - W1_grad(m, c, kh, kw) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * I(n, c, h, w) + def convolution_grad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) d_O) -> (d_I, d_W1) {{ + d_I(n, c, h, w) +=! d_O( n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw) + d_W1(m, c, kh, kw) +=! d_O(r_n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w) }} """ N, C, H, W, O, kH, kW, sH, sW = 32, 4, 56, 56, 16, 1, 1, 1, 1 @@ -100,11 +100,11 @@ Let's see how to cache options to file when we tune a training layer. import torch CONV_LANG = """ def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {{ - O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw) + O(n, m, h, w) +=! I(n, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(m, r_c, r_kh, r_kw) }} - def convolution_grad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) O_grad) -> (I_grad, W1_grad) {{ - I_grad(n, c, h, w) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * W1(m, c, kh, kw) - W1_grad(m, c, kh, kw) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * I(n, c, h, w) + def convolution_grad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) d_O) -> (d_I, d_W1) {{ + d_I(n, c, h, w) +=! d_O( n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw) + d_W1(m, c, kh, kw) +=! d_O(r_n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w) }} """ N, C, H, W, O, kH, kW, sH, sW = 32, 4, 56, 56, 16, 1, 1, 1, 1 @@ -133,14 +133,14 @@ the example below for how to use it: import torch LANG = """ def convolution(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (tmp, O) { - tmp(n, m, h, w) +=! I(n, c, h + kh, w + kw) * W1(m, c, kh, kw) + tmp(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw) O(n, m, h, w) = tmp(n, m, h, w) + B(m) } - def convolution_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B, float(N, M, H, W) O_grad) - -> (I_grad, W1_grad, B_grad) { - I_grad(n, c, h, w) +=! O_grad(n, m, h - kh, w - kw) * W1(m, c, kh, kw) - W1_grad(m, c, kh, kw) +=! O_grad(n, m, h - kh, w - kw) * I(n, c, h, w) - B_grad(m) +=! O_grad(n, m, h, w) + def convolution_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B, float(N, M, H, W) d_O) + -> (d_I, d_W1, d_B) { + d_I(n, c, h, w) +=! d_O( n, r_m, h - r_kh, w - r_kw) * W1(r_m, c, r_kh, r_kw) + d_W1(m, c, kh, kw) +=! d_O(r_n, m, r_h - kh, r_w - kw) * I(r_n, c, r_h, r_w) + d_B(m) +=! d_O(n, m, h, w) } """ diff --git a/docs/source/framework/pytorch_integration/autotuning_layers.rst b/docs/source/framework/pytorch_integration/autotuning_layers.rst index 502e49037..4574e4f85 100644 --- a/docs/source/framework/pytorch_integration/autotuning_layers.rst +++ b/docs/source/framework/pytorch_integration/autotuning_layers.rst @@ -24,8 +24,8 @@ An example demonstrating each step above is: import tensor_comprehensions as tc import torch lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M,K) A, float(N,K) B) -> (output) { + output(m, n) +=! A(m, r_k) * B(n, r_k) } """ matmul = tc.define(lang, name="matmul") @@ -108,8 +108,8 @@ An example for how to pass options: import tensor_comprehensions as tc import torch lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M,K) A, float(N,K) B) -> (output) { + output(m, n) +=! A(m, r_k) * B(n, r_k) } """ matmul = tc.define(lang, name="matmul") @@ -134,8 +134,8 @@ argument to the autotuning call. There are two ways of caching the tuned options import tensor_comprehensions as tc import torch lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M,K) A, float(N,K) B) -> (output) { + output(m, n) +=! A(m, r_k) * B(n, r_k) } """ matmul = tc.define(lang, name="matmul") @@ -151,8 +151,8 @@ argument to the autotuning call. There are two ways of caching the tuned options import tensor_comprehensions as tc import torch lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M,K) A, float(N,K) B) -> (output) { + output(m, n) +=! A(m, r_k) * B(n, r_k) } """ matmul = tc.define(lang, name="matmul") @@ -182,8 +182,8 @@ For example: import tensor_comprehensions as tc import torch lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M,K) A, float(N,K) B) -> (output) { + output(m, n) +=! A(m, r_k) * B(n, r_k) } """ matmul = tc.define(lang, name="matmul") @@ -207,8 +207,8 @@ For example: import tensor_comprehensions as tc lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M,K) A, float(N,K) B) -> (output) { + output(m, n) +=! A(m, r_k) * B(n, r_k) } """ matmul = tc.define(lang, name="matmul") @@ -237,8 +237,8 @@ Below is example describing the above usage: import tensor_comprehensions as tc cache = "{}/matmul_3_4_5".format(PATH_PREFIX) lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M,K) A, float(N,K) B) -> (output) { + output(m, n) +=! A(m, r_k) * B(n, r_k) } """ matmul = tc.define(lang, name="matmul") diff --git a/docs/source/framework/pytorch_integration/frequently_asked_questions.rst b/docs/source/framework/pytorch_integration/frequently_asked_questions.rst index e2a2ebc19..8e4f5026f 100644 --- a/docs/source/framework/pytorch_integration/frequently_asked_questions.rst +++ b/docs/source/framework/pytorch_integration/frequently_asked_questions.rst @@ -20,10 +20,10 @@ as input not output. .. code:: def softmax(float(N, D) I) -> (O, maxVal, expDistance) { - maxVal(n) max= I(n, d) - expDistance(n, d) = exp(I(n, d) - maxVal(n)) - expSum(n) +=! expDistance(n, d) - O(n, d) = expDistance(n, d) / expSum(n) + maxVal(n) max=! I(n, d) + expDistance(n, d) = exp(I(n, d) - maxVal(n)) + expSum(n) +=! expDistance(n, d) + O(n, d) = expDistance(n, d) / expSum(n) } **Valid TC** @@ -33,15 +33,15 @@ The correct TC would be: .. code:: def softmax(float(N, D) I) -> (O, maxVal, expDistance, expSum) { - maxVal(n) max= I(n, d) - expDistance(n, d) = exp(I(n, d) - maxVal(n)) - expSum(n) +=! expDistance(n, d) - O(n, d) = expDistance(n, d) / expSum(n) + maxVal(n) max=! I(n, d) + expDistance(n, d) = exp(I(n, d) - maxVal(n)) + expSum(n) +=! expDistance(n, d) + O(n, d) = expDistance(n, d) / expSum(n) } Can I re-use a temporary variable? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -You can as long as the tensor dependencies are strictly DAG. For example: +You can as long as the tensor dependencies form a DAG. For example: **Invalid** @@ -54,19 +54,18 @@ You can as long as the tensor dependencies are strictly DAG. For example: O(n, d) = O(n, d) / tmp(n) } -This TC is invalid because :code:`tmp` and :code:`O(n, d)` have cyclic dependency. +This TC is invalid because :code:`tmp` and :code:`O(n, d)` have a cyclic dependency. **Valid** .. code:: def softmax(float(N, D) I) -> (O, expsum, maxVal) { - maxVal(n) max= I(n, d) + maxVal(n) max=! I(n, d) expsum(n) +=! exp(I(n, d) - maxVal(n)) O(n, d) = exp(I(n, d) - maxVal(n)) / expsum(n) } - Autotuner --------- diff --git a/docs/source/framework/pytorch_integration/getting_started.rst b/docs/source/framework/pytorch_integration/getting_started.rst index 3d526bd43..50f6f1d53 100644 --- a/docs/source/framework/pytorch_integration/getting_started.rst +++ b/docs/source/framework/pytorch_integration/getting_started.rst @@ -70,8 +70,8 @@ For demonstration purpose, we will pick a simple example for :code:`matmul` laye import tensor_comprehensions as tc import torch lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M,K) A, float(N,K) B) -> (output) { + output(m, n) +=! A(m, r_k) * B(n, r_k) } """ matmul = tc.define(lang, name="matmul") diff --git a/docs/source/framework/pytorch_integration/layers_database.rst b/docs/source/framework/pytorch_integration/layers_database.rst index 3f798f154..e577deb84 100644 --- a/docs/source/framework/pytorch_integration/layers_database.rst +++ b/docs/source/framework/pytorch_integration/layers_database.rst @@ -17,11 +17,11 @@ An example to do so: .. code-block:: python - import tensor_comprehensions as tc - import torch - matmul = tc.define(tc.database['matmul']['lang'], name='matmul') - mat1, mat2 = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda() - out = matmul(mat1, mat2) + import tensor_comprehensions as tc + import torch + matmul = tc.define(tc.database['matmul']['lang'], name='matmul') + mat1, mat2 = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda() + out = matmul(mat1, mat2) Pooling Layers @@ -32,8 +32,9 @@ Average pooling .. code:: - def avgpool(float(B, C, H, W) input) -> (output) {{ - output(b, c, h, w) +=! input(b, c, h * {sH} + kh, w * {sW} + kw) / ({kH} * {kW}) where kh in 0:{kH}, kw in 0:{kW} + def avgpool(float(B, C, H, W) Input) -> (Output) {{ + Output(b, c, h, w) +=! Input(b, c, h * {sH} + r_kh, w * {sW} + r_kw) / ({kH} * {kW}) + where r_kh in 0:{kH}, r_kw in 0:{kW} }} @@ -42,8 +43,9 @@ Max pooling .. code:: - def maxpool(float(B, C, H, W) input) -> (output) {{ - output(b, c, h, w) max=! input(b, c, h * {sH} + kh, w * {sW} + kw) where kh in 0:{kH}, kw in 0:{kW} + def maxpool(float(B, C, H, W) Input) -> (Output) {{ + Output(b, c, h, w) max=! Input(b, c, h * {sH} + r_kh, w * {sW} + r_kw) + where r_kh in 0:{kH}, r_kw in 0:{kW} }} Convolution layers @@ -55,8 +57,8 @@ Simple Convolution .. code:: def convolution(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (O) { - O(n, m, h, w) +=! I(n, c, h + kh, w + kw) * W1(m, c, kh, kw) - O(n, m, h, w) = O(n, m, h, w) + B(m) + O(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw) + O(n, m, h, w) = O(n, m, h, w) + B(m) } Strided Convolution @@ -65,8 +67,8 @@ Strided Convolution .. code:: def convolution_strided(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (O) {{ - O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw) - O(n, m, h, w) = O(n, m, h, w) + B(m) + O(n, m, h, w) +=! I(n, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(m, r_c, r_kh, r_kw) + O(n, m, h, w) = O(n, m, h, w) + B(m) }} Strided Convolution Gradient @@ -74,9 +76,9 @@ Strided Convolution Gradient .. code:: - def convolution_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) O_grad) -> (I_grad, W1_grad) {{ - I_grad(n, c, h, w) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * W1(m, c, kh, kw) - W1_grad(m, c, kh, kw) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * I(n, c, h, w) + def convolution_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) d_O) -> (d_I, d_W1) {{ + d_I(n, c, h, w) +=! d_O(n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw) + d_W1(m, c, kh, kw) +=! d_O(n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w) }} Simple Group Convolution @@ -85,8 +87,8 @@ Simple Group Convolution .. code:: def group_convolution(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) { - O(n, g, f, h, w) +=! I(n, g, c, h + kh, w + kw) * W1(g, f, c, kh, kw) - O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) + O(n, g, f, h, w) +=! I(n, g, r_c, h + r_kh, w + r_kw) * W1(g, f, r_c, r_kh, r_kw) + O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) } Group Convolution Strided @@ -95,7 +97,7 @@ Group Convolution Strided .. code:: def group_convolution_strided(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) {{ - O(n, g, f, h, w) +=! I(n, g, c, {sh} * h + kh, {sw} * w + kw) * W1(g, f, c, kh, kw) + O(n, g, f, h, w) +=! I(n, g, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(g, f, r_c, r_kh, r_kw) O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) }} @@ -108,7 +110,7 @@ Fully Connected layer .. code:: def fully_connected(float(B, M) I, float(N, M) W1, float(N) B1) -> (O1) { - O1(b, n) +=! I(b, m) * W1(n, m) + O1(b, n) +=! I(b, r_m) * W1(n, r_m) O1(b, n) = O1(b, n) + B1(n) } @@ -138,11 +140,11 @@ Softmax .. code:: - def softmax(float(N, D) I) -> (O, maxVal, expDistance, expSum) { - maxVal(n) max=! I(n, d) - expDistance(n, d) = exp(I(n, d) - maxVal(n)) - expSum(n) +=! expDistance(n, d) - O(n, d) = expDistance(n, d) / expSum(n) + def softmax(float(N, D) I) -> (O, MaxVal, ExpDistance, ExpSum) { + MaxVal(n) max=! I(n, d) + ExpDistance(n, d) = exp(I(n, d) - MaxVal(n)) + ExpSum(n) +=! ExpDistance(n, d) + O(n, d) = ExpDistance(n, d) / ExpSum(n) } Tanh @@ -172,7 +174,7 @@ TensorDot .. code:: def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) { - O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w) + O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w) } Matmul @@ -180,8 +182,8 @@ Matmul .. code:: - def matmul(float(M, N) A, float(N, K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M, K) A, float(K, N) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(r_k, n) } Matmul Gradient @@ -189,9 +191,9 @@ Matmul Gradient .. code:: - def matmul_grad(float(M, N) A, float(N, K) B) -> (output) { - A_grad(i, j) +=! O_grad(i, kk) * B(j, kk) - B_grad(i, j) +=! O_grad(kk, j) * A(kk, i) + def matmul_bw(float(M,K) A, float(K,N) B, float(M,N) d_C) -> (d_A, d_B){ + d_A(m, k) +=! d_C( m, r_n) * B( k, r_n) + d_B(k, n) +=! d_C(r_m, n) * A(r_m, k) } Batch Matmul @@ -200,7 +202,7 @@ Batch Matmul .. code:: def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) { - Z(b, n, k) +=! X(b, n, mm) * Y(b, mm, k) + Z(b, n, k) +=! X(b, n, r_m) * Y(b, r_m, k) } Absolute @@ -217,8 +219,8 @@ Add .. code:: - def add(float(N) A, float(N) B) -> (output) { - output(i) = A(i) + B(i) + def add(float(N) A, float(N) B) -> (Output) { + Output(n) = A(n) + B(n) } Tensor Operations @@ -229,8 +231,8 @@ Indexing .. code:: - def indexing(float(H, W) input, int32(L) index) -> (output) {{ - output(l, w) = input(index(l), w) where l in 0:{L} + def indexing(float(H, W) Input, int32(L) Index) -> (Output) {{ + Output(l, w) = Input(Index(l), w) }} Lookup Table @@ -239,7 +241,7 @@ Lookup Table .. code:: def lut(float(B, R) LUT, int32(B, N) I) -> (O) { - O(b, n) +=! LUT(I(b, n), r) + O(b, n) +=! LUT(I(b, n), r_r) } Transpose @@ -275,7 +277,7 @@ Copy .. code:: def copy(float(M, N) I) -> (O) { - O(i, j) = I(i, j) + O(m, n) = I(m, n) } Scale @@ -296,9 +298,9 @@ FCRelu .. code:: def fcrelu(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1){ - O1(b, n) +=! I(b, m) * W1(n, m) - O1(b, n) = O1(b, n) + B1(n) - O1(b, n) = fmax(O1(b, n), 0) + O1(b, n) +=! I(b, r_m) * W1(n, r_m) + O1(b, n) = O1(b, n) + B1(n) + O1(b, n) = fmax(O1(b, n), 0) } Small MobileNet @@ -308,12 +310,12 @@ Small MobileNet def small_mobilenet(float(C1, H, W) I, float(C1, KH1, KW1) W1, float(C1) B1, float(C2, C1) W2, float(C2) B2) -> (O1, O2) { - O1(c1, h, w) +=! I(c1, h + kh, w + kw) * W1(c1, kh, kw) - O1(c1, h, w) = O1(c1, h, w) + B1(c1) + O1(c1, h, w) +=! I(c1, h + r_kh, w + r_kw) * W1(c1, r_kh, r_kw) + O1(c1, h, w) = O1(c1, h, w) + B1(c1) O1(c1, h, w) = fmax(O1(c1, h, w), 0) - O2(c2, h, w) +=! O1(c1, h, w) * W2(c2, c1) - O2(c2, h, w) = O2(c2, h, w) + B2(c2) + O2(c2, h, w) +=! O1(r_c1, h, w) * W2(c2, r_c1) + O2(c2, h, w) = O2( c2, h, w) + B2(c2) O2(c2, h, w) = fmax(O2(c2, h, w), 0) } @@ -325,17 +327,17 @@ Batch Normalization .. code:: - def batchnorm(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) - -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut) + def batchnorm(float(N,C,H,W) I, float(C) RMeanIn, float(C) RVarIn) + -> (O, RMeanOut, RVarOut, Mean, Centered, Variance, ExpectedVariance, normalizedOut) {{ - mean(c) +=! I(nn, c, hh, ww) - mean(c) = mean(c) / (N * H * W) - rMeanOut(c) = (1 - {momentum}) * rMeanIn(c) + {momentum} * mean(c) - centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) - variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) - expectedVariance(c) +=! (variance(n, c, h, w) + {eps}) / (N * H * W) - rVarOut(c) = rsqrt((1 - {momentum}) * rVarIn(c) + {momentum} * expectedVariance(c)) - O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c) + Mean(c) +=! I(nn, c, hh, ww) + Mean(c) = Mean(c) / (N * H * W) + RMeanOut(c) = (1 - {momentum}) * RMeanIn(c) + {momentum} * Mean(c) + Centered(n, c, h, w) = I(n, c, h, w) - RMeanOut(c) + Variance(n, c, h, w) = Centered(n, c, h, w) * Centered(n, c, h, w) + ExpectedVariance(c) +=! (Variance(n, c, h, w) + {eps}) / (N * H * W) + RVarOut(c) = rsqrt((1 - {momentum}) * RVarIn(c) + {momentum} * ExpectedVariance(c)) + O(n, c, h, w) = Centered(n, c, h, w) * RVarOut(c) normalizedOut(n, c, h, w) = O(n, c, h, w) }} @@ -344,12 +346,12 @@ Layer Normalization .. code:: - def layernorm(float(T, B, C) I) -> (O, mean, centered, var) {{ - mean(t, b) +=! I(t, b, c) / C - centered(t, b, c) = I(t, b, c) - mean(t, b) - var(t, b) +=! centered(t, b, c) * centered(t, b, c) - var(t, b) = (var(t, b) + {eps}) / C - O(t, b, c) = centered(t, b, c) / rsqrt(var(t, b)) + def layernorm(float(T, B, C) I) -> (O, Mean, Centered, Var) {{ + Mean(t, b) +=! I(t, b, c) / C + Centered(t, b, c) = I(t, b, c) - Mean(t, b) + Var(t, b) +=! Centered(t, b, c) * Centered(t, b, c) + Var(t, b) = (Var(t, b) + {eps}) / C + O(t, b, c) = Centered(t, b, c) / rsqrt(Var(t, b)) }} Distance Functions @@ -360,10 +362,10 @@ Cosine Similarity .. code:: - def cosine_similarity(float(M, N) I1, float(M, N) I2) -> (O, sumI1, sumI2) {{ - sumI1(m) +=! I1(m, n) * I1(m, n) - sumI2(m) +=! I2(m, n) * I2(m, n) - O(m) +=! (I1(m, n) * I2(m, n)) / fmax(rsqrt(sumI1(m)) * sqrt(sumI2(m)), {eps}) + def cosine_similarity(float(M, N) I1, float(M, N) I2) -> (O, SumI1, SumI2) {{ + SumI1(m) +=! I1(m, n) * I1(m, n) + SumI2(m) +=! I2(m, n) * I2(m, n) + O(m) +=! (I1(m, n) * I2(m, n)) / fmax(rsqrt(SumI1(m)) * sqrt(SumI2(m)), {eps}) }} What operations can not be expressed diff --git a/docs/source/framework/pytorch_integration/note_about_performance.rst b/docs/source/framework/pytorch_integration/note_about_performance.rst index dc8ba34be..3b8a97992 100644 --- a/docs/source/framework/pytorch_integration/note_about_performance.rst +++ b/docs/source/framework/pytorch_integration/note_about_performance.rst @@ -18,8 +18,8 @@ argument when you run the TC. For a concrete example: import tensor_comprehensions as tc import torch lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M, K) A, float(K, N) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(r_k, n) } """ matmul = tc.define(lang, name="matmul") diff --git a/docs/source/framework/pytorch_integration/writing_layers.rst b/docs/source/framework/pytorch_integration/writing_layers.rst index 1a7f3e751..e02f035b0 100644 --- a/docs/source/framework/pytorch_integration/writing_layers.rst +++ b/docs/source/framework/pytorch_integration/writing_layers.rst @@ -22,8 +22,8 @@ An example demonstrating each step above is: import tensor_comprehensions as tc import torch MATMUL_LANG = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M, K) A, float(K, N) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(r_k, n) } """ # the `name` should match the definition name in the `lang` @@ -72,8 +72,8 @@ An example for how to pass options: import tensor_comprehensions as tc import torch lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M, K) A, float(K, N) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(r_k, n) } """ matmul = tc.define(lang, name="matmul") @@ -107,8 +107,8 @@ of input sizes, you need to define TC once. An example: import tensor_comprehensions as tc import torch lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M, K) A, float(K, N) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(r_k, n) } """ matmul = tc.define(lang, name="matmul") @@ -138,11 +138,11 @@ definition and get the TC layer for it. Below is an example for how to do this: import tensor_comprehensions as tc import torch lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M, K) A, float(K, N) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(r_k, n) } def abs(float(M, N) A) -> (O1) { - O1(m, n) = fabs(A(m, n)) + O1(m, n) = fabs(A(m, n)) } """ matmul = tc.define(lang, name="matmul") @@ -182,7 +182,8 @@ adopt whatever feels more convenient. import torch lang = """ def avgpool(float(B, C, H, W) input) -> (output) {{ - output(b, c, h, w) +=! input(b, c, h * {sH} + kh, w * {sW} + kw) / ({kH} * {kW}) where kh in 0:{kH}, kw in 0:{kW} + output(b, c, h, w) +=! input(b, c, h * {sH} + r_kh, w * {sW} + r_kw) / ({kH} * {kW}) + where r_kh in 0:{kH}, r_kw in 0:{kW} }} """ avgpool = tc.define(lang, name="avgpool", constants={"sH":1, "sW":1, "kH":2, "kW":2}) @@ -205,7 +206,8 @@ adopt whatever feels more convenient. import re LANG=""" def avgpool(float(B, C, H, W) input) -> (output) { - output(b, c, h, w) +=! input(b, c, h * + kh, w * + kw) / ( * ) where kh in 0:, kw in 0: + output(b, c, h, w) +=! input(b, c, h * + r_kh, w * + r_kw) / ( * ) + where r_kh in 0:, r_kw in 0: } """ sH, sW, kH, kW = 1, 1, 2, 2 @@ -233,7 +235,7 @@ call. For example: import torch lang = """ def add(float(N) A, float(N) B) -> (output) { - output(i) = A(i) + B(i) + 1 + output(n) = A(n) + B(n) } """ diff --git a/docs/source/index.rst b/docs/source/index.rst index 5c46479c0..7a8f900c9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -29,6 +29,7 @@ Machine Learning. ml_with_tc integrating_any_ml_framework + coding_conventions .. toctree:: :maxdepth: 1 diff --git a/docs/source/inference.rst b/docs/source/inference.rst index 6abdef270..87a541392 100644 --- a/docs/source/inference.rst +++ b/docs/source/inference.rst @@ -18,8 +18,8 @@ over :code:`k` in a matrix multiply: .. code:: - def matmul(float(I, K) B, float(K, J) C) -> A { - A(i, j) +=! B(i, k) * C(k, j) + def matmul(float(M, K) A, float(K, N) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(r_k, n) } diff --git a/docs/source/introduction.rst b/docs/source/introduction.rst index efaa2eccf..cb1b09b5b 100644 --- a/docs/source/introduction.rst +++ b/docs/source/introduction.rst @@ -19,8 +19,8 @@ An example of how using TC in PyTorch looks like: import tensor_comprehensions as tc import torch lang = """ - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M, K) A, float(K, N) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(r_k, n) } """ matmul = tc.define(lang, name="matmul") @@ -48,24 +48,25 @@ Let's start with a simple example is a matrix vector product: .. code:: - def mv(float(R,C) A, float(C) B) -> (o) { - o(i) += A(i,j) * B(j) + def mv(float(R,C) A, float(C) x) -> (o) { + o(r) +=! A(r,r_c) * x(r_c) } -:code:`A` and :code:`B` are input tensors. :code:`o` is an output tensor. -The statement :code:`o(i) += A(i,j) * B(j)` introduces two index variables :code:`i` and :code:`j`. -Their range is inferred by their use indexing :code:`A` and :code:`B`. :code:`i = [0,R)`, :code:`j = [0,C)`. -Because :code:`j` only appears on the right side, -stores into :code:`o` will reduce over :code:`j` with the reduction specified for the loop. +:code:`A` and :code:`x` are input tensors. :code:`o` is an output tensor. +The statement :code:`o(r) += A(r,r_c) * x(r_c)` introduces two index variables :code:`r` and :code:`r_`. +Their range is inferred by their use indexing :code:`A` and :code:`x`. :code:`r = [0,R)`, :code:`r_c = [0,C)`. +Because :code:`r_c` only appears on the right side, +stores into :code:`o` will reduce over :code:`r_c` with the reduction specified for the loop. Reductions can occur across multiple variables, but they all share the same kind of associative reduction (e.g. :code:`+=`) -to maintain invariant (3). :code:`mv` computes the same thing as this C++ loop: +to maintain invariant (3). Note that we prefix reduction indices names with +:code:`r_` for improved readability. :code:`mv` computes the same thing as this C++ loop: .. code:: for(int i = 0; i < R; i++) { o(i) = 0.0f; for(int j = 0; j < C; j++) { - o(i) += A(i,j) * B(j); + o(i) += A(i,j) * x(j); } } @@ -81,8 +82,8 @@ Simple matrix-vector .. code:: - def mv(float(R,C) A, float(C) B) -> (o) { - o(i) += A(i,j) * B(j) + def mv(float(R,C) A, float(C) x) -> (o) { + o(r) +=! A(r,r_c) * x(r_c) } Simple 2-D convolution (no stride, no padding) @@ -91,7 +92,7 @@ Simple 2-D convolution (no stride, no padding) .. code:: def conv(float(B,IP,H,W) input, float(OP,IP,KH,KW) weight) -> (output) { - output(b, op, h, w) += input(b, ip, h + kh, w + kw) * weight(op, ip, kh, kw) + output(b, op, h, w) +=! input(b, r_ip, h + r_kh, w + r_kw) * weight(op, r_ip, r_kh, r_kw) } Simple 2D max pooling @@ -102,6 +103,6 @@ Note the similarity with a convolution with a "select"-style kernel: .. code:: def maxpool2x2(float(B,C,H,W) input) -> (output) { - output(b,c,i,j) max= input(b,c,2*i + kw, 2*j + kh) - where kw in 0:2, kh in 0:2 + output(b,c,h,w) max=! input(b,c,2*h + r_kw, 2*w + r_kh) + where r_kw in 0:2, r_kh in 0..2 } diff --git a/docs/source/ml_with_tc.rst b/docs/source/ml_with_tc.rst index f92bc1512..54d9cee72 100644 --- a/docs/source/ml_with_tc.rst +++ b/docs/source/ml_with_tc.rst @@ -44,8 +44,8 @@ consider the TC definition below: .. code:: def softmax(float(N, D) I) -> (O, expsum) { - expsum(n) +=! exp(I(n, d)) - O(n, d) = exp(I(n, d)) / expsum(n) + expsum(n) +=! exp(I(n, d)) + O(n, d) = exp(I(n, d)) / expsum(n) } In this TC, :code:`expsum` is a temporary variable that needs to be computed but diff --git a/docs/source/tutorials/tutorial_tensordot_with_tc.rst b/docs/source/tutorials/tutorial_tensordot_with_tc.rst index 765d2f626..6a98695df 100644 --- a/docs/source/tutorials/tutorial_tensordot_with_tc.rst +++ b/docs/source/tutorials/tutorial_tensordot_with_tc.rst @@ -39,18 +39,18 @@ A simple 2D matrix multiply operation in TC is expressed as: .. code:: def matmul(float(M, N) X, float(N, K) W) -> (output) { - output(m, k) +=! X(m, nn) * W(nn, k) + output(m, k) +=! X(m, r_n) * W(r_n, k) } -The variable :code:`nn` is being reduced in above expression. Now, let's write a +The variable :code:`r_n` is being reduced in above expression. Now, let's write a **batched matrix-multiply** operation using above expression. For that, we need to add a batch dimension to it and the expression becomes: .. code:: def batch_matmul(float(B, M, N) X, float(B, N, K) W) -> (output) { - output(b, m, k) +=! X(b, m, nn) * W(b, nn, k) + output(b, m, k) +=! X(b, m, r_n) * W(b, r_n, k) } Now, for the tensordot operation, we need to add spatial dimensions :code:`H` and :code:`W` @@ -59,7 +59,7 @@ to the batched matrix multiply, and the expression for TensorDot becomes: .. code:: def tensordot(float(B, C1, C2, H, W) I0, float(B, C2, C3, H, W) I1) -> (O) { - O(b, c1, c3, h, w) +=! I0(b, c1, c2, h, w) * I1(b, c2, c3, h, w) + O(b, c1, c3, h, w) +=! I0(b, c1, r_c2, h, w) * I1(b, r_c2, c3, h, w) } Now, we have our :code:`TensorDot` expression, we are ready to use this and write @@ -73,7 +73,7 @@ Now, we have our :code:`TensorDot` expression, we are ready to use this and writ # define the operation as TC language lang = """ def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) { - O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w) + O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w) } """ diff --git a/tensor_comprehensions/library/layers.yaml b/tensor_comprehensions/library/layers.yaml index 4c6357cee..6d9fbd504 100644 --- a/tensor_comprehensions/library/layers.yaml +++ b/tensor_comprehensions/library/layers.yaml @@ -25,226 +25,224 @@ - name: indexing lang: | - def indexing(float(H, W) input, int32(L) index) -> (output) {{ - output(l, w) = input(index(l), w) where l in 0:{L} + def indexing(float(H, W) Input, int32(L) Index) -> (Output) {{ + Output(l, w) = Input(Index(l), w) where l in 0:{L} }} - name: lookup_table lang: | def lookup_table(float(B, R) LUT, int32(B, N) I) -> (O) { - O(b, n) +=! LUT(I(b, n), r) + O(b, n) +=! LUT(I(b, n), r_r) } - name: matmul lang: | - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) + def matmul(float(M, K) A, float(K, N) B) -> (C) { + C(m, n) +=! A(m, r_k) * B(r_k, n) } grad: | - def matmul_grad(float(M,N) A, float(N,K), float(M,K) O_grad) -> (A_grad, B_grad){ - A_grad(i, j) +=! O_grad(i, kk) * B(j, kk) - B_grad(i, j) +=! O_grad(kk, j) * A(kk, i) + def matmul_grad(float(M,K) A, float(K,N) B, float(M,N) d_C) -> (d_A, d_B){ + d_A(m, k) +=! d_C( m, r_n) * B( k, r_n) + d_B(k, n) +=! d_C(r_m, n) * A(r_m, k) } - name: batch_matmul lang: | def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) { - Z(b, n, k) +=! X(b, n, mm) * Y(b, mm, k) + Z(b, n, k) +=! X(b, n, r_m) * Y(b, r_m, k) } - name: transpose lang: | def transpose(float(N, C, H, W) I) -> (O) { - O(c, n, w, h) = I(n, c, h, w) + O(c, n, w, h) = I(n, c, h, w) } - name: avgpool lang: | - def avgpool(float(B, C, H, W) input) -> (output) {{ - output(b, c, h, w) +=! input(b, c, h * {sH} + kh, w * {sW} + kw) / ({kH} * {kW}) where kh in 0:{kH}, kw in 0:{kW} + def avgpool(float(B, C, H, W) Input) -> (Output) {{ + Output(b, c, h, w) +=! Input(b, c, h * {sH} + r_kh, w * {sW} + r_kw) / ({kH} * {kW}) + where r_kh in 0:{kH}, r_kw in 0:{kW} }} - name: maxpool lang: | - def maxpool(float(B, C, H, W) input) -> (output) {{ - output(b, c, h, w) max=! input(b, c, h * {sH} + kh, w * {sW} + kw) where kh in 0:{kH}, kw in 0:{kW} + def maxpool(float(B, C, H, W) Input) -> (Output) {{ + Output(b, c, h, w) max=! Input(b, c, h * {sH} + r_kh, w * {sW} + r_kw) + where r_kh in 0:{kH}, r_kw in 0:{kW} }} - name: scale lang: | def scale(float(M, N) I) -> (O) {{ - O(m, n) = I(m, n) * {s} + O(m, n) = I(m, n) * {s} }} - name: sigmoid lang: | def sigmoid(float(N, C, H, W) I) -> (O) { - O(n, c, h, w) = 1 / (1 + exp(-I(n, c, h, w))) + O(n, c, h, w) = 1 / (1 + exp(-I(n, c, h, w))) } - name: softmax lang: | - def softmax(float(N, D) I) -> (O, expsum, maxVal) { - maxVal(n) max= I(n, d) - expsum(n) +=! exp(I(n, d) - maxVal(n)) - O(n, d) = exp(I(n, d)) / expsum(n) + def softmax(float(N, D) I) -> (O, ExpSum, MaxVal) { + MaxVal(n) max= I(n, d) + ExpSum(n) +=! exp(I(n, d) - MaxVal(n)) + O(n, d) = exp(I(n, d)) / ExpSum(n) } - name: Tanh lang: | def Tanh(float(M) I) -> (O) { - O(m) = tanh(I(m)) + O(m) = tanh(I(m)) } - name: tensordot lang: | def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) { - O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w) + O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w) } - name: fully_connected lang: | def fully_connected(float(B, M) I, float(N, M) W1, float(N) B1) -> (O1) { - O1(b, n) +=! I(b, m) * W1(n, m) - O1(b, n) = O1(b, n) + B1(n) + O1(b, n) +=! I(b, r_m) * W1(n, r_m) + O1(b, n) = O1(b, n) + B1(n) } - name: relu lang: | def relu(float(B, M) I) -> (O1){ - O1(b, m) = fmax(I(b, m), 0) + O1(b, m) = fmax(I(b, m), 0) } - name: fcrelu lang: | - def fcrelu(float(B, M) I, float(N, M) W1, float(N) B1) -> (O1){ - O1(b, n) +=! I(b, m) * W1(n, m) - O1(b, n) = O1(b, n) + B1(n) - O1(b, n) = fmax(O1(b, n), 0) + def fcrelu(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1) { + O1(b, n) +=! I(b, r_m) * W1(n, r_m) + O1(b, n) = O1(b, n) + B1(n) + O1(b, n) = fmax(O1(b, n), 0) } - name: cast lang: | - def cast(float(M, N) A) -> (int32(M, N) O1) {{ - O1(m, n) = int32(A(m, n) + {constant}) + def cast(float(M,N) A) -> (int32(M,N) O1) {{ + O1(m, n) = int32(A(m, n) + {constant}) }} - name: concat lang: | def concat(float(M, N) A, float(M, N) B) -> (O1) { - O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2 + O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2 } - name: convolution lang: | def convolution(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (O) { - O(n, m, h, w) +=! I(n, c, h + kh, w + kw) * W1(m, c, kh, kw) - O(n, m, h, w) = O(n, m, h, w) + B(m) + O(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw) + O(n, m, h, w) = O(n, m, h, w) + B(m) } - name: convolution_strided lang: | def convolution_strided(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (O) {{ - O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw) - O(n, m, h, w) = O(n, m, h, w) + B(m) + O(n, m, h, w) +=! I(n, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(m, r_c, r_kh, r_kw) + O(n, m, h, w) = O(n, m, h, w) + B(m) }} grad: | - def convolution_strided_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) O_grad) - -> (I_grad, W1_grad) {{ - I_grad(n, c, h, w) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * W1(m, c, kh, kw) - W1_grad(m, c, kh, kw) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * I(n, c, h, w) + def convolution_strided_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) d_O) -> (d_I, d_W1) {{ + d_I(n, c, h, w) +=! d_O(n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw) + d_W1(m, c, kh, kw) +=! d_O(n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w) }} - name: group_convolution lang: | def group_convolution(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) { - O(n, g, f, h, w) +=! I(n, g, c, h + kh, w + kw) * W1(g, f, c, kh, kw) - O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) + O(n, g, f, h, w) +=! I(n, g, r_c, h + r_kh, w + r_kw) * W1(g, f, r_c, r_kh, r_kw) + O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) } - name: group_convolution_strided lang: | - def group_convolution_strided(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) - {{ - O(n, g, f, h, w) +=! I(n, g, c, {sh} * h + kh, {sw} * w + kw) * W1(g, f, c, kh, kw) - O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) + def group_convolution_strided(float(N, G, C, H, W) I, float(G, F, C, KH, KW) W1, float(G, F) B) -> (O) {{ + O(n, g, f, h, w) +=! I(n, g, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(g, f, r_c, r_kh, r_kw) + O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) }} - name: copy2D lang: | def copy2D(float(M, N) I) -> (O) { - O(i, j) = I(i, j) + O(m, n) = I(m, n) } - name: copy lang: | def copy(float({dimParams}) I) -> (O) {{ - O({dimIndices}) = I({dimIndices}) + O({dimIndices}) = I({dimIndices}) }} - name: cosine lang: | def cosine(float(M) I) -> (O) { - O(i) = cos(I(i)) + O(i) = cos(I(i)) } - name: cosine_similarity lang: | def cosine_similarity(float(M, N) I1, float(M, N) I2) -> (O, sumI1, sumI2) {{ - sumI1(m) +=! I1(m, n) * I1(m, n) - sumI2(m) +=! I2(m, n) * I2(m, n) - O(m) +=! (I1(m, n) * I2(m, n)) / fmax(rsqrt(sumI1(m)) * sqrt(sumI2(m)), {eps}) + sumI1(m) +=! I1(m, r_n) * I1(m, r_n) + sumI2(m) +=! I2(m, r_n) * I2(m, r_n) + O(m) +=! (I1(m, r_n) * I2(m, r_n)) / fmax(rsqrt(sumI1(m)) * sqrt(sumI2(m)), {eps}) }} - name: add lang: | - def add(float(N) A, float(N) B) -> (output) { - output(i) = A(i) + B(i) + def add(float(N) A, float(N) B) -> (Output) { + Output(n) = A(n) + B(n) } - name: abs lang: | def abs(float(M, N) A) -> (O1) { - O1(m, n) = fabs(A(m, n)) + O1(m, n) = fabs(A(m, n)) } - name: layernorm lang: | def layernorm(float(T, B, C) I) -> (O, mean, centered, var) {{ - mean(t, b) +=! I(t, b, c) / C - centered(t, b, c) = I(t, b, c) - mean(t, b) - var(t, b) +=! centered(t, b, c) * centered(t, b, c) - var(t, b) = (var(t, b) + {eps}) / C - O(t, b, c) = centered(t, b, c) / rsqrt(var(t, b)) + mean(t, b) +=! I(t, b, c) / C + centered(t, b, c) = I(t, b, c) - mean(t, b) + var(t, b) +=! centered(t, b, c) * centered(t, b, c) + var(t, b) = (var(t, b) + {eps}) / C + O(t, b, c) = centered(t, b, c) / rsqrt(var(t, b)) }} - name: batchnorm lang: | - def batchnorm(float(N, C, H, W) I, float(C) rMeanIn, float(C) rVarIn) + def batchnorm(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut) {{ - mean(c) +=! I(nn, c, hh, ww) - mean(c) = mean(c) / (N * H * W) - rMeanOut(c) = (1 - {momentum}) * rMeanIn(c) + {momentum} * mean(c) - centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) - variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) - expectedVariance(c) +=! (variance(n, c, h, w) + {eps}) / (N * H * W) - rVarOut(c) = rsqrt( - (1 - {momentum}) * rVarIn(c) + {momentum} * expectedVariance(c)) - O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c) - normalizedOut(n, c, h, w) = O(n, c, h, w) + mean(c) +=! I(nn, c, hh, ww) + mean(c) = mean(c) / (N * H * W) + rMeanOut(c) = (1 - {momentum}) * rMeanIn(c) + {momentum} * mean(c) + centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) + variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) + expectedVariance(c) +=! (variance(n, c, h, w) + {eps}) / (N * H * W) + rVarOut(c) = rsqrt((1 - {momentum}) * rVarIn(c) + {momentum} * expectedVariance(c)) + O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c) + normalizedOut(n, c, h, w) = O(n, c, h, w) }} - name: small_mobilenet lang: | def small_mobilenet(float(C1, H, W) I, float(C1, KH1, KW1) W1, float(C1) B1, float(C2, C1) W2, float(C2) B2) - -> (O1, O2) - { - O1(c1, h, w) +=! I(c1, h + kh, w + kw) * W1(c1, kh, kw) - O1(c1, h, w) = O1(c1, h, w) + B1(c1) - O1(c1, h, w) = fmax(O1(c1, h, w), 0) - - O2(c2, h, w) +=! O1(c1, h, w) * W2(c2, c1) - O2(c2, h, w) = O2(c2, h, w) + B2(c2) - O2(c2, h, w) = fmax(O2(c2, h, w), 0) + -> (O1, O2) { + O1(c1, h, w) +=! I(c1, h + r_kh, w + r_kw) * W1(c1, r_kh, r_kw) + O1(c1, h, w) = O1(c1, h, w) + B1(c1) + O1(c1, h, w) = fmax(O1(c1, h, w), 0) + + O2(c2, h, w) +=! O1(r_c1, h, w) * W2(c2, r_c1) + O2(c2, h, w) = O2( c2, h, w) + B2(c2) + O2(c2, h, w) = fmax(O2(c2, h, w), 0) } diff --git a/test/test_autotuner.cc b/test/test_autotuner.cc index f66a4c0ac..7de32a0c3 100644 --- a/test/test_autotuner.cc +++ b/test/test_autotuner.cc @@ -96,13 +96,14 @@ TEST_F(ATenCompilationUnitTest, LayerNorm) { std::vector outputs; static constexpr auto TC = R"TC( - def layernorm(float(T, B, C) I) -> (O, mean, centered, var) { - mean(t, b) +=! I(t, b, c) / C - centered(t, b, c) = I(t, b, c) - mean(t, b) - var(t, b) +=! centered(t, b, c) * centered(t, b, c) - var(t, b) = (var(t, b)) / C - O(t, b, c) = centered(t, b, c) / rsqrt(var(t, b)) - } +def layernorm(float(T, B, C) I) -> (O, mean, centered, var) { + mean(t, b) +=! I(t, b, c) / C + centered(t, b, c) = I(t, b, c) - mean(t, b) + + var(t, b) +=! centered(t, b, c) * centered(t, b, c) + var(t, b) = var(t, b) / C + O(t, b, c) = centered(t, b, c) / rsqrt(var(t, b)) +} )TC"; auto options = tc::CudaMappingOptions::makeNaiveCudaMappingOptions(); auto name = "layernorm"; @@ -119,9 +120,9 @@ TEST_F(ATenCompilationUnitTest, MatmulA) { std::vector outputs; static constexpr auto TC = R"TC( - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) - } +def matmul(float(M,N) A, float(N,K) B) -> (output) { + output(m, k) +=! A(m, r_n) * B(r_n, k) +} )TC"; auto options = tc::CudaMappingOptions::makeNaiveCudaMappingOptions(); auto name = "matmul"; @@ -138,9 +139,9 @@ TEST_F(ATenCompilationUnitTest, MatmulB) { std::vector outputs; static constexpr auto TC = R"TC( - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) - } +def matmul(float(M,N) A, float(N,K) B) -> (output) { + output(m, k) +=! A(m, r_n) * B(r_n, k) +} )TC"; auto options = tc::CudaMappingOptions::makeNaiveCudaMappingOptions(); auto name = "matmul"; @@ -157,9 +158,9 @@ TEST_F(ATenCompilationUnitTest, MatmulC) { std::vector outputs; static constexpr auto TC = R"TC( - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(i, j) +=! A(i, kk) * B(kk, j) - } +def matmul(float(M,N) A, float(N,K) B) -> (output) { + output(m, k) +=! A(m, r_n) * B(r_n, k) +} )TC"; auto options = tc::CudaMappingOptions::makeNaiveCudaMappingOptions(); auto name = "matmul"; @@ -176,9 +177,9 @@ TEST_F(ATenCompilationUnitTest, TensorDot) { std::vector outputs; static constexpr auto TC = R"TC( - def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) { - O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w) - } +def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) { + O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w) +} )TC"; auto options = tc::CudaMappingOptions::makeConvolutionCudaMappingOptions(); auto name = "tensordot"; diff --git a/test/test_autotuner_utility.cc b/test/test_autotuner_utility.cc index 571cd0d97..0558096a1 100644 --- a/test/test_autotuner_utility.cc +++ b/test/test_autotuner_utility.cc @@ -64,9 +64,9 @@ TEST(RestoreCandidates, NoCache) { } static constexpr auto tc_ = R"( - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(m, k) +=! A(m, nn) * B(nn, k) - })"; +def matmul(float(M,N) A, float(N,K) B) -> (output) { + output(m, k) +=! A(m, r_n) * B(r_n, k) +})"; void EnableCaches() { tc::CudaCache::enableCache(); diff --git a/test/test_caffe2.cc b/test/test_caffe2.cc index fac6559e2..b0a17326d 100644 --- a/test/test_caffe2.cc +++ b/test/test_caffe2.cc @@ -76,7 +76,7 @@ TEST_F(Caffe2CopyTest, DISABLED_TcCopyOp_Gradient1D) { auto AddInput = TestHarness::AddDeterministicallyRandomInput; AddInput(w, {M}, "I"); - AddInput(w, {M}, "O_grad"); + AddInput(w, {M}, "g_O"); }; OperatorDef def = TestHarness::ConfigureCUDA("TcCopyOp", {"I"}, {"O"}, {strategyArg}); @@ -99,7 +99,7 @@ TEST_F(Caffe2CopyTest, DISABLED_TcCopyOp_Gradient2D) { auto AddInput = TestHarness::AddDeterministicallyRandomInput; AddInput(w, {M, N}, "I"); - AddInput(w, {M, N}, "O_grad"); + AddInput(w, {M, N}, "g_O"); }; OperatorDef def = TestHarness::ConfigureCUDA("TcCopyOp", {"I"}, {"O"}, {strategyArg}); @@ -122,7 +122,7 @@ TEST_F(Caffe2CopyTest, DISABLED_TcCopyOp_Gradient3D) { auto AddInput = TestHarness::AddDeterministicallyRandomInput; AddInput(w, {M, N, P}, "I"); - AddInput(w, {M, N, P}, "O_grad"); + AddInput(w, {M, N, P}, "g_O"); }; OperatorDef def = TestHarness::ConfigureCUDA("TcCopyOp", {"I"}, {"O"}, {strategyArg}); @@ -145,7 +145,7 @@ TEST_F(Caffe2CopyTest, DISABLED_TcCopyOp_Gradient4D) { auto AddInput = TestHarness::AddDeterministicallyRandomInput; AddInput(w, {M, N, P, Q}, "I"); - AddInput(w, {M, N, P, Q}, "O_grad"); + AddInput(w, {M, N, P, Q}, "g_O"); }; OperatorDef def = TestHarness::ConfigureCUDA("TcCopyOp", {"I"}, {"O"}, {strategyArg}); @@ -168,7 +168,7 @@ TEST_F(Caffe2CopyTest, DISABLED_TcCopyOp_Gradient5D) { auto AddInput = TestHarness::AddDeterministicallyRandomInput; AddInput(w, {M, N, P, Q, R}, "I"); - AddInput(w, {M, N, P, Q, R}, "O_grad"); + AddInput(w, {M, N, P, Q, R}, "g_O"); }; OperatorDef def = TestHarness::ConfigureCUDA("TcCopyOp", {"I"}, {"O"}, {strategyArg}); @@ -201,7 +201,7 @@ TEST_F(Caffe2Test, DISABLED_TcMatMulOp_Gradient) { TestHarness::AddDeterministicallyRandomInput; AddInput(w, {M, K}, "I"); AddInput(w, {K, N}, "W"); - AddInput(w, {M, N}, "O_grad"); + AddInput(w, {M, N}, "g_O"); }; CudaMappingOptions options = tc::makeBaseCliStrategy() @@ -572,7 +572,7 @@ TEST_F(Caffe2Test, DISABLED_TcConvolutionOp_Gradient) { AddInput(w, {NN, C, H, W}, "I"); AddInput(w, {F, C, KH, KW}, "filter"); AddInput(w, {F}, "bias"); - AddInput(w, {NN, F, H - KH + 1, W - KW + 1}, "H_grad"); + AddInput(w, {NN, F, H - KH + 1, W - KW + 1}, "g_H"); }; CudaMappingOptions options = diff --git a/test/test_compilation_cache.cc b/test/test_compilation_cache.cc index 3a4226edc..2517cf85f 100644 --- a/test/test_compilation_cache.cc +++ b/test/test_compilation_cache.cc @@ -996,7 +996,7 @@ TEST( * std::vector inputs_; * int M; * static constexpr auto tc_ = R"( - * def fcrelu(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1) { + * def fcrelu(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1) { * O1(b, n) += I(b, m) * W1(n, m) * O1(b, n) = O1(b, n) + B1(n) * O1(b, n) = fmax(O1(b, n), 0) @@ -1026,9 +1026,9 @@ class MatMulTester { std::vector inputs_; int M; static constexpr auto tc_ = R"( - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(m, k) +=! A(m, nn) * B(nn, k) - })"; +def matmul(float(M,N) A, float(N,K) B) -> (output) { + output(m, k) +=! A(m, r_n) * B(r_n, k) +})"; }; class ConvolutionTester { @@ -1061,11 +1061,12 @@ class ConvolutionTester { int KH; int KW; static constexpr auto tc_ = R"( - def convolution(float(N,C,H,W) I, float(O,C,KH,KW) W1, float(O) B) - -> (tmp, O1) { - tmp(n, o, h, w) +=! I(n, c, h + kh, w + kw) * W1(o, c, kh, kw) - O1(n, o, h, w) = tmp(n, o, h, w) + B(o) - })"; +def convolution(float(N,C,H,W) I, float(O,C,KH,KW) W1, float(O) B) +-> (tmp, O1) +{ + tmp(n, o, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(o, r_c, r_kh, r_kw) + O1(n, o, h, w) = tmp(n, o, h, w) + B(o) +})"; }; class CompilationCacheTest : public ::testing::Test { @@ -1357,9 +1358,9 @@ TEST_F(CompilationCacheTest, Serialization) { TEST(CompilationCache, ManualInjection) { static constexpr auto tc = R"( - def add(float(N) A, float(N) B) -> (output) { - output(i) = A(i) + B(i) - })"; +def add(float(N) A, float(N) B) -> (output) { + output(n) = A(n) + B(n) +})"; tc::ManualCudaCache::enableCache(); tc::ATenCompilationUnit atCompl; diff --git a/test/test_core.cc b/test/test_core.cc index d8c7f29d3..022ef1020 100644 --- a/test/test_core.cc +++ b/test/test_core.cc @@ -61,30 +61,30 @@ struct GenericHalideCoreTest : public ::testing::Test { TEST_F(GenericHalideCoreTest, TwoMatmul) { string tc = R"TC( def fun(float(M, K) I, float(K, N) W1, float(N, P) W2) -> (O1, O2) { - O1(i, j) +=! I(i, k) * W1(k, j) - O2(i, j) +=! O1(i, n) * W2(n, j) + O1(m, n) +=! I(m, r_k) * W1(r_k, n) + O2(m, p) +=! O1(m, r_n) * W2(r_n, p) } )TC"; CheckC( tc, { - "for (int O1_s0_i = 0; O1_s0_i < M; O1_s0_i++) {", - " for (int O1_s0_j = 0; O1_s0_j < N; O1_s0_j++) {", - " O1[O1_s0_i][O1_s0_j] = 0.000000f", - " for (int O1_s1_k = 0; O1_s1_k < K; O1_s1_k++) {", - " O1[O1_s0_i][O1_s0_j] = (O1[O1_s0_i][O1_s0_j] + (I[O1_s0_i][O1_s1_k]*W1[O1_s1_k][O1_s0_j]))", - "for (int O2_s0_i = 0; O2_s0_i < M; O2_s0_i++) {", - " for (int O2_s0_j = 0; O2_s0_j < P; O2_s0_j++) {", - " O2[O2_s0_i][O2_s0_j] = 0.000000f", - " for (int O2_s1_n = 0; O2_s1_n < N; O2_s1_n++) {", - " O2[O2_s0_i][O2_s0_j] = (O2[O2_s0_i][O2_s0_j] + (O1[O2_s0_i][O2_s1_n]*W2[O2_s1_n][O2_s0_j]))", + "for (int O1_s0_m = 0; O1_s0_m < M; O1_s0_m++) {", + " for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {", + " O1[O1_s0_m][O1_s0_n] = 0.000000f", + " for (int O1_s1_r_k = 0; O1_s1_r_k < K; O1_s1_r_k++) {", + " O1[O1_s0_m][O1_s0_n] = (O1[O1_s0_m][O1_s0_n] + (I[O1_s0_m][O1_s1_r_k]*W1[O1_s1_r_k][O1_s0_n]))", + "for (int O2_s0_m = 0; O2_s0_m < M; O2_s0_m++) {", + " for (int O2_s0_p = 0; O2_s0_p < P; O2_s0_p++) {", + " O2[O2_s0_m][O2_s0_p] = 0.000000f", + " for (int O2_s1_r_n = 0; O2_s1_r_n < N; O2_s1_r_n++) {", + " O2[O2_s0_m][O2_s0_p] = (O2[O2_s0_m][O2_s0_p] + (O1[O2_s0_m][O2_s1_r_n]*W2[O2_s1_r_n][O2_s0_p]))", }); } TEST_F(GenericHalideCoreTest, Convolution) { string tc = R"TC( def fun(float(N, C, H, W) I1, float(C, F, KH, KW) W1) -> (O1) { - O1(n, f, h, w) +=! I1(n, c, h + kh, w + kw) * W1(c, f, kh, kw) + O1(n, f, h, w) +=! I1(n, r_c, h + r_kh, w + r_kw) * W1(r_c, f, r_kh, r_kw) } )TC"; CheckC( @@ -94,10 +94,10 @@ def fun(float(N, C, H, W) I1, float(C, F, KH, KW) W1) -> (O1) { " for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {", " for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {", " O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f", - " for (int O1_s1_c = 0; O1_s1_c < C; O1_s1_c++) {", - " for (int O1_s1_kh = 0; O1_s1_kh < KH; O1_s1_kh++) {", - " for (int O1_s1_kw = 0; O1_s1_kw < KW; O1_s1_kw++) {", - " O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = (O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] + (I1[O1_s0_n][O1_s1_c][(O1_s0_h + O1_s1_kh)][(O1_s0_w + O1_s1_kw)]*W1[O1_s1_c][O1_s0_f][O1_s1_kh][O1_s1_kw]))"}); + " for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {", + " for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {", + " for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {", + " O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = (O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] + (I1[O1_s0_n][O1_s1_r_c][(O1_s0_h + O1_s1_r_kh)][(O1_s0_w + O1_s1_r_kw)]*W1[O1_s1_r_c][O1_s0_f][O1_s1_r_kh][O1_s1_r_kw]))"}); } TEST_F(GenericHalideCoreTest, Copy) { @@ -112,7 +112,7 @@ TEST_F(GenericHalideCoreTest, Copy) { TEST_F(GenericHalideCoreTest, GroupConvolution) { string tc = R"TC( def fun(float(N, G, C, H, W) I1, float(G, C, F, KH, KW) W1) -> (O1) { - O1(n, g, f, h, w) +=! I1(n, g, c, h + kh, w + kw) * W1(g, c, f, kh, kw) + O1(n, g, f, h, w) +=! I1(n, g, r_c, h + r_kh, w + r_kw) * W1(g, r_c, f, r_kh, r_kw) } )TC"; CheckC( @@ -123,10 +123,10 @@ def fun(float(N, G, C, H, W) I1, float(G, C, F, KH, KW) W1) -> (O1) { " for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {", " for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {", " O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f", - " for (int O1_s1_c = 0; O1_s1_c < C; O1_s1_c++) {", - " for (int O1_s1_kh = 0; O1_s1_kh < KH; O1_s1_kh++) {", - " for (int O1_s1_kw = 0; O1_s1_kw < KW; O1_s1_kw++) {", - " O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = (O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] + (I1[O1_s0_n][O1_s0_g][O1_s1_c][(O1_s0_h + O1_s1_kh)][(O1_s0_w + O1_s1_kw)]*W1[O1_s0_g][O1_s1_c][O1_s0_f][O1_s1_kh][O1_s1_kw]))"}); + " for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {", + " for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {", + " for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {", + " O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = (O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] + (I1[O1_s0_n][O1_s0_g][O1_s1_r_c][(O1_s0_h + O1_s1_r_kh)][(O1_s0_w + O1_s1_r_kw)]*W1[O1_s0_g][O1_s1_r_c][O1_s0_f][O1_s1_r_kh][O1_s1_r_kw]))"}); } TEST_F(GenericHalideCoreTest, Matmul) { @@ -170,7 +170,7 @@ struct TC2Isl : public ::testing::Test { TEST_F(TC2Isl, Copy1D) { string tc = R"TC( def fun(float(M) I) -> (O) { - O(i) = I(i) + O(m) = I(m) } )TC"; Check(tc, {123}); @@ -179,7 +179,7 @@ def fun(float(M) I) -> (O) { TEST_F(TC2Isl, Copy2D) { string tc = R"TC( def fun(float(M, N) I) -> (O) { - O(i, j) = I(i, j) + O(m, n) = I(m, n) } )TC"; Check(tc, {123, 1}); @@ -188,7 +188,7 @@ def fun(float(M, N) I) -> (O) { TEST_F(TC2Isl, Copy3D) { string tc = R"TC( def fun(float(M, N, P) I) -> (O) { - O(i, j, k) = I(i, j, k) + O(m, n, p) = I(m, n, p) } )TC"; Check(tc, {123, 3, 2}); @@ -197,7 +197,7 @@ def fun(float(M, N, P) I) -> (O) { TEST_F(TC2Isl, Copy4D) { string tc = R"TC( def fun(float(M, N, P, Q) I) -> (O) { - O(i, j, k, l) = I(i, j, k, l) + O(m, n, p, q) = I(m, n, p, q) } )TC"; Check(tc, {123, 3, 4, 5}); @@ -206,7 +206,7 @@ def fun(float(M, N, P, Q) I) -> (O) { TEST_F(TC2Isl, Copy5D) { string tc = R"TC( def fun(float(M, N, P, Q, R) I) -> (O) { - O(i, j, k, l, m) = I(i, j, k, l, m) + O(m, n, p, q, r) = I(m, n, p, q, r) } )TC"; Check(tc, {123, 10, 2, 3, 4}); @@ -216,7 +216,7 @@ def fun(float(M, N, P, Q, R) I) -> (O) { TEST_F(TC2Isl, DISABLED_Reduction1D) { string tc = R"TC( def fun(float(M) I) -> (O) { - O(0) +=! I(i) + O(0) +=! I(r_m) } )TC"; Check(tc, {123}); @@ -225,7 +225,7 @@ def fun(float(M) I) -> (O) { TEST_F(TC2Isl, Reduction2D) { string tc = R"TC( def fun(float(M, N) I) -> (O) { - O(i) +=! I(i, j) + O(m) +=! I(m, r_n) } )TC"; Check(tc, {123, 12}); @@ -234,7 +234,7 @@ def fun(float(M, N) I) -> (O) { TEST_F(TC2Isl, Reduction3D) { string tc = R"TC( def fun(float(M, N, P) I) -> (O) { - O(i) +=! I(i, j, k) + O(m) +=! I(m, r_n, r_p) } )TC"; Check(tc, {123, 12, 16}); @@ -243,8 +243,8 @@ def fun(float(M, N, P) I) -> (O) { TEST_F(TC2Isl, Copy1D2Stmt) { string tc = R"TC( def fun(float(M) I) -> (O1, O2) { - O1(i) = I(i) - O2(i) = O1(i) + O1(m) = I(m) + O2(m) = O1(m) } )TC"; Check(tc, {123}); @@ -253,8 +253,8 @@ def fun(float(M) I) -> (O1, O2) { TEST_F(TC2Isl, Copy2D2Stmt) { string tc = R"TC( def fun(float(M, N) I) -> (O1, O2) { - O1(i, j) = I(i, j) - O2(i, j) = O1(i, j) + O1(m, n) = I(m, n) + O2(m, n) = O1(m, n) } )TC"; Check(tc, {123, 13}); @@ -263,9 +263,9 @@ def fun(float(M, N) I) -> (O1, O2) { TEST_F(TC2Isl, Copy2D3Stmt) { string tc = R"TC( def fun(float(M, N) I) -> (O1, O2, O3) { - O1(i, j) = I(i, j) - O2(i, j) = O1(i, j) - O3(i, j) = O2(i, j) + O1(m, n) = I(m, n) + O2(m, n) = O1(m, n) + O3(m, n) = O2(m, n) } )TC"; Check(tc, {123, 13}); @@ -275,8 +275,8 @@ def fun(float(M, N) I) -> (O1, O2, O3) { TEST_F(TC2Isl, DISABLED_Reduction1D2Stmt) { string tc = R"TC( def fun(float(M) I) -> (O1, O2) { - O1(i) = I(i) - O2(i) = O1(i) + O1(m) = I(m) + O2(m) = O1(m) } )TC"; Check(tc, {123}); @@ -285,8 +285,8 @@ def fun(float(M) I) -> (O1, O2) { TEST_F(TC2Isl, Reduction2D2StmtA) { string tc = R"TC( def fun(float(M, N) I) -> (O1, O2) { - O1(i) +=! I(i, j) - O2(i) = O1(i) + O1(m) +=! I(m, r_n) + O2(m) = O1(m) } )TC"; Check(tc, {123, 13}); @@ -295,8 +295,8 @@ def fun(float(M, N) I) -> (O1, O2) { TEST_F(TC2Isl, Reduction2D2StmtB) { string tc = R"TC( def fun(float(M, N) I) -> (O1, O2) { - O1(i, j) = I(i, j) - O2(i) +=! O1(i, j) + O1(m, n) = I(m, n) + O2(m) +=! O1(m, r_n) } )TC"; Check(tc, {123, 13}); @@ -305,9 +305,9 @@ def fun(float(M, N) I) -> (O1, O2) { TEST_F(TC2Isl, Reduction2D3Stmt) { string tc = R"TC( def fun(float(M, N) I) -> (O1, O2, O3) { - O1(i, j) = I(i, j) - O2(i) +=! O1(i, j) - O3(i) = O2(i) + O1(m, n) = I(m, n) + O2(m) +=! O1(m, r_n) + O3(m) = O2(m) } )TC"; Check(tc, {123, 13}); diff --git a/test/test_execution_engine.cc b/test/test_execution_engine.cc index 3e0971d8d..246cbd0f1 100644 --- a/test/test_execution_engine.cc +++ b/test/test_execution_engine.cc @@ -51,15 +51,15 @@ TEST_F(ATenCompilationUnitTest, DISABLED_SoftmaxA) { std::vector inputs = {a}; std::vector outputs; - // Tensor dependencies should strictly be DAG + // Tensor dependencies should form a DAG Check( R"( - def softmax(float(N, D) I) -> (O, tmp) { - tmp(n) max= I(n, d) - O(n, d) = exp(I(n, d) - tmp(n)) - tmp(n) +=! O(n, d) - O(n, d) = O(n, d) / tmp(n) - } +def softmax(float(N, D) I) -> (O, tmp) { + tmp(n) max= I(n, d) + O(n, d) = exp(I(n, d) - tmp(n)) + tmp(n) +=! O(n, d) + O(n, d) = O(n, d) / tmp(n) +} )", "softmax", tc::CudaMappingOptions::makeNaiveCudaMappingOptions(), @@ -72,15 +72,15 @@ TEST_F(ATenCompilationUnitTest, DISABLED_SoftmaxB) { std::vector inputs = {a}; std::vector outputs; - // Tensor dependencies should strictly be DAG + // Tensor dependencies should form a DAG Check( R"( - def softmax(float(N, D) I) -> (O, tmp, tmp1) { - tmp(n) max=! I(n, d) - O(n, d) = exp(I(n, d) - tmp(n)) - tmp1(n) +=! O(n, d) - O(n, d) = O(n, d) / tmp1(n) - } +def softmax(float(N, D) I) -> (O, tmp) { + tmp(n) max= I(n, d) + O(n, d) = exp(I(n, d) - tmp(n)) + tmp(n) +=! O(n, d) + O(n, d) = O(n, d) / tmp(n) +} )", "softmax", tc::CudaMappingOptions::makeNaiveCudaMappingOptions(), @@ -95,11 +95,11 @@ TEST_F(ATenCompilationUnitTest, SoftmaxC) { Check( R"( - def softmax(float(N, D) I) -> (O, expsum, maxVal) { - maxVal(n) max=! I(n, d) - expsum(n) +=! exp(I(n, d) - maxVal(n)) - O(n, d) = exp(I(n, d) - maxVal(n)) / expsum(n) - } +def softmax(float(N, D) I) -> (O, expsum, maxVal) { + maxVal(n) max=! I(n, d) + expsum(n) +=! exp(I(n, d) - maxVal(n)) + O(n, d) = exp(I(n, d) - maxVal(n)) / expsum(n) +} )", "softmax", tc::CudaMappingOptions::makeNaiveCudaMappingOptions(), @@ -114,12 +114,12 @@ TEST_F(ATenCompilationUnitTest, SoftmaxD) { Check( R"( - def softmax(float(N, D) I) -> (O, maxVal, expDistance, expSum) { - maxVal(n) max=! I(n, d) - expDistance(n, d) = exp(I(n, d) - maxVal(n)) - expSum(n) +=! expDistance(n, d) - O(n, d) = expDistance(n, d) / expSum(n) - } +def softmax(float(N, D) I) -> (O, maxVal, expDistance, expSum) { + maxVal(n) max=! I(n, d) + expDistance(n, d) = exp(I(n, d) - maxVal(n)) + expSum(n) +=! expDistance(n, d) + O(n, d) = expDistance(n, d) / expSum(n) +} )", "softmax", tc::CudaMappingOptions::makeNaiveCudaMappingOptions(), @@ -135,9 +135,9 @@ TEST_F(ATenCompilationUnitTest, Concat) { Check( R"( - def concat(float(M, N) A, float(M, N) B) -> (O1) { - O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2 - } +def concat(float(M, N) A, float(M, N) B) -> (O1) { + O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2 +} )", "concat", tc::CudaMappingOptions::makeNaiveCudaMappingOptions(), @@ -153,9 +153,9 @@ TEST_F(ATenCompilationUnitTest, Indexing) { Check( R"( - def indexing(float(H, W) input, int32(L) index) -> (output) { - output(l, w) = input(index(l), w) where l in 0:2 - } +def indexing(float(H, W) input, int32(L) index) -> (output) { + output(l, w) = input(index(l), w) where l in 0:2 +} )", "indexing", tc::CudaMappingOptions::makeNaiveCudaMappingOptions(), @@ -171,9 +171,9 @@ TEST_F(ATenCompilationUnitTest, MatMul) { Check( R"( - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(m, k) +=! A(m, nn) * B(nn, k) - } +def matmul(float(M,N) A, float(N,K) B) -> (output) { + output(m, k) +=! A(m, r_n) * B(r_n, k) +} )", "matmul", tc::CudaMappingOptions::makeMlpCudaMappingOptions(), @@ -194,9 +194,9 @@ TEST_F(ATenCompilationUnitTest, MatMulInplace) { Check( R"( - def matmul(float(M,N) A, float(N,K) B) -> (output) { - output(m, k) += A(m, nn) * B(nn, k) - } +def matmul(float(M,N) A, float(N,K) B) -> (output) { + output(m, k) += A(m, r_n) * B(r_n, k) +} )", "matmul", tc::CudaMappingOptions::makeMlpCudaMappingOptions(), @@ -216,14 +216,14 @@ TEST_F(ATenCompilationUnitTest, Convolution2d) { Check( R"( - def convolution(float(N,C,H,W) I, float(O,C,KH,KW) W1, float(O) B) - -> (tmp, O1) { - tmp(n, o, h, w) +=! I(n, c, h + kh, w + kw) * W1(o, c, kh, kw) - # this can be equivalently written with =, - # but this line tests that we correctly handle - # degenerate +=! that have no reduction dimensions - O1(n, o, h, w) +=! tmp(n, o, h, w) + B(o) - } +def convolution(float(N,C,H,W) I, float(O,C,KH,KW) W1, float(O) B) +-> (tmp, O1) { + tmp(n, o, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(o, r_c, r_kh, r_kw) + # this can be equivalently written with =, + # but this line tests that we correctly handle + # degenerate +=! that have no reduction dimensions + O1(n, o, h, w) +=! tmp(n, o, h, w) + B(o) +} )", "convolution", tc::CudaMappingOptions::makeConvolutionCudaMappingOptions(), @@ -243,11 +243,11 @@ TEST_F(ATenCompilationUnitTest, Convolution2dStrided) { std::vector outputs; constexpr static auto convolutionStrided = R"TC( - def convolutionStrided(float(N,C,H,W) I, float(O,C,KH,KW) W1, float(O) B) - -> (O1) { - O1(n, o, h, w) +=! I(n, c, * h + kh, * w + kw) * W1(o, c, kh, kw) - O1(n, o, h, w) = O1(n, o, h, w) + B(o) - } +def convolutionStrided(float(N,C,H,W) I, float(O,C,KH,KW) W1, float(O) B) +-> (O1) { + O1(n, o, h, w) +=! I(n, r_c, * h + r_kh, * w + r_kw) * W1(o, r_c, r_kh, r_kw) + O1(n, o, h, w) = O1(n, o, h, w) + B(o) +} )TC"; std::string tcStr; @@ -278,9 +278,9 @@ TEST_F(ATenCompilationUnitTest, Casts) { Check( R"( - def cast(float(M,N) A, int32 four) -> (int32(M,N) output) { - output(i,j) = int32(A(i,j) + four) - } +def cast(float(M,N) A, int32 four) -> (int32(M,N) output) { + output(m,n) = int32(A(m,n) + four) +} )", "cast", tc::CudaMappingOptions::makeNaiveCudaMappingOptions(), diff --git a/test/test_execution_engine_cache.cc b/test/test_execution_engine_cache.cc index 45c710b72..acadc4548 100644 --- a/test/test_execution_engine_cache.cc +++ b/test/test_execution_engine_cache.cc @@ -33,9 +33,9 @@ TEST(ATenCompilationCacheTest, Matmul) { tc::ATenCompilationUnit atCompl; auto tc = R"( - def matmul(float(M,K) A, float(K,N) B) -> (output) { - output(m, n) +=! A(m, kk) * B(kk, n) - } +def matmul(float(M,K) A, float(K,N) B) -> (output) { + output(m, n) +=! A(m, r_k) * B(r_k, n) +} )"; atCompl.define(tc); diff --git a/test/test_execution_engine_db.cc b/test/test_execution_engine_db.cc index f6c68f101..7078a315c 100644 --- a/test/test_execution_engine_db.cc +++ b/test/test_execution_engine_db.cc @@ -34,14 +34,14 @@ TEST(ATenCompilationDbTest, MultiTc) { KW = 3; tc::ATenCompilationUnit atCompl; auto tc = R"( - def matmul(float(M,K) A, float(K,N) B) -> (output) { - output(m, n) +=! A(m, kk) * B(kk, n) - } - def convolution(float(N,C,H,W) I, float(O,C,KH,KW) W1, float(O) B) - -> (tmp, O1) { - tmp(n, o, h, w) +=! I(n, c, h + kh, w + kw) * W1(o, c, kh, kw) - O1(n, o, h, w) = tmp(n, o, h, w) + B(o) - } +def matmul(float(M,K) A, float(K,N) B) -> (output) { + output(m, n) +=! A(m, r_k) * B(r_k, n) +} +def convolution(float(N,C,H,W) I, float(O,C,KH,KW) W1, float(O) B) +-> (tmp, O1) { + tmp(n, o, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(o, r_c, r_kh, r_kw) + O1(n, o, h, w) = tmp(n, o, h, w) + B(o) +} )"; atCompl.define(tc); diff --git a/test/test_mapper.cc b/test/test_mapper.cc index 01827947a..45f303d38 100644 --- a/test/test_mapper.cc +++ b/test/test_mapper.cc @@ -163,7 +163,7 @@ struct PolyhedralMapperTest : public ::testing::Test { TEST_F(PolyhedralMapperTest, Basic) { string tc = R"TC( def fun(float(N, M) A, float(N, M) B) -> (C) { - C(i, j) = A(i, j) + B(i, j) + C(n, m) = A(n, m) + B(n, m) } )TC"; @@ -203,9 +203,9 @@ TEST_F(PolyhedralMapperTest, MultiStmt) { def fun(float(N, N, N, N) A, float(N, N) B, float(N, N) C, float(N, N) D) -> (O1, O2, O3) { - O1(i, j) +=! A(i, j, rk, rl) * B(i, j) - O2(i, j) = C(i, j) * D(i, j) - O3(i, j) = O1(i, j) + O2(i, j) + O1(n0, n1) +=! A(n0, n1, r_n2, r_n3) * B(n0, n1) + O2(n0, n1) = C(n0, n1) * D(n0, n1) + O3(n0, n1) = O1(n0, n1) + O2(n0, n1) } )TC"; @@ -253,7 +253,7 @@ TEST_F(PolyhedralMapperTest, BareVariables) { string tc = R"TC( def fun(float(N, N) A) -> (O) { - O(i, j) = A(i, j) + i + j + N + O(n0, n1) = A(n0, n1) + n0 + n1 + N } )TC"; @@ -281,7 +281,7 @@ TEST_F(PolyhedralMapperTest, CudaFunctions) { string tc = R"TC( def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O) { - O(i, j) = nextafter(C(i), exp(A(i, j))) + log(B(j, i)) + O(n0, n1) = nextafter(C(n0), exp(A(n0, n1))) + log(B(n1, n0)) } )TC"; @@ -378,7 +378,7 @@ TEST_F(PolyhedralMapperTest, Match1) { TEST_F(PolyhedralMapperTest, CopyTC) { string tc = R"TC( def fun(float(M, N) I) -> (O) { - O(i, j) = I(i, j) + O(m, n) = I(m, n) } )TC"; @@ -390,7 +390,7 @@ def fun(float(M, N) I) -> (O) { TEST_F(PolyhedralMapperTest, MatmulTC) { string tc = R"TC( def fun(float(M, K) A, float(K, N) B) -> (C) { - C(i, j) +=! A(i, k) * B(k, j) + C(m, n) +=! A(m, r_k) * B(r_k, n) } )TC"; @@ -425,7 +425,7 @@ TEST_F(PolyhedralMapperTest, MatmulNoshiftNoscale) { static const string kTcAdd = R"TC( def fun(float(N, M) A, float(N, M) B) -> (C) { - C(i, j) = A(i, j) + B(i, j) + C(n, m) = A(n, m) + B(n, m) } )TC"; @@ -474,7 +474,7 @@ TEST_F(PolyhedralMapperTest, Unroll2D) { TEST_F(PolyhedralMapperTest, Copy1D) { auto tc = R"TC( def fun(float(N) I) -> (O) { - O(i) = I(i) + O(n) = I(n) } )TC"; auto scop = Prepare(tc); @@ -494,7 +494,7 @@ def fun(float(N) I) -> (O) { TEST_F(PolyhedralMapperTest, DISABLED_0D) { auto tc = R"TC( def fun() -> (O) { - O = 0 + O = 0 } )TC"; auto code = codegenMapped(tc, DefaultOptions()); @@ -511,8 +511,8 @@ def fun() -> (O) { TEST_F(PolyhedralMapperTest, Copy2) { auto tc = R"TC( def fun(float(N) I) -> (O1, O2) { - O1(i) = I(i) - O2(i) = O1(i) + O1(n) = I(n) + O2(n) = O1(n) } )TC"; auto mappingOptions = DefaultOptions(); @@ -539,8 +539,8 @@ def fun(float(N) I) -> (O1, O2) { TEST_F(PolyhedralMapperTest, CopyUnbalanced) { auto tc = R"TC( def fun(float(N) I1, float(N, N) I2) -> (O1, O2) { - O1(i) = I1(i) - O2(i, j) = I2(i, j) + O1(n) = I1(n) + O2(n0, n1) = I2(n0, n1) } )TC"; auto mappingOptions = DefaultOptions(); @@ -557,8 +557,8 @@ def fun(float(N) I1, float(N, N) I2) -> (O1, O2) { TEST_F(PolyhedralMapperTest, ReschedulingMaxMinFuse) { std::string tc = R"TC( def fun(float(N, M) A, float(N, M) B) -> (C,D) { - C(i, j) = A(i, j) - D(i, j) = B(i, j) + C(n, m) = A(n, m) + D(n, m) = B(n, m) })TC"; auto originalScop = Prepare(tc); @@ -647,8 +647,8 @@ def fun(float(N, M) A, float(N, M) B) -> (C,D) { TEST_F(PolyhedralMapperTest, Rescheduling2MM) { std::string tc = R"TC( def fun(float(M, K) A, float(K, N) B, float(K, N) C) -> (D, E) { - D(i, j) +=! A(i, k) * B(k, j) - E(i, j) +=! A(i, k) * C(k, j) + D(m, n) +=! A(m, r_k) * B(r_k, n) + E(m, n) +=! A(m, r_k) * C(r_k, n) })TC"; auto mappingOptions = DefaultOptions(); @@ -680,7 +680,7 @@ def fun(float(M, K) A, float(K, N) B, float(K, N) C) -> (D, E) { TEST_F(PolyhedralMapperTest, Reduction1D) { string tc = R"TC( def fun(float(N) I) -> (O) { - O +=! I(i) + O +=! I(r_n) } )TC"; auto mappingOptions = DefaultOptions(); @@ -694,7 +694,7 @@ def fun(float(N) I) -> (O) { static const string kTcMM = R"TC( def fun(float(M, K) A, float(K, N) B) -> (C) { - C(i, j) +=! A(i, k) * B(k, j) + C(m, n) +=! A(m, r_k) * B(r_k, n) })TC"; /* diff --git a/test/test_mapper_llvm.cc b/test/test_mapper_llvm.cc index 065ade992..2c4ce4f26 100644 --- a/test/test_mapper_llvm.cc +++ b/test/test_mapper_llvm.cc @@ -40,7 +40,7 @@ using namespace tc::polyhedral::detail; TEST(LLVMCodegen, Basic) { string tc = R"TC( def fun(float(N, M) A, float(N, M) B) -> (C) { - C(i, j) = A(i, j) + B(i, j) + C(n, m) = A(n, m) + B(n, m) } )TC"; auto N = 40; @@ -69,7 +69,7 @@ def fun(float(N, M) A, float(N, M) B) -> (C) { TEST(LLVMCodegen, DISABLED_BasicExecutionEngine) { string tc = R"TC( def fun(float(N, M) A, float(N, M) B) -> (C) { - C(i, j) = A(i, j) + B(i, j) + C(n, m) = A(n, m) + B(n, m) } )TC"; @@ -94,9 +94,9 @@ TEST(LLVMCodegen, MultiStmt) { def fun(float(N, M, K, L) A, float(N, M) B, float(N, M) C, float(N, M) D) -> (O1, O2, O3) { - O1(i, j) +=! A(i, j, rk, rl) * B(i, j) - O2(i, j) = C(i, j) * D(i, j) - O3(i, j) = O1(i, j) + O2(i, j) + O1(n, m) +=! A(n, m, r_k, r_l) * B(n, m) + O2(n, m) = C(n, m) * D(n, m) + O3(n, m) = O1(n, m) + O2(n, m) } )TC"; @@ -168,9 +168,9 @@ TEST(LLVMCodegen, BatchMatMul) { auto M = 24; auto K = 21; std::string tc = R"( - def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) { - Z(b, n, k) +=! X(b, n, mm) * Y(b, mm, k) - } +def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) { + Z(b, n, k) +=! X(b, n, r_m) * Y(b, r_m, k) +} )"; at::Tensor X = at::CPU(at::kFloat).rand({B, N, M}); at::Tensor Y = at::CPU(at::kFloat).rand({B, M, K}); @@ -200,11 +200,11 @@ TEST(LLVMCodegen, Convolution) { auto KW = 2; auto KH = 3; std::string tc = R"( - def convolution(float(N,C,H,W) I, float(O,C,KH,KW) W1, float(O) B) - -> (tmp, O1) { - tmp(n, o, h, w) +=! I(n, c, h + kh, w + kw) * W1(o, c, kh, kw) - O1(n, o, h, w) = tmp(n, o, h, w) + B(o) - } +def convolution(float(N,C,H,W) I, float(O,C,KH,KW) W1, float(O) B) -> (tmp, O1) +{ + tmp(n, o, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(o, r_c, r_kh, r_kw) + O1(n, o, h, w) = tmp(n, o, h, w) + B(o) +} )"; at::Tensor I = at::CPU(at::kFloat).rand({NN, C, H, W}); diff --git a/test/test_mapper_memory_promotion.cc b/test/test_mapper_memory_promotion.cc index 5439c273b..2f56e54a9 100644 --- a/test/test_mapper_memory_promotion.cc +++ b/test/test_mapper_memory_promotion.cc @@ -91,7 +91,7 @@ class Sum4D : public TestMapper { std::vector childPos) { string tc = R"TC( def fun(float(N,M,K,L) A, float(N,M,K,L) B) -> (C) { - C(i,j,k,l) = A(i,j,k,l) + B(i,j,k,l) + C(n,m,k,l) = A(n,m,k,l) + B(n,m,k,l) } )TC"; @@ -245,7 +245,7 @@ class MapperMemoryPromotionSum2D : public MapperMemoryPromotion2DHelper { public: const string tc = R"TC( def fun(float(N, M) A, float(N, M) B) -> (C) { - C(i, j) = A(i, j) + B(i, j) + C(n, m) = A(n, m) + B(n, m) } )TC"; @@ -326,8 +326,8 @@ class MapperMemoryPromotionRAW : public MapperMemoryPromotion2DHelper { public: const string tc = R"TC( def fun(float(N, M) A) -> (B, C) { - B(j, i) = A(j, i) - C(j, i) = B(i, j) + B(n, m) = A(n, m) + C(m, n) = B(n, m) } )TC"; diff --git a/test/test_tc_mapper.cc b/test/test_tc_mapper.cc index de0b0b45f..a5f5f0c28 100644 --- a/test/test_tc_mapper.cc +++ b/test/test_tc_mapper.cc @@ -76,22 +76,22 @@ struct TcMapperTest : public ::testing::Test { constexpr auto reduction1DTCs = { R"TC( def sum1D(float(M) A) -> (C) { - C(0) +=! A(j) where i in 0:2 + C(0) +=! A(r_m) where i in 0:1 } )TC", R"TC( def sum1D(float(M) A) -> (C) { - C() +=! A(j) + C() +=! A(r_m) } )TC", R"TC( def sum1D(float(M) A) -> (C) { - C +=! A(j) + C +=! A(r_m) } )TC", R"TC( def sum1D(float(M) A) -> (C) { - C(i) +=! A(j) where i in 0:1 + C(i) +=! A(r_m) where i in 0:1 } )TC"}; @@ -176,7 +176,7 @@ struct TcMapper2DReductionTest : public TcMapperTest { bool skipCheck = false) { string tc = R"TC( def sum2D(float(M, N) A) -> (C) { - C(i) +=! A(i, j) + C(m) +=! A(m, r_n) } )TC"; auto refOutput = A.sum(1); @@ -300,7 +300,7 @@ struct TcMapperMatmulTest : public TcMapperTest { const tc::CudaMappingOptions& mappingOptions) { string tc = R"TC( def matmul(float(M, K) A, float(K, N) B) -> (C) { - C(i, j) +=! A(i, k) * B(k, j) + C(m, n) +=! A(m, r_k) * B(r_k, n) } )TC"; auto refOutput = A.mm(B); @@ -387,9 +387,9 @@ struct TcMapperBatchMatmulTest : public TcMapperTest { at::Tensor B, const tc::CudaMappingOptions& mappingOptions) { string tc = R"TC( - def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) { - Z(b, n, k) +=! X(b, n, mm) * Y(b, mm, k) - } +def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) { + Z(b, n, k) +=! X(b, n, r_m) * Y(b, r_m, k) +} )TC"; auto refOutput = A.bmm(B); auto checkFun = [&, refOutput]( @@ -430,13 +430,9 @@ TEST_F(TcMapperTest, BatchTripleHadamard) { std::vector outputs; static constexpr auto TC = R"TC( - def batch_triple_hadamard(float(B, D) U, - float(B, D) V, - float(B, D) W) - -> (Z) - { - Z(b, d) = U(b, d) * V(b, d) * W(b, d) - } +def batch_triple_hadamard(float(B, D) U, float(B, D) V, float(B, D) W) -> (Z) { + Z(b, d) = U(b, d) * V(b, d) * W(b, d) +} )TC"; auto checkFun = [=](const std::vector& inputs, @@ -460,12 +456,9 @@ TEST_F(TcMapperTest, TensorDot) { std::vector outputs; static constexpr auto TC = R"TC( - def tensordot(float(N, C1, C2, H, W) I0, - float(N, C2, C3, H, W) I1) - -> (O) - { - O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w) - } +def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) { + O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w) +} )TC"; // No defaults for this case auto checkFun = [](const std::vector& inputs, @@ -486,7 +479,7 @@ TEST_F(TcMapperTest, LUT) { static constexpr auto TC = R"TC( def fun(float(B, R) LUT, int32(B, N) I) -> (O) { - O(b, n) +=! LUT(I(b, n), r) + O(b, n) +=! LUT(I(b, n), r_r) } )TC"; @@ -533,22 +526,26 @@ TEST_F(TcMapperTest, DISABLED_SpatialBatchNormalization) { std::vector outputs; static constexpr auto TC = R"TC( - def spatial_batch_norm( +def spatial_batch_norm( float momentum, float eps, float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) - -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut) - { - mean(c) +=! I(nn, c, hh, ww) - mean(c) = mean(c) / (N * H * W) - rMeanOut(c) = (1 - momentum) * rMeanIn(c) + momentum * mean(c) - centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) - variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) - expectedVariance(c) +=! (variance(n, c, h, w) + eps) / (N * H * W) - rVarOut(c) = rsqrt( - (1 - momentum) * rVarIn(c) + momentum * expectedVariance(c)) - O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c) - normalizedOut(n, c, h, w) = O(n, c, h, w) - })TC"; +-> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut) +{ + mean(c) +=! I(r_n, c, r_h, r_w) + mean(c) = mean(c) / (N * H * W) + rMeanOut(c) = (1 - momentum) * rMeanIn(c) + momentum * mean(c) + + centered(n, c, h, w) = I( n, c, h, w) - rMeanOut(c) + variance(n, c, h, w) = centered( n, c, h, w) * centered(n, c, h, w) + expectedVariance(c) +=! (variance(r_n, c, r_h, r_w) + eps) / (N * H * W) + + rVarOut(c) = rsqrt( + (1 - momentum) * rVarIn(c) + + momentum * expectedVariance(c)) + + O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c) + normalizedOut(n, c, h, w) = O(n, c, h, w) +})TC"; auto checkFun = [=](const std::vector& inputs, std::vector& outputs) { diff --git a/test/test_tc_mapper_bugs.cc b/test/test_tc_mapper_bugs.cc index 44fdd4e2f..348247805 100644 --- a/test/test_tc_mapper_bugs.cc +++ b/test/test_tc_mapper_bugs.cc @@ -65,12 +65,10 @@ struct TensorDot_32_512_8_2_28_28 : public ::testing::Test { // Build naive options baseline to check correctness // Make naive compile first to better see debug spew auto TC = std::string(R"TC( - def tensordot_naive(float(N, C1, C2, H, W) I0, - float(N, C2, C3, H, W) I1) - -> (O) - { - O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w) - } +def tensordot_naive(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) +{ + O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w) +} )TC"); // If running cuda-gdb only run on test code, not reference: things are @@ -92,10 +90,10 @@ struct TensorDot_32_512_8_2_28_28 : public ::testing::Test { auto TC = std::string("def ") + name + std::string(R"TC((float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) - -> (O) - { - O(n, c1, c3, h, w) +=! I0(n, c1, c2_red, h, w) * I1(n, c2_red, c3, h, w) - } +-> (O) +{ + O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w) +} )TC"); std::vector outputs; @@ -342,13 +340,13 @@ struct GroupConvolution_32_32_4_4_56_56_3_3 : public ::testing::Test { // Build naive options baseline to check correctness // Make naive compile first to better see debug spew auto TC = std::string(R"TC( - def group_convolution_naive(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W_, float(G,F) B) - -> (O) - { +def group_convolution_naive(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W_, float(G,F) B) +-> (O) +{ O(n, g, f, h, w) +=! - I(n, g, c, h + kh, w + kw) * W_(g, f, c, kh, kw) - O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) - } + I(n, g, r_c, h + r_kh, w + r_kw) * W_(g, f, r_c, r_kh, r_kw) + O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) +} )TC"); // If running cuda-gdb only run on test code, not reference: things are @@ -371,12 +369,12 @@ struct GroupConvolution_32_32_4_4_56_56_3_3 : public ::testing::Test { auto TC = std::string("def ") + name + std::string( R"TC((float(N,G,C,H,W) I, float(G,F,C,KH,KW) W_, float(G,F) B) - -> (O) - { +-> (O) +{ O(n, g, f, h, w) +=! - I(n, g, c, h + kh, w + kw) * W_(g, f, c, kh, kw) + I(n, g, r_c, h + r_kh, w + r_kw) * W_(g, f, r_c, r_kh, r_kw) O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) - } +} )TC"); std::vector outputs; @@ -461,9 +459,9 @@ struct C3_128_1000_1024 : public ::testing::Test { // Build naive options baseline to check correctness // Make naive compile first to better see debug spew auto TC = std::string(R"TC( - def _C3_naive(float(B,WX) I, float(WY, WX) W) -> (C3) { - C3(b, wy) +=! I(b, wxx) * W(wy, wxx) - } +def _C3_naive(float(B,WX) I, float(WY, WX) W) -> (C3) { + C3(b, wy) +=! I(b, r_wx) * W(wy, r_wx) +} )TC"); // If running cuda-gdb only run on test code, not reference: things are @@ -485,8 +483,8 @@ struct C3_128_1000_1024 : public ::testing::Test { auto TC = std::string("def ") + name + std::string( R"TC((float(B,WX) I, float(WY, WX) W) -> (C3) { - C3(b, wy) +=! I(b, wxx) * W(wy, wxx) - } + C3(b, wy) +=! I(b, r_wx) * W(wy, r_wx) +} )TC"); std::vector outputs; @@ -577,9 +575,9 @@ struct TMM_128_1024_1024 : public ::testing::Test { // Build naive options baseline to check correctness // Make naive compile first to better see debug spew auto TC = std::string(R"TC( - def tmm_naive(float(B, X) I, float(Y, X) W) -> (O) { - O(b, y) +=! I(b, rx) * W(y, rx) - } +def tmm_naive(float(B, X) I, float(Y, X) W) -> (O) { + O(b, y) +=! I(b, r_x) * W(y, r_x) +} )TC"); // If running cuda-gdb only run on test code, not reference: things are @@ -601,8 +599,8 @@ struct TMM_128_1024_1024 : public ::testing::Test { auto TC = std::string("def ") + name + std::string( R"TC((float(B, X) I, float(Y, X) W) -> (O) { - O(b, y) +=! I(b, rx) * W(y, rx) - } + O(b, y) +=! I(b, r_x) * W(y, r_x) +} )TC"); std::vector outputs; @@ -669,13 +667,14 @@ TEST(LayerNorm, ReferenceBelongsToTwoGroups) { std::vector outputs; static constexpr auto TC = R"TC( - def layernorm(float(T, B, C) I) -> (O, mean, centered, var) { - mean(t, b) +=! I(t, b, c) / C - centered(t, b, c) = I(t, b, c) - mean(t, b) - var(t, b) +=! centered(t, b, c) * centered(t, b, c) - var(t, b) = (var(t, b)) / C - O(t, b, c) = centered(t, b, c) / rsqrt(var(t, b)) - } +def layernorm(float(T, B, C) I) -> (O, mean, centered, var) { + mean(t, b) +=! I(t, b, r_c) / C + centered(t, b, c) = I(t, b, c) - mean(t, b) + + var(t, b) +=! centered(t, b, r_c) * centered(t, b, r_c) + var(t, b) = var(t, b) / C + O(t, b, c) = centered(t, b, c) / rsqrt(var(t, b)) +} )TC"; auto options = tc::CudaMappingOptions::makeNaiveCudaMappingOptions() .outerScheduleFusionStrategy(tc::FusionStrategy::Max) @@ -715,7 +714,7 @@ TEST(TMM_128_1024_1000, DisjunctiveFilter) { auto TC = std::string(R"TC( def tmm_naive(float(B, X) I, float(Y, X) W) -> (O) { - O(b, y) +=! I(b, rx) * W(y, rx) + O(b, y) +=! I(b, r_x) * W(y, r_x) } )TC"); auto options = @@ -749,10 +748,10 @@ TEST(Halide2Isl, MinInUpperBound) { std::vector inputs = {mat1, mat1_pad, mat2}; static constexpr auto TC = R"TC( - def graph2(float(N, C, H, W) I, float(N, C, R, T) J, float(KH, KW) W1) -> (O, Out) { - O(n, c, h, w) +=! J(n, c, h + kh, w + kw) * W1(kh, kw) - Out(i, j) +=! I(n, i, h, w) * O(n, j, h, w) - } +def graph2(float(N, C, H, W) I, float(N, C, R, T) J, float(KH, KW) W1) -> (O, Out) { + O(n, c, h, w) +=! J(n, c, h + r_kh, w + r_kw) * W1(r_kh, r_kw) + Out(c0, c1) +=! I(n, c0, h, w) * O( n, c1, h, w) +} )TC"; auto options = tc::CudaMappingOptions::makeNaiveCudaMappingOptions();