In [37]:
import math

class FWD:
    """Dual number used for forward-mode automatic differentiation.

    Each instance stores a ``value`` together with its ``derivative`` with
    respect to one chosen input θ. Arithmetic on instances propagates both
    parts using the standard differentiation rules, so evaluating an
    expression yields its value and its derivative in a single pass.
    """

    def __init__(self, value:float, derivative:float=0):
        """Create a node from a number, or copy an existing node.

        Args:
            value: An ``int``/``float``, or another ``FWD`` to copy.
            derivative: dvalue/dθ. Ignored when ``value`` is already a
                ``FWD`` (its own derivative is copied instead).
        """
        if isinstance(value, FWD):
            # Copy-construct: keep the source node's derivative untouched.
            self.value = value.value
            self.derivative = value.derivative
            return
        assert isinstance(value, (int, float)), f"{value}, {type(value)}"
        assert isinstance(derivative, (int, float))
        self.value = value
        self.derivative = derivative

    def __repr__(self):
        """Return a debugging representation showing both components."""
        return f"{type(self).__name__}(value={self.value!r}, derivative={self.derivative!r})"

    # z = -x(θ), dz/dθ = -dx/dθ
    def __neg__(self):
        """Return the negated node."""
        return FWD(-self.value, -self.derivative)

    # z = x(θ) + y(θ), dz/dθ = dx/dθ + dy/dθ
    def __add__(self, other):
        """Return ``self + other``; a plain number is treated as a constant."""
        if isinstance(other, FWD):
            return FWD(self.value + other.value, self.derivative + other.derivative)
        else:
            return FWD(self.value + other, self.derivative)

    def __radd__(self, other):
        """Return ``other + self`` (addition is commutative)."""
        return self.__add__(other)

    # z = x(θ) - y(θ), dz/dθ = dx/dθ - dy/dθ
    def __sub__(self, other):
        """Return ``self - other``; a plain number is treated as a constant."""
        if isinstance(other, FWD):
            return FWD(self.value - other.value, self.derivative - other.derivative)
        else:
            return FWD(self.value - other, self.derivative)

    def __rsub__(self, other):
        """Return ``other - self`` for a constant ``other``."""
        return FWD(other - self.value, -self.derivative)

    # z = x(θ) * y(θ), dz/dθ = x(θ) * dy/dθ + y(θ) * dx/dθ
    def __mul__(self, other):
        """Return ``self * other`` using the product rule."""
        if isinstance(other, FWD):
            return FWD(self.value * other.value, self.derivative * other.value + self.value * other.derivative)
        else:
            return FWD(self.value * other, self.derivative * other)

    def __rmul__(self, other):
        """Return ``other * self`` (multiplication is commutative)."""
        return self.__mul__(other)

    # z = x(θ) / y(θ), dz/dθ = (y(θ) * dx/dθ - x(θ) * dy/dθ) / (y(θ))^2
    def __truediv__(self, other):
        """Return ``self / other`` using the quotient rule."""
        if isinstance(other, FWD):
            return FWD(self.value / other.value, (other.value * self.derivative - self.value * other.derivative) / (other.value ** 2))
        else:
            return FWD(self.value / other, self.derivative / other)

    # z = c / x(θ), dz/dθ = -c * dx/dθ / (x(θ))^2
    def __rtruediv__(self, other):
        """Return ``other / self`` for a constant ``other``."""
        return FWD(other / self.value, -other * self.derivative / (self.value ** 2))

    # y = sin(x(θ)), dy/dθ = cos(x(θ)) * dx/dθ
    @staticmethod
    def sin(x):
        """Return the sine of ``x``; a plain number is treated as a constant."""
        x = FWD(x)  # accept plain numbers, like the arithmetic operators do
        return FWD(math.sin(x.value), math.cos(x.value) * x.derivative)

    # y = cos(x(θ)), dy/dθ = -sin(x(θ)) * dx/dθ
    @staticmethod
    def cos(x):
        """Return the cosine of ``x``; a plain number is treated as a constant."""
        x = FWD(x)  # accept plain numbers, like the arithmetic operators do
        return FWD(math.cos(x.value), -math.sin(x.value) * x.derivative)

    # y = sqrt(x(θ)), dy/dθ = 0.5 / sqrt(x(θ)) * dx/dθ
    @staticmethod
    def sqrt(x):
        """Return the square root of ``x``; a plain number is treated as a constant.

        Raises:
            ValueError: If ``x`` is negative (from ``math.sqrt``).
            ZeroDivisionError: If ``x`` is zero, where the derivative is unbounded.
        """
        x = FWD(x)  # accept plain numbers, like the arithmetic operators do
        root = math.sqrt(x.value)
        return FWD(root, 0.5 / root * x.derivative)

In [38]:
def forward(func):
    """Wrap ``func`` so that calling it also yields its partial derivatives.

    The returned function evaluates ``func`` once per positional argument,
    seeding that argument with derivative 1 and all others with derivative 0
    (forward-mode differentiation, one pass per input).

    Args:
        func: A callable taking ``FWD`` nodes and returning a single result
            or a tuple of results. Each result may be a ``FWD`` node or a
            plain number (treated as a constant with zero derivative).

    Returns:
        A function ``gradient_func(*args)`` returning ``(out, grads)``:
        ``out`` is the value of ``func`` (a number, or a tuple of numbers
        when ``func`` returns a tuple) and ``grads`` is a list with one
        entry per argument holding the derivative(s) of the output(s) with
        respect to that argument. With no arguments, ``func()`` is evaluated
        once and ``grads`` is empty.
    """
    import functools  # local import keeps this definition self-contained

    def _split(result):
        # Separate func's output into (value(s), derivative(s)); FWD(...)
        # also accepts plain numbers, so constant outputs do not crash.
        if isinstance(result, tuple):
            nodes = tuple(FWD(e) for e in result)
            return tuple(n.value for n in nodes), tuple(n.derivative for n in nodes)
        node = FWD(result)
        return node.value, node.derivative

    @functools.wraps(func)
    def gradient_func(*args):
        if not args:
            # Nothing to differentiate against, but the value is still defined.
            out, _ = _split(func())
            return out, []

        out: tuple[float, ...] | float | None = None
        grads: list[float | tuple[float, ...]] = []
        for cur_index in range(len(args)):
            # Seed only the argument currently being differentiated.
            fwdNodes = [
                FWD(arg, 1 if index == cur_index else 0)
                for index, arg in enumerate(args)
            ]
            values, derivatives = _split(func(*fwdNodes))
            if cur_index == 0:
                # The value is the same on every pass; record it once.
                out = values
            grads.append(derivatives)

        return out, grads

    return gradient_func

def simple_function(x:float, y:float, z:float):
    """Demo function returning the pair ``(x + y * z, x * y)``.

    Each argument is wrapped with ``FWD`` first, so both plain numbers and
    existing ``FWD`` nodes are accepted; the results are ``FWD`` nodes.
    """
    node_x, node_y, node_z = FWD(x), FWD(y), FWD(z)
    product_yz = node_y * node_z
    return node_x + product_yz, node_x * node_y

if __name__ == "__main__":
    # Demo: value and per-argument derivatives of simple_function at (4, 3, 2).
    grad_func = forward(simple_function)
    f_val, f_grads = grad_func(4, 3, 2)
    # One call, two lines: same text on stdout as two separate prints.
    print(f"f(4) = {f_val}", f"f'(4) = {f_grads}", sep="\n")

f(4) = (10, 12)
f'(4) = [(1, 3), (2, 4), (3, 0)]
