# Tape-Based Reverse-Mode Automatic Differentiation

In [26]:

import functools
import math
from collections.abc import Callable

# Tape of (result node, backward closure) pairs, appended in the order the
# operations execute; the backward pass replays it from the end.
global_graph: list[tuple['BWD', Callable[[float], None]]] = []
# Every BWD node constructed (inputs, constants and intermediates), so their
# accumulated gradients can be reset between backward passes.
var_list: list['BWD'] = []

class BWD:
    """Scalar node of a tape-based reverse-mode autodiff graph.

    Each instance holds a forward ``value`` and an accumulated ``grad`` and
    registers itself in the module-level ``var_list``. Every operation below
    builds a new ``BWD`` for its result and appends ``(result, backward_fn)``
    to the module-level ``global_graph`` tape. ``backward_fn(upperGradVal)``
    adds the operation's local derivative times ``upperGradVal`` (the gradient
    accumulated on ``result``) into each operand's ``grad``. Plain numbers used
    as operands are wrapped in constant ``BWD`` nodes.
    """

    def __init__(self, value):
        self.value: float = value
        self.grad: float = 0
        # Register every node so clear_gradients() can reset it.
        var_list.append(self)

    def __repr__(self):
        return f"BWD({self.value:.3f}, grad={self.grad:.3f})"

    @staticmethod
    def _lift(operand):
        """Return ``operand`` if it is already a BWD, else a new constant node for it."""
        return operand if isinstance(operand, BWD) else BWD(operand)

    def __add__(self, other):
        other = BWD._lift(other)

        result = BWD(self.value + other.value)

        # z = h(y), y = a + b => dz/da = upperGradVal * 1, dz/db = upperGradVal * 1
        def backward_fn(upperGradVal):
            self.grad += upperGradVal
            other.grad += upperGradVal

        global_graph.append((result, backward_fn))
        return result

    def __radd__(self, other):
        # const + BWD; addition commutes, so reuse __add__.
        return self.__add__(other)

    def __sub__(self, other):
        other = BWD._lift(other)

        result = BWD(self.value - other.value)

        # z = h(y), y = a - b => dz/da = upperGradVal * 1, dz/db = upperGradVal * -1
        def backward_fn(upperGradVal):
            self.grad += upperGradVal
            other.grad -= upperGradVal

        global_graph.append((result, backward_fn))
        return result

    def __rsub__(self, other):
        # const - BWD: lift the constant and keep the operand order.
        return BWD._lift(other).__sub__(self)

    def __mul__(self, other):
        other = BWD._lift(other)

        result = BWD(self.value * other.value)

        # z = h(y), y = a * b => dz/da = upperGradVal * b, dz/db = upperGradVal * a
        def backward_fn(upperGradVal):
            self.grad += upperGradVal * other.value
            other.grad += upperGradVal * self.value

        global_graph.append((result, backward_fn))
        return result

    def __rmul__(self, other):
        # const * BWD; multiplication commutes, so reuse __mul__.
        return self.__mul__(other)

    def __truediv__(self, other):
        other = BWD._lift(other)

        result = BWD(self.value / other.value)

        # z = h(y), y = a / b => dz/da = upperGradVal * 1/b, dz/db = upperGradVal * -a/b^2
        def backward_fn(upperGradVal):
            self.grad += upperGradVal / other.value
            other.grad -= upperGradVal * self.value / (other.value ** 2)

        global_graph.append((result, backward_fn))
        return result

    def __rtruediv__(self, other):
        # const / BWD: lift the constant and keep the operand order.
        return BWD._lift(other).__truediv__(self)

    def __neg__(self):
        """Return ``-self`` as a new node (supports unary minus)."""
        result = BWD(-self.value)

        # z = h(y), y = -a => dz/da = upperGradVal * -1
        def backward_fn(upperGradVal):
            self.grad -= upperGradVal

        global_graph.append((result, backward_fn))
        return result

    def __pow__(self, exponent):
        """Return ``self ** exponent`` for a constant (non-BWD) numeric exponent.

        A BWD exponent is not differentiable here: returning NotImplemented
        makes Python raise TypeError, as it did before this method existed.
        """
        if isinstance(exponent, BWD):
            return NotImplemented

        result = BWD(self.value ** exponent)

        # z = h(y), y = a ** p => dz/da = upperGradVal * p * a^(p-1)
        def backward_fn(upperGradVal):
            # With p == 0 the derivative is 0; skip it so a == 0 cannot
            # trigger 0 ** -1.
            if exponent != 0:
                self.grad += upperGradVal * exponent * self.value ** (exponent - 1)

        global_graph.append((result, backward_fn))
        return result

    @staticmethod
    def sin(var):
        """Return ``sin(var)`` as a new node; ``var`` may be a BWD or a number."""
        var = BWD._lift(var)

        result = BWD(math.sin(var.value))

        # z = h(y), y = sin(a) => dz/da = upperGradVal * cos(a)
        def backward_fn(upperGradVal):
            var.grad += upperGradVal * math.cos(var.value)

        global_graph.append((result, backward_fn))
        return result

    @staticmethod
    def cos(var):
        """Return ``cos(var)`` as a new node; ``var`` may be a BWD or a number."""
        var = BWD._lift(var)

        result = BWD(math.cos(var.value))

        # z = h(y), y = cos(a) => dz/da = upperGradVal * -sin(a)
        def backward_fn(upperGradVal):
            var.grad -= upperGradVal * math.sin(var.value)

        global_graph.append((result, backward_fn))
        return result

    @staticmethod
    def sqrt(var):
        """Return ``sqrt(var)`` as a new node; ``var`` may be a BWD or a number."""
        var = BWD._lift(var)

        result = BWD(math.sqrt(var.value))

        # z = h(y), y = sqrt(a) => dz/da = upperGradVal * 1/(2*sqrt(a));
        # sqrt(a) is already stored in result.value.
        def backward_fn(upperGradVal):
            var.grad += upperGradVal * (1 / (2 * result.value))

        global_graph.append((result, backward_fn))
        return result
    
