In [28]:
import functools
import math

class FWD:
    """Dual number for forward-mode automatic differentiation.

    An instance pairs a primal ``value`` x(θ) with its tangent ``derivative``
    dx/dθ for one seed direction θ. The overloaded operators and the static
    helpers ``sin``, ``cos`` and ``sqrt`` apply the chain rule to both parts, so
    evaluating an expression on FWD inputs yields its value and its derivative
    along the seed direction in a single pass.

    Plain ``int``/``float`` operands are treated as constants (derivative 0).
    """

    def __init__(self, value: float, derivative: float = 0):
        """Create a dual number.

        Args:
            value: primal value (int or float), or an existing FWD to copy.
            derivative: tangent dx/dθ; defaults to 0, i.e. a constant.

        When ``value`` is already an FWD, both its value and its derivative are
        copied and the ``derivative`` argument is ignored.

        Raises:
            AssertionError: if ``value`` or ``derivative`` is not an int/float
                (checks are skipped when Python runs with ``-O``).
        """
        if isinstance(value, FWD):
            # Copy constructor: the `derivative` argument is intentionally ignored.
            self.value = value.value
            self.derivative = value.derivative
            return
        assert isinstance(value, (int, float)), f"{value}, {type(value)}"
        assert isinstance(derivative, (int, float)), f"{derivative}, {type(derivative)}"
        self.value = value
        self.derivative = derivative

    def __repr__(self):
        return f"FWD(value={self.value!r}, derivative={self.derivative!r})"

    # z = x(θ) + y(θ), dz/dθ = dx/dθ + dy/dθ
    def __add__(self, other):
        if isinstance(other, FWD):
            return FWD(self.value + other.value, self.derivative + other.derivative)
        else:
            return FWD(self.value + other, self.derivative)

    def __radd__(self, other):
        return self.__add__(other)

    # z = x(θ) - y(θ), dz/dθ = dx/dθ - dy/dθ
    def __sub__(self, other):
        if isinstance(other, FWD):
            return FWD(self.value - other.value, self.derivative - other.derivative)
        else:
            return FWD(self.value - other, self.derivative)

    def __rsub__(self, other):
        # other - self with a constant `other`.
        return FWD(other - self.value, -self.derivative)

    # z = -x(θ), dz/dθ = -dx/dθ
    def __neg__(self):
        return FWD(-self.value, -self.derivative)

    # z = +x(θ); returns a copy so the result never aliases the operand.
    def __pos__(self):
        return FWD(self.value, self.derivative)

    # z = x(θ) * y(θ), dz/dθ = x(θ) * dy/dθ + y(θ) * dx/dθ
    def __mul__(self, other):
        if isinstance(other, FWD):
            return FWD(self.value * other.value, self.derivative * other.value + self.value * other.derivative)
        else:
            return FWD(self.value * other, self.derivative * other)

    def __rmul__(self, other):
        return self.__mul__(other)

    # z = x(θ) / y(θ), dz/dθ = (y(θ) * dx/dθ - x(θ) * dy/dθ) / (y(θ))^2
    def __truediv__(self, other):
        if isinstance(other, FWD):
            return FWD(self.value / other.value, (other.value * self.derivative - self.value * other.derivative) / (other.value ** 2))
        else:
            return FWD(self.value / other, self.derivative / other)

    def __rtruediv__(self, other):
        # other / self with a constant `other`: d(c / x) = -c * dx / x^2
        return FWD(other / self.value, -other * self.derivative / (self.value ** 2))

    # z = x(θ) ** y(θ), dz/dθ = y * x^(y-1) * dx/dθ + x^y * ln(x) * dy/dθ
    def __pow__(self, other):
        exponent = FWD(other)  # lifts a plain number to a constant; copies an FWD
        value = self.value ** exponent.value
        derivative = 0
        # Power-rule term. Skipped when identically zero, so that e.g. a constant
        # exponent of 0 at x = 0 does not evaluate 0 ** -1.
        if self.derivative and exponent.value != 0:
            derivative += exponent.value * self.value ** (exponent.value - 1) * self.derivative
        # Exponential term. Only needed when the exponent varies with θ; ln(x)
        # then requires x > 0.
        if exponent.derivative:
            derivative += value * math.log(self.value) * exponent.derivative
        return FWD(value, derivative)

    def __rpow__(self, other):
        # other ** self with a constant base `other`.
        return FWD(other) ** self

    # y = sin(x(θ)), dy/dθ = cos(x(θ)) * dx/dθ
    @staticmethod
    def sin(x):
        x = FWD(x)  # plain numbers become constants (derivative 0)
        return FWD(math.sin(x.value), math.cos(x.value) * x.derivative)

    # y = cos(x(θ)), dy/dθ = -sin(x(θ)) * dx/dθ
    @staticmethod
    def cos(x):
        x = FWD(x)  # plain numbers become constants (derivative 0)
        return FWD(math.cos(x.value), -math.sin(x.value) * x.derivative)

    # y = sqrt(x(θ)), dy/dθ = 0.5 / sqrt(x(θ)) * dx/dθ
    @staticmethod
    def sqrt(x):
        x = FWD(x)  # plain numbers become constants (derivative 0)
        return FWD(math.sqrt(x.value), 0.5 / math.sqrt(x.value) * x.derivative)

