In [1]:
class Token:
    """A lexical token: a category name paired with the matched source text."""

    def __init__(self, type: str, value: str):
        # ``type`` shadows the builtin, but it is part of the constructor's
        # keyword interface, so the name is kept.
        self.type, self.value = type, value

    def __repr__(self):
        """Render as ``Token(<type>, '<value>')`` for debugging output."""
        return f"Token({self.type}, '{self.value}')"

In [None]:
class ParseTreeNode:
    """A node of the parse tree built by ``Grammar.build_parse_tree``.

    Attributes:
        type: Grammar symbol or token type of the node (e.g. "expression",
            "term", "factor", "number", "operator").
        value: Source text for token (leaf) nodes; ``None`` for interior nodes.
        children: Child nodes in source order; empty for leaves.
        start_index, end_index: Inclusive token-index span. They are ``None``
            after construction and are filled in by the parser.
    """

    # Binary operator text -> (name looked up in the function factory for late
    # binding, function used to fold two numeric constants immediately).
    _BINARY_OPERATORS = {
        "+": ("Add", lambda a, b: a + b),
        "-": ("Sub", lambda a, b: a - b),
        "*": ("Mul", lambda a, b: a * b),
        "/": ("Div", lambda a, b: a / b),
        "%": ("Mod", lambda a, b: a % b),
        "**": ("Pow", lambda a, b: a ** b),
        "<": ("Lt", lambda a, b: a < b),
        "<=": ("Le", lambda a, b: a <= b),
        ">": ("Gt", lambda a, b: a > b),
        ">=": ("Ge", lambda a, b: a >= b),
        "==": ("Eq", lambda a, b: a == b),
        "!=": ("Ne", lambda a, b: a != b),
    }

    def __init__(self, type: str, value: typing.Optional[str] = None, children: typing.Optional[typing.List["ParseTreeNode"]] = None):
        self.type = type
        self.value = value
        self.children = children or []
        self.start_index = None
        self.end_index = None

    def __repr__(self):
        return f"ParseTreeNode({self.type}, value={self.value}, children={self.children}, start_index={self.start_index}, end_index={self.end_index})"

    def reify(self, function_factory):
        """Convert this subtree into a Python value or a function instance.

        Args:
            function_factory: Object whose ``get(name)`` returns a definition
                exposing a ``parameters`` mapping and ``create_function(**kw)``.
                It is only consulted for function calls and for binary
                operators whose operands are not both numeric constants.

        Returns:
            A number for numeric literals and constant-folded arithmetic, a
            bool for constant-folded comparisons, the raw text for other leaf
            nodes, or whatever ``create_function`` returns.

        Raises:
            ValueError: If the node shape is not supported, a called function
                is unknown, or its arguments do not match its parameters.

        TODO: binary expressions are strictly ``term <op> term``; there is no
        precedence or associativity handling, so chained operators need
        parentheses.
        """
        children = self.children or []
        count = len(children)

        # Single-child wrappers simply forward to their child.
        if count == 1 and self.type in ("expression", "term", "factor"):
            return children[0].reify(function_factory)

        if (
            self.type == "expression"
            and count == 3
            and children[0].type == "term"
            and children[2].type == "term"
            and children[1].value in self._BINARY_OPERATORS
        ):
            return self._reify_binary(children[0], children[1].value, children[2], function_factory)

        if self.type == "factor":
            # Parenthesised sub-expression.
            if count == 3 and children[0].value == "(" and children[2].value == ")":
                return children[1].reify(function_factory)
            # Function call: identifier "(" arguments ")".
            if (
                count == 4
                and children[0].type == "identifier"
                and children[1].value == "("
                and children[2].type == "arguments"
                and children[3].value == ")"
            ):
                return self._reify_call(function_factory)

        if count == 0:
            if self.type == "number":
                try:
                    return int(self.value)
                except ValueError:
                    # Not an integer literal (e.g. "1.5"); fall back to float.
                    return float(self.value)
            # Every other leaf evaluates to its text. The original condition
            # (`type == "string" and children is None or len(children) == 0`)
            # behaved this way because of operator precedence, and identifier
            # leaves depend on it, so the behaviour is kept and made explicit.
            return self.value

        raise ValueError(f"Cannot reify node of type: {self.type} with {len(self.children)} children: {self}")

    def _reify_binary(self, left_node, symbol, right_node, function_factory):
        """Evaluate ``left <symbol> right``, folding numeric constants eagerly."""
        function_name, fold = self._BINARY_OPERATORS[symbol]
        arg0 = left_node.reify(function_factory)
        arg1 = right_node.reify(function_factory)
        if isinstance(arg0, (int, float)) and isinstance(arg1, (int, float)):
            return fold(arg0, arg1)

        # At least one operand is not a plain number, so defer evaluation
        # (late binding) to the registered operator function. Operands are
        # bound to its first two parameters in declaration order.
        definition = function_factory.get(function_name)
        parameter_names = list(definition.parameters.keys())
        params = {parameter_names[0]: arg0, parameter_names[1]: arg1}
        return definition.create_function(**params)

    @staticmethod
    def _flatten_arguments(arguments_node):
        """Flatten the right-recursive ``arguments`` chain into a list of ``argument`` nodes."""
        flattened = []
        node = arguments_node
        while node.children and node.children[0].type == "argument":
            flattened.append(node.children[0])
            if len(node.children) == 3 and node.children[1].value == ",":
                node = node.children[2]
            else:
                break
        return flattened

    def _reify_call(self, function_factory):
        """Bind the arguments of a function-call factor and instantiate the function."""
        name = self.children[0].value
        try:
            definition = function_factory.get(name)
        except ValueError as err:
            raise ValueError(f"Function '{name}' not found in the factory.") from err

        argument_nodes = self._flatten_arguments(self.children[2])
        param_names = list(definition.parameters.keys())
        params = {}

        # Pass 1: named arguments (identifier "=" expression). A name that is
        # repeated keeps its first value.
        named = set()
        for argument_node in argument_nodes:
            parts = argument_node.children
            if len(parts) == 3 and parts[1].type == "operator" and parts[1].value == "=":
                param_name = parts[0].value
                if param_name in named:
                    continue
                if param_name not in definition.parameters:
                    raise ValueError(f"Parameter '{param_name}' not found for indicator '{name}'.")
                params[param_name] = parts[2].reify(function_factory)
                named.add(param_name)

        # Pass 2: positional arguments fill the remaining parameters in
        # declaration order. Parameters already supplied by name are skipped
        # WITHOUT consuming the positional argument (previously the argument
        # itself was dropped in that case).
        param_index = 0
        for argument_node in argument_nodes:
            if len(argument_node.children) != 1:
                continue
            while param_index < len(param_names) and param_names[param_index] in named:
                param_index += 1
            if param_index >= len(param_names):
                raise ValueError(f"Incorrect number of positional parameters for '{name}'.")
            params[param_names[param_index]] = argument_node.children[0].reify(function_factory)
            param_index += 1

        # Every declared parameter must have been supplied explicitly.
        missing = set(definition.parameters.keys()) - set(params.keys())
        if missing:
            raise ValueError(f"Missing required parameters for {name}: {missing}")

        return definition.create_function(**params)



