In [19]:
import numpy as np
import torch

from lib.Tensor import Tensor

In [54]:
data1, data2 = np.random.randn(2, 3), np.random.randn(2, 3) # random data

# ! Custom Tensor class
t1 = Tensor(data1, requires_grad=True) # init
t2 = Tensor(data2, requires_grad=True) # init
pt1 = torch.Tensor(data1).double().requires_grad_(True) # pytorch
pt2 = torch.Tensor(data2).double().requires_grad_(True) # pytorch

def ops(a, b):
    c = a * b
    return c.sum()

def ops_2(a, b):
    # ! THIS FAILS
    c = a + b # add, radd
    d = c - b # sub, rsub
    e = c * 2 # mul scalar
    x = d * e # mul tensor (elementwise)
    return x.sum()

In [39]:
result1 = ops(t1, t2)
result2 = ops(pt1, pt2)

In [40]:
result1.backward()
result2.backward()

In [45]:
def check_tol(t1, t2, pt1, pt2):
    tols = [1e-10, 1e-8, 1e-6, 1e-4, 1e-2]
    for tol in tols:
        print(f"tol: {tol}")
        print(np.allclose(t1.grad, pt1.grad.numpy(), atol=tol))
        print(np.allclose(t2.grad, pt2.grad.numpy(), atol=tol))

In [46]:
print(f"t1 grad {t1.grad}")
print(f"pt1 grad {pt1.grad.numpy()}")
print(f"t2 grad {t2.grad}") 
print(f"pt2 grad {pt2.grad.numpy()}")

t1 grad [[-1.37240982 -1.26545133  1.96119259]
 [ 0.16582696 -1.53786034  0.61533814]]
pt1 grad [[-1.37240982 -1.26545131  1.96119261]
 [ 0.16582696 -1.53786039  0.61533815]]
t2 grad [[ 1.20716919 -1.57160246  0.36859056]
 [ 0.19955592 -0.35067234 -0.56818483]]
pt2 grad [[ 1.20716918 -1.57160246  0.36859056]
 [ 0.19955592 -0.35067233 -0.56818485]]


In [47]:
check_tol(t1, t2, pt1, pt2)

tol: 1e-10
True
True
tol: 1e-08
True
True
tol: 1e-06
True
True
tol: 0.0001
True
True
tol: 0.01
True
True


In [53]:
result3 = ops_2(t1, t2)
result4 = ops_2(pt1, pt2)

result3.backward()
result4.backward()

print(f"t1 grad {t1.grad}")
print(f"pt1 grad {pt1.grad.numpy()}")

print(f"t2 grad {t2.grad}")
print(f"pt2 grad {pt2.grad.numpy()}")

t1 grad [[ 0.25706986 -2.53289069 -8.23273654]
 [ 0.53637653 -4.95895261  2.1283796 ]]
pt1 grad [[-0.81876063 -1.04067016 -5.24569333]
 [ 0.18049803 -3.24273695  1.55202657]]
t2 grad [[-0.81876055 -1.04067013 -5.24569314]
 [ 0.18049802 -3.24273697  1.55202658]]
pt2 grad [[-1.89459097  0.45155042 -2.25864983]
 [-0.17538048 -1.52652133  0.97567356]]
