In [1]:
import ast, inspect

ALLOWED_FUNCTIONS = {'itertools', 'numpy', 'np', 'math'}
DISALLOWED_BUILTINS = {'print','__import__','breakpoint','compile','open','dir','eval','exec','globals','input','repr'}

class FunctionChecker(ast.NodeVisitor):
    def __init__(self):
        self.is_safe = True
        self.vars = []

    def visit_Import(self, node):
        for alias in node.names:
            if alias.name not in ALLOWED_FUNCTIONS:
                self.is_safe = False
            self.generic_visit(node)

    def visit_ImportFrom(self, node):
        if node.module not in ALLOWED_FUNCTIONS:
            self.is_safe = False
        self.generic_visit(node)
    
    def visit_Assign(self, node):
        for target in node.targets:
            if isinstance(target, ast.Name):
                self.vars.append(target.id)
        self.generic_visit(node)

    def visit_Call(self, node):
        # Check for disallowed built-in function calls
        if isinstance(node.func, ast.Name):
            if node.func.id in DISALLOWED_BUILTINS:
                self.is_safe = False
        # Check if function calls are from allowed modules
        elif isinstance(node.func, ast.Attribute):
            # if node.func.value.id not in ALLOWED_FUNCTIONS:
            #     self.is_safe = False
            func_value = ast.unparse(node.func.value)
            if func_value.split('.')[0] not in ALLOWED_FUNCTIONS and func_value.split('.')[0] not in self.vars:
                self.is_safe = False
        self.generic_visit(node)

def is_function_safe(func):
    function_code = inspect.getsource(func)
    tree = ast.parse(function_code)
    checker = FunctionChecker()
    checker.visit(tree)
    return checker.is_safe

def is_function_safe_string(function_code:str)->bool:
    #function_code = inspect.getsource(func)
    tree = ast.parse(function_code)
    checker = FunctionChecker()
    checker.visit(tree)
    return checker.is_safe

In [2]:
import numpy as np
import jax.numpy as jnp
import itertools

def my_function():
    #import matplotlib.pyplot as plt
    #print(5)
    eval('print(2)')

    return np.math.comb(10,2)

print(is_function_safe(my_function))

#my_function()

False


In [3]:
def priority1(v: tuple[int, ...], n: int) -> float:
  """Returns the priority, as a floating point number, of the vector `v` of length `n`. The vector 'v' is a tuple of values in {0,1,2}.
    The cap set will be constructed by adding vectors that do not create a line in order by priority.
  """
  pair_count = len(set(itertools.combinations(v, 2)))
  return pair_count / (n * (n - 1) / 2)


def priority2(v: tuple[int, ...], n: int) -> float:
  """Returns the priority, as a floating point number, of the vector `v` of length `n`. The vector 'v' is a tuple of values in {0,1,2}.
    The cap set will be constructed by adding vectors that do not create a line in order by priority.
  """
  v = sorted(v)
  pair_count = sum(1 for i in range(1, n) if v[i] != v[i-1])
  return pair_count / n


def priority3(v: tuple[int, ...], n: int) -> float:
  """Returns the priority, as a floating point number, of the vector `v` of length `n`. The vector 'v' is a tuple of values in {0,1,2}.
    The cap set will be constructed by adding vectors that do not create a line in order by priority.
  """
  unique_values = np.unique(v)
  pair_count = sum(1 for i in range(n) for j in range(i+1, n) if v[i] != v[j])
  diversity = len(unique_values) / n
  return pair_count / n + diversity


def priority4(v: tuple[int, ...], n: int) -> float:
  """Returns the priority, as a floating point number, of the vector `v` of length `n`. The vector 'v' is a tuple of values in {0,1,2}.
    The cap set will be constructed by adding vectors that do not create a line in order by priority.
  """
  unique_elements = len(set(v))
  pair_count = sum(1 for i in range(n) for j in range(i+1, n) if v[i] != v[j])
  return (unique_elements * pair_count) / (n * (n-1))

def priority5(v: tuple[int, ...], n: int) -> float:
  """Returns the priority, as a floating point number, of the vector `v` of length `n`. The vector 'v' is a tuple of values in {0,1,2}.
    The cap set will be constructed by adding vectors that do not create a line in order by priority.
  """
  # This function is identical to `priority_v2`; we're only adding it to maintain the sequence of function names
  return priority_v2(v, n)


