In [None]:

import sys, random
sys.path.append("./build")
from cpp_custom_bind import *
import torch

EPS = 1e-4
_LO_N, _HI_N = -5, 5
_LO_R, _HI_R = 4, 10

In [None]:
def rand_matrix(rows, cols, lo=_LO_N, hi=_HI_N):
    return [[random.uniform(lo, hi) for _ in range(cols)] for _ in range(rows)]

def rand_shape(min_r=_LO_R, max_r=_HI_R, min_c=_LO_R, max_c=_HI_R):
    return random.randint(min_r, max_r), random.randint(min_c, max_c)

def is_close(a, b, eps=EPS):
    return torch.all(torch.abs(torch.tensor(a) - torch.tensor(b)) < eps)

def sample_unary():
    r, c = rand_shape()
    return [rand_matrix(r, c)]

def sample_binary_same():
    r, c = rand_shape()
    return [rand_matrix(r, c), rand_matrix(r, c)]

def sample_matmul():
    m = random.randint(4, 8)
    k = random.randint(4, 8)
    n = random.randint(4, 8)
    return [rand_matrix(m, k), rand_matrix(k, n)]

def _to_col_major(mat):
    return torch.tensor(mat).T.flatten().tolist()

def _from_col_major(flat, like):
    t = torch.tensor(flat).reshape(torch.tensor(like).T.shape).T
    return t.tolist()


In [None]:
def make_cpp_var(mat, requires_grad=True):
    ten = cTensor(_to_col_major(mat), list(torch.tensor(mat).shape))
    return cVariable(ten, requires_grad)

def compute_grads(cpp_op, torch_op, mats, *extra):
    torch_vars = [torch.tensor(m, dtype=torch.float64, requires_grad=True) for m in mats]
    torch_out = torch_op(*torch_vars, *extra)
    torch_out.backward(torch.ones_like(torch_out, dtype=torch.float64))
    torch_grads = [v.grad.tolist() for v in torch_vars]
    cpp_vars = [make_cpp_var(m) for m in mats]
    cpp_op(*cpp_vars, *extra)
    cpp_grads = [_from_col_major(v.grad.data, m) for v, m in zip(cpp_vars, mats)]
    return cpp_grads, torch_grads

def run_test(name, cpp_op, torch_op, sampler, *extra):
    mats = sampler()
    cpp_grads, torch_grads = compute_grads(cpp_op, torch_op, mats, *extra)
    for cg, tg in zip(cpp_grads, torch_grads):
        if not is_close(cg, tg):
            print(cg)
            print(tg)
        assert is_close(cg, tg), f"{name} failed, {torch.max(torch.abs(torch.tensor(cg) - torch.tensor(tg)))}"

In [None]:


TEST_CASES = [
    ("add", add, torch.add, sample_binary_same),
    ("sub", sub, torch.sub, sample_binary_same),
    ("mul", mul, torch.mul, sample_binary_same),
    ("div", div, torch.div, sample_binary_same),
    ("relu", relu, torch.relu, sample_unary),
    ("exp", exp, torch.exp, sample_unary),
    ("log", log, torch.log, sample_unary),
    ("matmul", matmul, torch.matmul, sample_matmul),
    ("transpose", transpose, lambda x: x.T, sample_unary),
    ("sum_axis0", lambda x, axis: sum(x, axis), lambda t, axis: torch.sum(t, dim=axis), sample_unary, 0),
    ("sum_axis1", lambda x, axis: sum(x, axis), lambda t, axis: torch.sum(t, dim=axis), sample_unary, 1),
]

for case in TEST_CASES:
    name, cpp_op, torch_op, sampler, *extra = case
    run_test(name, cpp_op, torch_op, sampler, *extra)

print("All gradient checks passed")