Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.FunctionInvocation;
import liquidjava.rj_language.ast.UnaryExpression;
import liquidjava.rj_language.ast.Var;
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
Expand All @@ -23,7 +24,7 @@ public class VariablePropagation {
*/
public static ValDerivationNode propagate(Expression exp, ValDerivationNode previousOrigin) {
Map<String, Expression> substitutions = VariableResolver.resolve(exp);
Map<String, Expression> directSubstitutions = new HashMap<>(); // var == literal or var == var
Map<String, Expression> directSubstitutions = new HashMap<>(); // var == literal or var == var
Map<String, Expression> expressionSubstitutions = new HashMap<>(); // var == expression
for (Map.Entry<String, Expression> entry : substitutions.entrySet()) {
Expression value = entry.getValue();
Expand Down Expand Up @@ -69,6 +70,12 @@ private static ValDerivationNode propagateRecursive(Expression exp, Map<String,
return new ValDerivationNode(var, null);
}

if (exp instanceof FunctionInvocation) {
Expression value = subs.get(exp.toString());
if (value != null)
return new ValDerivationNode(value.clone(), new VarDerivationNode(exp.toString()));
}

// lift unary origin
if (exp instanceof UnaryExpression unary) {
ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs, varOrigins);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.FunctionInvocation;
import liquidjava.rj_language.ast.Var;

public class VariableResolver {
Expand All @@ -25,7 +26,7 @@ public static Map<String, Expression> resolve(Expression exp) {
resolveRecursive(exp, map);

// remove variables that were not used in the expression
map.entrySet().removeIf(entry -> !hasUsage(exp, entry.getKey()));
map.entrySet().removeIf(entry -> !hasUsage(exp, entry.getKey(), entry.getValue()));

// transitively resolve variables
return resolveTransitive(map);
Expand All @@ -45,33 +46,49 @@ private static void resolveRecursive(Expression exp, Map<String, Expression> map
if ("&&".equals(op)) {
resolveRecursive(be.getFirstOperand(), map);
resolveRecursive(be.getSecondOperand(), map);
} else if ("==".equals(op)) {
Expression left = be.getFirstOperand();
Expression right = be.getSecondOperand();
if (left instanceof Var var && right.isLiteral()) {
map.put(var.getName(), right.clone());
} else if (right instanceof Var var && left.isLiteral()) {
map.put(var.getName(), left.clone());
} else if (left instanceof Var leftVar && right instanceof Var rightVar) {
// to substitute internal variable with user-facing variable
if (isInternal(leftVar) && !isInternal(rightVar) && !isReturnVar(leftVar)) {
map.put(leftVar.getName(), right.clone());
} else if (isInternal(rightVar) && !isInternal(leftVar) && !isReturnVar(rightVar)) {
map.put(rightVar.getName(), left.clone());
} else if (isInternal(leftVar) && isInternal(rightVar)) {
// to substitute the lower-counter variable with the higher-counter one
boolean isLeftCounterLower = getCounter(leftVar) <= getCounter(rightVar);
Var lowerVar = isLeftCounterLower ? leftVar : rightVar;
Var higherVar = isLeftCounterLower ? rightVar : leftVar;
if (!isReturnVar(lowerVar) && !isFreshVar(higherVar))
map.putIfAbsent(lowerVar.getName(), higherVar.clone());
}
} else if (left instanceof Var var && !(right instanceof Var) && canSubstitute(var, right)) {
map.put(var.getName(), right.clone());
return;
}
if (!"==".equals(op))
return;

Expression left = be.getFirstOperand();
Expression right = be.getSecondOperand();
String leftKey = substitutionKey(left);
String rightKey = substitutionKey(right);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is using strings for substitution really the best approach?


if (leftKey != null && right.isLiteral()) {
map.put(leftKey, right.clone());
} else if (rightKey != null && left.isLiteral()) {
map.put(rightKey, left.clone());
} else if (left instanceof Var leftVar && right instanceof Var rightVar) {
// to substitute internal variable with user-facing variable
if (isInternal(leftVar) && !isInternal(rightVar) && !isReturnVar(leftVar)) {
map.put(leftVar.getName(), right.clone());
} else if (isInternal(rightVar) && !isInternal(leftVar) && !isReturnVar(rightVar)) {
map.put(rightVar.getName(), left.clone());
} else if (isInternal(leftVar) && isInternal(rightVar)) {
// to substitute the lower-counter variable with the higher-counter one
boolean isLeftCounterLower = getCounter(leftVar) <= getCounter(rightVar);
Var lowerVar = isLeftCounterLower ? leftVar : rightVar;
Var higherVar = isLeftCounterLower ? rightVar : leftVar;
if (!isReturnVar(lowerVar) && !isFreshVar(higherVar))
map.putIfAbsent(lowerVar.getName(), higherVar.clone());
}
} else if (left instanceof Var var && canSubstitute(var, right)) {
map.put(var.getName(), right.clone());
} else if (left instanceof FunctionInvocation && !containsExpression(right, left)) {
map.put(leftKey, right.clone());
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this means we would not substitute something like:
f(a) == ff(a) + b

because the string on the right contains f(a) even though they are not related, right?
Ig better be conservative but we should be aware of this limitation

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, I'll replace that string check with an AST check.

}
}

private static String substitutionKey(Expression exp) {
if (exp instanceof Var var)
return var.getName();
if (exp instanceof FunctionInvocation)
return exp.toString();
return null;
}

/**
* Handles transitive variable equalities in the map (e.g. map: x -> y, y -> 1 => map: x -> 1, y -> 1)
*
Expand All @@ -98,10 +115,10 @@ private static Map<String, Expression> resolveTransitive(Map<String, Expression>
* @return resolved expression
*/
private static Expression lookup(Expression exp, Map<String, Expression> map, Set<String> seen) {
if (!(exp instanceof Var))
String name = substitutionKey(exp);
if (name == null)
return exp;

String name = exp.toString();
if (seen.contains(name))
return exp; // circular reference

Expand All @@ -121,27 +138,36 @@ private static Expression lookup(Expression exp, Map<String, Expression> map, Se
*
* @return true if used, false otherwise
*/
private static boolean hasUsage(Expression exp, String name) {
private static boolean hasUsage(Expression exp, String name, Expression value) {
// exclude own definitions
if (exp instanceof BinaryExpression binary && "==".equals(binary.getOperator())) {
Expression left = binary.getFirstOperand();
Expression right = binary.getSecondOperand();
if (left instanceof Var v && v.getName().equals(name)
if (left instanceof Var v && v.getName().equals(name) && right.equals(value)
&& (right.isLiteral() || (!(right instanceof Var) && canSubstitute(v, right))))
return false;
if (right instanceof Var v && v.getName().equals(name) && left.isLiteral())
if (left instanceof FunctionInvocation && left.toString().equals(name) && right.equals(value)
&& (right.isLiteral() || (!(right instanceof Var) && !containsExpression(right, left))))
return false;
if (right instanceof Var v && v.getName().equals(name) && left.equals(value) && left.isLiteral())
return false;
if (right instanceof FunctionInvocation && right.toString().equals(name) && left.equals(value)
&& left.isLiteral())
return false;
}

// usage found
if (exp instanceof Var var && var.getName().equals(name)) {
return true;
}
if (exp instanceof FunctionInvocation && exp.toString().equals(name)) {
return true;
}

// recurse children
if (exp.hasChildren()) {
for (Expression child : exp.getChildren())
if (hasUsage(child, name))
if (hasUsage(child, name, value))
return true;
}

Expand Down Expand Up @@ -185,4 +211,18 @@ private static boolean containsVariable(Expression exp, String name) {
}
return false;
}

private static boolean containsExpression(Expression exp, Expression target) {
if (exp.equals(target))
return true;

if (!exp.hasChildren())
return false;

for (Expression child : exp.getChildren()) {
if (containsExpression(child, target))
return true;
}
return false;
}
}
Loading
Loading