Skip to content

Commit

Permalink
Merge pull request #70 from Shaikh-Ubaid/wasm_multi_dim_arrays
Browse files Browse the repository at this point in the history
WASM: Supporting Multi-Dimensional Arrays
  • Loading branch information
Shaikh-Ubaid committed Aug 2, 2022
2 parents 1e4c01a + c2b280c commit df50564
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 27 deletions.
32 changes: 30 additions & 2 deletions grammar/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,33 @@ def visitField(self, field):
self.emit( "this->visit_symbol(*a.second);", 3)
self.emit("}", 2)

class StatementsFirstWalkVisitorVisitor(ASTWalkVisitorVisitor, ASDLVisitor):

def visitModule(self, mod):
self.emit("/" + "*"*78 + "/")
self.emit("// Statements First Visitor base class")
self.emit("")
self.emit("template <class Derived>")
self.emit("class StatementsFirstBaseWalkVisitor : public BaseVisitor<Derived>")
self.emit("{")
self.emit("private:")
self.emit(" Derived& self() { return static_cast<Derived&>(*this); }")
self.emit("public:")
super(ASTWalkVisitorVisitor, self).visitModule(mod)
self.emit("};")

def make_visitor(self, name, fields):
self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1)
self.used = False
have_body = False
for field in fields[::-1]:
self.visitField(field)
if not self.used:
# Note: a better solution would be to change `&x` to `& /* x */`
# above, but we would need to change emit to return a string.
self.emit("if ((bool&)x) { } // Suppress unused warning", 2)
self.emit("}", 1)

# This class generates a visitor that prints the tree structure of AST/ASR
class TreeVisitorVisitor(ASDLVisitor):

Expand Down Expand Up @@ -2006,8 +2033,9 @@ def add_masks(fields, node):

visitors = [ASTNodeVisitor0, ASTNodeVisitor1, ASTNodeVisitor,
ASTVisitorVisitor1, ASTVisitorVisitor1b, ASTVisitorVisitor2,
ASTWalkVisitorVisitor, PickleVisitorVisitor,
SerializationVisitorVisitor, DeserializationVisitorVisitor]
ASTWalkVisitorVisitor, TreeVisitorVisitor, PickleVisitorVisitor,
StatementsFirstWalkVisitorVisitor, SerializationVisitorVisitor,
DeserializationVisitorVisitor]


def main(argv):
Expand Down
14 changes: 7 additions & 7 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -176,18 +176,18 @@ RUN(NAME types_06 LABELS gfortran llvm cpp)
RUN(NAME doconcurrentloop_01 LABELS gfortran cpp)

RUN(NAME arrays_01 LABELS gfortran cpp llvm wasm)
RUN(NAME arrays_01_size LABELS gfortran llvm)
RUN(NAME arrays_02_size LABELS gfortran llvm)
RUN(NAME arrays_01_size LABELS gfortran llvm wasm)
RUN(NAME arrays_02_size LABELS gfortran llvm wasm)
RUN(NAME matrix_01_transpose LABELS gfortran)
RUN(NAME matrix_02_matmul LABELS gfortran)
RUN(NAME array_01_pack LABELS gfortran)
RUN(NAME array_01_transfer LABELS gfortran)
RUN(NAME array_02_pack LABELS gfortran)
RUN(NAME array_02_transfer LABELS gfortran)
RUN(NAME array_03_transfer LABELS gfortran)
RUN(NAME arrays_01_real LABELS gfortran llvm)
RUN(NAME arrays_01_real LABELS gfortran llvm wasm)
RUN(NAME arrays_01_complex LABELS gfortran llvm)
RUN(NAME arrays_01_logical LABELS gfortran llvm)
RUN(NAME arrays_01_logical LABELS gfortran llvm wasm)
RUN(NAME array_bound_1 LABELS gfortran llvm)
RUN(NAME arrays_op_1 LABELS gfortran llvm)
RUN(NAME arrays_op_2 LABELS gfortran llvm)
Expand All @@ -198,15 +198,15 @@ RUN(NAME arrays_op_6 LABELS gfortran llvm)
RUN(NAME arrays_op_7 LABELS gfortran llvm)
RUN(NAME arrays_reshape_14 LABELS gfortran llvm)
RUN(NAME arrays_elemental_15 LABELS gfortran llvm)
RUN(NAME arrays_03_func LABELS gfortran cpp llvm)
RUN(NAME arrays_04_func LABELS gfortran cpp llvm)
RUN(NAME arrays_03_func LABELS gfortran cpp llvm wasm)
RUN(NAME arrays_04_func LABELS gfortran cpp llvm wasm)
RUN(NAME arrays_05 LABELS gfortran llvm cpp wasm)

