Skip to content

Commit

Permalink
Return function schema for atom constructors (#8743)
Browse files Browse the repository at this point in the history
close #8663

Changelog:
- update: use `MethodRootNode` for the atom constructor function to preserve the call info in runtime
- fix: return function schema for atom constructors
  • Loading branch information
4e6 committed Jan 12, 2024
1 parent 3c29a58 commit 972b359
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,14 @@ public boolean isFunctionCallChanged() {
public record FunctionPointer(
QualifiedName moduleName, QualifiedName typeName, String functionName) {

public static FunctionPointer fromAtomConstructor(AtomConstructor atomConstructor) {
QualifiedName moduleName = atomConstructor.getDefinitionScope().getModule().getName();
QualifiedName typeName = atomConstructor.getType().getQualifiedName();
String functionName = atomConstructor.getName();

return new FunctionPointer(moduleName, typeName, functionName);
}

public static FunctionPointer fromFunction(Function function) {
RootNode rootNode = function.getCallTarget().getRootNode();

Expand Down Expand Up @@ -767,6 +775,9 @@ public static FunctionPointer fromFunction(Function function) {
public static int[] collectNotAppliedArguments(Function function) {
FunctionSchema functionSchema = function.getSchema();
Object[] preAppliedArguments = function.getPreAppliedArguments();
if (preAppliedArguments == null) {
preAppliedArguments = new Object[functionSchema.getArgumentsCount()];
}
boolean isStatic = preAppliedArguments[0] instanceof Type;
int selfArgumentPosition = isStatic ? -1 : 0;
int[] notAppliedArguments = new int[functionSchema.getArgumentsCount()];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import org.enso.interpreter.instrument.profiling.ExecutionTime
import org.enso.interpreter.node.callable.FunctionCallInstrumentationNode.FunctionCall
import org.enso.interpreter.node.expression.builtin.meta.TypeOfNode
import org.enso.interpreter.runtime.`type`.{Types, TypesGen}
import org.enso.interpreter.runtime.callable.atom.AtomConstructor
import org.enso.interpreter.runtime.callable.function.Function
import org.enso.interpreter.runtime.control.ThreadInterruptedException
import org.enso.interpreter.runtime.error.{
Expand Down Expand Up @@ -344,12 +345,12 @@ object ProgramExecutionSupport {
syncState: UpdatesSynchronizationState,
value: ExpressionValue
)(implicit ctx: RuntimeContext): Unit = {
val expressionId = value.getExpressionId
val methodPointer = toMethodCall(value)
val expressionId = value.getExpressionId
val methodCall = toMethodCall(value)
if (
!syncState.isExpressionSync(expressionId) ||
(
methodPointer.isDefined && !syncState.isMethodPointerSync(
methodCall.isDefined && !syncState.isMethodPointerSync(
expressionId
)
) ||
Expand Down Expand Up @@ -409,13 +410,23 @@ object ProgramExecutionSupport {
val schema = value.getValue match {
case function: Function =>
val functionInfo = FunctionPointer.fromFunction(function)
toMethodPointer(functionInfo).map { methodPointer =>
Api.FunctionSchema(
methodPointer,
FunctionPointer.collectNotAppliedArguments(function).toVector
val notAppliedArguments = FunctionPointer
.collectNotAppliedArguments(function)
.toVector
toMethodPointer(functionInfo).map(methodPointer =>
Api.FunctionSchema(methodPointer, notAppliedArguments)
)
case atomConstructor: AtomConstructor =>
val functionInfo =
FunctionPointer.fromAtomConstructor(atomConstructor)
val notAppliedArguments = FunctionPointer
.collectNotAppliedArguments(
atomConstructor.getConstructorFunction
)
}

.toVector
toMethodPointer(functionInfo).map(methodPointer =>
Api.FunctionSchema(methodPointer, notAppliedArguments)
)
case _ =>
None
}
Expand All @@ -430,7 +441,7 @@ object ProgramExecutionSupport {
Api.ExpressionUpdate(
value.getExpressionId,
Option(value.getType),
methodPointer,
methodCall,
value.getProfilingInfo.map { case e: ExecutionTime =>
Api.ProfilingInfo.ExecutionTime(e.getNanoTimeElapsed)
}.toVector,
Expand All @@ -444,7 +455,7 @@ object ProgramExecutionSupport {
)

syncState.setExpressionSync(expressionId)
if (methodPointer.isDefined) {
if (methodCall.isDefined) {
syncState.setMethodPointerSync(expressionId)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import java.util.concurrent.locks.ReentrantLock;
import org.enso.compiler.context.LocalScope;
import org.enso.interpreter.EnsoLanguage;
import org.enso.interpreter.node.ClosureRootNode;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.node.MethodRootNode;
import org.enso.interpreter.node.callable.argument.ReadArgumentNode;
import org.enso.interpreter.node.callable.function.BlockNode;
import org.enso.interpreter.node.expression.atom.InstantiateNode;
Expand Down Expand Up @@ -161,15 +161,8 @@ private Function buildConstructorFunction(
}
BlockNode instantiateBlock = BlockNode.buildSilent(assignments, instantiateNode);
RootNode rootNode =
ClosureRootNode.build(
language,
localScope,
definitionScope,
instantiateBlock,
section,
type.getName() + "." + name,
null,
false);
MethodRootNode.build(
language, localScope, definitionScope, instantiateBlock, section, type, name);
RootCallTarget callTarget = rootNode.getCallTarget();
return new Function(callTarget, null, new FunctionSchema(annotations, args));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package org.enso.interpreter.test.instrument;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import java.nio.file.Paths;
import java.util.Map;
import java.util.logging.Level;
import org.enso.interpreter.runtime.callable.atom.AtomConstructor;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.service.ExecutionService.FunctionPointer;
import org.enso.interpreter.test.TestBase;
import org.enso.polyglot.RuntimeOptions;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Language;
import org.graalvm.polyglot.Source;
import org.graalvm.polyglot.io.IOAccess;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public class FunctionPointerTest extends TestBase {

private Context context;

@Before
public void initContext() {
context =
Context.newBuilder()
.allowExperimentalOptions(true)
.option(
RuntimeOptions.LANGUAGE_HOME_OVERRIDE,
Paths.get("../../distribution/component").toFile().getAbsolutePath())
.option(RuntimeOptions.LOG_LEVEL, Level.WARNING.getName())
.logHandler(System.err)
.allowExperimentalOptions(true)
.allowIO(IOAccess.ALL)
.allowAllAccess(true)
.build();

var engine = context.getEngine();
Map<String, Language> langs = engine.getLanguages();
Assert.assertNotNull("Enso found: " + langs, langs.get("enso"));
}

@After
public void disposeContext() {
context.close();
}

@Test
public void moduleFunctionPointer() throws Exception {
var rawCode = """
from Standard.Base import all
run a b = a + b
""";
var src = Source.newBuilder("enso", rawCode, "TestFunctionPointer.enso").build();
var module = context.eval(src);
var res = module.invokeMember("eval_expression", "run");

assertTrue("fn: " + res, res.canExecute());
var rawRes = TestBase.unwrapValue(context, res);
assertTrue("function: " + rawRes, rawRes instanceof Function);
var c = FunctionPointer.fromFunction((Function) rawRes);
assertNotNull(c);
assertEquals("TestFunctionPointer", c.moduleName().toString());
assertEquals("TestFunctionPointer", c.typeName().toString());
assertEquals("run", c.functionName().toString());
}

@Test
public void typeStaticMethodPointer() throws Exception {
var rawCode =
"""
from Standard.Base import all
type X
run a b = a + b
""";
var src = Source.newBuilder("enso", rawCode, "StaticMethodPointer.enso").build();
var module = context.eval(src);
var res = module.invokeMember("eval_expression", "X.run");

assertTrue("fn: " + res, res.canExecute());
var rawRes = TestBase.unwrapValue(context, res);
assertTrue("function: " + rawRes, rawRes instanceof Function);
var c = FunctionPointer.fromFunction((Function) rawRes);
assertNotNull(c);
assertEquals("StaticMethodPointer", c.moduleName().toString());
assertEquals("StaticMethodPointer.X", c.typeName().toString());
assertEquals("run", c.functionName().toString());

var apply = res.execute(1);
assertTrue("fn: " + apply, apply.canExecute());
var rawApply = TestBase.unwrapValue(context, res);
assertTrue("function: " + rawApply, rawApply instanceof Function);
var a = FunctionPointer.fromFunction((Function) rawApply);
assertNotNull(a);
assertEquals("StaticMethodPointer", a.moduleName().toString());
assertEquals("StaticMethodPointer.X", a.typeName().toString());
assertEquals("run", a.functionName().toString());
}

@Test
public void typeInstanceMethodPointer() throws Exception {
var rawCode =
"""
from Standard.Base import all
type X
run self b c = [self, b, c]
""";
var src = Source.newBuilder("enso", rawCode, "InstanceMethodPointer.enso").build();
var module = context.eval(src);
var res = module.invokeMember("eval_expression", "X.run");

assertTrue("fn: " + res, res.canExecute());
var rawRes = TestBase.unwrapValue(context, res);
assertTrue("function: " + rawRes, rawRes instanceof Function);
var c = FunctionPointer.fromFunction((Function) rawRes);
assertNotNull(c);
assertEquals("InstanceMethodPointer", c.moduleName().toString());
assertEquals("InstanceMethodPointer.X", c.typeName().toString());
assertEquals("run", c.functionName().toString());

var apply = res.execute(1);
assertTrue("fn: " + apply, apply.canExecute());
var rawApply = TestBase.unwrapValue(context, res);
assertTrue("function: " + rawApply, rawApply instanceof Function);
var a = FunctionPointer.fromFunction((Function) rawApply);
assertNotNull(a);
assertEquals("InstanceMethodPointer", a.moduleName().toString());
assertEquals("InstanceMethodPointer.X", a.typeName().toString());
assertEquals("run", a.functionName().toString());
}

@Test
public void typeConstructorPointer() throws Exception {
var rawCode =
"""
from Standard.Base import all
type X
Run a b
""";
var src = Source.newBuilder("enso", rawCode, "ConstructorPointer.enso").build();
var module = context.eval(src);
var res = module.invokeMember("eval_expression", "X.Run");

assertTrue("fn: " + res, res.canInstantiate());
var rawRes = TestBase.unwrapValue(context, res);
assertTrue("function: " + rawRes.getClass(), rawRes instanceof AtomConstructor);
var rawFn = ((AtomConstructor) rawRes).getConstructorFunction();
var c = FunctionPointer.fromFunction(rawFn);
assertNotNull("We should get a pointer for " + rawFn, c);

assertEquals("ConstructorPointer", c.moduleName().toString());
assertEquals("ConstructorPointer.X", c.typeName().toString());
assertEquals("Run", c.functionName());

var d = FunctionPointer.fromAtomConstructor((AtomConstructor) rawRes);
assertNotNull("We should get a pointer from " + rawRes, d);

assertEquals("ConstructorPointer", d.moduleName().toString());
assertEquals("ConstructorPointer.X", d.typeName().toString());
assertEquals("Run", d.functionName());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,23 @@ class BuiltinTypesTest
3
) should contain theSameElementsAs Seq(
Api.Response(requestId, Api.PushContextResponse(contextId)),
TestMessages.update(contextId, idMain, ConstantsGen.FUNCTION),
TestMessages.update(
contextId,
idMain,
ConstantsGen.FUNCTION,
payload = Api.ExpressionUpdate.Payload.Value(
functionSchema = Some(
Api.FunctionSchema(
Api.MethodPointer(
"Enso_Test.Test.Main",
"Enso_Test.Test.Main.Foo",
"Bar"
),
Vector(0)
)
)
)
),
context.executionComplete(contextId)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1004,20 +1004,42 @@ class RuntimeServerTest
contextId,
id_x_0,
ConstantsGen.FUNCTION_BUILTIN,
Api.MethodCall(
Api
.MethodPointer("Enso_Test.Test.Main", "Enso_Test.Test.Main.T", "A")
methodCall = Some(
Api.MethodCall(Api.MethodPointer(moduleName, s"$moduleName.T", "A"))
),
payload = Api.ExpressionUpdate.Payload.Value(
functionSchema = Some(
Api.FunctionSchema(
Api.MethodPointer(moduleName, s"$moduleName.T", "A"),
Vector(0, 1)
)
)
)
),
TestMessages.update(
contextId,
id_x_1,
ConstantsGen.FUNCTION_BUILTIN
ConstantsGen.FUNCTION_BUILTIN,
methodCall = Some(
Api.MethodCall(
Api.MethodPointer(moduleName, s"$moduleName.T", "A"),
Vector(1)
)
),
payload = Api.ExpressionUpdate.Payload.Value(
functionSchema = Some(
Api.FunctionSchema(
Api.MethodPointer(moduleName, s"$moduleName.T", "A"),
Vector(1)
)
)
)
),
TestMessages.update(
contextId,
id_x_2,
"Enso_Test.Test.Main.T"
s"$moduleName.T",
Api.MethodCall(Api.MethodPointer(moduleName, s"$moduleName.T", "A"))
),
context.executionComplete(contextId)
)
Expand Down

0 comments on commit 972b359

Please sign in to comment.