Skip to content

Commit

Permalink
Added support for handling chaining of Symbolic operators and printin…
Browse files Browse the repository at this point in the history
…g without assignment . (#2051)
  • Loading branch information
anutosh491 committed Jun 28, 2023
1 parent 180ba71 commit dc3510b
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 96 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,7 @@ RUN(NAME structs_30 LABELS cpython llvm c)
RUN(NAME symbolics_01 LABELS cpython_sym c_sym)
RUN(NAME symbolics_02 LABELS cpython_sym c_sym)
RUN(NAME symbolics_03 LABELS cpython_sym c_sym)
RUN(NAME symbolics_04 LABELS cpython_sym c_sym)

RUN(NAME sizeof_01 LABELS llvm c
EXTRAFILES sizeof_01b.c)
Expand Down
41 changes: 41 additions & 0 deletions integration_tests/symbolics_04.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from sympy import Symbol, pi, S
from lpython import S

def test_chained_operations():
x: S = Symbol('x')
y: S = Symbol('y')
z: S = Symbol('z')
a: S = Symbol('a')
b: S = Symbol('b')

# Chained Operations
w: S = (x + y) * ((a - b) / (pi + z))
result: S = (w ** S(2) - pi) + S(3)

# Print Statements
print(result)
# Expected: 3 + (a - b)**2*(x + y)**2/(z + pi)**2 - pi

# Additional Variables
c: S = Symbol('c')
d: S = Symbol('d')
e: S = Symbol('e')
f: S = Symbol('f')

# Chained Operations with Additional Variables
x = (c * d + e) / f
y = (x - S(10)) * (pi + S(5))
z = y ** (S(2) / (f + d))
result = (z + e) * (a - b)

# Print Statements
print(result)
# Expected: (a - b)*(e + ((5 + pi)*(-10 + (e + c*d)/f))**(2/(d + f)))
print(x)
# Expected: (e + c*d)/f
print(y)
# Expected: (5 + pi)*(-10 + (e + c*d)/f)
print(z)
# Expected: ((5 + pi)*(-10 + (e + c*d)/f))**(2/(d + f))

test_chained_operations()
15 changes: 11 additions & 4 deletions src/libasr/codegen/asr_to_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1138,8 +1138,12 @@ R"( // Initialise Numpy
if( ASRUtils::is_array(value_type) ) {
src += "->data";
}
if (value_type->type == ASR::ttypeType::List ||
value_type->type == ASR::ttypeType::Tuple) {
if(ASR::is_a<ASR::SymbolicExpression_t>(*value_type)) {
out += symengine_src;
symengine_src = "";
}
if( ASR::is_a<ASR::List_t>(*value_type) ||
ASR::is_a<ASR::Tuple_t>(*value_type)) {
tmp_gen += "\"";
if (!v.empty()) {
for (auto &s: v) {
Expand All @@ -1156,13 +1160,16 @@ R"( // Initialise Numpy
}
tmp_gen += c_ds_api->get_print_type(value_type, ASR::is_a<ASR::ArrayItem_t>(*x.m_values[i]));
v.push_back(src);
if (value_type->type == ASR::ttypeType::Complex) {
if (ASR::is_a<ASR::Complex_t>(*value_type)) {
v.pop_back();
v.push_back("creal(" + src + ")");
v.push_back("cimag(" + src + ")");
} else if(value_type->type == ASR::ttypeType::SymbolicExpression){
} else if(ASR::is_a<ASR::SymbolicExpression_t>(*value_type)){
v.pop_back();
v.push_back("basic_str(" + src + ")");
if(ASR::is_a<ASR::Var_t>(*x.m_values[i])) {
symengine_queue.pop();
}
}
if (i+1!=x.n_values) {
tmp_gen += "\%s";
Expand Down
153 changes: 123 additions & 30 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,32 @@ struct CPPDeclarationOptions: public DeclarationOptions {
}
};

class SymEngineQueue {
public:
std::vector<std::string> queue;
int queue_front = -1;
std::string& symengine_src;

SymEngineQueue(std::string& symengine_src) : symengine_src(symengine_src) {}

std::string push() {
std::string indent(4, ' ');
std::string var;
if(queue_front == -1 || queue_front >= static_cast<int>(queue.size())) {
var = "queue" + std::to_string(queue.size());
queue.push_back(var);
symengine_src = indent + "basic " + var + ";\n";
symengine_src += indent + "basic_new_stack(" + var + ");\n";
}
return queue[queue_front++];
}

void pop() {
LCOMPILERS_ASSERT(queue_front != -1 && queue_front < static_cast<int>(queue.size()));
queue_front++;
}
};

template <class Struct>
class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
{
Expand Down Expand Up @@ -115,6 +141,8 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
bool is_c;
std::set<std::string> headers, user_headers, user_defines;
std::vector<std::string> tmp_buffer_src;
std::string symengine_src;
SymEngineQueue symengine_queue{symengine_src};

SymbolTable* global_scope;
int64_t lower_bound;
Expand Down Expand Up @@ -1178,6 +1206,17 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
target = "&" + target;
}
}
if( ASR::is_a<ASR::SymbolicExpression_t>(*value_type) ) {
if(ASR::is_a<ASR::Var_t>(*x.m_value)){
src = indent + "basic_assign(" + target + ", " + value + ");\n";
symengine_queue.pop();
symengine_queue.pop();
return;
}
src = symengine_src;
symengine_src = "";
return;
}
if( !from_std_vector_helper.empty() ) {
src = from_std_vector_helper;
} else {
Expand Down Expand Up @@ -1243,12 +1282,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
src += alloc + indent + c_ds_api->get_deepcopy(m_target_type, value, target) + "\n";
}
} else {
if (m_target_type->type == ASR::ttypeType::SymbolicExpression){
ASR::expr_t* m_value_expr = x.m_value;
src += alloc + indent + c_ds_api->get_deepcopy_symbolic(m_value_expr, value, target) + "\n";
} else {
src += alloc + indent + c_ds_api->get_deepcopy(m_target_type, value, target) + "\n";
}
src += alloc + indent + c_ds_api->get_deepcopy(m_target_type, value, target) + "\n";
}
} else {
src += indent + c_ds_api->get_deepcopy(m_target_type, value, target) + "\n";
Expand Down Expand Up @@ -1646,6 +1680,15 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
src = std::string(ASR::down_cast<ASR::Variable_t>(s)->m_name);
}
last_expr_precedence = 2;
ASR::ttype_t* var_type = sv->m_type;
if( ASR::is_a<ASR::SymbolicExpression_t>(*var_type)) {
std::string var_name = std::string(ASR::down_cast<ASR::Variable_t>(s)->m_name);
symengine_queue.queue.push_back(var_name);
if (symengine_queue.queue_front == -1) {
symengine_queue.queue_front = 0;
}
symengine_src = "";
}
}

void visit_StructInstanceMember(const ASR::StructInstanceMember_t& x) {
Expand Down Expand Up @@ -1858,6 +1901,8 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
break;
}
case (ASR::cast_kindType::IntegerToSymbolicExpression): {
self().visit_expr(*x.m_value);
last_expr_precedence = 2;
break;
}
default : throw CodeGenError("Cast kind " + std::to_string(x.m_kind) + " not implemented",
Expand Down Expand Up @@ -2591,8 +2636,34 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
out += func_name; break; \
}

std::string performSymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) {
headers.insert("symengine/cwrapper.h");
std::string indent(4, ' ');
LCOMPILERS_ASSERT(x.n_args == 2);
std::string target = symengine_queue.push();
std::string target_src = symengine_src;
this->visit_expr(*x.m_args[0]);
std::string arg1 = src;
std::string arg1_src = symengine_src;
// Check if x.m_args[0] is a Var
if (ASR::is_a<ASR::Var_t>(*x.m_args[0])) {
symengine_queue.pop();
}
this->visit_expr(*x.m_args[1]);
std::string arg2 = src;
std::string arg2_src = symengine_src;
// Check if x.m_args[0] is a Var
if (ASR::is_a<ASR::Var_t>(*x.m_args[1])) {
symengine_queue.pop();
}
symengine_src = target_src + arg1_src + arg2_src;
symengine_src += indent + functionName + "(" + target + ", " + arg1 + ", " + arg2 + ");\n";
return target;
}

