Skip to content

Commit

Permalink
Resolved issue #1823 (#1831)
Browse files Browse the repository at this point in the history
* Restored TemplateBinOp

* Updated tests

* Renamed tests

* Forgot to update reference tests

* Got the interface working in requirement

* Converted binop in template into the corresponding functions during template construction, add various checks for requirement's parameters

* Removed TemplateBinOp

* Modified the template example with + operator

---------

Co-authored-by: Ondřej Čertík <ondrej@certik.us>
  • Loading branch information
ansharlubis and certik committed Jul 21, 2023
1 parent 5c741a6 commit 67c3a20
Show file tree
Hide file tree
Showing 21 changed files with 1,504 additions and 5 deletions.
2 changes: 1 addition & 1 deletion integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ RUN(NAME assign_to2 LABELS gfortran llvm)
RUN(NAME conv_complex2real LABELS gfortran llvm wasm)

RUN(NAME template_add_01 LABELS llvm)
RUN(NAME template_add_01b LABELS llvm)
RUN(NAME template_add_02 LABELS llvm)
RUN(NAME template_add_03 LABELS llvm)
RUN(NAME template_add_04 LABELS llvm NOFAST)
Expand All @@ -854,7 +855,6 @@ RUN(NAME template_array_04 LABELS llvm NOFAST)
RUN(NAME template_01 LABELS llvm NOFAST)
RUN(NAME template_struct_01 LABELS llvm)


RUN(NAME statement1 LABELS gfortran llvm)

RUN(NAME implied_do_loops1 LABELS gfortran)
Expand Down
59 changes: 59 additions & 0 deletions integration_tests/template_add_01b.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
module template_add_01b_m
implicit none
private
public :: add_t

requirement R(T, F)
type :: T; end type
interface operator (+)
procedure F
end interface
function F(x, y) result(z)
type(T), intent(in) :: x, y
type(T) :: z
end function
end requirement

template add_t(T, F)
requires R(T, F)
private
public :: add_generic
contains
function add_generic(x, y) result(z)
type(T), intent(in) :: x, y
type(T) :: z
!print*, "x, y, z, x+y+z =", x, y, z, x+y+z
z = x + y
end function
function add_generic2(x, y, z) result(s)
type(T), intent(in) :: x, y, z
type(T) :: s
s = x + y + z
end function
end template

contains

integer function func_arg_int(x, y) result(z)
integer, intent(in) :: x, y
z = x + y
end function

subroutine test_template()
integer :: n1, n2

instantiate add_t(integer, func_arg_int), only: add_integer => add_generic, add_integer2 => add_generic2
n1 = add_integer(5, 9)
n2 = add_integer2(5, 9, 10)
print*, "The result is", n1
print*, "The result is", n2
end subroutine
end module

program template_add_01b
use template_add_01b_m
implicit none

call test_template()

end program
26 changes: 26 additions & 0 deletions src/lfortran/semantics/ast_common_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -5248,6 +5248,32 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {

asr = ASR::make_ComplexBinOp_t(al, x.base.base.loc, left, op, right, dest_type, value);

} else if (ASRUtils::is_type_parameter(*left_type) || ASRUtils::is_type_parameter(*right_type)) {
// if overloaded is not found, then reject
if (overloaded == nullptr) {
std::string op_str = "+";
switch (op) {
case (ASR::Add):
break;
case (ASR::Sub):
op_str = "-";
break;
case (ASR::Mul):
op_str = "*";
break;
case (ASR::Div):
op_str = "/";
break;
case (ASR::Pow):
op_str = "**";
break;
default:
LCOMPILERS_ASSERT(false);
}
throw SemanticError("Operator undefined for " + ASRUtils::type_to_str(left_type)
+ " " + op_str + " " + ASRUtils::type_to_str(right_type), x.base.base.loc);
}
return;
}