In [None]:
# Forward-mode Jacobian of a function with n inputs (x_1..x_n) and m outputs
# (y_1..y_m). Mathematically the result is the m×n matrix
#         [∂y_1 / ∂x_1, ..., ∂y_1 / ∂x_n ]
#         [∂y_2 / ∂x_1, ..., ∂y_2 / ∂x_n ]
#         [...,         ...,        ...  ]
#         [∂y_m / ∂x_1, ..., ∂y_m / ∂x_n ]
# but it is returned column by column: grads[i] = (∂y_1/∂x_i, ..., ∂y_m/∂x_i).


def _split_output(result):
    """Split one function output into (value, derivative).

    FWD nodes contribute their stored value and derivative. Plain int/float
    results do not depend on any input, so their derivative is 0.
    """
    if isinstance(result, FWD):
        return result.value, result.derivative
    if isinstance(result, (int, float)):
        return result, 0
    raise TypeError(f"jvp: unsupported output type {type(result).__name__}")


def _split_outputs(y):
    """Split a single output, or a tuple/list of outputs, into (values, derivatives)."""
    if isinstance(y, (tuple, list)):
        pairs = [_split_output(e) for e in y]
        return tuple(v for v, _ in pairs), tuple(d for _, d in pairs)
    return _split_output(y)


def jvp(customfunc):
    """Wrap ``customfunc`` so that calling the wrapper returns its value and Jacobian.

    The wrapped function is evaluated once per positional argument (n passes
    for n inputs). On pass i, argument i is seeded with derivative 1 and every
    other argument with derivative 0, so each pass yields one Jacobian column.

    Args:
        customfunc: function of n positional arguments (received as FWD nodes).
            It returns either a single result or a tuple/list of m results. Each
            result may be an FWD node or a plain int/float (a constant).

    Returns:
        ``wrapper_func(*args) -> (out, grads)`` where
          out   -- the function value: a number for a single output, or a
                   tuple (y_1, ..., y_m) for multiple outputs.
          grads -- list of length n; grads[i] is ∂y/∂x_i, i.e. a number for a
                   single output or a tuple (∂y_1/∂x_i, ..., ∂y_m/∂x_i).
                   For zero arguments the function is evaluated once and
                   grads is [].

    Raises:
        TypeError: if an output is neither an FWD nor an int/float.
    """

    @functools.wraps(customfunc)
    def wrapper_func(*args):
        n_inputs = len(args)
        if n_inputs == 0:
            # Nothing to differentiate against, but still report the value.
            out, _ = _split_outputs(customfunc())
            return out, []

        out = None   # (y_1, ..., y_m) or y, taken from the first pass
        grads = []   # one Jacobian column per input
        for cur_index in range(n_inputs):
            # Seed with column `cur_index` of the n×n identity matrix:
            # derivative 1 for x_cur_index, 0 for every other input.
            fwd_nodes = [
                FWD(arg, 1 if index == cur_index else 0)
                for index, arg in enumerate(args)
            ]
            values, column = _split_outputs(customfunc(*fwd_nodes))
            if cur_index == 0:
                out = values
            grads.append(column)
        return out, grads

    return wrapper_func

In [30]:
# Sanity checks for FWD / jvp. Each test prints the computed result next to
# the hand-derived expectation, then asserts they agree so a regression
# fails loudly instead of relying on a visual comparison.

# Test 1: Simple polynomial function
def poly_function_fwd(x):
    return 3 * (x * x) + 2 * x + 1

