In [None]:
! pip install sympy==1.13.0

In [89]:
def debug_print(message):
    #print(f"{message}")
    pass

In [95]:
import sympy as sp
from sympy.logic.boolalg import And, Or, Not
from sympy.logic.inference import satisfiable
from src.tokenizer import Tokenizer
from src.parser import Parser
from src.simplifier import Simplifier
from src.config import debug_print

class Comparator:
   def __init__(self):
       self.tokenizer = Tokenizer()
       self.simplifier = Simplifier()

   def compare(self, predicate1: str, predicate2: str) -> str:
       # Tokenize, parse, and simplify the first predicate
       tokens1 = self.tokenizer.tokenize(predicate1)
       debug_print(f"Tokens1: {tokens1}")
       parser1 = Parser(tokens1)
       ast1 = parser1.parse()
       debug_print(f"Parsed AST1: {ast1}")

       # Tokenize, parse, and simplify the second predicate
       tokens2 = self.tokenizer.tokenize(predicate2)
       debug_print(f"Tokens2: {tokens2}")
       parser2 = Parser(tokens2)
       ast2 = parser2.parse()
       debug_print(f"Parsed AST2: {ast2}")

       # Convert ASTs to SymPy expressions
       expr1 = self._to_sympy_expr(ast1)
       expr2 = self._to_sympy_expr(ast2)

       # Simplify expressions
       debug_print(f"SymPy Expression 1: {expr1}")
       simplified_expr1 = sp.simplify(expr1)
       debug_print(f"Simplified SymPy Expression 1: {simplified_expr1}")

       debug_print(f"SymPy Expression 2: {expr2}")
       simplified_expr2 = sp.simplify(expr2)
       debug_print(f"Simplified SymPy Expression 2: {simplified_expr2}")

       # Manually check implications
       implies1_to_2 = self._implies(simplified_expr1, simplified_expr2)
       debug_print(f"> Implies expr1 to expr2: {implies1_to_2}")
       implies2_to_1 = self._implies(simplified_expr2, simplified_expr1)
       debug_print(f"> Implies expr2 to expr1: {implies2_to_1}")

       if implies1_to_2 and not implies2_to_1:
           return "The first predicate is stronger."
       elif implies2_to_1 and not implies1_to_2:
           return "The second predicate is stronger."
       elif implies1_to_2 and implies2_to_1:
           return "The predicates are equivalent."
       else:
           return "The predicates are not equivalent and neither is stronger."

   def _to_sympy_expr(self, ast):
       if not ast.children:
           try:
               # Try converting to int or float if the value is a numeric string
               value = float(ast.value) if '.' in ast.value else int(ast.value)
               return sp.Number(value)
           except ValueError:
               # If conversion fails, treat it as a symbol
               return sp.Symbol(ast.value.replace('.', '_'))
       args = [self._to_sympy_expr(child) for child in ast.children]
       if ast.value in ('&&', '||', '!', '==', '!=', '>', '<', '>=', '<='):
           return getattr(sp, self._sympy_operator(ast.value))(*args)
       elif ast.value == '/':
           return sp.Mul(sp.Pow(args[1], -1), args[0])
       elif ast.value == '+':
           return sp.Add(*args)
       elif ast.value == '-':
           return sp.Add(args[0], sp.Mul(-1, args[1]))
       elif ast.value == '*':
           return sp.Mul(*args)
       elif '()' in ast.value:
           func_name = ast.value.replace('()', '')
           return sp.Function(func_name)(*args)
       return sp.Symbol(ast.value.replace('.', '_'))

   def _sympy_operator(self, op):
       return {
           '&&': 'And',
           '||': 'Or',
           '!': 'Not',
           '==': 'Eq',
           '!=': 'Ne',
           '>': 'Gt',
           '<': 'Lt',
           '>=': 'Ge',
           '<=': 'Le'
       }[op]

   def _implies(self, expr1, expr2):
       """
       Check if expr1 implies expr2 by manually comparing the expressions.
       """
       debug_print(f"Checking implication: {expr1} -> {expr2}")
       if expr1 == expr2:
           debug_print("Expressions are identical.")
           return True

       # Handle AND expression for expr2
       if isinstance(expr2, And):
           # expr1 should imply all parts of expr2 if expr2 is an AND expression
           results = [self._implies(expr1, arg) for arg in expr2.args]
           debug_print(f"Implication results for And expr2 which was `{expr1} => {expr2}`: {results}")
           return all(results)
      
       # Handle AND expression for expr1
       if isinstance(expr1, And):
           # All parts of expr1 should imply expr2 if expr1 is an AND expression
           results = [self._implies(arg, expr2) for arg in expr1.args]
           debug_print(f"Implication results for And expr1 which was `{expr1} => {expr2}`: {results}")
           return any(results)

       # Handle OR expression for expr2
       if isinstance(expr2, Or):
           # expr1 should imply at least one part of expr2 if expr2 is an OR expression
           results = [self._implies(expr1, arg) for arg in expr2.args]
           debug_print(f"Implication results for Or expr2 which was `{expr1} => {expr2}`: {results}")
           return any(results)
      
       # Handle OR expression for expr1
       if isinstance(expr1, Or):
           # All parts of expr1 should imply expr2 if expr1 is an OR expression
           results = [self._implies(arg, expr2) for arg in expr1.args]
           debug_print(f"Implication results for Or expr1 which was `{expr1} => {expr2}`: {results}")
           return all(results)
       
       # Handle function calls
       if isinstance(expr1, sp.Function) and isinstance(expr2, sp.Function):
           # Ensure the function names and the number of arguments match
           if expr1.func == expr2.func and len(expr1.args) == len(expr2.args):
               return all(self._implies(arg1, arg2) for arg1, arg2 in zip(expr1.args, expr2.args))
           return False
       
       if isinstance(expr1, sp.Symbol) and isinstance(expr2, sp.Symbol):
           return expr1 == expr2

       # Specific relational operator checks for numerical comparisons
       relational_operators = (sp.Gt, sp.Ge, sp.Lt, sp.Le, sp.Eq, sp.Ne)
       if isinstance(expr1, relational_operators) and isinstance(expr2, relational_operators):
           debug_print(f'we are here!... expr1: {expr1}, expr2: {expr2}')
           # Check for Eq vs non-Eq comparisons; we don't handle this well, let's return False
           if (isinstance(expr1, sp.Eq) and not isinstance(expr2, sp.Eq)) or (not isinstance(expr1, sp.Eq) and isinstance(expr2, sp.Eq)):
               return False  # Handle Eq vs non-Eq cases explicitly
           
           if all(isinstance(arg, (sp.Float, sp.Integer, sp.Symbol)) for arg in [expr1.lhs, expr1.rhs, expr2.lhs, expr2.rhs]):
               debug_print(f'Inside!... expr1: {expr1}, expr2: {expr2}')
               # Check if the negation of the implication is not satisfiable
               try:
                   negation = sp.And(expr1, Not(expr2))
                   debug_print(f"Negation of the implication {expr1} -> {expr2}: {satisfiable(negation)}; type of {type(satisfiable(negation))}")
                   result = not satisfiable(negation, use_lra_theory=True)
                   debug_print(f"Implication {expr1} -> {expr2} using satisfiable: {result}")
                   return result
               except Exception as e:
                   print(f"Exception: {e}")
                   return False

       #Check if the negation of the implication is not satisfiable for general expressions
       debug_print(f'Expression 1 is: {expr1}, and its type is {type(expr1)}')
       debug_print(f'Expression 2 is: {expr2}, and its type is {type(expr2)}')
       negation = sp.And(expr1, Not(expr2))
       result = not satisfiable(negation, use_lra_theory=True) # here.
       debug_print(f"Implication {expr1} -> {expr2} using satisfiable: {result}")
       return result
       
    #    just return False for all other cases we haven't taken into account  
    #    return False


