Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion src/libasr/asr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,63 @@ ASR::asr_t* symbol_resolve_external_generic_procedure_without_eval(
}
}

} // namespace ASRUtils
ASR::asr_t* make_Cast_t_value(Allocator &al, const Location &a_loc,
ASR::expr_t* a_arg, ASR::cast_kindType a_kind, ASR::ttype_t* a_type) {

ASR::expr_t* value = nullptr;

if (ASRUtils::expr_value(a_arg)) {
// calculate value
if (a_kind == ASR::cast_kindType::RealToInteger) {
int64_t v = ASR::down_cast<ASR::RealConstant_t>(
ASRUtils::expr_value(a_arg))->m_r;
value = ASR::down_cast<ASR::expr_t>(
ASR::make_IntegerConstant_t(al, a_loc, v, a_type));
} else if (a_kind == ASR::cast_kindType::RealToReal) {
double v = ASR::down_cast<ASR::RealConstant_t>(
ASRUtils::expr_value(a_arg))->m_r;
value = ASR::down_cast<ASR::expr_t>(
ASR::make_RealConstant_t(al, a_loc, v, a_type));
} else if (a_kind == ASR::cast_kindType::RealToComplex) {
double double_value = ASR::down_cast<ASR::RealConstant_t>(
ASRUtils::expr_value(a_arg))->m_r;
value = ASR::down_cast<ASR::expr_t>(ASR::make_ComplexConstant_t(al, a_loc,
double_value, 0, a_type));
} else if (a_kind == ASR::cast_kindType::IntegerToReal) {
// TODO: Clashes with the pow functions
// int64_t value = ASR::down_cast<ASR::ConstantInteger_t>(ASRUtils::expr_value(a_arg))->m_n;
// value = ASR::down_cast<ASR::expr_t>(ASR::make_ConstantReal_t(al, a_loc, (double)v, a_type));
} else if (a_kind == ASR::cast_kindType::IntegerToComplex) {
int64_t int_value = ASR::down_cast<ASR::IntegerConstant_t>(
ASRUtils::expr_value(a_arg))->m_n;
value = ASR::down_cast<ASR::expr_t>(ASR::make_ComplexConstant_t(al, a_loc,
(double)int_value, 0, a_type));
} else if (a_kind == ASR::cast_kindType::IntegerToInteger) {
// TODO: implement
// int64_t v = ASR::down_cast<ASR::ConstantInteger_t>(ASRUtils::expr_value(a_arg))->m_n;
// value = ASR::down_cast<ASR::expr_t>(ASR::make_ConstantInteger_t(al, a_loc, v, a_type));
} else if (a_kind == ASR::cast_kindType::IntegerToLogical) {
// TODO: implement
} else if (a_kind == ASR::cast_kindType::ComplexToComplex) {
ASR::ComplexConstant_t* value_complex = ASR::down_cast<ASR::ComplexConstant_t>(
ASRUtils::expr_value(a_arg));
double real = value_complex->m_re;
double imag = value_complex->m_im;
value = ASR::down_cast<ASR::expr_t>(
ASR::make_ComplexConstant_t(al, a_loc, real, imag, a_type));
} else if (a_kind == ASR::cast_kindType::ComplexToReal) {
ASR::ComplexConstant_t* value_complex = ASR::down_cast<ASR::ComplexConstant_t>(
ASRUtils::expr_value(a_arg));
double real = value_complex->m_re;
value = ASR::down_cast<ASR::expr_t>(
ASR::make_RealConstant_t(al, a_loc, real, a_type));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you can fix the formatting of these just like I suggested above, that would be great. I think it's ready after that.

}

return ASR::make_Cast_t(al, a_loc, a_arg, a_kind, a_type, value);
}

} // namespace ASRUtils


} // namespace LFortran
4 changes: 4 additions & 0 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,10 @@ ASR::asr_t* symbol_resolve_external_generic_procedure_without_eval(
SymbolTable* current_scope, Allocator& al,
const std::function<void (const std::string &, const Location &)> err);

