In [2]:
import ast
import inspect
from pathlib import Path

In [3]:
kernel_path = Path('vector_add_kernel.py')


In [19]:
import ast
import importlib.util
import tempfile
import inspect

# # Assume we have an ast.FunctionDef object named func_def
# # This is just an example, in real use you would parse your function code into an AST
# func_def = ast.parse("def foo(): return 'Hello, World!'").body[0]

# # Generate the source code from the ast.FunctionDef object
# source_code = ast.unparse(func_def)

# # Write the source code to a temporary file
# with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as tmp:
#     tmp.write(source_code)
#     tmp_path = tmp.name

# Load the temporary file as a module
spec = importlib.util.spec_from_file_location('temp_module', kernel_path)
temp_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(temp_module)

with open(kernel_path, 'r') as f:
    source_code = f.read()  
    tree = ast.parse(source_code)


In [20]:
fns = [f for f in tree.body if isinstance(f, ast.FunctionDef)]

In [22]:
fn_name = fns[0].name

In [23]:
getattr(temp_module, fn_name)

<function temp_module.add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: triton.language.core.constexpr)>

In [24]:
fn = temp_module.add_kernel

In [25]:
fn

<function temp_module.add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: triton.language.core.constexpr)>

In [9]:

# Now you can use inspect.getsource to get the source code
print(inspect.getsource(fn))  # Outputs: def foo(): return 'Hello, World!'


def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE:  so it can be used as a shape value.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:

    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    

In [15]:
lines = inspect.getsourcelines(fn) 

In [18]:
fn.__globals__

{'__name__': 'temp_module',
 '__doc__': None,
 '__package__': '',
 '__loader__': <_frozen_importlib_external.SourceFileLoader at 0x7ff54c4096d0>,
 '__spec__': ModuleSpec(name='temp_module', loader=<_frozen_importlib_external.SourceFileLoader object at 0x7ff54c4096d0>, origin='vector_add_kernel.py'),
 '__file__': 'vector_add_kernel.py',
 '__cached__': '__pycache__/vector_add_kernel.cpython-39.pyc',
 '__builtins__': {'__name__': 'builtins',
  '__doc__': "Built-in functions, exceptions, and other objects.\n\nNoteworthy: None is the `nil' object; Ellipsis represents `...' in slices.",
  '__package__': '',
  '__loader__': _frozen_importlib.BuiltinImporter,
  '__spec__': ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>, origin='built-in'),
  '__build_class__': <function __build_class__>,
  '__import__': <function __import__>,
  'abs': <function abs(x, /)>,
  'all': <function all(iterable, /)>,
  'any': <function any(iterable, /)>,
  'ascii': <function ascii(obj,

In [54]:
tree = ast.parse(kernel_path.read_text())

In [55]:
import types

# Parse the function definition into an ast.FunctionDef object
func_def = ast.parse('def foo(x):\n    return x + 1').body[0]

# Compile the function definition into a code object
code = compile(func_def, '<string>', 'exec')

# Create a Python function object from the code object
foo = types.FunctionType(code, globals(), 'foo')

# Call the function
print(foo(5)) # prints 6

TypeError: expected Module node, got FunctionDef

In [62]:
import ast

code = """
import triton.language as tl

def add_kernel(
    x_ptr, # *Pointer* to first input vector.
    y_ptr, # *Pointer* to second input vector.
    output_ptr, # *Pointer* to output vector.
    n_elements, # Size of the vector.
    BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
    # NOTE: so it can be used as a shape value.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:

    pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
"""
tree = ast.parse(code)

functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
namespace = {}
exec(code, namespace)


In [66]:
fn_name = [f.name for f in functions][0]

In [67]:
fn = namespace[fn_name] 

In [73]:
inspect.signature(fn)

<Signature (x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: triton.language.core.constexpr)>

In [None]:

# Parse the code into an AST
tree = ast.parse(code)

# Extract the function definition from the AST
func_def = tree.body[0]

# Compile the function definition into a Python function object
# func = compile(func_def, "<string>", "exec")

# Extract the globals from the AST
# globals = {name: value for name, value in func_def.body[0].items()}

# Add the function to the globals
# globals["add_kernel"] = func

# # Execute the function
# result = eval(func, globals)

# print(result)

In [57]:
func_def

<ast.Import at 0x7f73f3eb1fa0>

In [39]:
f = [n for n in nodes if isinstance(n, ast.FunctionDef)][0]

In [48]:
f.args.args[0].arg


'x_ptr'

In [52]:
f_def = f.body[0].value

In [53]:
code = compile(f_def, '<string>', 'exec')

TypeError: expected Module node, got Call

In [10]:
fn = exec(kernel_path.read_text())

2023-10-31 20:20:58.819 | DEBUG    | triton.compiler.debugging:wrapper:33 - args
2023-10-31 20:20:58.820 | DEBUG    | triton.compiler.debugging:wrapper:34 - kwargs
2023-10-31 20:20:58.821 | DEBUG    | triton.compiler.debugging:wrapper:36 - function.return_value


In [22]:
import ast

func_def = ast.parse("def foo(): return 'Hello, World!'").body[0]


In [23]:
from triton.runtime.jit import JITFunction
code = compile(ast.Module(body=[func_def], type_ignores=[]), filename='<ast>', mode='exec')

# Execute the bytecode to define the function in the current scope
exec(code)


In [25]:
inspect.getsource(foo)

OSError: could not get source code

In [29]:
inspect.signature(add_kernel)

<Signature (x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: triton.language.core.constexpr)>

In [None]:
jitted = JITFunction(
                add_kernel,
            )