# Example usage
predicate1 = "_getIdentifierWhitelist().isIdentifierSupported(_priceIdentifier)"
predicate2 = "_getIdentifierWhitelist().isIdentifierSupported(smt)"

# predicate1 = "(_tTotalpercentBuy)/divisorBuy>=(_tTotal/5000)" 
# predicate2 = "(percentBuy_decimals)/divisorBuy>=(_tTotal/10000)"


# predicate1 = "(_tTotalpercentBuy)/divisorBuy>=(10000/5000)" 
# predicate2 = "(_tTotalpercentBuy)/divisorBuy>=(10000/7000)"

# predicate1 = "12 < a"
# predicate2 = "a < 12"



# predicate1 = "a < 12"
# predicate2 = "a < 13"

# predicate1 = "a < 12"
# predicate2 = "a == 12"


predicate1 = "balanceOf(to)<=holdLimitAmount-amount" 
predicate2 = "balanceOf(to)+amount<=holdLimitAmount" 

comparator = Comparator()
result = comparator.compare(predicate1, predicate2)
print(result)


The predicates are equivalent.


In [96]:
import sympy as sp
from typing import Union
from src.parser import ASTNode
from src.config import debug_print

class Simplifier:
    def __init__(self):
        self.symbols = {
            'msg.sender': sp.Symbol('msg_sender'),
            'msg.origin': sp.Symbol('msg_origin'),
            '==': sp.Eq,
            '!=': sp.Ne,
            '>=': sp.Ge,
            '<=': sp.Le,
            '>': sp.Gt,
            '<': sp.Lt,
            '&&': sp.And,
            '||': sp.Or,
            '!': sp.Not
        }

    def simplify(self, ast: ASTNode) -> Union[str, ASTNode]:
        debug_print(f"Simplifying AST: {ast}")
        sympy_expr = self._to_sympy(ast)
        debug_print(f"Converted to sympy expression: {sympy_expr}")
        simplified_expr = sp.simplify(sympy_expr)
        debug_print(f"Simplified sympy expression: {simplified_expr}")
        simplified_ast = self._to_ast(simplified_expr)
        debug_print(f"Converted back to AST: {simplified_ast}")
        return simplified_ast

    def _to_sympy(self, node: ASTNode):
        if node.value in self.symbols and not node.children:
            return self.symbols[node.value]
        elif node.value in self.symbols:
            if node.value in ('&&', '||'):
                return self.symbols[node.value](*[self._to_sympy(child) for child in node.children])
            elif node.value == '!':
                return self.symbols[node.value](self._to_sympy(node.children[0]))
            elif len(node.children) == 2:
                return self.symbols[node.value](self._to_sympy(node.children[0]), self._to_sympy(node.children[1]))
            else:
                raise ValueError(f"Invalid number of children for operator {node.value}")
        elif isinstance(node.value, (int, float)):
            return sp.Number(node.value)
        else:
            # Preserve function calls and other identifiers as-is
            if '(' in node.value and ')' in node.value:
                func_name = node.value  # Ensure the function name is preserved entirely
                args = node.children
                return sp.Function(func_name)(*map(self._to_sympy, args))
            else:
                return sp.Symbol(node.value.replace('.', '_'))

    def _to_ast(self, expr):
        if isinstance(expr, sp.Equality):
            return ASTNode('==', [self._to_ast(expr.lhs), self._to_ast(expr.rhs)])
        elif isinstance(expr, sp.Rel):
            op_map = {'>': '>', '<': '<', '>=': '>=', '<=': '<=', '!=': '!='}
            return ASTNode(op_map[expr.rel_op], [self._to_ast(expr.lhs), self._to_ast(expr.rhs)])
        elif isinstance(expr, sp.And):
            return ASTNode('&&', [self._to_ast(arg) for arg in expr.args])
        elif isinstance(expr, sp.Or):
            return ASTNode('||', [self._to_ast(arg) for arg in expr.args])
        elif isinstance(expr, sp.Not):
            return ASTNode('!', [self._to_ast(expr.args[0])])
        elif isinstance(expr, sp.Function):
            func_name = str(expr.func)
            return ASTNode(func_name, [self._to_ast(arg) for arg in expr.args])
        else:
            return ASTNode(str(expr))