def clear_gradients():
    """Set the accumulated gradient of every node registered in ``var_list`` to 0."""
    # BWD.__init__ registers every node in var_list, so intermediates and
    # constants are reset too, not only the inputs.
    for registered_node in var_list:
        registered_node.grad = 0

In [27]:
def vjp(customfunc):
    """Wrap ``customfunc`` so that calling it returns its value(s) and Jacobian.

    ``customfunc`` receives one ``BWD`` per positional argument and must
    return a ``BWD`` or number, or a list/tuple of them. For n inputs
    x_1..x_n and m outputs y_1..y_m, the Jacobian is laid out as
    (entry [i][j] = ∂y_i / ∂x_j, i = 1~m, j = 1~n)::

        [∂y_1 / ∂x_1, ..., ∂y_1 / ∂x_n ]
        [∂y_2 / ∂x_1, ..., ∂y_2 / ∂x_n ]
        [...,         ...,        ...  ]
        [∂y_m / ∂x_1, ..., ∂y_m / ∂x_n ]

    Despite the name, the wrapper builds the full Jacobian. It runs one
    reverse sweep per output y_i, seeding ∂y_i/∂y_i = 1 (a vector-Jacobian
    product with a one-hot vector), and each sweep yields row i.

    Returns:
        A wrapper returning ``(value, grads)``. With a single output,
        ``value`` is a scalar and ``grads`` is a flat list of n partials.
        Otherwise ``value`` is a tuple of m values and ``grads`` is a list
        of m rows.

    Raises (from the wrapper):
        TypeError: if an output is neither a ``BWD`` nor an int/float.
    """

    def _as_node(output):
        # Validate before reading .value, so a bad return type gives a clear
        # error instead of an AttributeError.
        if isinstance(output, BWD):
            return output
        if isinstance(output, (int, float)):
            # A constant output: its Jacobian row is all zeros.
            return BWD(output)
        raise TypeError(f"Output must be a BWD or a number, got {type(output).__name__}")

    @functools.wraps(customfunc)
    def wrapper_func(*args):
        # Start every call from an empty tape and node registry.
        global_graph.clear()
        var_list.clear()

        # Lift the inputs to BWD nodes; running customfunc records the tape.
        var_inputs = [BWD(arg) for arg in args]
        outputs = customfunc(*var_inputs)

        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]
        outputs = [_as_node(output) for output in outputs]

        values = tuple(output.value for output in outputs)
        out_values = values[0] if len(values) == 1 else values

        grads = []
        for output in outputs:
            # Each output needs a fresh sweep: reset every node's gradient.
            clear_gradients()
            output.grad = 1.0
            # Tape entries are appended in creation order, so walking it
            # backwards handles every consumer of a node before the node's
            # own entry propagates further.
            for cur_node, backward_fn in reversed(global_graph):
                # A zero upstream gradient contributes nothing; skip it.
                if cur_node.grad != 0:
                    backward_fn(cur_node.grad)

            # Row i: [∂y_i / ∂x_1, ..., ∂y_i / ∂x_n]
            grads.append([inp.grad for inp in var_inputs])

        return out_values, (grads[0] if len(grads) == 1 else grads)

    return wrapper_func

