In [1]:
import sys
sys.path.append('../../pyutils')

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

import metrics

In [2]:
class Node:
    '''
    Base class of every node of the computation graph.

    Attributes set by the constructor:
    preds - list of the input nodes, in argument order
    succs - list of the nodes that were built with this node as an input
            (a consumer using this node in several slots is listed once per slot)
    shape - shape given at construction
    grads - dict, empty at construction; get_grad() stores in out.grads the
            gradient nodes it already built for the output node out
    '''

    def __init__(self, preds, shape):
        '''
        preds - iterable of input nodes
        shape - tuple, shape of node array, must be complete (no -1)
        '''
        self.preds = list(preds)
        self.succs = []
        self.shape = shape
        self.grads = {}

        # Register this node as a consumer of each of its inputs.
        for p in self.preds:
            p.succs.append(self)

    def data(self):
        '''
        Evaluate and return the value of the node
        Must be overridden by subclasses.
        '''
        # NotImplementedError is an Exception subclass, so existing
        # `except Exception` handlers keep working.
        raise NotImplementedError('Node::data() Not implemented')

    def backward(self, dout, i):
        '''
        Build new nodes to compute the gradient of the predecessors
        dout - node to compute the gradient of the output dE_out
        i - index of the predecessor in self.preds
        returns node to compute gradient of predecessor i dE_pred[i]
        Must be overridden by subclasses.
        '''
        raise NotImplementedError('Node::backward() Not implemented')

    def is_pred_rec(self, x):
        '''
        Check if node self is an ancestor of x, i.e. x can be reached from
        self by following succs links. A node counts as its own ancestor.
        returns True or False
        '''
        # Iterative DFS with a visited set: every node is expanded at most
        # once, so shared sub-paths (diamonds) no longer cause exponential
        # work and long chains cannot overflow the call stack.
        seen = {id(self)}
        stack = [self]
        while stack:
            node = stack.pop()
            if node is x:
                return True
            for s in node.succs:
                if id(s) not in seen:
                    seen.add(id(s))
                    stack.append(s)
        return False

In [3]:
class Value(Node):
    '''
    Leaf node (no predecessors) holding a fixed array
    '''

    def __init__(self, val):
        '''
        val - numpy array, stored as is; any other object is first
              converted with np.array()
        '''
        arr = val if type(val) is np.ndarray else np.array(val)

        super().__init__([], arr.shape)
        self.val = arr

    def data(self):
        '''
        Return the stored array
        '''
        return self.val
    
class Transpose(Node):
    '''
    Transpose of a 2-D node: out = x.T
    '''

    def __init__(self, x):
        '''
        x - node whose shape has (at least) two entries; the output shape is
            (x.shape[1], x.shape[0])
        '''
        super().__init__([x], (x.shape[1], x.shape[0]))
        self.x = x

    def data(self):
        '''
        Evaluate x and return its transpose
        '''
        return self.x.data().T

    def backward(self, dout, i):
        '''
        dout - node to compute the gradient of the output dE_out
        i - index of the predecessor (only x, so unused)
        returns node computing dE_x = dE_out.T

        Transposition only moves entries around, so the gradient is moved
        back by the same operation. Without this method, gradients could not
        flow through the Transpose nodes created by Matmul.backward().
        '''
        return Transpose(dout)
    
class Matmul(Node):
    '''
    Matrix product node: out = x @ y
    '''

    def __init__(self, x, y):
        '''
        x, y - input nodes; the output shape is (x.shape[0], y.shape[1])
        '''
        out_shape = (x.shape[0], y.shape[1])
        super().__init__([x, y], out_shape)
        self.x = x
        self.y = y

    def data(self):
        '''
        Evaluate both inputs (x first) and return their matrix product
        '''
        lhs = self.x.data()
        rhs = self.y.data()
        return lhs @ rhs

    def backward(self, dout, i):
        '''
        dout - node to compute the gradient of the output dE_out
        i - 0 for the gradient of x, anything else for the gradient of y
        returns node computing dE_x = dE_out @ y.T or dE_y = x.T @ dE_out
        '''
        if i != 0:
            x_t = Transpose(self.x)
            return Matmul(x_t, dout)

        y_t = Transpose(self.y)
        return Matmul(dout, y_t)
        
class Sum(Node):
    '''
    Sum of all the entries of x, the output is a scalar (shape ())
    '''

    def __init__(self, x):
        '''
        x - input node
        '''
        super().__init__([x], ())
        self.x = x

    def data(self):
        '''
        Evaluate x and return np.sum() of it
        '''
        values = self.x.data()
        return np.sum(values)

    def backward(self, dout, i):
        '''
        dout - node to compute the gradient of the output dE_out
        i - index of the predecessor (only x, so unused)
        returns node computing dE_x = ones(x.shape) * dE_out
        '''
        # Every entry of x contributes to the sum with weight 1.
        ones = Value(np.ones(self.x.shape))
        return Multiply(ones, dout)
    
class Add(Node):
    '''
    Addition node: out = x + y
    The output shape is declared as x.shape, so y is expected to have the
    same shape as x or to broadcast into it.
    '''

    def __init__(self, x, y):
        '''
        x, y - input nodes
        '''
        super().__init__([x, y], x.shape)
        self.x = x
        self.y = y

    def data(self):
        '''
        Evaluate both inputs and return their sum
        '''
        return self.x.data() + self.y.data()

    def backward(self, dout, i):
        '''
        dout - node to compute the gradient of the output dE_out
        i - index of the predecessor in self.preds (0 for x, 1 for y)
        returns node computing dE_pred[i]
        raises NotImplementedError when the predecessor was broadcast from a
        non-scalar shape (no node can express that reduction)

        Without this method get_grad() fails on any path that goes through an
        addition, including the Add nodes it builds itself.
        '''
        pred = self.preds[i]

        # Same shape as the output: the gradient passes through unchanged.
        if pred.shape == self.shape:
            return dout

        # A scalar operand was added to every entry, so its gradient is the
        # sum of all the entries of dE_out.
        if pred.shape == ():
            return Sum(dout)

        raise NotImplementedError('Add::backward() Not implemented for broadcast shapes')
    