In [None]:
class GrammarRule:
    """A single production: ``left`` rewrites to the symbol sequence ``right``."""

    def __init__(self, left: str, right: typing.List[str]):
        # left: the nonterminal being defined; right: its symbols, in order.
        self.left, self.right = left, right

    def __repr__(self):
        """Render in the same ``left -> sym sym ...`` form used by grammar text."""
        return f"{self.left} -> {' '.join(self.right)}"



In [None]:
class Grammar:
    """A rule list with a regex tokenizer and a backtracking top-down parser.

    Rules sharing a left-hand side are tried in declaration order and the
    first one that matches wins, so more specific rules must be listed before
    the rules they extend.
    """

    def __init__(self, grammar_string: str):
        """Parse ``grammar_string`` into ``self.rules``.

        Each non-blank line must have the form ``left -> sym1 sym2 ...`` with
        whitespace-separated symbols.

        Raises:
            ValueError: If a line does not contain exactly one ``->``.
        """
        self.rules = []
        for line in grammar_string.strip().splitlines():
            if not line.strip():  # Skip empty lines
                continue
            parts = line.split("->")
            if len(parts) != 2:
                raise ValueError(f"Invalid grammar rule: {line}")
            # str.split() with no argument already strips each symbol.
            self.rules.append(GrammarRule(parts[0].strip(), parts[1].split()))

    def build_parse_tree(self, tokens: typing.List["Token"], start_symbol: str = "expression") -> typing.Optional["ParseTreeNode"]:
        """Build a parse tree for ``tokens`` starting from ``start_symbol``.

        A rule symbol matches a token when it equals the token's type or equals
        the token's value wrapped in double quotes (e.g. ``"+"``). Any other
        symbol that appears as a rule's left-hand side is parsed recursively.

        Returns:
            The root ``ParseTreeNode``, or ``None`` if no rule matches.

        NOTE: the parse is not required to consume every token; trailing
        tokens after a successful match are ignored.
        """
        # Group rules once so lookups inside the recursion are cheap.
        rules_by_left = {}
        for rule in self.rules:
            rules_by_left.setdefault(rule.left, []).append(rule)

        def _match(rule, index):
            """Return the child nodes if ``rule`` matches at ``index``, else None."""
            children = []
            current_index = index
            for symbol in rule.right:
                if current_index >= len(tokens):
                    return None
                token = tokens[current_index]
                if symbol == token.type or symbol == f'"{token.value}"':
                    # Terminal: consume exactly one token.
                    child = ParseTreeNode(token.type, value=token.value)
                    child.start_index = current_index
                    child.end_index = current_index
                    current_index += 1
                elif symbol in rules_by_left:
                    child = _parse(current_index, symbol)
                    if child is None:
                        return None
                    current_index = child.end_index + 1
                else:
                    return None
                children.append(child)
            return children

        def _parse(index: int, nonterminal: str) -> typing.Optional["ParseTreeNode"]:
            applicable_rules = rules_by_left.get(nonterminal, [])

            if index >= len(tokens):  # End of tokens: only an epsilon rule can match
                if any(not rule.right for rule in applicable_rules):
                    node = ParseTreeNode(nonterminal, children=[])
                    # Empty span, so that a caller's `end_index + 1` resumes at
                    # `index` (these were previously left as None and crashed).
                    node.start_index = index
                    node.end_index = index - 1
                    return node
                return None

            for rule in applicable_rules:
                children = _match(rule, index)
                if children is None:
                    continue
                node = ParseTreeNode(nonterminal, children=children)
                # An epsilon match has no children and therefore an empty span.
                node.start_index = children[0].start_index if children else index
                node.end_index = children[-1].end_index if children else index - 1
                return node

            return None

        return _parse(0, start_symbol)

    def parse(self, input_string: str, start_symbol: str = "expression"):
        """Tokenize ``input_string`` and parse it from ``start_symbol``."""
        tokens = self.tokenize(input_string)
        return self.build_parse_tree(tokens, start_symbol)

    def tokenize(self, expression: str) -> typing.List["Token"]:
        """Split ``expression`` into operator, string, number and identifier tokens.

        Whitespace between tokens is ignored. Quoted strings (single or double
        quotes, non-empty) yield their contents without the quotes.

        Raises:
            ValueError: If the input contains text that is not part of any token.
        """
        # Note: if we wanted to be really fancy, we would specify the token
        # types in the grammar.
        operators = [
            "**", "*", "/", "//", "%", "+", "-", "==", "!=", "<=", ">=", "<", ">",
            "=", "!", "&&", "||", "^^", "&", "|", "^", "~", "<<", ">>",
            "(", ")", "[", "]", "{", "}", ",", ":", ".", "->", "@", ";",
            "+=", "-=", "*=", "/=", "//=", "%=", "&=", "|=", "^=", "<<=", ">>=",
        ]
        # Regex alternation takes the first alternative that matches, so longer
        # operators must come first or e.g. "//" would be read as "/" "/".
        operator_pattern = "|".join(re.escape(op) for op in sorted(operators, key=len, reverse=True))
        pattern = rf"({operator_pattern})|'([^']+)'|\"([^\"]+)\"|(\d+\.?\d*)|([a-zA-Z_]\w*)"

        # Token type for each capture group of `pattern` (index 0 is unused).
        group_types = (None, "operator", "string", "string", "number", "identifier")

        tokens = []
        position = 0
        for match in re.finditer(pattern, expression):
            # Anything other than whitespace between two matches is text the
            # pattern could not tokenize; report it instead of dropping it.
            if expression[position:match.start()].strip():
                raise ValueError(f"invalid token in {expression}")
            position = match.end()
            group = match.lastindex
            tokens.append(Token(group_types[group], match.group(group)))

        if expression[position:].strip():
            raise ValueError(f"invalid token in {expression}")

        return tokens

