Skip to content

Support multiple assignments #978

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions integration_tests/expr_09.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,41 @@ def main0():
print(-i1 ^ -i2)
assert -i1 ^ -i2 == 6


def test_multiple_assign_1():
a: i32; b: i32; c: i32
d: f64; e: f32; g: i32
g = 5
d = e = g + 1.0
a = b = c = 10
assert a == b
assert b == c
assert a == 10
x: f32; y: f64
x = y = 23.0
assert abs(x - 23.0) < 1e-6
assert abs(y - 23.0) < 1e-12
assert abs(e - 6.0) < 1e-6
assert abs(d - 6.0) < 1e-12
i: list[f64]; j: list[f64]; k: list[f64] = []
g = 0
for g in range(10):
k.append(g*2.0 + 5.0)
i = j = k
for g in range(10):
assert abs(i[g] - j[g]) < 1e-12
assert abs(i[g] - k[g]) < 1e-12
assert abs(g*2.0 + 5.0 - k[g]) < 1e-12


def test_issue_928():
a: i32; b: i32; c: tuple[i32, i32]
a, b = c = 2, 1
assert a == 2
assert b == 1
assert c[0] == a and c[1] == b


test_multiple_assign_1()
test_issue_928()
main0()
110 changes: 67 additions & 43 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,14 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {


ASR::asr_t *tmp;

/*
If `tmp` is not null, then `tmp_vec` is ignored and `tmp` is used as the only result (statement or
expression). If `tmp` is null, then `tmp_vec` is used to return any number of statements:
0 (no statement returned), 1 (redundant, one should use `tmp` for that), 2, 3, ... etc.
*/
std::vector<ASR::asr_t *> tmp_vec;

Allocator &al;
SymbolTable *current_scope;
// The current_module contains the current module that is being visited;
Expand Down Expand Up @@ -732,7 +740,7 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
if (ASR::is_a<ASR::Function_t>(*t)) {
new_call_name = (ASR::down_cast<ASR::Function_t>(t))->m_name;
}
return make_call_helper(al, t, current_scope, args, new_call_name, loc);
return make_call_helper(al, t, current_scope, args, new_call_name, loc);
}
if (ASR::down_cast<ASR::Function_t>(s)->m_return_var != nullptr) {
ASR::ttype_t *a_type = nullptr;
Expand Down Expand Up @@ -2421,7 +2429,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
/* a_name */ s2c(al, sym_name),
/* a_args */ args.p,
/* n_args */ args.size(),
/* a_type_params */ tps.p,
/* a_type_params */ tps.p,
/* n_type_params */ tps.size(),
/* a_body */ nullptr,
/* n_body */ 0,
Expand Down Expand Up @@ -2648,6 +2656,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
// The `body` Vec must already be reserved
void transform_stmts(Vec<ASR::stmt_t*> &body, size_t n_body, AST::stmt_t **m_body) {
tmp = nullptr;
tmp_vec.clear();
Vec<ASR::stmt_t*>* current_body_copy = current_body;
current_body = &body;
for (size_t i=0; i<n_body; i++) {
Expand All @@ -2656,9 +2665,18 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
if (tmp != nullptr) {
ASR::stmt_t* tmp_stmt = ASRUtils::STMT(tmp);
body.push_back(al, tmp_stmt);
} else if (!tmp_vec.empty()) {
for (auto t: tmp_vec) {
if (t != nullptr) {
ASR::stmt_t* tmp_stmt = ASRUtils::STMT(t);
body.push_back(al, tmp_stmt);
}
}
tmp_vec.clear();
}
// To avoid last statement to be entered twice once we exit this node
tmp = nullptr;
tmp_vec.clear();
}
current_body = current_body_copy;
}
Expand All @@ -2677,9 +2695,16 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
items.reserve(al, 4);
for (size_t i=0; i<x.n_body; i++) {
tmp = nullptr;
tmp_vec.clear();
visit_stmt(*x.m_body[i]);
if (tmp) {
items.push_back(al, tmp);
} else if (!tmp_vec.empty()) {
for (auto t: tmp_vec) {
if (t) items.push_back(al, t);
}
// Ensure that statements in tmp_vec are used only once.
tmp_vec.clear();
}
}
// These global statements are added to the translation unit for now,
Expand Down Expand Up @@ -2782,10 +2807,21 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
}

