In [28]:
import ast
import inspect
import textwrap

# Module names (plus the conventional `np` alias) that analysed code may
# import and call methods on. Despite the name, entries are module names.
ALLOWED_FUNCTIONS = {'itertools', 'numpy', 'np'}
# Builtin names analysed code is not allowed to use. getattr/setattr/delattr,
# vars/locals and __builtins__ are listed because they let code reach any
# other builtin or attribute through a computed string, which would otherwise
# slip past a check that only looks at literal names.
DISALLOWED_BUILTINS = {
    'print', '__import__', 'breakpoint', 'compile', 'open', 'dir', 'eval',
    'exec', 'globals', 'input', 'repr',
    'getattr', 'setattr', 'delattr', 'vars', 'locals', '__builtins__',
}

class FunctionChecker(ast.NodeVisitor):
    """AST visitor that flags code using modules or builtins outside the policy.

    After ``checker.visit(tree)``, ``checker.is_safe`` is False if any of the
    following appeared anywhere in the tree:

    * an ``import`` / ``from ... import`` of a module not in ALLOWED_FUNCTIONS
      (dotted names are compared in full, so ``import numpy.linalg`` is
      rejected unless that exact name is allowed);
    * any reference to a name in DISALLOWED_BUILTINS, whether it is called
      directly or first aliased (``g = eval; g(...)``);
    * access to a dunder attribute such as ``x.__class__``, a common way to
      escape a name-based denylist;
    * a method call whose receiver chain is not rooted at a name in
      ALLOWED_FUNCTIONS (``np.random.rand()`` is accepted; ``obj.method()``
      and ``"".join(...)`` are not).

    NOTE(review): this is a best-effort static denylist, not a sandbox; do
    not rely on it alone before executing untrusted code.
    """

    def __init__(self):
        self.is_safe = True

    def visit_Import(self, node):
        if any(alias.name not in ALLOWED_FUNCTIONS for alias in node.names):
            self.is_safe = False
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        # node.module is None for relative imports (`from . import x`); None is
        # never in the allow-list, so those are rejected too.
        if node.module not in ALLOWED_FUNCTIONS:
            self.is_safe = False
        self.generic_visit(node)

    def visit_Name(self, node):
        # Catches direct calls (`eval(...)`) as well as aliasing (`g = eval`),
        # since generic_visit reaches every Name node in the tree.
        if node.id in DISALLOWED_BUILTINS:
            self.is_safe = False
        self.generic_visit(node)

    def visit_Attribute(self, node):
        if node.attr.startswith('__') and node.attr.endswith('__'):
            self.is_safe = False
        self.generic_visit(node)

    def visit_Call(self, node):
        # Calls through a plain name are handled by visit_Name (reached via
        # generic_visit below); only attribute calls need a receiver check.
        if isinstance(node.func, ast.Attribute):
            if self._attribute_root(node.func) not in ALLOWED_FUNCTIONS:
                self.is_safe = False
        self.generic_visit(node)

    @staticmethod
    def _attribute_root(node):
        """Return the base identifier of an attribute chain, or None.

        ``np.random.rand`` -> ``'np'``; ``"".join`` or ``f().x`` -> ``None``
        because the chain does not start at a plain name.
        """
        while isinstance(node, ast.Attribute):
            node = node.value
        return node.id if isinstance(node, ast.Name) else None

def is_function_safe(func):
    """Statically check *func*'s source code against the FunctionChecker policy.

    Args:
        func: A function (or other object accepted by ``inspect.getsource``)
            whose source code is retrievable.

    Returns:
        True if FunctionChecker found nothing disallowed, False otherwise.

    Raises:
        Whatever ``inspect.getsource`` raises when the source is unavailable
        (e.g. OSError for code created via ``exec``, TypeError for builtins).
    """
    # getsource preserves the original indentation, so methods and nested
    # functions come back indented and ast.parse would raise IndentationError.
    # Dedenting lets them parse like top-level code.
    function_code = textwrap.dedent(inspect.getsource(func))
    checker = FunctionChecker()
    checker.visit(ast.parse(function_code))
    return checker.is_safe

In [40]:
import numpy as np
import jax.numpy as jnp

def my_function():
    """Return the NumPy mean of [1, 2, 3]; a sample target for is_function_safe."""
    # Probes: uncomment any of these to make the checker reject this function.
    #import matplotlib.pyplot as plt
    #print(5)
    #eval('print(2)')
    sample = [1, 2, 3]
    return np.mean(sample)

# Run the static check on the sample function and display the verdict.
verdict = is_function_safe(my_function)
print(verdict)

#my_function()

False