In [None]:
class FunctionInstance:
    """A function definition bound to parameter values.

    A parameter value may itself be a ``FunctionInstance``; such nested
    instances are calculated first when this one is calculated.
    """

    def __init__(self, name: str, parameters: typing.Dict[str, typing.Any], definition):
        self.definition = definition
        self.parameters = parameters
        self.name = name

    def evaluate_parameters(self, data):
        """Return the parameters with nested instances replaced by their result on ``data``."""
        return {
            key: value.calculate(data) if isinstance(value, FunctionInstance) else value
            for key, value in self.parameters.items()
        }

    def calculate(self, data: pd.DataFrame):
        """
        Run the definition's calculation on ``data`` with resolved parameters.

        Args:
            data: The Pandas DataFrame containing the data.

        Returns:
            Whatever the definition's ``calculate`` returns.
        """
        resolved = self.evaluate_parameters(data)
        return self.definition.calculate(data, resolved)

    def __repr__(self):
        rendered = [f"{name}={value}" for name, value in self.parameters.items()]
        params_str = ", ".join(rendered)
        return f"{self.definition.name}({params_str})"



In [None]:
class ParameterType:
    """
    A class for specifying parameters for screeners and indicators.
    """

    def __init__(self,
#                 name: str,
                 data_type: typing.Literal["integer", "real", "boolean", "string"],
                 min_val: typing.Union[int, float, None] = None,
                 max_val: typing.Union[int, float, None] = None,
                 default: typing.Any = None,
                 timeframe_defaults: typing.Dict[typing.Literal["tick", "1s", "5s", "15s", "1m", "2m", "5m", "15m", "1d", "1w", "1M"], typing.Any] = None,
                 increment: typing.Union[int, float, None] = None,
                 allowed_strings: typing.List[str] | None = None):