if (overloaded != nullptr) {
Expand Down
67 changes: 64 additions & 3 deletions src/lfortran/semantics/ast_symboltable_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1281,7 +1281,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
Str s;
s.from_str_view(pname);
char *name = s.c_str(al);
x = resolve_symbol(loc, name);
x = resolve_symbol(loc, to_lower(name));
symbols.push_back(al, x);
}
LCOMPILERS_ASSERT(strlen(generic_name) > 0);
Expand Down Expand Up @@ -2117,6 +2117,12 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
current_procedure_args.push_back(arg);
}

std::map<AST::intrinsicopType, std::vector<std::string>> requirement_op_procs;
for (auto &proc: overloaded_op_procs) {
requirement_op_procs[proc.first] = proc.second;
}
overloaded_op_procs.clear();

Vec<ASR::require_instantiation_t*> reqs;
reqs.reserve(al, x.n_decl);
for (size_t i=0; i<x.n_decl; i++) {
Expand All @@ -2130,6 +2136,39 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
this->visit_program_unit(*x.m_funcs[i]);
}

for (size_t i=0; i<x.n_namelist; i++) {
std::string arg = to_lower(x.m_namelist[i]);
if (!current_scope->get_symbol(arg)) {
diag.add(Diagnostic(
"Parameter " + arg + " is unused in " + x.m_name,
Level::Warning, Stage::Semantic, {
Label("", {x.base.base.loc})
}
));
}
current_procedure_args.push_back(arg);
}

for (auto &item: current_scope->get_scope()) {
bool defined = false;
std::string sym = item.first;
for (size_t i=0; i<current_procedure_args.size(); i++) {
std::string arg = current_procedure_args[i];
if (sym.compare(arg) == 0) {
defined = true;
}
}
if (!defined) {
throw SemanticError("Symbol " + sym + " is not declared in " + to_lower(x.m_name) + "'s parameters",
x.base.base.loc);
}
}

add_overloaded_procedures();
for (auto &proc: requirement_op_procs) {
overloaded_op_procs[proc.first] = proc.second;
}

ASR::asr_t *req = ASR::make_Requirement_t(al, x.base.base.loc,
current_scope, s2c(al, to_lower(x.m_name)), args.p, args.size(),
reqs.p, reqs.size());
Expand All @@ -2142,6 +2181,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
}

void visit_Requires(const AST::Requires_t &x) {
std::map<std::string,std::string> parameter_map;
std::string require_name = to_lower(x.m_name);
ASR::symbol_t *req0 = current_scope->resolve_symbol(require_name);

Expand Down Expand Up @@ -2171,13 +2211,35 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
ASR::symbol_t *temp_arg_sym = current_scope->resolve_symbol(temp_arg);
if (!temp_arg_sym) {
std::string req_arg = req->m_args[i];
parameter_map[req_arg] = temp_arg;
ASR::symbol_t *req_arg_sym = (req->m_symtab)->get_symbol(req_arg);
// TODO: inline this? convert into a static method?
temp_arg_sym = replace_symbol(req_arg_sym, temp_arg);
current_scope->add_symbol(temp_arg, temp_arg_sym);
}
}

// adding custom operators
for (auto &item: req->m_symtab->get_scope()) {
if (ASR::is_a<ASR::CustomOperator_t>(*item.second)) {
ASR::CustomOperator_t *c_op = ASR::down_cast<ASR::CustomOperator_t>(item.second);

// may not need to add new custom operators if another requires already got an interface
Vec<ASR::symbol_t*> symbols;
symbols.reserve(al, c_op->n_procs);
for (size_t i=0; i<c_op->n_procs; i++) {
ASR::symbol_t *proc = c_op->m_procs[i];
std::string new_proc_name = parameter_map[ASRUtils::symbol_name(proc)];
proc = current_scope->resolve_symbol(new_proc_name);
symbols.push_back(al, proc);
}

ASR::symbol_t *new_c_op = ASR::down_cast<ASR::symbol_t>(ASR::make_CustomOperator_t(
al, c_op->base.base.loc, current_scope,
s2c(al, c_op->m_name), symbols.p, symbols.size(), c_op->m_access));
current_scope->add_symbol(c_op->m_name, new_c_op);
}
}

