In [1]:
import torch
import pcn_kernels
class pcnpass(torch.autograd.Function):
    @staticmethod
    def forward(c: torch.Tensor, l: torch.Tensor, lnext: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        output: torch.Tensor = pcn_kernels.forward(c, l, lnext, b)
        return output
    
    @staticmethod
    def setup_context(ctx, inputs, _):
        c, l, lnext, b = inputs
        ctx.save_for_backward(c, l, lnext, b)

    @staticmethod
    def backward(ctx, grad_cnext) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c, l, lnext, b = ctx.saved_tensors
        grad_c, grad_l, grad_lnext, grad_b = pcn_kernels.backward(grad_cnext.contiguous(), c, l, lnext, b)
        return grad_c, grad_l, grad_lnext, grad_b

FREQ = 10.0

# def tri(z):
#     return 1 - 2 * abs(((z) % 2) - 1)

def tri_new_derivative(z):
    return 2 if z % 2 < 1 else -2

def tri(period: float, amplitude: float):
    """
    triangle wave function centered around 0 with period and amplitude
    """

    def triangle_wave_transform(x: torch.Tensor):
        # using sigal
        return (amplitude / period) * (
            (period - abs(x % (2 * period) - (1 * period)) - period / 2)
        )

    return triangle_wave_transform

In [2]:
# test.py
import numpy as np
from time import time
import copy
gpu = torch.device('cuda:0')

B = 32
N = 500
F = 50000
D = 4

l = torch.rand(N, D, requires_grad=True)
lnext = torch.rand(F, D, requires_grad=True)
c = torch.rand(B, N, requires_grad=True)
b = torch.rand(F, requires_grad=True)


# all to gpu


In [3]:




# if N * F < 1e8:
l2 = copy.deepcopy(l)
lnext2 = copy.deepcopy(lnext)
c2 = copy.deepcopy(c)
b2 = copy.deepcopy(b)

l = l.to(gpu)
lnext = lnext.to(gpu)
c = c.to(gpu)
b = b.to(gpu)
l2 = l2.to(gpu)
lnext2 = lnext2.to(gpu)
c2 = c2.to(gpu)
b2 = b2.to(gpu)
l2.retain_grad()
lnext2.retain_grad()
c2.retain_grad()
b2.retain_grad()
# Retain grad
l.retain_grad()
lnext.retain_grad()
c.retain_grad()
b.retain_grad()
t = time()
cnext: torch.Tensor = pcnpass.apply(c, l, lnext, b)
print(time() - t)
cnext.retain_grad()

cnext2 = None
if N * F < 1e8:
    t = time()
    cnext2 = (c2 @ (tri(0.1, 1)(torch.cdist(l2, lnext2)) / np.sqrt(l2.shape[0]))) + b2
    print(time() - t)
    cnext2.retain_grad()

0.038500070571899414
0.20549821853637695


In [6]:
cnext2.grad

tensor([[0.0156, 0.0041, 0.0038,  ..., 0.0076,    nan, 0.0049],
        [0.0769, 0.0054, 0.0036,  ..., 0.0069,    nan, 0.0044],
        [0.0248, 0.0050, 0.0033,  ..., 0.0062,    nan, 0.0063],
        ...,
        [0.0354, 0.0050, 0.0034,  ..., 0.0059,    nan, 0.0038],
        [0.0185, 0.0068, 0.0037,  ..., 0.0076,    nan, 0.0061],
        [0.0151, 0.0046, 0.0039,  ..., 0.0056,    nan, 0.0040]],
       device='cuda:0')

In [5]:
loss = torch.sum(cnext ** (1/10)) / B

t = time()
loss.backward()
print(time() - t)

if cnext2 is not None:
    loss2 = torch.sum(cnext2 ** (1/10)) / B
    t = time()
    loss2.backward()
    print(time() - t)

    # Report findings
    print(torch.allclose(c.grad, c2.grad))
    print(torch.allclose(l.grad, l2.grad))
    print(torch.allclose(lnext.grad, lnext2.grad))
    print(torch.allclose(b.grad, b2.grad))

0.5574979782104492
0.04899883270263672
False
False
False
False


In [17]:
v = lnext
v2 = lnext2

print(v.grad)
print(v2.grad)
print(v2.grad - v.grad)

tensor([[-0.2256,  0.2692,  0.4401,  0.5024],
        [    nan,     nan,     nan,     nan],
        [ 0.0703,  2.0552, -1.9464, -0.9433],
        ...,
        [    nan,     nan,     nan,     nan],
        [-0.7863,  0.0325, -0.0623, -0.4276],
        [ 0.1722, -0.6660, -0.0199, -0.1288]], device='cuda:0')
tensor([[-0.0827,  0.1060,  0.1561,  0.1937],
        [    nan,     nan,     nan,     nan],
        [ 0.0281,  0.8795, -0.8420, -0.3483],
        ...,
        [    nan,     nan,     nan,     nan],
        [-0.4081,  0.0196, -0.0338, -0.2233],
        [ 0.0872, -0.3473, -0.0112, -0.0654]], device='cuda:0')
tensor([[ 0.1430, -0.1633, -0.2840, -0.3087],
        [    nan,     nan,     nan,     nan],
        [-0.0422, -1.1757,  1.1044,  0.5950],
        ...,
        [    nan,     nan,     nan,     nan],
        [ 0.3782, -0.0129,  0.0284,  0.2043],
        [-0.0850,  0.3187,  0.0087,  0.0634]], device='cuda:0')


In [18]:
(c.grad - c2.grad).max(), \
(l.grad - l2.grad).max(), \
(lnext.grad - lnext2.grad).max(), \
(b.grad - b2.grad).max()

(tensor(nan, device='cuda:0'),
 tensor(nan, device='cuda:0'),
 tensor(nan, device='cuda:0'),
 tensor(nan, device='cuda:0'))

In [9]:
(c.grad - c2.grad).max() / c.grad.var(), \
(l.grad - l2.grad).max() / l.grad.var(), \
(lnext.grad - lnext2.grad).max() / lnext.grad.var(), \
(b.grad - b2.grad).max() / b.grad.var()

(tensor(0.0004, device='cuda:0'),
 tensor(6.4654e-05, device='cuda:0'),
 tensor(0.0602, device='cuda:0'),
 tensor(0.0029, device='cuda:0'))

In [10]:
(lnext.grad - lnext2.grad).max()

tensor(0.0007, device='cuda:0')

In [None]:

# Scale test

from time import time
import torch
gpu = torch.device('cuda:0')

B = 64
N = 5000
F = 500000
D = 4

l = torch.rand(N, D, requires_grad=True).to(gpu)
lnext = torch.rand(F, D, requires_grad=True).to(gpu)
c = torch.rand(B, N, requires_grad=True).to(gpu)
b = torch.rand(F, requires_grad=True).to(gpu)

l.retain_grad()
lnext.retain_grad()
c.retain_grad()
b.retain_grad()

t = time()
cnext = pcnpass.apply(c, l, lnext, b)
print(time() - t)
cnext.retain_grad()

loss = torch.sum(cnext)

t = time()
loss.backward()
print(time() - t)