#        if not isinstance(name, str):
#            raise TypeError("name must be a string")
        if data_type not in ("integer", "real", "boolean", "string", "any"):
            raise ValueError("data_type must be 'integer', 'real', 'boolean', 'string', or 'any'")

        if min_val is not None:
            if data_type == "integer" and not isinstance(min_val, int):
                raise TypeError("min_val must be an integer for integer data_type")
            elif data_type in ("real", "integer") and not isinstance(min_val, (int, float)):
                raise TypeError("min_val must be a number for real or integer data_type")

        if max_val is not None:
            if data_type == "integer" and not isinstance(max_val, int):
                raise TypeError("max_val must be an integer for integer data_type")
            elif data_type in ("real", "integer") and not isinstance(max_val, (int, float)):
                raise TypeError("max_val must be a number for real or integer data_type")

        if timeframe_defaults is not None:
            if not isinstance(timeframe_defaults, dict):
                raise TypeError("timeframe_defaults must be a dictionary")
            for timeframe in timeframe_defaults:
                if timeframe not in ("tick", "1s", "5s", "15s", "1m", "2m", "5m", "15m", "1d", "1w", "1M"):
                    raise ValueError(f"Invalid timeframe: {timeframe}")

        if data_type == "integer" and increment is None:
            increment = 1
        elif data_type == "real" and increment is None:
            increment = 0.01

        if data_type == "string" and allowed_strings is not None and not isinstance(allowed_strings, list):
          raise TypeError("allowed_strings must be a list of strings")

        if data_type != "string" and allowed_strings is not None:
          raise ValueError("allowed_strings can only be specified for string data type")