grad_poly = jvp(poly_function_fwd)
val, grads = grad_poly(2)
print("Polynomial f(x) = 3x² + 2x + 1 at x=2:")
print(f"Value: {val}, Gradient: {grads}")
print("Expected: f(2) = 17, f'(2) = 14\n")
assert val == 17 and grads == [14]

# Test 2: Trigonometric function
def trig_function(x):
    return FWD.sin(x) + FWD.cos(x)

grad_trig = jvp(trig_function)
val, grads = grad_trig(0)
print("Trigonometric f(x) = sin(x) + cos(x) at x=0:")
print(f"Value: {val}, Gradient: {grads}")
print("Expected: f(0) = 1, f'(0) = 1\n")
assert math.isclose(val, 1.0)
assert len(grads) == 1 and math.isclose(grads[0], 1.0)

# Test 3: Multiple variables with products and divisions
def complex_function_fwd(x, y):
    product = x * y
    sum_xy = x + y
    quotient = product / sum_xy
    squares = (x * x) + (y * y)
    return quotient, squares

grad_complex = jvp(complex_function_fwd)
val, grads = grad_complex(3, 4)
print("Complex function at (3,4):")
print("f1(x,y) = xy/(x+y), f2(x,y) = x² + y²")
print(f"Values: {val}")
print(f"Gradients: {grads}")
print("Expected: f1(3,4) = 12/7 ≈ 1.714, f2(3,4) = 25")
print("Expected gradients: ∂f1/∂x = 16/49 ≈ 0.327, ∂f1/∂y = 9/49 ≈ 0.184")
print("                    ∂f2/∂x = 6, ∂f2/∂y = 8\n")
assert math.isclose(val[0], 12 / 7) and val[1] == 25
# grads is stored column by column: grads[i] = (∂f1/∂x_i, ∂f2/∂x_i)
(df1_dx, df2_dx), (df1_dy, df2_dy) = grads
assert math.isclose(df1_dx, 16 / 49) and math.isclose(df1_dy, 9 / 49)
assert df2_dx == 6 and df2_dy == 8

# Test 4: Chain rule validation
def chain_function(x):
    inner = x * x + 1
    return FWD.sin(inner)

grad_chain = jvp(chain_function)
val, grads = grad_chain(1)
print("Chain rule f(x) = sin(x² + 1) at x=1:")
print(f"Value: {val}, Gradient: {grads}")
print("Expected: f(1) = sin(2) ≈ 0.909, f'(1) = 2cos(2) ≈ -0.833\n")
assert math.isclose(val, math.sin(2))
assert len(grads) == 1 and math.isclose(grads[0], 2 * math.cos(2))

# Test 5: Square root function
def sqrt_function(x):
    return FWD.sqrt(x)

grad_sqrt = jvp(sqrt_function)
val, grads = grad_sqrt(4)
print("Square root f(x) = √x at x=4:")
print(f"Value: {val}, Gradient: {grads}")
print("Expected: f(4) = 2, f'(4) = 0.25")
assert math.isclose(val, 2.0)
assert len(grads) == 1 and math.isclose(grads[0], 0.25)

Polynomial f(x) = 3x² + 2x + 1 at x=2:
Value: 17, Gradient: [14]
Expected: f(2) = 17, f'(2) = 14

Trigonometric f(x) = sin(x) + cos(x) at x=0:
Value: 1.0, Gradient: [1.0]
Expected: f(0) = 1, f'(0) = 1

Complex function at (3,4):
f1(x,y) = xy/(x+y), f2(x,y) = x² + y²
Values: (1.7142857142857142, 25)
Gradients: [(0.32653061224489793, 6), (0.1836734693877551, 8)]
Expected: f1(3,4) = 12/7 ≈ 1.714, f2(3,4) = 25
Expected gradients: ∂f1/∂x = 16/49 ≈ 0.327, ∂f1/∂y = 9/49 ≈ 0.184
                    ∂f2/∂x = 6, ∂f2/∂y = 8

Chain rule f(x) = sin(x² + 1) at x=1:
Value: 0.9092974268256817, Gradient: [-0.8322936730942848]
Expected: f(1) = sin(2) ≈ 0.909, f'(1) = 2cos(2) ≈ -0.833

Square root f(x) = √x at x=4:
Value: 2.0, Gradient: [0.25]
Expected: f(4) = 2, f'(4) = 0.25