In [28]:
# Test 1: polynomial f(x) = 3x² + 2x + 1, evaluated at x = 2
def poly_function_fwd(x):
    """Return 3x² + 2x + 1 (x arrives as a BWD node when called through vjp)."""
    return 3 * (x * x) + 2 * x + 1

grad_poly = vjp(poly_function_fwd)
val, grads = grad_poly(2)
print(
    "Polynomial f(x) = 3x² + 2x + 1 at x=2:",
    f"Value: {val}, Gradient: {grads}",
    "Expected: f(2) = 17, f'(2) = 14\n",
    sep="\n",
)

Polynomial f(x) = 3x² + 2x + 1 at x=2:
Value: 17, Gradient: [14.0]
Expected: f(2) = 17, f'(2) = 14



In [29]:
# Test 2: trigonometric f(x) = sin(x) + cos(x), evaluated at x = 0
def trig_function(x):
    """Return sin(x) + cos(x) built from the BWD primitives."""
    return BWD.sin(x) + BWD.cos(x)

grad_trig = vjp(trig_function)
val, grads = grad_trig(0)
print(
    "Trigonometric f(x) = sin(x) + cos(x) at x=0:",
    f"Value: {val}, Gradient: {grads}",
    "Expected: f(0) = 1, f'(0) = 1\n",
    sep="\n",
)

Trigonometric f(x) = sin(x) + cos(x) at x=0:
Value: 1.0, Gradient: [1.0]
Expected: f(0) = 1, f'(0) = 1



In [30]:
# Chain-rule check: f(x) = sin(x² + 1), evaluated at x = 1
def chain_function(x):
    """Return sin(x² + 1), composing BWD multiply/add with BWD.sin."""
    squared_plus_one = x * x + 1
    return BWD.sin(squared_plus_one)

grad_chain = vjp(chain_function)
val, grads = grad_chain(1)
print(
    "Chain rule f(x) = sin(x² + 1) at x=1:",
    f"Value: {val}, Gradient: {grads}",
    "Expected: f(1) = sin(2) ≈ 0.909, f'(1) = 2cos(2) ≈ -0.833\n",
    sep="\n",
)

Chain rule f(x) = sin(x² + 1) at x=1:
Value: 0.9092974268256817, Gradient: [-0.8322936730942848]
Expected: f(1) = sin(2) ≈ 0.909, f'(1) = 2cos(2) ≈ -0.833



In [31]:
# Square-root check: f(x) = √x, evaluated at x = 4
def sqrt_function(x):
    """Return √x via BWD.sqrt."""
    return BWD.sqrt(x)

grad_sqrt = vjp(sqrt_function)
val, grads = grad_sqrt(4)
print(
    "Square root f(x) = √x at x=4:",
    f"Value: {val}, Gradient: {grads}",
    "Expected: f(4) = 2, f'(4) = 0.25",
    sep="\n",
)

Square root f(x) = √x at x=4:
Value: 2.0, Gradient: [0.25]
Expected: f(4) = 2, f'(4) = 0.25


# Source-Transformation Automatic Differentiation

In [32]:
import ast
import inspect
import textwrap
import math

