## section 2: support for softmax

In [56]:
import math

In [53]:
from math import exp, log

class Value:
  
  def __init__(self, data, _children=(), _op='', label=''):
    self.data = data
    self.grad = 0.0
    self._backward = lambda: None
    self._prev = set(_children)
    self._op = _op
    self.label = label

  def __repr__(self):
    return f"Value(data={self.data})"
  
  def __add__(self, other):
    other = other if isinstance(other, Value) else Value(other)
    out = Value(self.data + other.data, (self, other), '+')
    
    def _backward():
      self.grad += 1.0 * out.grad
      other.grad += 1.0 * out.grad
    out._backward = _backward
    
    return out
  
  def __mul__(self, other):
    other = other if isinstance(other, Value) else Value(other)
    out = Value(self.data * other.data, (self, other))
    def _backward():
      self.grad += other.data * out.grad
      other.grad += self.data * out.grad
    out._backward = _backward
    return out
  def __pow__(self, other):
    assert isinstance(other, (int, float))
    
    out = Value(self.data ** (other), (self,) ,'**' )
    def _backward():
      self.grad += other * (self.data ** (other-1 )) * out.grad

    out._backward = _backward
    return out 
  def __rmul__(self, other): # other * self
      return self * other

  def __truediv__(self, other): # self / other
      return self * other**-1

  def __neg__(self): # -self
      return self * -1

  def __sub__(self, other): # self - other
      return self + (-other)

  def __radd__(self, other): # other + self
      return self + other
  def log(self):
    x = self.data
    out = Value(math.log(x), (self, ))
    def _backward():
      self.grad = (1.0 / self.data) * out.grad
    out._backward = _backward
    return out 
  def exp(self):
    x = self.data
    out = Value(math.exp(x), (self,), 'exp')
    def _backward():
      self.grad += out.data * out.grad
    out._backward = _backward
    return out
  def backward(self):   
    topo = []
    visited = set()
    def build_topo(v):
      if v not in visited:
        visited.add(v)
        for child in v._prev:
          build_topo(child)
        topo.append(v)
    build_topo(self)
    
    self.grad = 1.0
    for node in reversed(topo):
      node._backward()

In [66]:

def softmax(logits):
  counts = [logit.exp() for logit in logits]
  denominator = (sum(counts))
  out = [c / denominator for c in counts]
  return out

logits = [Value(0.0), Value(3.0), Value(-2.0), Value(1.0)]
probs = softmax(logits)
loss = -probs[3].log() # dim 3 acts as the label for this input example
loss.backward()
print(loss.data)

ans = [0.041772570515350445, 0.8390245074625319, 0.005653302662216329, -0.8864503806400986]
for dim in range(4):
  ok = 'OK' if abs(logits[dim].grad - ans[dim]) < 1e-5 else 'WRONG!'
  print(f"{ok} for dim {dim}: expected {ans[dim]}, yours returns {logits[dim].grad}")


2.1755153626167147
OK for dim 0: expected 0.041772570515350445, yours returns 0.041772570515350445
OK for dim 1: expected 0.8390245074625319, yours returns 0.8390245074625319
OK for dim 2: expected 0.005653302662216329, yours returns 0.005653302662216329
OK for dim 3: expected -0.8864503806400986, yours returns -0.8864503806400986


In [77]:
# verify the gradient using the torch library
# torch should give you the exact same gradient
import torch
logits = torch.tensor([0.0, 3.0, -2.0, 1.0], requires_grad=True)
probs = torch.softmax(logits, 0 )
loss = - probs[3].log()
loss.backward()
print(loss.data)

tensor(2.1755)
