In [48]:
class Scalar():
    """A scalar node in a computation graph, supporting reverse-mode autodiff.

    Each node stores its numeric ``data``, the gradient ``grad`` accumulated by
    ``backward``, the set of operand nodes it was computed from (``_prev``), a
    short label of the producing operation (``_op``), and a ``_backward``
    closure that pushes this node's gradient to its operands.
    """

    def __init__(self, data: "int | float", _children=(), _op=''):
        """Create a node.

        Args:
            data: the numeric value held by this node.
            _children: operand nodes this node was computed from (internal use).
            _op: label of the operation that produced this node, e.g. '+' (internal use).
        """
        self.data = data
        self.grad = 0
        # Leaf nodes have no operands, so their backward step does nothing.
        self._backward = lambda: None
        self._prev = set(_children)
        self._op = _op

    # The annotation is quoted so it is not evaluated while the class body runs:
    # an unquoted ``Scalar | int`` raises NameError on a fresh kernel because the
    # name ``Scalar`` is only bound once the whole class statement has finished.
    def __add__(self, other: "Scalar | int | float"):
        """Return a new node holding ``self.data + other.data``.

        A plain number is wrapped in a ``Scalar`` first. The derivative of a sum
        with respect to each operand is 1, so the backward step forwards
        ``out.grad`` unchanged to both operands.
        """
        other = other if isinstance(other, Scalar) else Scalar(other)
        out = Scalar(self.data + other.data, (self, other), '+')

        def _backward():
            # ``+=`` rather than ``=`` so a node used more than once (e.g. ``x + x``)
            # accumulates a contribution from every use.
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out

    def __radd__(self, other):
        """Support ``number + Scalar`` by delegating to ``__add__`` (addition commutes)."""
        return self + other

    def backward(self):
        """Backpropagate from this node, accumulating ``grad`` on every node it depends on.

        Sets ``self.grad`` to 1, then calls each node's ``_backward`` in reverse
        topological order, so a node's gradient is complete before it is pushed to
        its operands. Other nodes' gradients are accumulated with ``+=`` and are not
        reset here, so calling ``backward`` again adds to their previous values.
        """
        # Iterative post-order DFS: a node is appended only after all of its
        # operands, and an explicit stack avoids Python's recursion limit on long chains.
        topo = []
        visited = set()
        stack = [(self, False)]
        while stack:
            node, operands_done = stack.pop()
            if operands_done:
                topo.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            # Re-push the node so it is emitted after the operands pushed above it.
            stack.append((node, True))
            for child in node._prev:
                if child not in visited:
                    stack.append((child, False))

        self.grad = 1
        for v in reversed(topo):
            v._backward()

    def __repr__(self):
        """Show data and grad, plus the producing op for non-leaf nodes."""
        op_suffix = f', {self._op}' if self._op else ''
        return f"Scalar(data={self.data}, grad={self.grad}{op_suffix})"

In [49]:
# Build a tiny graph: z is the sum of two leaf nodes, then display all three.
x, y = Scalar(2), Scalar(5)
z = x + y
(z, y, x)

(Scalar(data=7, grad=0, +), Scalar(data=5, grad=0), Scalar(data=2, grad=0))

In [50]:
# Backpropagate from z. Since z = x + y, dz/dx = dz/dy = 1, so both leaves get grad=1.
z.backward()
(z, y, x)

(Scalar(data=7, grad=1, +), Scalar(data=5, grad=1), Scalar(data=2, grad=1))