def priority6(v: tuple[int, ...], n: int) -> float:
  """Returns the priority, as a floating point number, of the vector `v` of length `n`. The vector 'v' is a tuple of values in {0,1,2}.
    The cap set will be constructed by adding vectors that do not create a line in order by priority.
  """
  count = 0
  for i in range(n):
    for j in range(i+1, n):
      if v[i] != v[j]:
        count += 1
  return count / n


def priority7(v: tuple[int, ...], n: int) -> float:
  """Returns the priority, as a floating point number, of the vector `v` of length `n`. The vector 'v' is a tuple of values in {0,1,2}.
    The cap set will be constructed by adding vectors that do not create a line in order by priority.
  """
  # Convert vector to a bitmask by inverting each 3-valued entry and joining them as a single binary number
  bitmask = sum(2**i for i in range(n) if v[i] != 0)
  # Calculate the number of 1s in the bitmask to get the pair count
  pair_count = np.popcount(bitmask)
  return pair_count / n


def priority8(v: tuple[int, ...], n: int) -> float:
  """Returns the priority, as a floating point number, of the vector `v` of length `n`. The vector 'v' is a tuple of values in {0,1,2}.
    The cap set will be constructed by adding vectors that do not create a line in order by priority.
  """
  unique_pairs = len(set(list(itertools.combinations(v, 2))))
  return unique_pairs / np.math.comb(n, 2)

def priority9(v: tuple[int, ...], n: int) -> float:
  """Returns the priority, as a floating point number, of the vector `v` of length `n`. The vector 'v' is a tuple of values in {0,1,2}.
    The cap set will be constructed by adding vectors that do not create a line in order by priority.
  """
  unique_pairs = open('test')
  return unique_pairs / np.math.comb(n, 2)

def priority10(v: tuple[int, ...], n: int) -> float:
  """Returns the priority, as a floating point number, of the vector `v` of length `n`. The vector 'v' is a tuple of values in {0,1,2}.
      The cap set will be constructed by adding vectors that do not create a line in order by priority.
      This version checks for collinearity by calculating the determinant of the 3x3 matrix made from the vectors.
  """
  # Calculate the dot product of each pair of vectors
  dot_products = np.outer(v, v)[::-1].sum(axis=0)

  # Add a small epsilon to avoid singularity
  epsilon = 1e-8
  determinant = np.linalg.det(np.vstack((v, dot_products, np.ones(n))).T + epsilon * np.eye(n + 1))

  # Normalize the determinant to a range of [0, 1]
  return 1.0 - np.abs(determinant) / (np.linalg.det(np.eye(n + 1)) ** 2)

def priority11(v: tuple[int, ...], n: int) -> float:
  """Returns the priority, as a floating point number, of the vector `v` of length `n`. The vector 'v' is a tuple of values in {0,1,2}.
      The cap set will be constructed by adding vectors that do not create a line in order by priority.
      This version uses a simple heuristic to prioritize vectors with fewer lines.
  """
  lines = set()
  for i in range(n):
    for j in range(i+1, n):
      if np.sum(np.array(v[i:j+1]) == np.array(v[i:j+1]).sum(axis=0)) == 3:
        lines.add((i, j))
  if lines:
    return len(lines)
  else:
    return 0.0

In [7]:
%timeit is_function_safe_string(inspect.getsource(priority1))
%timeit is_function_safe_string(inspect.getsource(priority2))
%timeit is_function_safe_string(inspect.getsource(priority3))
%timeit is_function_safe_string(inspect.getsource(priority4))
%timeit is_function_safe_string(inspect.getsource(priority5))
%timeit is_function_safe_string(inspect.getsource(priority6))
%timeit is_function_safe_string(inspect.getsource(priority7))
%timeit is_function_safe_string(inspect.getsource(priority8))
%timeit is_function_safe_string(inspect.getsource(priority9))
%timeit is_function_safe_string(inspect.getsource(priority10))
%timeit is_function_safe_string(inspect.getsource(priority11))

134 µs ± 2 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
153 µs ± 3.88 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
197 µs ± 4.11 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
193 µs ± 1.77 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
81.6 µs ± 312 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
164 µs ± 687 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
153 µs ± 6.93 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
141 µs ± 4.81 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
110 µs ± 1.18 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
299 µs ± 1.73 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
296 µs ± 3.03 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
