Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/cancall_new' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
inumanag committed Apr 3, 2024
2 parents 4c0caeb + d3f3486 commit e7bb5c1
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 27 deletions.
26 changes: 8 additions & 18 deletions codon/parser/visitors/typecheck/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1033,25 +1033,15 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
callArgs.back().value->setType(a.second);
}

auto fn = expr->args[0].value->type->getFunc();
if (!fn) {
bool canCompile = true;
// Special case: not a function, just try compiling it!
auto ocache = *(ctx->cache);
auto octx = *ctx;
try {
transform(N<CallExpr>(clone(expr->args[0].value),
N<StarExpr>(clone(expr->args[1].value)),
N<KeywordStarExpr>(clone(expr->args[2].value))));
} catch (const exc::ParserException &e) {
// LOG("{}", e.what());
canCompile = false;
*ctx = octx;
*(ctx->cache) = ocache;
}
return {true, transform(N<BoolExpr>(canCompile))};
if (auto fn = expr->args[0].value->type->getFunc()) {
return {true, transform(N<BoolExpr>(canCall(fn, callArgs) >= 0))};
} else if (auto pt = expr->args[0].value->type->getPartial()) {
return {true, transform(N<BoolExpr>(canCall(pt->func, callArgs, pt) >= 0))};
} else {
compilationWarning("cannot use fn_can_call on non-functions", getSrcInfo().file,
getSrcInfo().line, getSrcInfo().col);
return {true, transform(N<BoolExpr>(false))};
}
return {true, transform(N<BoolExpr>(canCall(fn, callArgs) >= 0))};
} else if (expr->expr->isId("std.internal.static.fn_arg_has_type")) {
expr->staticValue.type = StaticValue::INT;
auto fn = ctx->extractFunction(expr->args[0].value->type);
Expand Down
27 changes: 22 additions & 5 deletions codon/parser/visitors/typecheck/typecheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ StmtPtr TypecheckVisitor::apply(Cache *cache, const StmtPtr &stmts) {
auto so = clone(stmts);
auto s = v.inferTypes(so, true);
if (!s) {
// LOG("{}", so->toString(2));
v.error("cannot typecheck the program");
}
if (s->getSuite())
Expand Down Expand Up @@ -251,13 +252,21 @@ class IdSearchVisitor : public CallbackASTVisitor<bool, bool> {
/// Check if a function can be called with the given arguments.
/// See @c reorderNamedArgs for details.
int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
const std::vector<CallExpr::Arg> &args) {
const std::vector<CallExpr::Arg> &args,
std::shared_ptr<types::PartialType> part) {
auto getPartialArg = [&](size_t pi) -> types::TypePtr {
if (pi < part->args.size())
return part->args[pi];
else
return nullptr;
};

std::vector<std::pair<types::TypePtr, size_t>> reordered;
auto niGenerics = fn->ast->getNonInferrableGenerics();
auto score = ctx->reorderNamedArgs(
fn.get(), args,
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
for (int si = 0, gi = 0; si < slots.size(); si++) {
for (int si = 0, gi = 0, pi = 0; si < slots.size(); si++) {
if (fn->ast->args[si].status == Param::Generic) {
if (slots[si].empty()) {
// is this "real" type?
Expand All @@ -275,15 +284,21 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
}
gi++;
} else if (si == s || si == k || slots[si].size() != 1) {
// Ignore *args, *kwargs and default arguments
reordered.emplace_back(nullptr, 0);
// Partials
if (slots[si].empty() && part && part->known[si]) {
reordered.emplace_back(getPartialArg(pi++), 0);
} else {
// Ignore *args, *kwargs and default arguments
reordered.emplace_back(nullptr, 0);
}
} else {
reordered.emplace_back(args[slots[si][0]].value->type, slots[si][0]);
}
}
return 0;
},
[](error::Error, const SrcInfo &, const std::string &) { return -1; });
[](error::Error, const SrcInfo &, const std::string &) { return -1; },
part ? part->known : std::vector<char>{});
int ai = 0, mai = 0, gi = 0, real_gi = 0;
for (; score != -1 && ai < reordered.size(); ai++) {
auto expectTyp = fn->ast->args[ai].status == Param::Normal
Expand Down Expand Up @@ -341,6 +356,8 @@ TypecheckVisitor::findMatchingMethods(const types::ClassTypePtr &typ,
continue; // avoid overloads that have not been seen yet
auto method = ctx->instantiate(mi, typ)->getFunc();
int score = canCall(method, args);
// LOG("{}: {} {} :: {} :: {}", getSrcInfo(), method->debugString(2), args, score,
// method->ast->getSrcInfo());
if (score != -1) {
results.push_back(mi);
}
Expand Down
3 changes: 2 additions & 1 deletion codon/parser/visitors/typecheck/typecheck.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ class TypecheckVisitor : public CallbackASTVisitor<ExprPtr, StmtPtr> {
types::FuncTypePtr
findBestMethod(const types::ClassTypePtr &typ, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args);
int canCall(const types::FuncTypePtr &, const std::vector<CallExpr::Arg> &);
int canCall(const types::FuncTypePtr &, const std::vector<CallExpr::Arg> &,
std::shared_ptr<types::PartialType> = nullptr);
std::vector<types::FuncTypePtr>
findMatchingMethods(const types::ClassTypePtr &typ,
const std::vector<types::FuncTypePtr> &methods,
Expand Down
9 changes: 6 additions & 3 deletions stdlib/internal/internal.codon
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,14 @@ class __internal__:

def _union_call_helper(union, args, kwargs) -> Union:
for tag, T in vars_types(union, with_index=1):
if hasattr(T, '__call__'):
if fn_can_call(__internal__.union_get_data(union, T), *args, **kwargs):
if fn_can_call(T, *args, **kwargs):
if __internal__.union_get_tag(union) == tag:
return __internal__.union_get_data(union, T)(*args, **kwargs)
elif hasattr(T, '__call__'):
if fn_can_call(T.__call__, *args, **kwargs):
if __internal__.union_get_tag(union) == tag:
return __internal__.union_get_data(union, T).__call__(*args, **kwargs)
raise TypeError("cannot call union")
raise TypeError("cannot call union " + union.__class__.__name__)

def union_call(union, args, kwargs):
t = __internal__._union_call_helper(union, args, kwargs)
Expand Down
3 changes: 3 additions & 0 deletions stdlib/internal/static.codon
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def fn_arg_has_type(F, i: Static[int]):
def fn_arg_get_type(F, i: Static[int]):
pass

@no_type_wrap
@no_argument_wrap
def fn_can_call(F, *args, **kwargs):
pass

Expand All @@ -28,6 +30,7 @@ def fn_get_default(F, i: Static[int]):
pass

@no_type_wrap
@no_argument_wrap
def static_print(*args):
pass

Expand Down

0 comments on commit e7bb5c1

Please sign in to comment.