#        self.name = name
        self.data_type = data_type
        self.min_val = min_val
        self.max_val = max_val
        self.default = default
        self.timeframe_defaults = timeframe_defaults or {}
        self.increment = increment
        self.allowed_strings = allowed_strings

    def get_default(self) -> typing.Any:
        return self.default

    def get_possible_values(self) -> typing.Iterable[typing.Any]:
        if self.data_type == "integer":
            if self.min_val is not None and self.max_val is not None:
                return range(self.min_val, self.max_val + 1)
        elif self.data_type == "real":
            if self.min_val is not None and self.max_val is not None:
                current = self.min_val
                while current <= self.max_val:
                    yield current
                    current += 0.01
        elif self.data_type == "boolean":
            return [True, False]
        elif self.data_type == "string":
            if self.allowed_strings is not None:  # Check if allowed_strings is defined
                return self.allowed_strings  # If defined, return those values
            else:
                return []  # Return an empty list if allowed_strings is None (unrestricted)
        return []

    def __repr__(self):
#        return f"ParameterType(name='{self.name}', data_type='{self.data_type}', min_val={self.min_val}, max_val={self.max_val}, default={self.default}, allowed_strings={self.allowed_strings})"
        return f"ParameterType(data_type='{self.data_type}', min_val={self.min_val}, max_val={self.max_val}, default={self.default}, allowed_strings={self.allowed_strings})"






In [None]:
class FunctionDefinition:
    """A named calculation together with the specification of its parameters."""

    def __init__(self, name: str, parameters: typing.Dict[str, "ParameterType"], calculation_function, factory=None):
        """Validate the arguments and store them.

        Raises:
            TypeError: ``name`` is not a str, ``parameters`` is not a dict of
                ``ParameterType`` values, or ``calculation_function`` is not
                callable.
            ValueError: Parameter names are not unique.
        """
        if not isinstance(name, str):
            raise TypeError("name must be a string")

        if not isinstance(parameters, dict):
            raise TypeError("parameters must be a dictionary")

        for param in parameters.values():
            if not isinstance(param, ParameterType):
                raise TypeError("All values in parameters must be ParameterType objects")

        # Check for duplicate keys
        if len(parameters.keys()) != len(set(parameters.keys())):
            raise ValueError("Parameter names must be unique.")

        if not callable(calculation_function):
            raise TypeError("calculation_function must be callable")

        self.name = name
        self.parameters = parameters
        self.calculation_function = calculation_function
        self.factory = factory

    @staticmethod
    def _validate(name, param_def, value):
        """Raise TypeError/ValueError if ``value`` violates ``param_def``."""
        kind = param_def.data_type

        # Type checks first; a value of the wrong type never reaches the
        # range or allowed-values checks below.
        if kind == "integer" and not isinstance(value, int):
            raise TypeError(f"Value for parameter '{name}' must be an integer")
        if kind == "real" and not isinstance(value, (int, float)):
            raise TypeError(f"Value for parameter '{name}' must be a number")
        if kind == "boolean" and not isinstance(value, bool):
            raise TypeError(f"Value for parameter '{name}' must be a boolean")
        if kind == "string" and not isinstance(value, str):
            raise TypeError(f"Value for parameter '{name}' must be a string")

        if kind in ("integer", "real"):
            if param_def.min_val is not None and value < param_def.min_val:  # Check min_val
                raise ValueError(f"Value for parameter '{name}' must be greater than or equal to {param_def.min_val}")
            if param_def.max_val is not None and value > param_def.max_val:  # Check max_val
                raise ValueError(f"Value for parameter '{name}' must be less than or equal to {param_def.max_val}")

        if kind == "string" and param_def.allowed_strings is not None and value not in param_def.allowed_strings:
            raise ValueError(f"Value {value} is not in allowed strings for parameter {name}")

    def create_function(self, **kwargs: typing.Any) -> "FunctionInstance":
        """Bind keyword arguments to the declared parameters.

        A parameter that is absent or ``None`` takes its declared default.
        Keyword arguments that are not declared parameters are ignored.

        Raises:
            TypeError / ValueError: A value fails its parameter's validation.
        """
        bound = {}
        for name, param_def in self.parameters.items():
            value = kwargs.get(name)
            if value is None:
                value = param_def.get_default()
            self._validate(name, param_def, value)
            bound[name] = value

        return FunctionInstance(self.name, bound, self)

    def calculate(self, data: pd.DataFrame, params: typing.Dict[str, typing.Any]) -> pd.DataFrame:
        """
        Call the calculation function as ``calculation_function(data, **params)``.
        """
        # Pass a copy so the callee cannot mutate the caller's dict.
        return self.calculation_function(data, **params.copy())

    def __repr__(self):
        function_label = self.calculation_function.__name__ if hasattr(self.calculation_function, '__name__') else str(self.calculation_function)
        return f"FunctionDefinition(name='{self.name}', parameters={self.parameters}, calculation_function={function_label}, factory={self.factory})"