void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t &x) {
std::string out;
std::string indent(4, ' ');
switch (x.m_intrinsic_id) {
SET_INTRINSIC_NAME(Sin, "sin");
SET_INTRINSIC_NAME(Cos, "cos");
Expand All @@ -2607,22 +2678,51 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
SET_INTRINSIC_NAME(Exp, "exp");
SET_INTRINSIC_NAME(Exp2, "exp2");
SET_INTRINSIC_NAME(Expm1, "expm1");
SET_INTRINSIC_NAME(SymbolicSymbol, "Symbol");
SET_INTRINSIC_NAME(SymbolicInteger, "Integer");
SET_INTRINSIC_NAME(SymbolicPi, "pi");
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd)):
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSub)):
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicMul)):
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiv)):
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd)): {
src = performSymbolicOperation("basic_add", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSub)): {
src = performSymbolicOperation("basic_sub", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicMul)): {
src = performSymbolicOperation("basic_mul", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiv)): {
src = performSymbolicOperation("basic_div", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPow)): {
LCOMPILERS_ASSERT(x.n_args == 2);
src = performSymbolicOperation("basic_pow", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPi)): {
headers.insert("symengine/cwrapper.h");
LCOMPILERS_ASSERT(x.n_args == 0);
std::string target = symengine_queue.push();
symengine_src += indent + "basic_const_pi(" + target + ");\n";
src = target;
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSymbol)): {
headers.insert("symengine/cwrapper.h");
LCOMPILERS_ASSERT(x.n_args == 1);
this->visit_expr(*x.m_args[0]);
std::string arg1 = src;
this->visit_expr(*x.m_args[1]);
std::string arg2 = src;
out = arg1 + "," + arg2;
src = out;
break;
std::string target = symengine_queue.push();
symengine_src += indent + "symbol_set(" + target + ", " + src + ");\n";
src = target;
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicInteger)): {
headers.insert("symengine/cwrapper.h");
LCOMPILERS_ASSERT(x.n_args == 1);
this->visit_expr(*x.m_args[0]);
std::string target = symengine_queue.push();
symengine_src += indent + "integer_set_si(" + target + ", " + src + ");\n";
src = target;
return;
}
default : {
throw LCompilersException("IntrinsicFunction: `"
Expand All @@ -2631,16 +2731,9 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
}
}
headers.insert("math.h");
if (x.n_args == 0){
src = out;
} else if (x.n_args == 1) {
this->visit_expr(*x.m_args[0]);
if ((x.m_intrinsic_id != static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSymbol)) &&
(x.m_intrinsic_id != static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicInteger))) {
out += "(" + src + ")";
src = out;
}
}
this->visit_expr(*x.m_args[0]);
out += "(" + src + ")";
src = out;
}
};