In [97]:

tokenizer = Tokenizer()
simplifier = Simplifier()

tokens1 = tokenizer.tokenize(predicate1)
parser1 = Parser(tokens1)
ast1 = parser1.parse()
simplified_ast1 = simplifier.simplify(ast1)

print(f"predicate1: {predicate1}")
print(f"AST1 is: {ast1}")
print(f"Simplified AST1: {simplified_ast1}")


print('--------------------------------------------------------------------------------------------')

tokens2 = tokenizer.tokenize(predicate2)
parser2 = Parser(tokens2)
ast2 = parser2.parse()
simplified_ast2 = simplifier.simplify(ast2)

print(f"predicate1: {predicate2}")
print(f"AST2 is: {ast2}")
print(f"Simplified AST1: {simplified_ast2}")


predicate1: balanceOf(to)<=holdLimitAmount-amount
AST1 is: ASTNode(value='<=', children=[ASTNode(value='balanceOf()', children=[ASTNode(value='to', children=[])]), ASTNode(value='-', children=[ASTNode(value='holdLimitAmount', children=[]), ASTNode(value='amount', children=[])])])
Simplified AST1: ASTNode(value='>=', children=[ASTNode(value='-', children=[]), ASTNode(value='balanceOf()', children=[ASTNode(value='to', children=[])])])
--------------------------------------------------------------------------------------------
predicate1: balanceOf(to)+amount<=holdLimitAmount
AST2 is: ASTNode(value='<=', children=[ASTNode(value='+', children=[ASTNode(value='balanceOf()', children=[ASTNode(value='to', children=[])]), ASTNode(value='amount', children=[])]), ASTNode(value='holdLimitAmount', children=[])])
Simplified AST1: ASTNode(value='<=', children=[ASTNode(value='+', children=[]), ASTNode(value='holdLimitAmount', children=[])])