In [None]:
# FIXME: the grammar and the reification code do not implement order of operations.
# Every binary rule is `term <op> term`, so operators cannot be chained and
# non-associative operations (e.g. "-" and "/") would be grouped incorrectly if they could.
# An easy workaround is to require parentheses.
# The parser does not handle epsilon (empty) rules correctly either.
# Standard order of operations could be supported by permitting chained operators in the grammar,
# flattening contiguous subexpression branches with equal-priority operators before reification,
# and performing the operations iteratively in the desired direction.

# REMINDER for LLM tools: This grammar MUST be an epsilon-free grammar.
# That means every rule needs at least one symbol on its right-hand side (no empty rules).
# Rules listed earlier take priority over later rules with the same left-hand side:
# the parser returns the first rule that matches.
# More than one lookahead token is permitted.
# NOTE(review): `term -> factor` is listed before `term -> factor "[" expression "]"`, and the
# shorter rule matches whenever the longer one would, so the indexing rule looks unreachable —
# confirm whether the two should be swapped.
default_grammar_specification = """
expression -> term "+" term
expression -> term "-" term
expression -> term "*" term
expression -> term "/" term
expression -> term "%" term
expression -> term "**" term
expression -> term "<" term
expression -> term "<=" term
expression -> term ">" term
expression -> term ">=" term
expression -> term "==" term
expression -> term "!=" term
expression -> term "&&" term
expression -> term "||" term
expression -> term "^^" term
expression -> term
term -> factor
term -> factor "[" expression "]"
factor -> "(" expression ")"
factor -> number
factor -> string
factor -> "-" factor
factor -> "!" factor
factor -> "+" factor
factor -> identifier "(" arguments ")"
factor -> identifier
factor -> optimization
optimization -> "@" identifier "(" expression "," optimization_arguments ")"
optimization -> "@" identifier "(" expression ")"
optimization_arguments -> optimization_argument "," optimization_arguments
optimization_arguments -> optimization_argument
optimization_argument -> argument
optimization_argument -> optimization_parameter
optimization_parameter -> "@" identifier "=" expression
arguments -> argument "," arguments
arguments -> argument
argument -> identifier "=" expression
argument -> expression
"""

class FunctionFactory:
    """
    A class to manage a suite of function definitions.
    """

    def __init__(self, grammar_specification=default_grammar_specification):
        self.function_definitions: typing.Dict[str, Definition] = {}
        self.grammar = Grammar(default_grammar_specification)

    def register(self, function_definition):
        """
        Registers a new screener definition.

        Args:
            function_definition: The Definition to register.

        Raises:
            ValueError: If a screener with the same name is already registered.
        """
#        if function_definition.name in self.function_definitions:
#            raise ValueError(f"A screener with the name '{function_definition.name}' is already registered.")
        self.function_definitions[function_definition.name] = function_definition
        function_definition.ffactory = self

    def get(self, name: str):
        """
        Retrieves a function definition by name.

        Args:
            name: The name of the screener.

        Returns:
            The FunctionDefinition object.

        Raises:
            ValueError: If no screener with the given name is registered.
        """
        if name not in self.function_definitions:
            raise ValueError(f"No function found with the name '{name}'.")
        return self.function_definitions[name]

    def parse(self, expression):
        parse_tree = self.grammar.parse(expression)
        reified_expression = parse_tree.reify(factory)
        return reified_expression

    def __repr__(self):
        return f"FunctionFactory(functions={self.function_definitions})"



In [None]:
#expression -> term "+" term
def calculate_add(df, a0, a1):
    """Return ``a0 + a1``. ``df`` is accepted but not used."""
    total = a0 + a1
    return total

# Register the "Add" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_add.
add_a0_param = ParameterType("any")
add_a1_param = ParameterType("any")
factory.register(
    FunctionDefinition(
        "Add",
        dict(a0=add_a0_param, a1=add_a1_param),
        calculate_add,
    )
)