class Multiply(Node):
    '''
    Elementwise product node: out = x * y
    The output shape is declared as x.shape, so y is expected to have the
    same shape as x or to broadcast into it (Sum.backward() passes a scalar y).
    '''

    def __init__(self, x, y):
        '''
        x, y - input nodes
        '''
        super().__init__([x, y], x.shape)
        self.x = x
        self.y = y

    def data(self):
        '''
        Evaluate both inputs and return their elementwise product
        '''
        return self.x.data() * self.y.data()

    def backward(self, dout, i):
        '''
        dout - node to compute the gradient of the output dE_out
        i - index of the predecessor in self.preds (0 for x, 1 for y)
        returns node computing dE_pred[i] = dE_out * (the other operand),
        summed over all entries when the predecessor is a scalar
        raises NotImplementedError when the predecessor was broadcast from a
        non-scalar shape (no node can express that reduction)

        Without this method get_grad() fails on any path that goes through a
        product, e.g. when differentiating a gradient graph built by
        Sum.backward().
        '''
        pred = self.preds[i]
        other = self.y if i == 0 else self.x

        # Check the shape first so that no node is attached to the graph
        # when the case is not supported.
        same_shape = pred.shape == self.shape
        if not same_shape and pred.shape != ():
            raise NotImplementedError('Multiply::backward() Not implemented for broadcast shapes')

        # dout goes first: it has the output shape, which becomes the
        # declared shape of the new product.
        grad = Multiply(dout, other)

        # A scalar operand multiplied every entry, so its gradient
        # accumulates the contribution of all of them.
        return grad if same_shape else Sum(grad)

In [4]:
# Forward pass check: evaluate sum(x @ y) with the Node graph and with PyTorch
# on the same random inputs.
x = np.random.randn(12, 7)
y = np.random.randn(7, 86)

# Node graph
nx = Value(x)
ny = Value(y)
nz = Matmul(nx, ny)
nloss = Sum(nz)

# PyTorch reference; requires_grad is needed by the gradient check run later
tx = torch.from_numpy(x).requires_grad_(True)
ty = torch.from_numpy(y).requires_grad_(True)
tz = tx @ ty
tloss = torch.sum(tz)

# detach() is the supported way to read the values of a tensor that is tracked
# by autograd (the legacy .data accessor bypasses autograd's safety checks).
print(nloss.data(), tloss.detach().numpy())
# NOTE(review): metrics.tdist comes from ../../pyutils; presumably a distance
# between the two arrays (0.0 meaning identical) - confirm.
print(metrics.tdist(nz.data(), tz.detach().numpy()))

-66.33951850504663 -66.33951850504661
0.0


In [5]:
def get_grad(out, x):
    '''
    Build the node computing the gradient of out with respect to x
    out - scalar node (shape ()) to differentiate
    x - node to differentiate with respect to, must be an ancestor of out
    returns node computing dE_x; the result is also stored in out.grads[x]
            and returned as is on later calls
    raises Exception if out is not a scalar or if x is not an ancestor of out
    '''

    # 1 Basic checks
    if len(out.shape) != 0:
        raise Exception('The output tensor must be a scalar')
    # Cached gradients were built for ancestors of out, and the graph only
    # grows, so the cache can be consulted before the ancestor search.
    if x in out.grads:
        return out.grads[x]
    if not x.is_pred_rec(out):
        #also possible to return a 0 tensor
        raise Exception('x is not an ancestor or out')
    if x is out:
        g = Value(1.)
        out.grads[x] = g
        return g

    res = None #res = Value(np.zeros(x.shape))

    # 2 Sum the contributions of every consumer of x that leads to out.
    # x.succs lists a consumer once per input slot it uses x in, so each
    # consumer is visited once and all its slots are handled below.
    visited = set()
    for yi in x.succs:
        if id(yi) in visited:
            continue
        visited.add(id(yi))

        # Consumers that do not lead to out (e.g. nodes of previously built
        # gradient graphs) do not contribute.
        if not yi.is_pred_rec(out):
            continue

        g_yi = get_grad(out, yi)

        # preds.index(x) would only find the first slot: a node such as
        # Matmul(x, x) needs the gradient of both of its inputs.
        for i, p in enumerate(yi.preds):
            if p is not x:
                continue

            g_x = yi.backward(g_yi, i)

            if res is None: #trick to avoid addition by 0
                res = g_x
            else:
                res = Add(res, g_x)

    if res is None: #should never happen
        print(x, x.shape)
        raise Exception('Backprop internal error')

    out.grads[x] = res
    return res

In [6]:
# Gradient check: build the gradient graphs of the loss with respect to both
# inputs and compare their values with the gradients computed by PyTorch.
ndx = get_grad(nloss, nx)
ndy = get_grad(nloss, ny)

tloss.backward()
tdx = tx.grad
tdy = ty.grad

# detach() replaces the legacy .data accessor when converting to numpy.
print(metrics.tdist(ndx.data(), tdx.detach().numpy()))
print(metrics.tdist(ndy.data(), tdy.detach().numpy()))

0.0
0.0