In [93]:
predicate1, predicate2 = "(_tTotalpercentBuy)/divisorBuy>=(_tTotal/5000)", "(percentBuy_decimals)/divisorBuy>=(_tTotal/10000)"
comparator = Comparator()
result = comparator.compare(predicate1, predicate2)
print(result)

Parsed AST1: ASTNode(value='>=', children=[ASTNode(value='/', children=[ASTNode(value='_tTotalpercentBuy', children=[]), ASTNode(value='divisorBuy', children=[])]), ASTNode(value='/', children=[ASTNode(value='_tTotal', children=[]), ASTNode(value='5000', children=[])])])
Parsed AST2: ASTNode(value='>=', children=[ASTNode(value='/', children=[ASTNode(value='percentBuy_decimals', children=[]), ASTNode(value='divisorBuy', children=[])]), ASTNode(value='/', children=[ASTNode(value='_tTotal', children=[]), ASTNode(value='10000', children=[])])])
SymPy Expression 1: _tTotalpercentBuy/divisorBuy >= _tTotal/5000
Simplified SymPy Expression 1: _tTotal/5000 <= _tTotalpercentBuy/divisorBuy
SymPy Expression 2: percentBuy_decimals/divisorBuy >= _tTotal/10000
Simplified SymPy Expression 2: _tTotal/10000 <= percentBuy_decimals/divisorBuy
we are here!... expr1: _tTotal/5000 <= _tTotalpercentBuy/divisorBuy, expr2: _tTotal/10000 <= percentBuy_decimals/divisorBuy
Expression 1 is: _tTotal/5000 <= _tTotalp

UnhandledInput: Nonlinearity is not handled

In [None]:
from datasets import load_dataset

ds = load_dataset("GGmorello/FLAMES_results", "100k", token='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')

