# Imports

In [1]:
import math
import numpy as np
import matplotlib.pyplot as plt
import graphviz

%matplotlib inline

# Utils

# Exercise

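We just implemented the `_backward()` function for `__add__`. Your goal is to modify this class to implement `_backward()` for the remaining operations: subtraction, multiplication, division, and `tanh`. As in `__add__`, each `_backward()` should multiply the operation's local derivative by `out.grad` and *accumulate* (`+=`) the result into each input's `.grad`.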

In [4]:
class Value:
    """A scalar node in a computation graph that supports reverse-mode autodiff.

    Every operation returns a new ``Value`` that remembers its operands (``_prev``)
    and the operation that produced it (``_op``). It also installs a ``_backward``
    closure that applies the chain rule for that single operation. The closure reads
    ``out.grad``, multiplies it by the local derivative, and *accumulates* (``+=``)
    the result into each operand's ``grad``.

    Accumulation matters when the same ``Value`` is used more than once, e.g.
    ``a * a``: every use contributes to its gradient.

    Operands may be ``Value`` instances or plain ``int``/``float`` numbers. Plain
    numbers are wrapped in a fresh leaf ``Value``.
    """

    def __init__(self, data, _children=(), _op='', label=''):
        self.data = data
        # At initialization a Value has not received any gradient from the output yet.
        self.grad = 0.0
        # Chain-rule step for the op that produced this node; leaves propagate nothing.
        self._backward = lambda: None
        self._prev = set(_children)  # operands this node was computed from
        self._op = _op               # symbol of the producing operation ('' for leaves)
        self.label = label

    def __repr__(self):
        return f"Value({self.data})"

    @staticmethod
    def _as_value(other):
        """Wrap a plain int/float in a leaf Value; return anything else unchanged."""
        return Value(other) if isinstance(other, (int, float)) else other

    # NOTE: annotations are strings because the name `Value` is not bound yet while
    # the class body executes; a bare `Value` raises NameError on Python < 3.14.
    def __add__(self, other: 'Value'):
        """Return ``self + other``; d/dself = d/dother = 1."""
        other = Value._as_value(other)
        out = Value(self.data + other.data, (self, other), '+')

        def _backward():
            self.grad += 1.0 * out.grad
            other.grad += 1.0 * out.grad
        out._backward = _backward
        return out

    def __mul__(self, other: 'Value'):
        """Return ``self * other``; d/dself = other, d/dother = self."""
        other = Value._as_value(other)
        out = Value(self.data * other.data, (self, other), '*')

        def _backward():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    def __sub__(self, other: 'Value'):
        """Return ``self - other``; d/dself = 1, d/dother = -1."""
        other = Value._as_value(other)
        out = Value(self.data - other.data, (self, other), '-')

        def _backward():
            self.grad += 1.0 * out.grad
            other.grad += -1.0 * out.grad
        out._backward = _backward
        return out

    def __truediv__(self, other: 'Value'):
        """Return ``self / other``; d/dself = 1/other, d/dother = -self/other**2.

        Raises ZeroDivisionError when ``other.data`` is zero, as plain ``/`` does.
        """
        other = Value._as_value(other)
        out = Value(self.data / other.data, (self, other), '/')

        def _backward():
            self.grad += (1.0 / other.data) * out.grad
            other.grad += (-self.data / other.data ** 2) * out.grad
        out._backward = _backward
        return out

    def tanh(self):
        """Return ``tanh(self)``; d/dself = 1 - tanh(self)**2.

        Uses ``math.tanh`` rather than ``(e^{2x} - 1) / (e^{2x} + 1)``, because
        ``math.exp(2*x)`` raises OverflowError for large positive x (about x > 355).
        """
        t = math.tanh(self.data)
        out = Value(t, (self, ), 'tanh')

        def _backward():
            # Reuse the forward result t instead of recomputing tanh.
            self.grad += (1.0 - t ** 2) * out.grad
        out._backward = _backward
        return out

# Test

In [None]:
a = Value(3)
b = Value(6)


def report(name, check):
    """Run ``check()`` and print whether the ``name`` check passed.

    Catches ``Exception`` rather than using a bare ``except:``. A failed assertion
    *or* a crash inside the operation is then reported with its cause instead of
    being hidden, and Ctrl-C still interrupts the cell.
    """
    try:
        check()
    except Exception as exc:
        print(f'Something Went wrong with {name}: {exc!r}')
    else:
        print(f'You got {name} right!')


def assert_close(actual, expected):
    """Assert two equal-length lists of numbers match within float tolerance."""
    assert len(actual) == len(expected) and all(
        math.isclose(x, y) for x, y in zip(actual, expected)
    ), f'expected {expected}, got {actual}'


def local_grads(op, *inputs):
    """Apply ``op`` to fresh leaf Values and run only the result's own ``_backward()``.

    Seeds ``out.grad = 1.0``, so each returned gradient is the local derivative
    d(out)/d(input), in the same order as ``inputs``.
    """
    leaves = [Value(x) for x in inputs]
    out = op(*leaves)
    out.grad = 1.0
    out._backward()
    return [leaf.grad for leaf in leaves]


# Forward pass.
report('multiplication', lambda: assert_close([(a * b).data], [18]))
report('division', lambda: assert_close([(b / a).data], [2]))
report('subtraction', lambda: assert_close([(a - b).data], [-3]))
report('tanh', lambda: assert_close([Value(0.5).tanh().data], [math.tanh(0.5)]))

# Backward pass: the local derivative of each operation (what the exercise asks for).
report('the addition gradient',
       lambda: assert_close(local_grads(lambda p, q: p + q, 3, 6), [1, 1]))
report('the multiplication gradient',
       lambda: assert_close(local_grads(lambda p, q: p * q, 3, 6), [6, 3]))
report('the subtraction gradient',
       lambda: assert_close(local_grads(lambda p, q: p - q, 3, 6), [1, -1]))
report('the division gradient',
       lambda: assert_close(local_grads(lambda p, q: p / q, 6, 3), [1 / 3, -6 / 9]))
report('the tanh gradient',
       lambda: assert_close(local_grads(Value.tanh, 0.5), [1 - math.tanh(0.5) ** 2]))