class WorkingSourceTransform:
    """
    Source-transformation automatic differentiation.

    ``grad(func)`` reads the source of ``func`` and returns a new function
    ``<name>_grad(*args)``. The body of ``func`` must be a single
    ``return <expr>``, optionally preceded by a docstring. The generated
    function returns ``(value, [∂value/∂arg for each positional arg])``.

    Supported in ``<expr>``:

    - +, -, *, / and ** (the ln-term of ** is emitted only when the exponent
      depends on an argument);
    - unary + and -;
    - numeric constants and argument names;
    - free names (treated as constants);
    - attribute constants that do not involve an argument (e.g. ``math.pi``);
    - ``math.sin`` / ``math.cos`` / ``math.sqrt``.

    Each sub-expression is represented as ``(forward_code, grad_exprs)``.
    ``forward_code`` is a list of source lines whose last line is
    ``result = <value expr>``. ``grad_exprs`` maps each argument name to the
    source of the sub-expression's derivative with respect to it.
    """

    # Binary operators this transformer can differentiate.
    _BINOP_SYMBOLS = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*", ast.Div: "/", ast.Pow: "**"}

    def __init__(self):
        # Not used by the transformation itself; kept so code that reads
        # these attributes keeps working.
        self.temp_counter = 0
        self.variables = set()

    def grad(self, func):
        """Return a compiled gradient function for ``func``.

        Raises:
            NotImplementedError: if ``func`` is not a plain ``def`` or its body
                uses an unsupported construct.
        """
        source = textwrap.dedent(inspect.getsource(func))
        func_def = ast.parse(source).body[0]
        if not isinstance(func_def, ast.FunctionDef):
            raise NotImplementedError("Only functions defined with `def` are supported")

        args = [arg.arg for arg in func_def.args.args]
        generated_code = self._generate_grad_function(func_def, args)

        # Run the generated code with the names `func` itself can see: its
        # referenced globals and closure variables. Add the real `math`
        # module, which the derivative expressions rely on.
        closure = inspect.getclosurevars(func)
        namespace = {**closure.globals, **closure.nonlocals, 'math': math}
        exec(generated_code, namespace)

        return namespace[f"{func_def.name}_grad"]

    def _generate_grad_function(self, func_def, args):
        """Build the source text of ``<name>_grad`` for ``func_def``."""
        func_name = func_def.name

        body = func_def.body
        if ast.get_docstring(func_def) is not None:
            body = body[1:]  # skip the docstring statement
        if not body or not isinstance(body[0], ast.Return) or body[0].value is None:
            raise NotImplementedError(
                "Function body must be a single `return <expr>` statement"
            )
        expr = body[0].value

        forward_code, grad_exprs = self._process_expression(expr, args)
        value_expr = self._extract_result_expr(forward_code)

        lines = [
            f"def {func_name}_grad({', '.join(args)}):",
            '    """自动生成的梯度函数"""',
        ]
        lines.extend(f"    {line}" for line in forward_code[:-1])
        grad_list = ', '.join(grad_exprs.get(arg, "0") for arg in args)
        # Return the value expression directly instead of binding it to a
        # temporary. An argument that happens to be named `result` can then
        # not be overwritten before the gradient expressions read it.
        lines.append(f"    return ({value_expr}), [{grad_list}]")

        return '\n'.join(lines)

    def _process_expression(self, expr, args):
        """Return ``(forward_code, grad_exprs)`` for ``expr``.

        Raises:
            NotImplementedError: for unsupported node types.
        """
        if isinstance(expr, ast.BinOp):
            return self._process_binop(expr, args)
        if isinstance(expr, ast.UnaryOp):
            return self._process_unaryop(expr, args)
        if isinstance(expr, ast.Call):
            return self._process_call(expr, args)
        if isinstance(expr, ast.Name):
            # d(name)/d(arg) is 1 for the argument itself and 0 otherwise;
            # free names are treated as constants.
            forward_code = [f"result = {expr.id}"]
            grad_exprs = {arg: "1" if arg == expr.id else "0" for arg in args}
            return forward_code, grad_exprs
        if isinstance(expr, ast.Constant):
            # repr() gives a literal that round-trips the exact value.
            return [f"result = {expr.value!r}"], {arg: "0" for arg in args}
        if isinstance(expr, ast.Attribute) and not self._depends_on(expr, args):
            # Attribute lookups not involving an argument (e.g. math.pi) are constants.
            return [f"result = {ast.unparse(expr)}"], {arg: "0" for arg in args}
        raise NotImplementedError(f"Expression type {type(expr)} not supported")

    def _process_unaryop(self, unary, args):
        """Differentiate ``+operand`` / ``-operand``."""
        if isinstance(unary.op, ast.USub):
            sign = "-"
        elif isinstance(unary.op, ast.UAdd):
            sign = "+"
        else:
            raise NotImplementedError(f"Operator {type(unary.op).__name__} not supported")

        operand_forward, operand_grads = self._process_expression(unary.operand, args)
        operand_val = self._extract_result_expr(operand_forward)

        forward_code = operand_forward[:-1] + [f"result = {sign}({operand_val})"]
        grad_exprs = {arg: f"{sign}({operand_grads.get(arg, '0')})" for arg in args}
        return forward_code, grad_exprs

    def _process_binop(self, binop, args):
        """Differentiate ``left <op> right`` for +, -, *, / and **."""
        left_forward, left_grads = self._process_expression(binop.left, args)
        right_forward, right_grads = self._process_expression(binop.right, args)

        symbol = self._BINOP_SYMBOLS.get(type(binop.op))
        if symbol is None:
            raise NotImplementedError(f"Operator {type(binop.op).__name__} not supported")

        # Operand values are spliced into larger expressions as text, so they
        # must be parenthesized; otherwise (x + y) * x would become x + y * x.
        a = f"({self._extract_result_expr(left_forward)})"
        b = f"({self._extract_result_expr(right_forward)})"

        forward_code = left_forward[:-1] + right_forward[:-1]
        forward_code.append(f"result = {a} {symbol} {b}")

        exponent_is_const = not self._depends_on(binop.right, args)
        grad_exprs = {
            arg: self._binop_derivative(
                binop.op,
                a,
                b,
                left_grads.get(arg, "0"),
                right_grads.get(arg, "0"),
                exponent_is_const,
            )
            for arg in args
        }
        return forward_code, grad_exprs

    @staticmethod
    def _binop_derivative(op, a, b, da, db, exponent_is_const):
        """Return the derivative source of ``a <op> b`` given operand derivatives."""
        if isinstance(op, ast.Add):
            # d(a + b) = da + db
            return f"({da}) + ({db})"
        if isinstance(op, ast.Sub):
            # d(a - b) = da - db
            return f"({da}) - ({db})"
        if isinstance(op, ast.Mult):
            # d(a * b) = da * b + a * db
            return f"({da}) * {b} + {a} * ({db})"
        if isinstance(op, ast.Div):
            # d(a / b) = (da * b - a * db) / b²
            return f"(({da}) * {b} - {a} * ({db})) / ({b} ** 2)"
        if isinstance(op, ast.Pow):
            # d(a ** b) = b * a^(b-1) * da + a^b * ln(a) * db
            power_term = f"{b} * {a} ** ({b} - 1) * ({da})"
            if exponent_is_const:
                # db is identically 0: drop the log term so a <= 0 still works.
                return power_term
            return f"{power_term} + {a} ** {b} * math.log({a}) * ({db})"
        raise NotImplementedError(f"Operator {type(op).__name__} not supported")

    @staticmethod
    def _depends_on(node, args):
        """Return True if ``node`` references any of the argument names."""
        return any(isinstance(sub, ast.Name) and sub.id in args for sub in ast.walk(node))

    def _process_call(self, call, args):
        """Differentiate a single-argument ``math.sin/cos/sqrt`` call."""
        func = call.func
        if not isinstance(func, ast.Attribute):
            raise NotImplementedError("Only math.func() calls supported")
        if not (isinstance(func.value, ast.Name) and func.value.id == 'math'):
            raise NotImplementedError("Only math module functions supported")

        func_name = func.attr

        if len(call.args) != 1:
            raise NotImplementedError("Only single-argument functions supported")

        # Outer derivative f'(u), written as source text.
        outer_derivative = {
            # d/dx sin(u) = cos(u) * du/dx
            'sin': lambda u: f"math.cos({u})",
            # d/dx cos(u) = -sin(u) * du/dx
            'cos': lambda u: f"(-math.sin({u}))",
            # d/dx sqrt(u) = 1/(2*sqrt(u)) * du/dx
            'sqrt': lambda u: f"(1.0 / (2.0 * math.sqrt({u})))",
        }
        if func_name not in outer_derivative:
            raise NotImplementedError(f"Function {func_name} not supported")

        arg_forward, arg_grads = self._process_expression(call.args[0], args)
        arg_val = self._extract_result_expr(arg_forward)

        forward_code = arg_forward[:-1] + [f"result = math.{func_name}({arg_val})"]

        factor = outer_derivative[func_name](arg_val)
        grad_exprs = {arg: f"{factor} * ({arg_grads.get(arg, '0')})" for arg in args}
        return forward_code, grad_exprs

    def _extract_result_expr(self, forward_code):
        """Return the value expression from the last ``result = ...`` line.

        Empty code yields ``"0"``. A last line without the ``result = ``
        prefix is returned unchanged.
        """
        if not forward_code:
            return "0"
        return forward_code[-1].removeprefix("result = ")