# GFortran
RUN(NAME arrays_02 LABELS gfortran)
RUN(NAME arrays_06 LABELS gfortran llvm)
RUN(NAME arrays_07 LABELS gfortran llvm)
RUN(NAME arrays_08_func LABELS gfortran llvm)
RUN(NAME arrays_08_func LABELS gfortran llvm wasm)
RUN(NAME arrays_09 LABELS gfortran)
RUN(NAME arrays_10 LABELS gfortran)
RUN(NAME arrays_11 LABELS gfortran llvm cpp)
Expand Down
6 changes: 4 additions & 2 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,6 @@ expr
| ListLen(expr arg, ttype type, expr? value)
| ListConcat(expr left, expr right, ttype type, expr? value)

| ArrayConstant(expr* args, ttype type)

| SetConstant(expr* elements, ttype type)
| SetLen(expr arg, ttype type, expr? value)

Expand All @@ -264,7 +262,10 @@ expr

| DictConstant(expr* keys, expr* values, ttype type)
| DictLen(expr arg, ttype type, expr? value)

| Var(symbol v)

| ArrayConstant(expr* args, ttype type)
| ArrayItem(expr v, array_index* args, ttype type, expr? value)
| ArraySection(expr v, array_index* args, ttype type, expr? value)
| ArraySize(expr v, expr? dim, ttype type, expr? value)
Expand All @@ -274,6 +275,7 @@ expr
| ArrayMatMul(expr matrix_a, expr matrix_b, ttype type, expr? value)
| ArrayPack(expr array, expr mask, expr? vector, ttype type, expr? value)
| ArrayReshape(expr array, expr shape, ttype type, expr? value)