#ds = load_dataset('GGmorello/FLAMES', 'infilled', split='train[:10000]', token='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', cache_dir='/Users/mojtabaeshghie/.cache/hf')#, num_proc=8)

In [98]:
df_100k = ds['train'].to_pandas()
head_100 = df_100k.head(100)
failures = []

for i, row in head_100.iterrows():
    pred1 = row['label']
    pred2 = row['predicate']
    #print(f"Row {i}: {pred1} vs. {pred2}")
    try:
        result = comparator.compare(pred1, pred2)
        print(f"({i}) For predicates {pred1} ************* {pred2} ############## {result}")
    except Exception as e:
        failures.append(({'pred1': pred1, 'pred2': pred2, 'exception': e}))
        continue


(0) For predicates E_0[_garbageAddress]==0,"Xec: garbage token already exists" ************* E_0[_garbageAddress]==0 ############## The predicates are equivalent.
(1) For predicates xrmToken.balanceOf(_user)>=_amount,"User balance is less than the requested stake size" ************* xrmToken.balanceOf(_user)>=_amount ############## The predicates are equivalent.
(2) For predicates numberOfTokens+buyerBalance<=pre_mint_limit,"LazyBonesSpaceTrip: You can mint a maximum of 20 Lazy Bones" ************* numberOfTokens+buyerBalance<=pre_mint_limit ############## The predicates are equivalent.
(3) For predicates MathHelper.sumNumbers(_milestonesFundings).add(_finalReward)<=getUintConfig(CONFIG_MAX_FUNDING_FOR_NON_DIGIX) ************* MathHelper.sumNumbers(_milestonesFundings).add(_finalReward)<=getUintConfig(CONFIG_MAX_FUNDING_FOR_NON_DIGIX) ############## The predicates are equivalent.
(4) For predicates msg.sender==governance||msg.sender==controller||msg.sender==address(this),"!authorized" 

In [122]:
failures[8]

{'pred1': 'NS<(1days)',
 'pred2': 'NS<(1days)',
 'exception': ValueError('Unexpected character: 1 at position 4')}

## Counting the label and predicates that have `+=` in them

In [113]:
# Convert the Dataset to a pandas DataFrame
df_100k = ds['train'].to_pandas()

# Filter the rows where 'label' or 'predicate' columns contain "+="
filtered_rows = df_100k[(df_100k['label'].str.contains('\+=', regex=True)) | (df_100k['predicate'].str.contains('\+=', regex=True))]

# Add the original indices to the filtered DataFrame
filtered_rows = filtered_rows.reset_index(drop=False).rename(columns={'index': 'original_index'})

# Display the count of such rows
count = len(filtered_rows)
print(f"Number of rows containing '+=': {count}")

# Display the DataFrame with original index and both 'label' and 'predicate' columns
print(filtered_rows[['original_index', 'label', 'predicate']])

# If you want to store it for further viewing, you can save it to a new DataFrame
filtered_rows_for_viewing = filtered_rows[['original_index', 'label', 'predicate']]


Number of rows containing '+=': 6
   original_index                                              label  \