tmp = ASR::make_Require_t(al, x.base.base.loc, s2c(al, require_name),
args.p, args.size());
}
Expand Down Expand Up @@ -2223,7 +2285,6 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
current_procedure_args.clear();
context_map.clear();
is_template = false;

}

void visit_Instantiate(const AST::Instantiate_t &x) {
Expand Down
1 change: 0 additions & 1 deletion src/libasr/pass/instantiate_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,6 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
x->m_original_name, args.p, args.size(), dt, nullptr, false);
}


ASR::ttype_t* substitute_type(ASR::ttype_t *param_type) {
if (ASR::is_a<ASR::List_t>(*param_type)) {
ASR::List_t *tlist = ASR::down_cast<ASR::List_t>(param_type);
Expand Down
25 changes: 25 additions & 0 deletions tests/errors/template_error_07a.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module template_01_m
implicit none
private
public :: op_t

requirement semigroup(t)
type, deferred :: t
elemental function combine(x, y) result(combined)
type(t), intent(in) :: x, y
type(t) :: combined
end function
end requirement

contains

subroutine test_template()
end subroutine

end module

program template_01
use template_01_m
implicit none

end program template_01
21 changes: 21 additions & 0 deletions tests/errors/template_error_07b.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module template_01_m
implicit none
private
public :: op_t

requirement semigroup(t, combine)
type, deferred :: t
end requirement

contains

subroutine test_template()
end subroutine

end module

program template_01
use template_01_m
implicit none

end program template_01
38 changes: 38 additions & 0 deletions tests/errors/template_error_07c.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
module template_01_m
implicit none
private
public :: op_t

requirement semigroup(t, combine)
type, deferred :: t
elemental function combine(x, y) result(combined)
type(t), intent(in) :: x, y
type(t) :: combined
end function
end requirement

requirement extended_semigroup(t, combine, sconcat, stimes)
requires semigroup(t, scombine)
pure function sconcat(list) result(combined)
type(t), intent(in) :: list(:)
type(t) :: combined
end function
elemental function stimes(n, a) result(repeated)
integer, intent(in) :: n
type(t), intent(in) :: a
type(t) :: repeated
end function
end requirement

contains

subroutine test_template()
end subroutine

end module

program template_01
use template_01_m
implicit none

end program template_01
50 changes: 50 additions & 0 deletions tests/errors/template_error_08.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
module template_add_01b_m
implicit none
private
public :: add_t

requirement R(T, F)
type :: T; end type
function F(x, y) result(z)
type(T), intent(in) :: x, y
type(T) :: z
end function
end requirement

template add_t(T, F)
requires R(T, F)
private
public :: add_generic
contains
function add_generic(x, y) result(z)
type(T), intent(in) :: x, y
type(T) :: z
z = x + y
end function
end template

contains

integer function func_arg_int(x, y) result(z)
integer, intent(in) :: x, y
z = x + y
end function

subroutine test_template()
real :: a
integer :: n, s

instantiate add_t(integer, func_arg_int), only: add_integer => add_generic
n = add_integer(5, 9)
!s = add_integer2(5, 9, 10)
print*, "The result is", n
end subroutine
end module

program template_add_01b
use template_add_01b_m
implicit none

call test_template()

end program
13 changes: 13 additions & 0 deletions tests/reference/asr-template_add_01b-bd911f4.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-template_add_01b-bd911f4",
"cmd": "lfortran --show-asr --no-color {infile} -o {outfile}",
"infile": "tests/../integration_tests/template_add_01b.f90",
"infile_hash": "7aff73d0cb778ae2dcac14a95cf3bd54d165adc17b929d7f5197ed50",
"outfile": null,
"outfile_hash": null,
"stdout": "asr-template_add_01b-bd911f4.stdout",
"stdout_hash": "a4bc452d7016e9611a6eee8d2bfdc2668be37d7ad4010f0dbbd1f71b",
"stderr": null,
"stderr_hash": null,
"returncode": 0
}

0 comments on commit 67c3a20

Please sign in to comment.