| BitCast(expr source, expr mold, expr? size, ttype type, expr? value)
| DerivedRef(expr v, symbol m, ttype type, expr? value)
| OverloadedCompare(expr left, cmpop op, expr right, ttype type, expr? value, expr overloaded)
Expand Down
3 changes: 2 additions & 1 deletion src/libasr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ set(SRC
codegen/asr_to_wasm.cpp
codegen/wasm_to_wat.cpp
codegen/wasm_utils.cpp

pass/param_to_const.cpp
pass/do_loops.cpp
pass/for_all.cpp
Expand All @@ -33,6 +33,7 @@ set(SRC
pass/implied_do_loops.cpp
pass/array_op.cpp
pass/class_constructor.cpp
pass/arr_dims_propagate.cpp
pass/arr_slice.cpp
pass/print_arr.cpp
pass/pass_utils.cpp
Expand Down
75 changes: 60 additions & 15 deletions src/libasr/codegen/asr_to_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <libasr/codegen/wasm_assembler.h>
#include <libasr/pass/do_loops.h>
#include <libasr/pass/unused_functions.h>
#include <libasr/pass/arr_dims_propagate.h>
#include <libasr/exception.h>
#include <libasr/asr_utils.h>

Expand Down Expand Up @@ -404,17 +405,16 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
LFORTRAN_ASSERT(m_var_name_idx_map.find(get_hash((ASR::asr_t *)v)) != m_var_name_idx_map.end())
wasm::emit_set_local(m_code_section, m_al, m_var_name_idx_map[get_hash((ASR::asr_t *)v)]);
} else if (ASRUtils::is_array(v->m_type)) {
ASR::dimension_t* m_dims;
uint32_t kind = ASRUtils::extract_kind_from_ttype_t(v->m_type);
uint32_t n_dims = ASRUtils::extract_dimensions_from_ttype(v->m_type, m_dims);

uint64_t total_array_size = 1;
for (uint32_t i = 0; i < n_dims; i++) {
ASR::expr_t* length_value = ASRUtils::expr_value(m_dims[i].m_length);
uint64_t len_in_this_dim = -1;
ASRUtils::extract_value(length_value, len_in_this_dim);
total_array_size *= len_in_this_dim;

Vec<uint32_t> array_dims;
get_array_dims(*v, array_dims);

uint32_t total_array_size = 1;
for (auto &dim:array_dims) {
total_array_size *= dim;
}

LFORTRAN_ASSERT(m_var_name_idx_map.find(get_hash((ASR::asr_t *)v)) != m_var_name_idx_map.end());
wasm::emit_i32_const(m_code_section, m_al, avail_mem_loc);
wasm::emit_set_local(m_code_section, m_al, m_var_name_idx_map[get_hash((ASR::asr_t *)v)]);
Expand Down Expand Up @@ -999,26 +999,41 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
}
}

void get_array_dims(const ASR::Variable_t &x, Vec<uint32_t> &dims) {
ASR::dimension_t* m_dims;
uint32_t n_dims = ASRUtils::extract_dimensions_from_ttype(x.m_type, m_dims);
dims.reserve(m_al, n_dims);
for (uint32_t i = 0; i < n_dims; i++) {
ASR::expr_t* length_value = ASRUtils::expr_value(m_dims[i].m_length);
uint64_t len_in_this_dim = -1;
ASRUtils::extract_value(length_value, len_in_this_dim);
dims.push_back(m_al, (uint32_t)len_in_this_dim);
}
}

void emit_array_item_address_onto_stack(const ASR::ArrayItem_t &x) {
this->visit_expr(*x.m_v);
ASR::ttype_t* ttype = ASRUtils::expr_type(x.m_v);
uint32_t kind = ASRUtils::extract_kind_from_ttype_t(ttype);
// ASR::dimension_t* m_dims;
// uint32_t n_dims = ASRUtils::extract_dimensions_from_ttype(ttype, m_dims);
// ASR::expr_t* length_value = ASRUtils::expr_value(m_dims[0].m_length);
// uint64_t array_size = -1;
// ASRUtils::extract_value(length_value, array_size);
Vec<uint32_t> array_dims;
get_array_dims(*ASRUtils::EXPR2VAR(x.m_v), array_dims);
uint32_t multiplier = 1;
wasm::emit_i32_const(m_code_section, m_al, 0);
for(uint32_t i = 0; i < x.n_args; i++) {
if (x.m_args[i].m_right) {
this->visit_expr(*x.m_args[i].m_right);
wasm::emit_i32_const(m_code_section, m_al, 1);
wasm::emit_i32_sub(m_code_section, m_al);
wasm::emit_i32_const(m_code_section, m_al, kind);
wasm::emit_i32_const(m_code_section, m_al, multiplier);
wasm::emit_i32_mul(m_code_section, m_al);
wasm::emit_i32_add(m_code_section, m_al);
multiplier *= array_dims[i];
} else {
diag.codegen_warning_label("/* FIXME right index */", {x.base.base.loc}, "");
}
}
wasm::emit_i32_const(m_code_section, m_al, kind);
wasm::emit_i32_mul(m_code_section, m_al);
wasm::emit_i32_add(m_code_section, m_al);
}

Expand All @@ -1027,6 +1042,35 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
emit_memory_load(x.m_v);
}

void visit_ArraySize(const ASR::ArraySize_t &x) {
if (x.m_value) {
this->visit_expr(*x.m_value);
return;
}
Vec<uint32_t> array_dims;
get_array_dims(*ASRUtils::EXPR2VAR(x.m_v), array_dims);
int kind = ASRUtils::extract_kind_from_ttype_t(x.m_type);
if (x.m_dim) {
uint32_t dim_idx = -1;
ASRUtils::extract_value(ASRUtils::expr_value(x.m_dim), dim_idx);
if (kind == 4) {
wasm::emit_i32_const(m_code_section, m_al, array_dims[dim_idx - 1]);
} else if (kind == 8) {
wasm::emit_i64_const(m_code_section, m_al, array_dims[dim_idx - 1]);
}
return;
}
uint32_t total_array_size = 1U;
for (auto &dim:array_dims) {
total_array_size *= dim;
}
if (kind == 4) {
wasm::emit_i32_const(m_code_section, m_al, total_array_size);
} else if (kind == 8) {
wasm::emit_i64_const(m_code_section, m_al, total_array_size);
}
}