void visit_Assign(const AST::Assign_t &x) {
ASR::expr_t *target;
if (x.n_targets == 1) {
if (AST::is_a<AST::Subscript_t>(*x.m_targets[0])) {
AST::Subscript_t *sb = AST::down_cast<AST::Subscript_t>(x.m_targets[0]);
ASR::expr_t *target, *assign_value = nullptr, *tmp_value;
this->visit_expr(*x.m_value);
if (tmp) {
// This happens if `m.m_value` is `empty`, such as in:
// a = empty(16)
// We skip this statement for now, the array is declared
// by the annotation.
// TODO: enforce that empty(), ones(), zeros() is called
// for every declaration.
assign_value = ASRUtils::EXPR(tmp);
}
for (size_t i=0; i<x.n_targets; i++) {
tmp_value = assign_value;
if (AST::is_a<AST::Subscript_t>(*x.m_targets[i])) {
AST::Subscript_t *sb = AST::down_cast<AST::Subscript_t>(x.m_targets[i]);
if (AST::is_a<AST::Name_t>(*sb->m_value)) {
std::string name = AST::down_cast<AST::Name_t>(sb->m_value)->m_id;
ASR::symbol_t *s = current_scope->get_symbol(name);
Expand All @@ -2799,8 +2835,6 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
// dict insert case;
this->visit_expr(*sb->m_slice);
ASR::expr_t *key = ASRUtils::EXPR(tmp);
this->visit_expr(*x.m_value);
ASR::expr_t *val = ASRUtils::EXPR(tmp);
ASR::ttype_t *key_type = ASR::down_cast<ASR::Dict_t>(type)->m_key_type;
ASR::ttype_t *value_type = ASR::down_cast<ASR::Dict_t>(type)->m_value_type;
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(key), key_type)) {
Expand All @@ -2815,29 +2849,30 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
);
throw SemanticAbort();
}
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(val), value_type)) {
std::string vtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(val));
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(tmp_value), value_type)) {
std::string vtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(tmp_value));
std::string totype = ASRUtils::type_to_str_python(value_type);
diag.add(diag::Diagnostic(
"Type mismatch in dictionary value, the types must be compatible",
diag::Level::Error, diag::Stage::Semantic, {
diag::Label("type mismatch (found: '" + vtype + "', expected: '" + totype + "')",
{val->base.loc})
{tmp_value->base.loc})
})
);
throw SemanticAbort();
}
ASR::expr_t* se = ASR::down_cast<ASR::expr_t>(
ASR::make_Var_t(al, x.base.base.loc, s));
tmp = make_DictInsert_t(al, x.base.base.loc, se, key, val);
return;
tmp = nullptr;
tmp_vec.push_back(make_DictInsert_t(al, x.base.base.loc, se, key, tmp_value));
continue;
} else if (ASRUtils::is_immutable(type)) {
throw SemanticError("'" + ASRUtils::type_to_str_python(type) + "' object does not support"
" item assignment", x.base.base.loc);
}
}
} else if (AST::is_a<AST::Attribute_t>(*x.m_targets[0])) {
AST::Attribute_t *attr = AST::down_cast<AST::Attribute_t>(x.m_targets[0]);
} else if (AST::is_a<AST::Attribute_t>(*x.m_targets[i])) {
AST::Attribute_t *attr = AST::down_cast<AST::Attribute_t>(x.m_targets[i]);
if (AST::is_a<AST::Name_t>(*attr->m_value)) {
std::string name = AST::down_cast<AST::Name_t>(attr->m_value)->m_id;
ASR::symbol_t *s = current_scope->get_symbol(name);
Expand All @@ -2852,62 +2887,51 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
}
}
}
this->visit_expr(*x.m_targets[0]);
if (!tmp_value) continue;
this->visit_expr(*x.m_targets[i]);
target = ASRUtils::EXPR(tmp);
} else {
throw SemanticError("Assignment to multiple targets not supported",
x.base.base.loc);
}