Expand Down
62 changes: 0 additions & 62 deletions src/libasr/codegen/c_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,68 +445,6 @@ class CCPPDSUtils {
return result;
}

std::string generate_binary_operator_code(std::string value, std::string target, std::string operatorName) {
size_t delimiterPos = value.find(",");
std::string leftPart = value.substr(0, delimiterPos);
std::string rightPart = value.substr(delimiterPos + 1);
std::string result = operatorName + "(" + target + ", " + leftPart + ", " + rightPart + ");";
return result;
}

std::string get_deepcopy_symbolic(ASR::expr_t *value_expr, std::string value, std::string target) {
std::string result;
if (ASR::is_a<ASR::Var_t>(*value_expr)) {
result = "basic_assign(" + target + ", " + value + ");";
} else if (ASR::is_a<ASR::IntrinsicFunction_t>(*value_expr)) {
ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicFunction_t>(value_expr);
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id;
switch (static_cast<LCompilers::ASRUtils::IntrinsicFunctions>(intrinsic_id)) {
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: {
result = "symbol_set(" + target + ", " + value + ");";
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: {
result = generate_binary_operator_code(value, target, "basic_add");
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: {
result = generate_binary_operator_code(value, target, "basic_sub");
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: {
result = generate_binary_operator_code(value, target, "basic_mul");
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: {
result = generate_binary_operator_code(value, target, "basic_div");
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: {
result = generate_binary_operator_code(value, target, "basic_pow");
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: {
result = "basic_const_pi(" + target + ");";
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicInteger: {
result = "integer_set_si(" + target + ", " + value + ");";
break;
}
default: {
throw LCompilersException("IntrinsicFunction: `"
+ LCompilers::ASRUtils::get_intrinsic_name(intrinsic_id)
+ "` is not implemented");
}
}
} else if (ASR::is_a<ASR::Cast_t>(*value_expr)) {
ASR::Cast_t* cast_expr = ASR::down_cast<ASR::Cast_t>(value_expr);
std::string cast_value_expr = get_deepcopy_symbolic(cast_expr->m_value, value, target);
return cast_value_expr;
}
return result;
}

std::string get_type(ASR::ttype_t *t) {
LCOMPILERS_ASSERT(CUtils::is_non_primitive_DT(t));
if (ASR::is_a<ASR::List_t>(*t)) {
Expand Down

0 comments on commit dc3510b

Please sign in to comment.