void handle_return() {
if (cur_sym_info->return_var) {
LFORTRAN_ASSERT(m_var_name_idx_map.find(get_hash((ASR::asr_t *)cur_sym_info->return_var)) != m_var_name_idx_map.end());
Expand Down Expand Up @@ -1545,6 +1589,7 @@ Result<Vec<uint8_t>> asr_to_wasm_bytes_stream(ASR::TranslationUnit_t &asr, Alloc

pass_unused_functions(al, asr, true);
pass_replace_do_loops(al, asr);
pass_propagate_arr_dims(al, asr);

// std::cout << pickle(asr, true /* use colors */, true /* indent */,
// true /* with_intrinsic_modules */) << std::endl;
Expand Down
66 changes: 66 additions & 0 deletions src/libasr/pass/arr_dims_propagate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include <libasr/asr.h>
#include <libasr/containers.h>
#include <libasr/exception.h>
#include <libasr/asr_utils.h>
#include <libasr/asr_verify.h>
#include <libasr/pass/for_all.h>
#include <libasr/pass/stmt_walk_visitor.h>

namespace LFortran {

/*
* This ASR pass replaces ttype for all arrays passed.
*
* Converts:
* integer :: a(:, :)
*
* to:
* integer :: a(2, 3)
*/

class ArrDimsPropagate : public ASR::StatementsFirstBaseWalkVisitor<ArrDimsPropagate>
{
private:
Allocator &m_al;
public:
ArrDimsPropagate(Allocator &al) : m_al(al) { }

void visit_FunctionCall(const ASR::FunctionCall_t &x) {
ASR::Function_t *fn = ASR::down_cast<ASR::Function_t>(ASRUtils::symbol_get_past_external(x.m_name));

for (size_t i = 0; i < x.n_args; i++) {
if (ASR::is_a<ASR::Var_t>(*x.m_args[i].m_value) && ASRUtils::is_array(ASRUtils::expr_type(x.m_args[i].m_value))) {
ASR::Variable_t* v = ASRUtils::EXPR2VAR(x.m_args[i].m_value);
ASR::Variable_t *fn_param = ASRUtils::EXPR2VAR(fn->m_args[i]);
ASR::dimension_t* m_dims;
int n_dims = ASRUtils::extract_dimensions_from_ttype(fn_param->m_type, m_dims);
if (n_dims > 0 && !m_dims[0].m_length && ASRUtils::check_equal_type(v->m_type, fn_param->m_type)) {
fn_param->m_type = v->m_type;
}
}
}
}

void visit_SubroutineCall(const ASR::SubroutineCall_t &x) {
ASR::Subroutine_t *sb = ASR::down_cast<ASR::Subroutine_t>(ASRUtils::symbol_get_past_external(x.m_name));
for (size_t i = 0; i < x.n_args; i++) {
if (ASR::is_a<ASR::Var_t>(*x.m_args[i].m_value) && ASRUtils::is_array(ASRUtils::expr_type(x.m_args[i].m_value))) {
ASR::Variable_t* v = ASRUtils::EXPR2VAR(x.m_args[i].m_value);
ASR::Variable_t *sb_param = ASRUtils::EXPR2VAR(sb->m_args[i]);
ASR::dimension_t* m_dims;
int n_dims = ASRUtils::extract_dimensions_from_ttype(sb_param->m_type, m_dims);
if (n_dims > 0 && !m_dims[0].m_length && ASRUtils::check_equal_type(v->m_type, sb_param->m_type)) {
sb_param->m_type = v->m_type;
}
}
}
}
};

void pass_propagate_arr_dims(Allocator &al, ASR::TranslationUnit_t &unit) {
ArrDimsPropagate v(al);
v.visit_TranslationUnit(unit);
LFORTRAN_ASSERT(asr_verify(unit));
}

} // namespace LFortran
12 changes: 12 additions & 0 deletions src/libasr/pass/arr_dims_propagate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef LFORTRAN_PASS_ARR_DIMS_PROPAGATE
#define LFORTRAN_PASS_ARR_DIMS_PROPAGATE

#include <libasr/asr.h>

namespace LFortran {

void pass_propagate_arr_dims(Allocator &al, ASR::TranslationUnit_t &unit);

} // namespace LFortran

#endif // LFORTRAN_PASS_ARR_DIMS_PROPAGATE

0 comments on commit df50564

Please sign in to comment.