In [101]:
import sys
sys.path.append("./build") 

from cpp_custom_bind import *
import torch

In [102]:
def torch_grads(mat_a, mat_b, op):
    """Reference gradients of ``op`` w.r.t. both operands via PyTorch autograd.

    Both matrices become float64 leaf tensors with ``requires_grad=True``.
    ``op`` is applied to them, and an all-ones upstream gradient is
    back-propagated. That is the same as differentiating the sum of the output.

    Returns:
        (grad_a, grad_b) as nested Python lists.
    """
    leaves = [
        torch.tensor(m, requires_grad=True, dtype=torch.float64)
        for m in (mat_a, mat_b)
    ]
    out = op(*leaves)
    out.backward(torch.ones_like(out, dtype=torch.float64))
    grad_a, grad_b = (leaf.grad.tolist() for leaf in leaves)
    return grad_a, grad_b

In [103]:
def _matrix_to_column_major_list(mat):
    mat = torch.tensor(mat)
    fmat = torch.flatten(mat.T)
    return fmat.tolist()

def _column_major_list_to_matrix(fmat, mat): # mat is the original matrix
    fmat = torch.tensor(fmat)
    mat = torch.tensor(mat)
    rmat = fmat.reshape(mat.T.shape).T
    return rmat.tolist()

def _mat_shape(mat):
    return list(torch.tensor(mat).shape)

In [104]:
def cpp_grads(mat_a, mat_b, op):
    """Gradients of ``op`` w.r.t. both operands, as produced by the C++ bindings.

    Steps:
        1. Flatten each nested-list matrix column-major.
        2. Wrap it in a ``cTensor`` together with its shape.
        3. Wrap that in a ``cVariable`` whose second argument is ``True``
           (presumably "requires grad" — confirm against cpp_custom_bind).
        4. Call ``op`` on the two variables and discard its return value.
        5. Reshape each variable's ``.grad.data`` back to its matrix's layout.

    NOTE(review): no explicit backward call is made here, so this relies on
    ``op`` populating ``.grad`` itself — confirm against the binding.

    Returns:
        (grad_a, grad_b) as nested lists shaped like mat_a and mat_b.
    """
    tensor_a, tensor_b = (
        cTensor(_matrix_to_column_major_list(m), _mat_shape(m))
        for m in (mat_a, mat_b)
    )
    var_a, var_b = (cVariable(t, True) for t in (tensor_a, tensor_b))

    op(var_a, var_b)

    grad_a = _column_major_list_to_matrix(var_a.grad.data, mat_a)
    grad_b = _column_major_list_to_matrix(var_b.grad.data, mat_b)
    return grad_a, grad_b

In [117]:
import random

LO_R, HI_R = 4, 10      # bounds passed to random.randint for the row/column count
LO_N, HI_N = -10, 10    # bounds passed to random.uniform for matrix entries
EPSILON = 1e-5          # tolerance used when comparing gradients

def generate_matrix(rows, cols): 
    """Return a ``rows`` x ``cols`` nested list of random floats.

    Each entry comes from ``random.uniform(LO_N, HI_N)``, drawn row by row.
    """
    matrix = []
    for _ in range(rows):
        matrix.append([random.uniform(LO_N, HI_N) for _ in range(cols)])
    return matrix

def _similar_enough(mat_a, mat_b):
    mat_a = torch.tensor(mat_a)
    mat_b = torch.tensor(mat_b)
    return torch.all(torch.abs(mat_a - mat_b) < EPSILON)

def binary_ops_test(ops): # ops is tuple, (name, cpp_op, torch_op)
    """Randomized gradient check of a single binary op against PyTorch.

    Args:
        ops: indexable ``(name, cpp_op, torch_op)``; ``name`` only appears in
            failure messages.

    Each dimension of the shape is drawn with ``random.randint(LO_R, HI_R)``.
    Two matrices of that shape are filled via ``generate_matrix``. The C++
    gradients for each operand must then match the PyTorch ones according to
    ``_similar_enough``.

    Raises:
        AssertionError: on mismatch; the message names the op and shows both
            gradients.
    """
    name, cpp_op, torch_op = ops[0], ops[1], ops[2]

    n_rows = random.randint(LO_R, HI_R)
    n_cols = random.randint(LO_R, HI_R)
    lhs = generate_matrix(n_rows, n_cols)
    rhs = generate_matrix(n_rows, n_cols)

    cpp_a_grad, cpp_b_grad = cpp_grads(lhs, rhs, cpp_op)
    torch_a_grad, torch_b_grad = torch_grads(lhs, rhs, torch_op)

    assert _similar_enough(cpp_a_grad, torch_a_grad), f"{name} failed, \ncpp_a_grad  : {cpp_a_grad}; \ntorch_a_grad: {torch_a_grad}"
    assert _similar_enough(cpp_b_grad, torch_b_grad), f"{name} failed, \ncpp_b_grad  : {cpp_b_grad}; \ntorch_b_grad: {torch_b_grad}"


In [118]:
# One entry per element-wise binary op: (name, C++ binding, PyTorch reference).
# The PyTorch reference is looked up under the same name (torch.add, torch.sub,
# ...), so the label and the reference implementation cannot drift apart.
binary_ops_tuples = [
    (name, cpp_op, getattr(torch, name))
    for name, cpp_op in (("add", add), ("sub", sub), ("mul", mul), ("div", div))
]

for ops in binary_ops_tuples:
    binary_ops_test(ops)