# Refining Fandango Constraints

As shown previously, FandangoLearn is able to learn constraints that distinguish between failing inputs and passing inputs.
However, the constraints learned by FandangoLearn are not always perfect. In this notebook, we show how to refine the constraints learned by FandangoLearn. 

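Let's start by learning some constraints with FandangoLearn, using a handful of initial inputs for our calculator subject.
First, let's load the grammar and the initial inputs. Each input is paired with a flag that is `True` if the input triggers the failure and `False` otherwise.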

In [6]:
from fandango.language.parse import parse_file, Grammar
from fandangoLearner.data.input import FandangoInput

grammar_file = "calculator.fan"
grammar, _ = parse_file(grammar_file)

assert isinstance(grammar, Grammar), "Grammar is not loaded correctly"

initial_inputs = {
    ("sqrt(-900)", True),
    ("sqrt(-10)", True),
    ("sqrt(0)", False),
    ("sin(-900)", False),
    ("sqrt(2)", False),
    ("cos(10)", False),
}

for inp, _ in initial_inputs:
    tree = grammar.parse(inp)
    assert tree is not None, f"Failed to parse {inp}"
    
initial_inputs = {FandangoInput.from_str(grammar, inp, oracle) for inp, oracle in initial_inputs}

Like before, we use FandangoLearn to learn constraints that distinguish between passing and failing inputs.

In [7]:
from fandangoLearner.learner import FandangoLearner

learner = FandangoLearner(grammar)
learned_constraints = learner.learn_constraints(
    initial_inputs,
)

fandango-learner:INFO: Instantiated patterns: 21
fandango-learner:INFO: Found 243 valid conjunctions
fandango-learner:INFO: Found 0 valid disjunctions


In [8]:
for candidate in learner.get_best_candidates():
    print(candidate.constraint)

(int(<number>) <= -10.0 and int(<onenine>) != len(str(<arithexp>)))
(int(<number>) <= -10.0 and int(<onenine>) != len(str(<start>)))
(int(<number>) <= -10.0 and str(<function>) == 'sqrt')
(int(<number>) <= -10.0 and int(<onenine>) < len(str(<arithexp>)))
(int(<number>) <= -10.0 and int(<onenine>) < len(str(<start>)))
((exists <container> in <digit>: int(<container>) <= 0.0) and str(<function>) == 'sqrt')
((exists <container> in <digit>: int(<container>) <= 0.0) and str(<function>) == 'sqrt')
((exists <container> in <number>: int(<container>) <= -10.0) and int(<onenine>) != len(str(<arithexp>)))
((exists <container> in <number>: int(<container>) <= -10.0) and int(<onenine>) != len(str(<start>)))
((exists <container> in <number>: int(<container>) <= -10.0) and str(<function>) == 'sqrt')
((exists <container> in <number>: int(<container>) <= -10.0) and int(<onenine>) < len(str(<arithexp>)))
((exists <container> in <number>: int(<container>) <= -10.0) and int(<onenine>) < len(str(<start>)))

We can see that the learned constraints are not perfect. For example, `str(<function>) == 'sqrt'` is correct, but `int(<number>) <= -10.0` is not yet general enough: the threshold of -10 only reflects the largest failing argument among the initial inputs, so a failing input such as `sqrt(-1)` is not covered. Furthermore, FandangoLearn learned additional constraints, such as the ones involving `<onenine>`, that are consistent with the initial inputs but not necessary for distinguishing between **all** passing and failing inputs.

Thus, we need to refine the learned constraints. We do this iteratively: for each learned constraint, we generate new inputs that satisfy it, label them with an oracle, add them to the set of inputs, and learn constraints again. Candidates that only fit the initial inputs by coincidence are then contradicted by the new inputs and drop out.
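
To make the effect of additional inputs concrete, here is a minimal, library-free sketch. It does not use FandangoLearn; the helper names (`parse`, `fails`, `leading_digit`, `consistent`) and the three candidate predicates are hypothetical stand-ins modeled on the learned constraints above, with the leading digit of the argument playing the role of `<onenine>` and the length of the whole input playing the role of `len(str(<start>))`.

```python
import math
import re

FUNCTIONS = {"sqrt": math.sqrt, "sin": math.sin, "cos": math.cos, "tan": math.tan}


def parse(text):
    """Split an input such as 'sqrt(-900)' into ('sqrt', -900)."""
    match = re.fullmatch(r"([a-z]+)\((-?\d+)\)", text)
    return match.group(1), int(match.group(2))


def fails(text):
    """Stand-in oracle: True if evaluating the call raises ValueError."""
    function, number = parse(text)
    try:
        FUNCTIONS[function](number)
    except ValueError:
        return True
    return False


def leading_digit(text):
    """First digit of the numeric argument (hypothetical stand-in for <onenine>)."""
    return int(str(abs(parse(text)[1]))[0])


# Hypothetical candidates, modeled on three of the learned constraints.
CANDIDATES = {
    "number <= -10 and function == 'sqrt'":
        lambda t: parse(t)[1] <= -10 and parse(t)[0] == "sqrt",
    "number <= -10 and onenine != len(start)":
        lambda t: parse(t)[1] <= -10 and leading_digit(t) != len(t),
    "number <= -10 and onenine < len(start)":
        lambda t: parse(t)[1] <= -10 and leading_digit(t) < len(t),
}


def consistent(inputs):
    """Names of the candidates that hold on exactly the failing inputs."""
    return [
        name
        for name, predicate in CANDIDATES.items()
        if all(predicate(text) == fails(text) for text in inputs)
    ]


initial = ["sqrt(-900)", "sqrt(-10)", "sqrt(0)", "sin(-900)", "sqrt(2)", "cos(10)"]
# Two of the generated inputs that appear later in this notebook.
extended = initial + ["sin(-400)", "sqrt(-80)"]

survivors_initial = consistent(initial)
survivors_extended = consistent(extended)
print(len(survivors_initial), "candidates fit the initial inputs")
print(len(survivors_extended), "candidates fit the extended inputs")
```

In this toy setting, the two `onenine` candidates separate `sin(-900)` from the failing inputs only because its leading digit happens to equal the input length. A single additional passing input such as `sin(-400)` contradicts both of them.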

In [9]:
import math
from debugging_framework.input.oracle import OracleResult

def calculator_oracle(inp: str) -> OracleResult:
    try:
        eval(
            str(inp), {"sqrt": math.sqrt, "sin": math.sin, "cos": math.cos, "tan": math.tan}
        )
    except ValueError:
        return OracleResult.FAILING
    return OracleResult.PASSING

In [10]:
from fandango.evolution.algorithm import Fandango
from fandango.language.tree import DerivationTree

more_inputs = set(initial_inputs)
for candidate in learner.get_best_candidates():
    fandango = Fandango(grammar, [candidate.constraint], desired_solutions=5)
    solutions = fandango.evolve()
    
    for tree in solutions:
        more_inputs.add(FandangoInput(tree, calculator_oracle(str(tree))))
    
assert all(inp.oracle is not None for inp in more_inputs)
assert all(isinstance(inp.tree, DerivationTree) for inp in more_inputs)

In [11]:
for inp in more_inputs:
    print(inp.tree)

sqrt(-8800)
tan(-4003)
cos(-3900)
sqrt(-80)
cos(-88151)
sqrt(-6700)
sqrt(-30091)
sqrt(-3900)
sin(-400)
sqrt(-560030)
tan(-69)
sqrt(-10598)
sqrt(496900)
sqrt(-109)
sin(-108700)
sin(-7401)
sqrt(-605000)
sin(-360)
tan(-600000)
sqrt(-4000)
tan(-629200)
sqrt(-16)
tan(-837)
sqrt(-4005)
sqrt(-40050)
sqrt(-4093)
cos(-4000)
sin(-330440)
cos(-50707)
sqrt(-8670)
sqrt(2700)
sqrt(290700)
sin(-904000)
tan(-8705)
cos(-15)
cos(-62080)
tan(-700)
cos(-35360)
sqrt(-10)
tan(-9000)
tan(-45000)
sqrt(-580000)
sqrt(0)
sqrt(-40)
cos(-400)
sqrt(-45420)
cos(-30010)
sqrt(-300)
sin(-900)
sqrt(-220073)
sqrt(-8070)
cos(-8093)
sqrt(-674)
cos(10)
sqrt(-92430)
sqrt(-600804)
sin(-53280)
sqrt(832450)
cos(-305)
sqrt(10990)
sqrt(-6062)
sqrt(-990304)
sqrt(600000)
cos(-220774)
sqrt(-90380)
tan(-206168)
sqrt(-762002)
sin(-621108)
cos(-7005)
cos(-8000)
sqrt(2)
tan(-54)
sqrt(60)
sqrt(-12)
sqrt(-603)
sqrt(-900)


In [12]:
from fandangoLearner.learner import NonTerminal
relevant_non_terminals = {
    NonTerminal("<number>"),
    NonTerminal("<maybeminus>"),
    NonTerminal("<function>"),
}

In [13]:
learner = FandangoLearner(grammar)
learned_constraints = learner.learn_constraints(
    more_inputs, relevant_non_terminals=relevant_non_terminals
)

fandango-learner:INFO: Instantiated patterns: 21
fandango-learner:INFO: Found 10 valid conjunctions
fandango-learner:INFO: Found 0 valid disjunctions


In [15]:
for candidate in learner.get_best_candidates():
    print(candidate)

(int(<number>) <= -10.0 and str(<function>) == 'sqrt'), Precision: 1.0, Recall: 1.0 (based on 32 failing and 44 passing inputs)
((exists <container> in <number>: int(<container>) <= -10.0) and str(<function>) == 'sqrt'), Precision: 1.0, Recall: 1.0 (based on 32 failing and 44 passing inputs)


By adding new inputs and restricting the learner to the relevant nonterminals, we reduced the number of valid conjunctions from 243 to 10, and only two best candidates remain.
Thus, we were able to exclude constraints that were not precise enough. However, the constraint `int(<number>) <= -10.0 and str(<function>) == 'sqrt'` is still not perfect. Generating more inputs from this constraint will not help, because every generated input fulfills it. Instead, we need inputs that the constraint does not cover, such as `sqrt(-2)` or `sqrt(-1)`. Thus, we have to negate the numeric part of the constraint while keeping the `sqrt` part.
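
The following sketch illustrates why negation is needed. It is a simplified stand-in, not the Fandango search: it enumerates a small range of integers instead of evolving inputs, fixes the function to `sqrt`, and uses hypothetical helpers `sqrt_fails` and `learned`. Inputs drawn from the constraint can at best reveal false positives; the failing inputs that the constraint misses only show up when we draw from its negation.

```python
import math


def sqrt_fails(number):
    """Stand-in oracle for sqrt(number): True if math.sqrt raises ValueError."""
    try:
        math.sqrt(number)
    except ValueError:
        return True
    return False


def learned(number):
    """Numeric part of the learned constraint; the function is fixed to sqrt."""
    return number <= -10


# Enumerate a small range instead of running an evolutionary search.
domain = range(-15, 6)

# Inputs generated from the constraint itself: all are predicted to fail.
from_constraint = [n for n in domain if learned(n)]
# Inputs generated from "sqrt and not (number <= -10)": all are predicted to pass.
from_negation = [n for n in domain if not learned(n)]

# An input is a counterexample when prediction and oracle disagree.
missed_by_constraint = [n for n in from_constraint if not sqrt_fails(n)]
missed_by_negation = [n for n in from_negation if sqrt_fails(n)]

print("counterexamples among constraint inputs:", missed_by_constraint)
print("counterexamples among negated inputs:", missed_by_negation)
```

Every input drawn from the constraint agrees with the oracle, so it adds no new information about the threshold. The negated region contains all failing inputs between -9 and -1, which are exactly the inputs the learner needs to see.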

In [93]:
from typing import Dict, Optional  # used in the type hints of fitness()

from fandango.constraints.base import *
from fandango.language.search import RuleSearch

class NegationConstraint(Constraint):
    """A simple constraint to represent logical negations."""
    def __init__(self, inner_constraint: Constraint, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inner_constraint = inner_constraint

    def transform(self, transformer: "ConstraintTransformer") -> "Constraint":
        # Apply the transformer to the inner constraint.
        normalized_inner = self.inner_constraint.transform(transformer)
        return NegationConstraint(normalized_inner)

    def accept(self, visitor: "ConstraintVisitor"):
        #visitor.visit_expression_constraint(self)
        pass

    def fitness(
            self, tree: DerivationTree, scope: Optional[Dict[NonTerminal, DerivationTree]] = None
    ) -> ConstraintFitness:
        """
        Computes the fitness for the negation of the inner constraint.

        Negation logic:
        - If the inner constraint is fully satisfied, this constraint is fully unsatisfied.
        - If the inner constraint is fully unsatisfied, this constraint is fully satisfied.
        - Otherwise, the fitness is calculated as the negation of the inner fitness.
        """
        # Evaluate the fitness of the inner constraint
        inner_fitness = self.inner_constraint.fitness(tree, scope)

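        # Negate the fitness results. Assumption: `solved` counts satisfied
        # checks out of `total`, so its complement is `total - solved`
        # (this equals `1 - solved` only when `total` is 1).
        total = inner_fitness.total
        solved = total - inner_fitness.solved
        success = not inner_fitness.success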

        return ConstraintFitness(
            solved=solved,
            total=total,
            success=success,
            # failing_trees=failing_trees,
        )

    def __repr__(self):
        return f"~({repr(self.inner_constraint)})"
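
A note on the arithmetic of negating a fitness value: assuming `solved` counts satisfied checks out of `total` (an assumption about `ConstraintFitness` that is not confirmed here), the complement of `solved` is `total - solved`. The form `1 - solved` gives the same result only when `total` is 1. The sketch below uses a hypothetical `ToyFitness` dataclass rather than the Fandango class, and only covers the fully satisfied and fully unsatisfied cases.

```python
from dataclasses import dataclass


@dataclass
class ToyFitness:
    """Hypothetical stand-in for a fitness record: solved checks out of total."""

    solved: int
    total: int
    success: bool


def negate(inner: ToyFitness) -> ToyFitness:
    """Complement the inner fitness so that solved stays within 0..total."""
    return ToyFitness(
        solved=inner.total - inner.solved,
        total=inner.total,
        success=not inner.success,
    )


# One check, satisfied: the negation is fully unsatisfied.
single = negate(ToyFitness(solved=1, total=1, success=True))
# Three checks, all satisfied: `1 - solved` would yield -2 here.
several = negate(ToyFitness(solved=3, total=3, success=True))
# Three checks, none satisfied: the negation is fully satisfied.
none_solved = negate(ToyFitness(solved=0, total=3, success=False))

print(single)
print(several)
print(none_solved)
```

Keeping `solved` within the range from 0 to `total` matters if the search ranks candidates by the ratio of the two values.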

In [94]:
from fandango.constraints.base import ConjunctionConstraint
negated_constraint = NegationConstraint(learned_constraints[0].constraint.constraints[0])
test_constraint = ConjunctionConstraint([learned_constraints[0].constraint.constraints[1], negated_constraint])

In [95]:
test_constraint

(str(<function>) == 'sqrt' and ~(int(<number>) <= -10.0))

In [103]:
fandango = Fandango(grammar, [test_constraint], desired_solutions=100)
solutions = fandango.evolve()

In [104]:
negated_inputs = {FandangoInput(tree, calculator_oracle(str(tree))) for tree in solutions}

new_failing_inputs = {inp for inp in negated_inputs if inp.oracle == OracleResult.FAILING}
for inp in new_failing_inputs:
    print(inp.tree)

sqrt(-3)
sqrt(-5)
sqrt(-9)
sqrt(-6)


In [106]:
learner = FandangoLearner(grammar)
learned_constraints = learner.learn_constraints(
    more_inputs.union(new_failing_inputs), relevant_non_terminals=relevant_non_terminals
)

fandango-learner:INFO: Instantiated patterns: 21
fandango-learner:INFO: Found 10 valid conjunctions
fandango-learner:INFO: Found 0 valid disjunctions


In [107]:
for candidate in learner.get_best_candidates():
    print(candidate)

(int(<number>) <= -3.0 and str(<function>) == 'sqrt'), Precision: 1.0, Recall: 1.0 (based on 36 failing and 44 passing inputs)
((exists <container> in <number>: int(<container>) <= -3.0) and str(<function>) == 'sqrt'), Precision: 1.0, Recall: 1.0 (based on 36 failing and 44 passing inputs)


The threshold moved from -10 to -3, so we got much closer to the perfect constraint. However, failing inputs such as `sqrt(-2)` and `sqrt(-1)` are still not covered, so we have to refine the constraint further. We can do this by adding more inputs that are not yet distinguished by the learned constraints.
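
Note that the reported precision and recall of 1.0 are relative to the inputs the learner has seen. As a rough illustration, the sketch below recomputes both values for the numeric part of the constraint, once on a few `sqrt` arguments taken from this notebook and once on a small exhaustive range. The helpers `sqrt_fails`, `learned`, and `precision_recall` are hypothetical and independent of FandangoLearn.

```python
import math


def sqrt_fails(number):
    """Stand-in oracle for sqrt(number): True if math.sqrt raises ValueError."""
    try:
        math.sqrt(number)
    except ValueError:
        return True
    return False


def learned(number):
    """Numeric part of the refined constraint; the function is fixed to sqrt."""
    return number <= -3


def precision_recall(predicate, numbers):
    """Precision and recall of `predicate` as a predictor of failure."""
    numbers = list(numbers)
    tp = sum(1 for n in numbers if predicate(n) and sqrt_fails(n))
    fp = sum(1 for n in numbers if predicate(n) and not sqrt_fails(n))
    fn = sum(1 for n in numbers if not predicate(n) and sqrt_fails(n))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# A few sqrt arguments that occur in this notebook.
sample = [-900, -10, -9, -6, -5, -3, 0, 2, 60]
# A small exhaustive range that also contains -2 and -1.
exhaustive = range(-20, 21)

sample_scores = precision_recall(learned, sample)
exhaustive_scores = precision_recall(learned, exhaustive)
print("sample:    ", sample_scores)
print("exhaustive:", exhaustive_scores)
```

On the sample, the constraint looks perfect. On the exhaustive range, precision stays at 1.0 but recall drops, because `sqrt(-2)` and `sqrt(-1)` fail without satisfying the constraint.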