0              56  _msgSender()!=_router||((_msgSender()==_router...   
1            6241  voucher.tokenUri==_tokenURIs[_tokenId]&&(_tota...   
2            7933  _msgSender()!=_router||((_msgSender()==_router...   
3           11569     nonces[_tx.from]++==_tx.nonce,"Nonce mismatch"   
4           25417  ((_msgSender()==_router)&&((balanceOf[_router]...   
5           28637        approvalMessages[signer][_target]++==_nonce   

                                           predicate  
0  _msgSender()!=_router||((_msgSender()==_router...  
1  voucher.tokenUri==_tokenURIs[_tokenId]&&(_tota...  
2  _msgSender()!=_router||((_msgSender()==_router...  
3                      nonces[_tx.from]++==_tx.nonce  
4  ((_msgSender()==_router)&&((balanceOf[_router]...  
5        approvalMessages[signer][_target]++==_nonce  


## Counting the number of rows having `days`, `minutes`, and `hours` in them

In [129]:
# Convert the Dataset to a pandas DataFrame
df_100k = ds['train'].to_pandas()

# Define the search strings
search_strings = ['days', 'minutes', 'hours']

# Filter the rows where 'label' or 'predicate' columns contain any of the search strings
filtered_rows = df_100k[
    df_100k['label'].str.contains('|'.join(search_strings), regex=True) |
    df_100k['predicate'].str.contains('|'.join(search_strings), regex=True)
]

# Add the original indices to the filtered DataFrame
filtered_rows = filtered_rows.reset_index(drop=False).rename(columns={'index': 'original_index'})

# Display the count of such rows
count = len(filtered_rows)
print(f"Number of rows containing 'days', 'minutes', or 'hours': {count}")

# Display the DataFrame with original index and both 'label' and 'predicate' columns
print(filtered_rows[['original_index', 'label', 'predicate']])

# If you want to store it for further viewing, you can save it to a new DataFrame
filtered_rows_for_viewing = filtered_rows[['original_index', 'label', 'predicate']]


Number of rows containing 'days', 'minutes', or 'hours': 136
     original_index                                              label  \
0                78                                         NS<(1days)   
1               212  block.timestamp>lastManualBuyBackUsdc+69minute...   
2               480  block.timestamp>lastPublicCollect+69minutes,"M...   
3               615  latePriceDrop*25<=latePriceDifference,"Final s...   
4              1199  block.timestamp>pool.startDate+90minutes,"Bet,...   
..              ...                                                ...   
131           30450  daysFrom(user.lastPaymentDate)>=1,"Wait at lea...   
132           30527                block.timestamp<=startTime+8minutes   
133           30574   block.timestamp<contractCreationTimestamp+14days   
134           30653     extendedTime%(7days)==0," Enter time in weeks"   
135           30827                   now>=(finalizePreIcoDate+30days)   

                                            predic

## Counting the ones that contain Ethereum currency units

In [127]:
# Convert the Dataset to a pandas DataFrame
df_100k = ds['train'].to_pandas()

# Define the Ethereum-related search strings
ethereum_keywords = ['wei', 'gwei', 'eth']

# Filter the rows where 'label' or 'predicate' columns contain any of the Ethereum-related keywords
filtered_rows = df_100k[
    df_100k['label'].str.contains('|'.join(ethereum_keywords), case=False, regex=True) |
    df_100k['predicate'].str.contains('|'.join(ethereum_keywords), case=False, regex=True)
]

# Add the original indices to the filtered DataFrame
filtered_rows = filtered_rows.reset_index(drop=False).rename(columns={'index': 'original_index'})

# Display the count of such rows
count = len(filtered_rows)
print(f"Number of rows containing Ethereum-related keywords: {count}")

# Display the DataFrame with original index and both 'label' and 'predicate' columns
print(filtered_rows[['original_index', 'label', 'predicate']])

# If you want to store it for further viewing, you can save it to a new DataFrame
filtered_rows_for_viewing = filtered_rows[['original_index', 'label', 'predicate']]


Number of rows containing Ethereum-related keywords: 1838
      original_index                                              label  \
0                 13  ethBalances[_msgSender()]<=9e18,"Cannot send m...   
1                 80  usersCurrentLentAmount[msg.sender]>=_amountEth...   
2                 82  (_fxs_oracle_addr!=address(0))&&(_weth_address...   
3                113  msg.value>=mintPriceEth*amount,"Value below pr...   
4                114  (MINT_PRICE*tokensCount)==msg.value,"The speci...   
...              ...                                                ...   
1833           30998                               msg.value>=0.01ether   
1834           31001  msg.value>=preSaleCost*_mintAmount,"Ether valu...   
1835           31011  calculatePrice(amountOfNfts)==msg.value,"ETH a...   
1836           31032              msg.value>=1ether&&msg.value<=50ether   
1837           31055  msg.value>=60000000000000000*_count,"It costs ...   

                                         