// Creates an Cast node and automatically computes the `value` if it can be computed at compile time
ASR::asr_t* make_Cast_t_value(Allocator &al, const Location &a_loc,
ASR::expr_t* a_arg, ASR::cast_kindType a_kind, ASR::ttype_t* a_type);

} // namespace ASRUtils

} // namespace LFortran
Expand Down
89 changes: 38 additions & 51 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -766,30 +766,30 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
int lkind = ASR::down_cast<ASR::Integer_t>(left_type)->m_kind;
int rkind = ASR::down_cast<ASR::Integer_t>(right_type)->m_kind;
if ((is_assign && (lkind != rkind)) || (lkind > rkind)) {
return ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
return ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::IntegerToInteger,
left_type, nullptr));
left_type));
}
} else if (ASRUtils::is_real(*left_type) && ASRUtils::is_real(*right_type)) {
bool is_l64 = ASR::down_cast<ASR::Real_t>(left_type)->m_kind == 8;
bool is_r64 = ASR::down_cast<ASR::Real_t>(right_type)->m_kind == 8;
if ((is_assign && (is_l64 != is_r64)) || (is_l64 && !is_r64)) {
return ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
return ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::RealToReal,
left_type, nullptr));
left_type));
}
} else if (ASRUtils::is_complex(*left_type) && ASRUtils::is_complex(*right_type)) {
bool is_l64 = ASR::down_cast<ASR::Complex_t>(left_type)->m_kind == 8;
bool is_r64 = ASR::down_cast<ASR::Complex_t>(right_type)->m_kind == 8;
if ((is_assign && (is_l64 != is_r64)) || (is_l64 && !is_r64)) {
return ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
return ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::ComplexToComplex,
left_type, nullptr));
left_type));
}
} else if (!is_assign && ASRUtils::is_real(*left_type) && ASRUtils::is_integer(*right_type)) {
return ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
return ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::IntegerToReal,
left_type, nullptr));
left_type));
} else if (is_assign && ASRUtils::is_real(*left_type) && ASRUtils::is_integer(*right_type)) {
throw SemanticError("Assigning integer to float is not supported",
right->base.loc);
Expand All @@ -798,22 +798,20 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
right->base.loc);
} else if (!is_assign && ASRUtils::is_complex(*left_type) && !ASRUtils::is_complex(*right_type)) {
if (ASRUtils::is_real(*right_type)) {
return ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
return ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::RealToComplex,
left_type, nullptr));
left_type));
} else if (ASRUtils::is_integer(*right_type)) {
return ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
return ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::IntegerToComplex,
left_type, nullptr));
left_type));
} else if (ASRUtils::is_logical(*right_type)) {
ASR::ttype_t* int_type = ASRUtils::TYPE(ASR::make_Integer_t(al,
right->base.loc, 4, nullptr, 0));
right = ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, right->base.loc, right, ASR::cast_kindType::LogicalToInteger, int_type,
nullptr));
return ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, right->base.loc, right, ASR::cast_kindType::IntegerToComplex, left_type,
nullptr));
right = ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::LogicalToInteger, int_type));
return ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::IntegerToComplex, left_type));
} else {
std::string rtype = ASRUtils::type_to_str_python(right_type);
throw SemanticError("Casting " + rtype + " to complex is not Implemented",
Expand All @@ -824,30 +822,24 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
if (ASRUtils::is_logical(*left_type) && ASRUtils::is_logical(*right_type)) {
ASR::ttype_t* int_type = ASRUtils::TYPE(ASR::make_Integer_t(al,
right->base.loc, 4, nullptr, 0));
return ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, right->base.loc, right, ASR::cast_kindType::LogicalToInteger, int_type,
nullptr));
return ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::LogicalToInteger, int_type));
} else if (ASRUtils::is_logical(*right_type)) {
ASR::ttype_t* int_type = ASRUtils::TYPE(ASR::make_Integer_t(al,
right->base.loc, 4, nullptr, 0));
if (ASRUtils::is_integer(*left_type)) {
return ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, right->base.loc, right, ASR::cast_kindType::LogicalToInteger, int_type,
nullptr));
return ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::LogicalToInteger, int_type));
} else if (ASRUtils::is_real(*left_type)) {
right = ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, right->base.loc, right, ASR::cast_kindType::LogicalToInteger, int_type,
nullptr));
return ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, right->base.loc, right, ASR::cast_kindType::IntegerToReal, left_type,
nullptr));
right = ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::LogicalToInteger, int_type));
return ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::IntegerToReal, left_type));
} else if (ASRUtils::is_complex(*left_type)) {
right = ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, right->base.loc, right, ASR::cast_kindType::LogicalToInteger, int_type,
nullptr));
return ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, right->base.loc, right, ASR::cast_kindType::IntegerToComplex, left_type,
nullptr));
right = ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::LogicalToInteger, int_type));
return ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::IntegerToComplex, left_type));
} else {
std::string ltype = ASRUtils::type_to_str_python(left_type);
throw SemanticError("Binary Operation not implemented for bool and " + ltype,
Expand Down Expand Up @@ -891,14 +883,12 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
dest_type = ASRUtils::TYPE(ASR::make_Real_t(al, loc,
8, nullptr, 0));
if (ASRUtils::is_integer(*left_type)) {
left = ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, left->base.loc, left, ASR::cast_kindType::IntegerToReal, dest_type,
value));
left = ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, left->base.loc, left, ASR::cast_kindType::IntegerToReal, dest_type));
}
if (ASRUtils::is_integer(*right_type)) {
right = ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, right->base.loc, right, ASR::cast_kindType::IntegerToReal, dest_type,
value));
right = ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::IntegerToReal, dest_type));
}
left = cast_helper(ASRUtils::expr_type(right), left);
right = cast_helper(ASRUtils::expr_type(left), right);
Expand Down Expand Up @@ -948,9 +938,8 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
dest_type = ASRUtils::TYPE(ASR::make_Real_t(al, loc,
8, nullptr, 0));
if (ASRUtils::is_integer(*left_type)) {
left = ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, left->base.loc, left, ASR::cast_kindType::IntegerToReal, dest_type,
value));
left = ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, left->base.loc, left, ASR::cast_kindType::IntegerToReal, dest_type));
}
if (ASRUtils::is_integer(*right_type)) {
if (ASRUtils::expr_value(right) != nullptr) {
Expand All @@ -966,9 +955,8 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
throw SemanticAbort();
}
}
right = ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, right->base.loc, right, ASR::cast_kindType::IntegerToReal, dest_type,
value));
right = ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, right->base.loc, right, ASR::cast_kindType::IntegerToReal, dest_type));
} else if (ASRUtils::is_real(*right_type)) {
if (ASRUtils::expr_value(right) != nullptr) {
double val = ASR::down_cast<ASR::RealConstant_t>(ASRUtils::expr_value(right))->m_r;
Expand Down Expand Up @@ -2340,9 +2328,8 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
int kind = ASRUtils::extract_kind_from_ttype_t(var->m_type);
ASR::ttype_t *dest_type = ASR::down_cast<ASR::ttype_t>(ASR::make_Real_t(al, x.base.base.loc,
kind, nullptr, 0));
ASR::expr_t *value = ASR::down_cast<ASR::expr_t>(ASR::make_Cast_t(
al, val->base.loc, val, ASR::cast_kindType::ComplexToReal, dest_type,
nullptr));
ASR::expr_t *value = ASR::down_cast<ASR::expr_t>(ASRUtils::make_Cast_t_value(
al, val->base.loc, val, ASR::cast_kindType::ComplexToReal, dest_type));
tmp = ASR::make_ComplexRe_t(al, x.base.base.loc, val, dest_type, ASRUtils::expr_value(value));
return;
} else {
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-assign2-8d1a2ee.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-assign2-8d1a2ee.stdout",
"stdout_hash": "166d952ff942fbd5f47a07859736c10446d8a1fedb4d32d9724ac426",
"stdout_hash": "543ba4ebcdcd0f6f9b360f6bde8fde60c4375e1e8b0d528e3bffc3c0",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-assign2-8d1a2ee.stdout
Original file line number Diff line number Diff line change
@@ -1 +1 @@
(TranslationUnit (SymbolTable 1 {f: (Variable 1 f Local () (Cast (RealConstant 1.234568 (Real 8 [])) RealToReal (Real 4 []) ()) Default (Real 4 []) Source Public Required .false.), f2: (Variable 1 f2 Local () (RealConstant 1.234568 (Real 8 [])) Default (Real 8 []) Source Public Required .false.), i: (Variable 1 i Local () (IntegerConstant 5 (Integer 4 [])) Default (Integer 4 []) Source Public Required .false.), i2: (Variable 1 i2 Local () (Cast (IntegerConstant 53430903434 (Integer 4 [])) IntegerToInteger (Integer 8 []) ()) Default (Integer 8 []) Source Public Required .false.), main_program: (Program (SymbolTable 2 {}) main_program [] [])}) [])
(TranslationUnit (SymbolTable 1 {f: (Variable 1 f Local () (Cast (RealConstant 1.234568 (Real 8 [])) RealToReal (Real 4 []) (RealConstant 1.234568 (Real 4 []))) Default (Real 4 []) Source Public Required .false.), f2: (Variable 1 f2 Local () (RealConstant 1.234568 (Real 8 [])) Default (Real 8 []) Source Public Required .false.), i: (Variable 1 i Local () (IntegerConstant 5 (Integer 4 [])) Default (Integer 4 []) Source Public Required .false.), i2: (Variable 1 i2 Local () (Cast (IntegerConstant 53430903434 (Integer 4 [])) IntegerToInteger (Integer 8 []) ()) Default (Integer 8 []) Source Public Required .false.), main_program: (Program (SymbolTable 2 {}) main_program [] [])}) [])
2 changes: 1 addition & 1 deletion tests/reference/asr-complex1-f26c460.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-complex1-f26c460.stdout",
"stdout_hash": "5869e314cb005e53ec05365a2887376794de5ecb2148e5f049e51096",
"stdout_hash": "614ff3da6f6eff7b6c3285296be10d7703f97966655c2eb70f8ef100",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
Loading