#expression -> term "-" term
def calculate_sub(df, a0, a1):
    """Return ``a0 - a1``. ``df`` is accepted but not used."""
    difference = a0 - a1
    return difference

# Register the "Sub" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_sub.
sub_a0_param = ParameterType("any")
sub_a1_param = ParameterType("any")
factory.register(
    FunctionDefinition(
        "Sub",
        dict(a0=sub_a0_param, a1=sub_a1_param),
        calculate_sub,
    )
)

#expression -> term "*" term
def calculate_mul(df, a0, a1):
    """Return ``a0 * a1``. ``df`` is accepted but not used."""
    product = a0 * a1
    return product

# Register the "Mul" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_mul.
mul_a0_param = ParameterType("any")
mul_a1_param = ParameterType("any")
factory.register(
    FunctionDefinition(
        "Mul",
        dict(a0=mul_a0_param, a1=mul_a1_param),
        calculate_mul,
    )
)

#expression -> term "/" term
def calculate_div(df, a0, a1):
    """Return ``a0 / a1`` (true division). ``df`` is accepted but not used."""
    quotient = a0 / a1
    return quotient

# Register the "Div" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_div.
div_a0_param = ParameterType("any")
div_a1_param = ParameterType("any")
factory.register(
    FunctionDefinition(
        "Div",
        dict(a0=div_a0_param, a1=div_a1_param),
        calculate_div,
    )
)

#expression -> term "%" term
def calculate_mod(df, a0, a1):
    """Return ``a0 % a1``. ``df`` is accepted but not used."""
    remainder = a0 % a1
    return remainder

# Register the "Mod" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_mod.
mod_a0_param = ParameterType("any")
mod_a1_param = ParameterType("any")
factory.register(
    FunctionDefinition(
        "Mod",
        dict(a0=mod_a0_param, a1=mod_a1_param),
        calculate_mod,
    )
)

#expression -> term "**" term
def calculate_pow(df, a0, a1):
    """Return ``a0 ** a1``. ``df`` is accepted but not used."""
    power = a0 ** a1
    return power

# Register the "Pow" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_pow.
pow_a0_param = ParameterType("any")
pow_a1_param = ParameterType("any")
factory.register(
    FunctionDefinition(
        "Pow",
        dict(a0=pow_a0_param, a1=pow_a1_param),
        calculate_pow,
    )
)

#expression -> term "<" term
def calculate_lt(df, a0, a1):
    """Return the result of ``a0 < a1``. ``df`` is accepted but not used."""
    outcome = a0 < a1
    return outcome

# Register the "Lt" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_lt.
lt_a0_param = ParameterType("any")
lt_a1_param = ParameterType("any")
factory.register(
    FunctionDefinition(
        "Lt",
        dict(a0=lt_a0_param, a1=lt_a1_param),
        calculate_lt,
    )
)

#expression -> term "<=" term
def calculate_le(df, a0, a1):
    """Return the result of ``a0 <= a1``. ``df`` is accepted but not used."""
    outcome = a0 <= a1
    return outcome

# Register the "Le" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_le.
le_a0_param = ParameterType("any")
le_a1_param = ParameterType("any")
factory.register(
    FunctionDefinition(
        "Le",
        dict(a0=le_a0_param, a1=le_a1_param),
        calculate_le,
    )
)

#expression -> term ">" term
def calculate_gt(df, a0, a1):
    """Return the result of ``a0 > a1``. ``df`` is accepted but not used."""
    outcome = a0 > a1
    return outcome

# Register the "Gt" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_gt.
gt_a0_param = ParameterType("any")
gt_a1_param = ParameterType("any")
factory.register(
    FunctionDefinition(
        "Gt",
        dict(a0=gt_a0_param, a1=gt_a1_param),
        calculate_gt,
    )
)

#expression -> term ">=" term
def calculate_ge(df, a0, a1):
    """Return the result of ``a0 >= a1``. ``df`` is accepted but not used."""
    outcome = a0 >= a1
    return outcome

