Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WASM: Supporting Multi-Dimensional Arrays #70

Merged
merged 9 commits into from
Aug 2, 2022
32 changes: 30 additions & 2 deletions grammar/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,33 @@ def visitField(self, field):
self.emit( "this->visit_symbol(*a.second);", 3)
self.emit("}", 2)

class StatementsFirstWalkVisitorVisitor(ASTWalkVisitorVisitor, ASDLVisitor):
    """Generator for the C++ `StatementsFirstBaseWalkVisitor` class template.

    The emitted class has the same shape as the one this generator's parent
    produces, except that each `visit_<node>` method processes the node's
    fields in reverse declaration order.
    NOTE(review): the "statements first" name suggests that statement bodies
    are declared last in the ASDL nodes, so reversing visits them before the
    other fields -- confirm against the grammar.
    """

    def visitModule(self, mod):
        """Emit the class header, all per-node visitors, and the closing brace."""
        self.emit("/" + "*"*78 + "/")
        self.emit("// Statements First Visitor base class")
        self.emit("")
        self.emit("template <class Derived>")
        self.emit("class StatementsFirstBaseWalkVisitor : public BaseVisitor<Derived>")
        self.emit("{")
        self.emit("private:")
        self.emit(" Derived& self() { return static_cast<Derived&>(*this); }")
        self.emit("public:")
        # Start the method lookup *after* ASTWalkVisitorVisitor in the MRO, so
        # that class's own visitModule is skipped and only the generic
        # traversal further up the hierarchy runs.
        super(ASTWalkVisitorVisitor, self).visitModule(mod)
        self.emit("};")

    def make_visitor(self, name, fields):
        """Emit `void visit_<name>(const <name>_t &x)` visiting `fields` last-to-first."""
        self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1)
        # `self.used` is reset here and read back below.
        # NOTE(review): presumably visitField sets it when the emitted code
        # references `x` -- confirm in the parent class.
        self.used = False
        for field in fields[::-1]:
            self.visitField(field)
        if not self.used:
            # Note: a better solution would be to change `&x` to `& /* x */`
            # above, but we would need to change emit to return a string.
            self.emit("if ((bool&)x) { } // Suppress unused warning", 2)
        self.emit("}", 1)

# This class generates a visitor that prints the tree structure of AST/ASR
class TreeVisitorVisitor(ASDLVisitor):

Expand Down Expand Up @@ -2006,8 +2033,9 @@ def add_masks(fields, node):

visitors = [ASTNodeVisitor0, ASTNodeVisitor1, ASTNodeVisitor,
ASTVisitorVisitor1, ASTVisitorVisitor1b, ASTVisitorVisitor2,
ASTWalkVisitorVisitor, PickleVisitorVisitor,
SerializationVisitorVisitor, DeserializationVisitorVisitor]
ASTWalkVisitorVisitor, TreeVisitorVisitor, PickleVisitorVisitor,
StatementsFirstWalkVisitorVisitor, SerializationVisitorVisitor,
DeserializationVisitorVisitor]


def main(argv):
Expand Down
14 changes: 7 additions & 7 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -176,18 +176,18 @@ RUN(NAME types_06 LABELS gfortran llvm cpp)
RUN(NAME doconcurrentloop_01 LABELS gfortran cpp)

RUN(NAME arrays_01 LABELS gfortran cpp llvm wasm)
RUN(NAME arrays_01_size LABELS gfortran llvm)
RUN(NAME arrays_02_size LABELS gfortran llvm)
RUN(NAME arrays_01_size LABELS gfortran llvm wasm)
RUN(NAME arrays_02_size LABELS gfortran llvm wasm)
RUN(NAME matrix_01_transpose LABELS gfortran)
RUN(NAME matrix_02_matmul LABELS gfortran)
RUN(NAME array_01_pack LABELS gfortran)
RUN(NAME array_01_transfer LABELS gfortran)
RUN(NAME array_02_pack LABELS gfortran)
RUN(NAME array_02_transfer LABELS gfortran)
RUN(NAME array_03_transfer LABELS gfortran)
RUN(NAME arrays_01_real LABELS gfortran llvm)
RUN(NAME arrays_01_real LABELS gfortran llvm wasm)
RUN(NAME arrays_01_complex LABELS gfortran llvm)
RUN(NAME arrays_01_logical LABELS gfortran llvm)
RUN(NAME arrays_01_logical LABELS gfortran llvm wasm)
RUN(NAME array_bound_1 LABELS gfortran llvm)
RUN(NAME arrays_op_1 LABELS gfortran llvm)
RUN(NAME arrays_op_2 LABELS gfortran llvm)
Expand All @@ -198,15 +198,15 @@ RUN(NAME arrays_op_6 LABELS gfortran llvm)
RUN(NAME arrays_op_7 LABELS gfortran llvm)
RUN(NAME arrays_reshape_14 LABELS gfortran llvm)
RUN(NAME arrays_elemental_15 LABELS gfortran llvm)
RUN(NAME arrays_03_func LABELS gfortran cpp llvm)
RUN(NAME arrays_04_func LABELS gfortran cpp llvm)
RUN(NAME arrays_03_func LABELS gfortran cpp llvm wasm)
RUN(NAME arrays_04_func LABELS gfortran cpp llvm wasm)
RUN(NAME arrays_05 LABELS gfortran llvm cpp wasm)

# GFortran
RUN(NAME arrays_02 LABELS gfortran)
RUN(NAME arrays_06 LABELS gfortran llvm)
RUN(NAME arrays_07 LABELS gfortran llvm)
RUN(NAME arrays_08_func LABELS gfortran llvm)
RUN(NAME arrays_08_func LABELS gfortran llvm wasm)
RUN(NAME arrays_09 LABELS gfortran)
RUN(NAME arrays_10 LABELS gfortran)
RUN(NAME arrays_11 LABELS gfortran llvm cpp)
Expand Down
6 changes: 4 additions & 2 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,6 @@ expr
| ListLen(expr arg, ttype type, expr? value)
| ListConcat(expr left, expr right, ttype type, expr? value)

| ArrayConstant(expr* args, ttype type)

| SetConstant(expr* elements, ttype type)
| SetLen(expr arg, ttype type, expr? value)

Expand All @@ -264,7 +262,10 @@ expr

| DictConstant(expr* keys, expr* values, ttype type)
| DictLen(expr arg, ttype type, expr? value)

| Var(symbol v)

| ArrayConstant(expr* args, ttype type)
| ArrayItem(expr v, array_index* args, ttype type, expr? value)
| ArraySection(expr v, array_index* args, ttype type, expr? value)
| ArraySize(expr v, expr? dim, ttype type, expr? value)
Expand All @@ -274,6 +275,7 @@ expr
| ArrayMatMul(expr matrix_a, expr matrix_b, ttype type, expr? value)
| ArrayPack(expr array, expr mask, expr? vector, ttype type, expr? value)
| ArrayReshape(expr array, expr shape, ttype type, expr? value)

| BitCast(expr source, expr mold, expr? size, ttype type, expr? value)
| DerivedRef(expr v, symbol m, ttype type, expr? value)
| OverloadedCompare(expr left, cmpop op, expr right, ttype type, expr? value, expr overloaded)
Expand Down
3 changes: 2 additions & 1 deletion src/libasr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ set(SRC
codegen/asr_to_wasm.cpp
codegen/wasm_to_wat.cpp
codegen/wasm_utils.cpp

pass/param_to_const.cpp
pass/do_loops.cpp
pass/for_all.cpp
Expand All @@ -33,6 +33,7 @@ set(SRC
pass/implied_do_loops.cpp
pass/array_op.cpp
pass/class_constructor.cpp
pass/arr_dims_propagate.cpp
pass/arr_slice.cpp
pass/print_arr.cpp
pass/pass_utils.cpp
Expand Down
75 changes: 60 additions & 15 deletions src/libasr/codegen/asr_to_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <libasr/codegen/wasm_assembler.h>
#include <libasr/pass/do_loops.h>
#include <libasr/pass/unused_functions.h>
#include <libasr/pass/arr_dims_propagate.h>
#include <libasr/exception.h>
#include <libasr/asr_utils.h>

Expand Down Expand Up @@ -404,17 +405,16 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
LFORTRAN_ASSERT(m_var_name_idx_map.find(get_hash((ASR::asr_t *)v)) != m_var_name_idx_map.end())
wasm::emit_set_local(m_code_section, m_al, m_var_name_idx_map[get_hash((ASR::asr_t *)v)]);
} else if (ASRUtils::is_array(v->m_type)) {
ASR::dimension_t* m_dims;
uint32_t kind = ASRUtils::extract_kind_from_ttype_t(v->m_type);
uint32_t n_dims = ASRUtils::extract_dimensions_from_ttype(v->m_type, m_dims);

uint64_t total_array_size = 1;
for (uint32_t i = 0; i < n_dims; i++) {
ASR::expr_t* length_value = ASRUtils::expr_value(m_dims[i].m_length);
uint64_t len_in_this_dim = -1;
ASRUtils::extract_value(length_value, len_in_this_dim);
total_array_size *= len_in_this_dim;

Vec<uint32_t> array_dims;
get_array_dims(*v, array_dims);

uint32_t total_array_size = 1;
for (auto &dim:array_dims) {
total_array_size *= dim;
}

LFORTRAN_ASSERT(m_var_name_idx_map.find(get_hash((ASR::asr_t *)v)) != m_var_name_idx_map.end());
wasm::emit_i32_const(m_code_section, m_al, avail_mem_loc);
wasm::emit_set_local(m_code_section, m_al, m_var_name_idx_map[get_hash((ASR::asr_t *)v)]);
Expand Down Expand Up @@ -999,26 +999,41 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
}
}

// Appends the length of every dimension of the array variable `x` to `dims`,
// in the order the dimensions are stored in the variable's type.
//
// Each length is read from the value of the dimension's `m_length` expression
// and narrowed to 32 bits before it is stored.
void get_array_dims(const ASR::Variable_t &x, Vec<uint32_t> &dims) {
    ASR::dimension_t* dim_info;
    const uint32_t rank = ASRUtils::extract_dimensions_from_ttype(x.m_type, dim_info);
    dims.reserve(m_al, rank);

    uint32_t d = 0;
    while (d < rank) {
        // NOTE(review): the all-ones initial value looks like a sentinel for
        // the case where extract_value() does not write a result -- confirm.
        uint64_t extent = -1;
        ASRUtils::extract_value(ASRUtils::expr_value(dim_info[d].m_length), extent);
        dims.push_back(m_al, static_cast<uint32_t>(extent));
        d++;
    }
}

// Emits wasm code that leaves the address of the array element `x` on the
// operand stack.
//
// The emitted sequence computes
//
//     <value of x.m_v> + kind * sum_i((index_i - 1) * multiplier_i)
//
// where multiplier_0 = 1 and multiplier_{i+1} = multiplier_i * array_dims[i],
// i.e. the first index varies fastest and indices are treated as 1-based.
// `kind` is used as the per-element scale factor.
// NOTE(review): `x.m_v` is presumably the array's base address in linear
// memory, and dimension lower bounds other than 1 are not handled -- confirm.
//
// Exactly one i32 is left on the stack. Each index contributes a balanced
// `<index> 1 sub <multiplier> mul add` sequence; pushing any further operand
// inside the loop would leave a stray value behind and corrupt the offset.
void emit_array_item_address_onto_stack(const ASR::ArrayItem_t &x) {
    this->visit_expr(*x.m_v);
    ASR::ttype_t* ttype = ASRUtils::expr_type(x.m_v);
    uint32_t kind = ASRUtils::extract_kind_from_ttype_t(ttype);

    Vec<uint32_t> array_dims;
    get_array_dims(*ASRUtils::EXPR2VAR(x.m_v), array_dims);

    // Running product of the extents of the dimensions already processed.
    uint32_t multiplier = 1;

    // Accumulator for the element offset (in elements, not bytes).
    wasm::emit_i32_const(m_code_section, m_al, 0);
    for (uint32_t i = 0; i < x.n_args; i++) {
        if (x.m_args[i].m_right) {
            // offset += (index - 1) * multiplier
            this->visit_expr(*x.m_args[i].m_right);
            wasm::emit_i32_const(m_code_section, m_al, 1);
            wasm::emit_i32_sub(m_code_section, m_al);
            wasm::emit_i32_const(m_code_section, m_al, multiplier);
            wasm::emit_i32_mul(m_code_section, m_al);
            wasm::emit_i32_add(m_code_section, m_al);
            multiplier *= array_dims[i];
        } else {
            diag.codegen_warning_label("/* FIXME right index */", {x.base.base.loc}, "");
        }
    }

    // Scale the element offset by `kind` and add it to the value pushed first.
    wasm::emit_i32_const(m_code_section, m_al, kind);
    wasm::emit_i32_mul(m_code_section, m_al);
    wasm::emit_i32_add(m_code_section, m_al);
}

Expand All @@ -1027,6 +1042,35 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
emit_memory_load(x.m_v);
}

// Emits a constant holding the size of the array `x.m_v`.
//
// - If the node already carries a value (`x.m_value`), that expression is
//   visited instead and nothing else is emitted.
// - If `x.m_dim` is given, the length of that (1-based) dimension is emitted.
// - Otherwise the product of all dimension lengths is emitted.
//
// The constant is emitted as i32 for kind 4 and as i64 for kind 8; for any
// other kind nothing is emitted.
void visit_ArraySize(const ASR::ArraySize_t &x) {
    if (x.m_value) {
        this->visit_expr(*x.m_value);
        return;
    }

    Vec<uint32_t> extents;
    get_array_dims(*ASRUtils::EXPR2VAR(x.m_v), extents);
    const int kind = ASRUtils::extract_kind_from_ttype_t(x.m_type);

    const bool single_dim = (x.m_dim != nullptr);
    uint32_t dim_idx = -1;
    if (single_dim) {
        ASRUtils::extract_value(ASRUtils::expr_value(x.m_dim), dim_idx);
    }

    // Evaluated only inside a supported-kind branch, so the dimension table
    // is not indexed when nothing is going to be emitted.
    auto requested_size = [&]() -> uint32_t {
        if (single_dim) {
            return extents[dim_idx - 1];
        }
        uint32_t total = 1U;
        for (auto &extent : extents) {
            total *= extent;
        }
        return total;
    };

    if (kind == 4) {
        wasm::emit_i32_const(m_code_section, m_al, requested_size());
    } else if (kind == 8) {
        wasm::emit_i64_const(m_code_section, m_al, requested_size());
    }
}

void handle_return() {
if (cur_sym_info->return_var) {
LFORTRAN_ASSERT(m_var_name_idx_map.find(get_hash((ASR::asr_t *)cur_sym_info->return_var)) != m_var_name_idx_map.end());
Expand Down Expand Up @@ -1545,6 +1589,7 @@ Result<Vec<uint8_t>> asr_to_wasm_bytes_stream(ASR::TranslationUnit_t &asr, Alloc

pass_unused_functions(al, asr, true);
pass_replace_do_loops(al, asr);
pass_propagate_arr_dims(al, asr);

// std::cout << pickle(asr, true /* use colors */, true /* indent */,
// true /* with_intrinsic_modules */) << std::endl;
Expand Down
66 changes: 66 additions & 0 deletions src/libasr/pass/arr_dims_propagate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include <libasr/asr.h>
#include <libasr/containers.h>
#include <libasr/exception.h>
#include <libasr/asr_utils.h>
#include <libasr/asr_verify.h>
#include <libasr/pass/for_all.h>
#include <libasr/pass/stmt_walk_visitor.h>

namespace LFortran {

/*
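* This ASR pass replaces the ttype of an array parameter that has no declared
* dimension lengths with the ttype of the array passed to it at a call site.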
*
* Converts:
* integer :: a(:, :)
*
* to:
* integer :: a(2, 3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you changing this based on the function call? I don't think you can do that in general, consider, e.g.:

subroutine f(a)
integer :: a(:,:)
end subroutine

integer :: x(2,3)
call f(x)

Then you can indeed "specialize" f for this specific case. But what if you have:

integer :: x(2,3), y(10,11), z(20,21)
call f(x)
call f(y)
call f(z)

?

I think the way to do it is to just inline the function with our inline pass, then that should take care of it.

Alternatively, modify this pass to do the following:

Change this:

subroutine f(a, b)
integer :: a(:,:)
real :: b(:,:,:)
end subroutine

to this

subroutine f(na1, na2, a, nb1, nb2, nb3, b)
integer, intent(in) :: na1, na2, nb1, nb2, nb3
integer :: a(na1, na2)
real :: b(nb1, nb2, nb3)
end subroutine

And you have to change all places that call this function to change from:

call f(x, y)

to

call f(size(x,1), size(x, 2), x, size(y, 1), size(y, 2), size(y, 3), y)

This would be beautiful and extremely useful, and for cases when the array is contiguous, this would be equivalent (I think).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Effectively this would pass the array descriptor "by value" already at the ASR level. It would be an optional pass (for now you can use it in the WASM backend), but very useful down the road. These are the kinds of code transformations that we need to have, and later on we will combine them in various ways to get a very good optimization pipeline.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. I will update the pass soon as suggested above.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, yes, I think, the current approach would/will fail on the shared example.

Then you can indeed "specialize" f for this specific case. But what if you have:

integer :: x(2,3), y(10,11), z(20,21)
call f(x)
call f(y)
call f(z)

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the suggested approach looks beautiful to me as well. Thank you for sharing it.

*/

class ArrDimsPropagate : public ASR::StatementsFirstBaseWalkVisitor<ArrDimsPropagate>
{
private:
Allocator &m_al;
public:
// Stores a reference to the allocator `al` in `m_al`.
ArrDimsPropagate(Allocator &al) : m_al(al) { }

// For every argument of the call that is a plain array variable, copies the
// argument's type onto the matching parameter of the called function when
// that parameter's first dimension has no length and the two types compare
// equal under ASRUtils::check_equal_type().
//
// Note that this modifies the callee's parameter declaration itself, so the
// change is visible to every other call of the same function.
void visit_FunctionCall(const ASR::FunctionCall_t &x) {
    ASR::Function_t *callee = ASR::down_cast<ASR::Function_t>(
        ASRUtils::symbol_get_past_external(x.m_name));

    for (size_t arg_idx = 0; arg_idx < x.n_args; arg_idx++) {
        ASR::expr_t *arg = x.m_args[arg_idx].m_value;
        // Only arguments that are array-typed variables are of interest.
        if (!ASR::is_a<ASR::Var_t>(*arg) || !ASRUtils::is_array(ASRUtils::expr_type(arg))) {
            continue;
        }

        ASR::Variable_t *actual = ASRUtils::EXPR2VAR(arg);
        ASR::Variable_t *formal = ASRUtils::EXPR2VAR(callee->m_args[arg_idx]);

        ASR::dimension_t *formal_dims;
        const int formal_rank = ASRUtils::extract_dimensions_from_ttype(formal->m_type, formal_dims);
        if (formal_rank <= 0) {
            continue;
        }
        // A first dimension that already has a length is left alone.
        if (formal_dims[0].m_length) {
            continue;
        }
        if (ASRUtils::check_equal_type(actual->m_type, formal->m_type)) {
            formal->m_type = actual->m_type;
        }
    }
}

// Subroutine counterpart of the function-call handler: for every argument of
// the call that is a plain array variable, copies the argument's type onto
// the matching parameter of the called subroutine when that parameter's first
// dimension has no length and the two types compare equal under
// ASRUtils::check_equal_type().
//
// Note that this modifies the callee's parameter declaration itself, so the
// change is visible to every other call of the same subroutine.
void visit_SubroutineCall(const ASR::SubroutineCall_t &x) {
    ASR::Subroutine_t *callee = ASR::down_cast<ASR::Subroutine_t>(
        ASRUtils::symbol_get_past_external(x.m_name));

    for (size_t arg_idx = 0; arg_idx < x.n_args; arg_idx++) {
        ASR::expr_t *arg = x.m_args[arg_idx].m_value;
        // Only arguments that are array-typed variables are of interest.
        if (!ASR::is_a<ASR::Var_t>(*arg) || !ASRUtils::is_array(ASRUtils::expr_type(arg))) {
            continue;
        }

        ASR::Variable_t *actual = ASRUtils::EXPR2VAR(arg);
        ASR::Variable_t *formal = ASRUtils::EXPR2VAR(callee->m_args[arg_idx]);

        ASR::dimension_t *formal_dims;
        const int formal_rank = ASRUtils::extract_dimensions_from_ttype(formal->m_type, formal_dims);
        if (formal_rank <= 0) {
            continue;
        }
        // A first dimension that already has a length is left alone.
        if (formal_dims[0].m_length) {
            continue;
        }
        if (ASRUtils::check_equal_type(actual->m_type, formal->m_type)) {
            formal->m_type = actual->m_type;
        }
    }
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, every time I see identical code like this, I get more and more convinced we should just merge Subroutine/Function: lcompilers/lpython#866.

};

// Entry point of the pass: walks the whole translation unit with an
// ArrDimsPropagate visitor, then asserts that the resulting ASR still
// passes asr_verify().
void pass_propagate_arr_dims(Allocator &al, ASR::TranslationUnit_t &unit) {
    ArrDimsPropagate propagator(al);
    propagator.visit_TranslationUnit(unit);

    LFORTRAN_ASSERT(asr_verify(unit));
}

} // namespace LFortran
12 changes: 12 additions & 0 deletions src/libasr/pass/arr_dims_propagate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef LFORTRAN_PASS_ARR_DIMS_PROPAGATE
#define LFORTRAN_PASS_ARR_DIMS_PROPAGATE

#include <libasr/asr.h>

namespace LFortran {

// Runs the array-dimension propagation pass over the translation unit `unit`,
// modifying it in place. `al` is the allocator handed to the pass.
void pass_propagate_arr_dims(Allocator &al, ASR::TranslationUnit_t &unit);

} // namespace LFortran

#endif // LFORTRAN_PASS_ARR_DIMS_PROPAGATE