In [34]:
# Usage examples / self-checks for WorkingSourceTransform
def test_source_transform():
    """Differentiate six small functions with WorkingSourceTransform.

    For each one, print the value, the gradient and a ✓/✗ verdict against
    the hand-derived answer.
    """
    print("Source Transform 自动微分测试")
    print("=" * 40)

    transformer = WorkingSourceTransform()

    def mark(passed):
        """Return the verdict symbol printed after each expectation."""
        return '✓' if passed else '✗'

    # Case 1: addition
    def f1(x, y):
        return x + y

    value, gradient = transformer.grad(f1)(2.0, 3.0)
    print("f1(x,y) = x + y")
    print(f"f1(2,3) = {value}, grad = {gradient}")
    print(f"期望: 5, [1, 1] - {mark(value == 5 and gradient == [1, 1])}")
    print()

    # Case 2: multiplication
    def f2(x, y):
        return x * y

    value, gradient = transformer.grad(f2)(2.0, 3.0)
    print("f2(x,y) = x * y")
    print(f"f2(2,3) = {value}, grad = {gradient}")
    print(f"期望: 6, [3, 2] - {mark(value == 6 and gradient == [3.0, 2.0])}")
    print()

    # Case 3: composite polynomial x² + x
    def f3(x):
        return x * x + x

    value, gradient = transformer.grad(f3)(3.0)
    print("f3(x) = x² + x")
    print(f"f3(3) = {value}, grad = {gradient}")
    print(f"期望: 12, [7] - {mark(value == 12 and abs(gradient[0] - 7) < 1e-10)}")
    print()

    # Case 4: sine
    def f4(x):
        return math.sin(x)

    value, gradient = transformer.grad(f4)(0.0)
    print("f4(x) = sin(x)")
    print(f"f4(0) = {value}, grad = {gradient}")
    want_grad = math.cos(0.0)
    passed = abs(value) < 1e-10 and abs(gradient[0] - want_grad) < 1e-10
    print(f"期望: 0, [1] - {mark(passed)}")
    print()

    # Case 5: sin(x) + cos(y)
    def f5(x, y):
        return math.sin(x) + math.cos(y)

    value, gradient = transformer.grad(f5)(0.0, 0.0)
    print("f5(x,y) = sin(x) + cos(y)")
    print(f"f5(0,0) = {value}, grad = {gradient}")
    want_value = math.sin(0) + math.cos(0)
    want_grads = [math.cos(0), -math.sin(0)]
    value_ok = abs(value - want_value) < 1e-10
    grads_ok = abs(gradient[0] - want_grads[0]) < 1e-10 and abs(gradient[1] - want_grads[1]) < 1e-10
    print(f"期望: {want_value}, {want_grads} - {mark(value_ok and grads_ok)}")
    print()

    # Case 6: division and square root.
    # sqrt(x) / x = x^(-1/2), so f'(x) = -1/2 * x^(-3/2) = -1 / (2 * x^(3/2)).
    def f6(x):
        return math.sqrt(x) / x

    value, gradient = transformer.grad(f6)(4.0)
    print("f6(x) = sqrt(x) / x")
    print(f"f6(4) = {value}, grad = {gradient}")
    want_value = math.sqrt(4) / 4
    want_grad = -1.0 / (2.0 * (4.0 ** 1.5))
    value_ok = abs(value - want_value) < 1e-10
    grad_ok = abs(gradient[0] - want_grad) < 1e-10
    print(f"期望: {want_value}, [{want_grad}] - {mark(value_ok and grad_ok)}")

test_source_transform()

Source Transform 自动微分测试
f1(x,y) = x + y
f1(2,3) = 5.0, grad = [1, 1]
期望: 5, [1, 1] - ✓

f2(x,y) = x * y
f2(2,3) = 6.0, grad = [3.0, 2.0]
期望: 6, [3, 2] - ✓

f3(x) = x² + x
f3(3) = 12.0, grad = [7.0]
期望: 12, [7] - ✓

f4(x) = sin(x)
f4(0) = 0.0, grad = [1.0]
期望: 0, [1] - ✓

f5(x,y) = sin(x) + cos(y)
f5(0,0) = 1.0, grad = [1.0, 0.0]
期望: 1.0, [1.0, -0.0] - ✓

f6(x) = sqrt(x) / x
f6(4) = 0.5, grad = [-0.0625]
期望: 0.5, [-0.0625] - ✓