# Register the "Ge" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_ge.
ge_a0_param = ParameterType("any")
ge_a1_param = ParameterType("any")
factory.register(
    FunctionDefinition(
        "Ge",
        dict(a0=ge_a0_param, a1=ge_a1_param),
        calculate_ge,
    )
)

#expression -> term "==" term
def calculate_eq(df, a0, a1):
    """Return the result of ``a0 == a1``. ``df`` is accepted but not used."""
    outcome = a0 == a1
    return outcome

# Register the "Eq" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_eq.
eq_a0_param = ParameterType("any")
eq_a1_param = ParameterType("any")
factory.register(
    FunctionDefinition(
        "Eq",
        dict(a0=eq_a0_param, a1=eq_a1_param),
        calculate_eq,
    )
)

#expression -> term "!=" term
def calculate_ne(df, a0, a1):
    """Return the result of ``a0 != a1``. ``df`` is accepted but not used."""
    outcome = a0 != a1
    return outcome

# Register the "Ne" definition: two operands, each declared as
# ParameterType("any"), evaluated by calculate_ne.
ne_a0_param = ParameterType("any")
ne_a1_param = ParameterType("any")
# "a0" uses ne_a0_param. It previously reused ge_a0_param from the "Ge"
# registration (copy-paste slip), which left ne_a0_param unused and made the
# two definitions share one parameter object.
factory.register(FunctionDefinition("Ne", {"a0": ne_a0_param, "a1": ne_a1_param}, calculate_ne))

#expression -> term "&&" term
#expression -> term "||" term
#expression -> term "^^" term


In [None]:
class OptimizerDefinition:
    def __init__(self, name: str, parameters: typing.Dict[str, "ParameterType"], optimization_function, factory=None): 
        if not isinstance(name, str):
            raise TypeError("name must be a string")

        if not isinstance(parameters, dict):
            raise TypeError("parameters must be a dictionary")

        if not all(isinstance(param, ParameterType) for param in parameters.values()):
            raise TypeError("All values in parameters must be ParameterType objects")

        if len(set(parameters.keys())) != len(parameters.keys()): # Check for duplicate keys
            raise ValueError("Parameter names must be unique.")

        if not callable(calculation_function):
            raise TypeError("calculation_function must be callable")

        self.name = name
        self.parameters = parameters
        self.calculation_function = calculation_function
        self.factory = factory

    def create_optimizer(self, **kwargs: typing.Any):
        params = {}
        for name, param_def in self.parameters.items():
            value = kwargs.get(name)

            if value is None:
                value = param_def.get_default()

            if param_def.data_type == "integer" and not isinstance(value, int):
                raise TypeError(f"Value for parameter '{name}' must be an integer")
            elif param_def.data_type == "real" and not isinstance(value, (int, float)):
                raise TypeError(f"Value for parameter '{name}' must be a number")
            elif param_def.data_type == "boolean" and not isinstance(value, bool):
                raise TypeError(f"Value for parameter '{name}' must be a boolean")
            elif param_def.data_type == "string" and not isinstance(value, str):
                raise TypeError(f"Value for parameter '{name}' must be a string")
            elif param_def.data_type in ("integer", "real"):
                if param_def.min_val is not None and value < param_def.min_val:  # Check min_val
                    raise ValueError(f"Value for parameter '{name}' must be greater than or equal to {param_def.min_val}")
                if param_def.max_val is not None and value > param_def.max_val:  # Check max_val
                    raise ValueError(f"Value for parameter '{name}' must be less than or equal to {param_def.max_val}")

            if param_def.data_type == "string" and param_def.allowed_strings is not None and value not in param_def.allowed_strings:
                raise ValueError(f"Value {value} is not in allowed strings for parameter {name}")

            params[name] = value

        return OptimizerInstance(self.name, params, self)

    def calculate(self, data: pd.DataFrame, params: typing.Dict[str, typing.Any]) -> pd.DataFrame:
        """
        Calculates the optimization using the provided data and parameters.
        """
        kwargs = params.copy() 
        return self.optimization_function(data, **kwargs)

    def __repr__(self):
        return f"OptimizerDefinition(name='{self.name}', parameters={self.parameters}, calculation_function={self.calculation_function.__name__ if hasattr(self.calculation_function, '__name__') else str(self.calculation_function)}, factory={self.factory})"

    