this->visit_expr(*x.m_value);
if (tmp == nullptr) {
// This happens if `m.m_value` is `empty`, such as in:
// a = empty(16)
// We skip this statement for now, the array is declared
// by the annotation.
// TODO: enforce that empty(), ones(), zeros() is called
// for every declaration.
tmp = nullptr;
} else {
ASR::expr_t *value = ASRUtils::EXPR(tmp);
ASR::ttype_t *target_type = ASRUtils::expr_type(target);
ASR::ttype_t *value_type = ASRUtils::expr_type(value);
ASR::ttype_t *value_type = ASRUtils::expr_type(tmp_value);
if( ASR::is_a<ASR::Pointer_t>(*target_type) &&
ASR::is_a<ASR::Var_t>(*target) ) {
if( !ASR::is_a<ASR::GetPointer_t>(*value) ) {
if( !ASR::is_a<ASR::GetPointer_t>(*tmp_value) ) {
throw SemanticError("A pointer variable can only "
"be associated with the output "
"of pointer() call.",
value->base.loc);
tmp_value->base.loc);
}
if( !ASRUtils::check_equal_type(target_type, value_type) ) {
throw SemanticError("Casting not supported for different pointer types. Received "
"target pointer type, " + ASRUtils::type_to_str_python(target_type) +
" and value pointer type, " + ASRUtils::type_to_str_python(value_type),
x.base.base.loc);
}
tmp = ASR::make_Assignment_t(al, x.base.base.loc, target, value, nullptr);
return ;
tmp = nullptr;
tmp_vec.push_back(ASR::make_Assignment_t(al, x.base.base.loc, target,
tmp_value, nullptr));
continue;
}

cast_helper(target, value, true);
value_type = ASRUtils::expr_type(value);
cast_helper(target, tmp_value, true);
value_type = ASRUtils::expr_type(tmp_value);
if (!ASRUtils::check_equal_type(target_type, value_type)) {
std::string ltype = ASRUtils::type_to_str_python(target_type);
std::string rtype = ASRUtils::type_to_str_python(value_type);
diag.add(diag::Diagnostic(
"Type mismatch in assignment, the types must be compatible",
diag::Level::Error, diag::Stage::Semantic, {
diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')",
{target->base.loc, value->base.loc})
{target->base.loc, tmp_value->base.loc})
})
);
throw SemanticAbort();
}
ASR::stmt_t *overloaded=nullptr;
tmp = ASR::make_Assignment_t(al, x.base.base.loc, target, value,
overloaded);
tmp = nullptr;
tmp_vec.push_back(ASR::make_Assignment_t(al, x.base.base.loc, target, tmp_value,
overloaded));
}
// to make sure that we add only those statements in tmp_vec
tmp = nullptr;
}

void visit_Assert(const AST::Assert_t &x) {
Expand Down
2 changes: 0 additions & 2 deletions src/lpython/semantics/python_comptime_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,6 @@ struct PythonIntrinsicProcedures {
throw SemanticError("str.capitalize() takes no arguments", loc);
}
ASR::expr_t *arg = args[0];
ASR::ttype_t *arg_type = ASRUtils::expr_type(arg);
std::string val = ASR::down_cast<ASR::StringConstant_t>(arg)->m_s;
if (val.size()) {
val[0] = std::toupper(val[0]);
Expand All @@ -676,7 +675,6 @@ struct PythonIntrinsicProcedures {
throw SemanticError("str.lower() takes no arguments", loc);
}
ASR::expr_t *arg = args[0];
ASR::ttype_t *arg_type = ASRUtils::expr_type(arg);
std::string val = ASR::down_cast<ASR::StringConstant_t>(arg)->m_s;
for (auto &i: val) {
if (i >= 'A' && i <= 'Z') {
Expand Down
4 changes: 2 additions & 2 deletions tests/reference/asr-expr_09-f3e89c8.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
"basename": "asr-expr_09-f3e89c8",
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
"infile": "tests/../integration_tests/expr_09.py",
"infile_hash": "7a3cdb6538c8d2d8e4555683aeac4f9b074be2fcaa6fe4532c01bf1a",
"infile_hash": "51dfe55e01443840104d583e5e21ba3dd48fa33a95f1f943aac1d5d0",
"outfile": null,
"outfile_hash": null,
"stdout": "asr-expr_09-f3e89c8.stdout",
"stdout_hash": "167f5176a21663f13aff75078c19a6bd7e07d7cab2605ef6302d9b8a",
"stdout_hash": "66cf441a7ed60ad292ae9933ae51d8aac76f803201b809eac2689a33",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
Loading