In [1]:
import torch
from torch import nn
import numpy as np
import copy 
from time import time

import pcn_kernels

class pcnpass(torch.autograd.Function):
    @staticmethod
    def forward(c: torch.Tensor, l: torch.Tensor, lnext: torch.Tensor, b: torch.Tensor):
        output = pcn_kernels.forward(c, l, lnext, b)
        return output
    
    @staticmethod
    def setup_context(ctx, inputs, _):
        c, l, lnext, b = inputs
        ctx.save_for_backward(c, l, lnext, b)

    @staticmethod
    def backward(ctx, grad_cnext):
        c, l, lnext, b = ctx.saved_tensors
        grad_c, grad_l, grad_lnext, grad_b = pcn_kernels.backward(grad_cnext.contiguous(), c, l, lnext, b)
        return grad_c, grad_l, grad_lnext, grad_b
        # return None, None, None, None

def tri(period: float, amplitude: float):
    """
    triangle wave function centered around 0 with period and amplitude
    """

    def triangle_wave_transform(x: torch.Tensor):
        # using sigal
        return (amplitude / period) * (
            (period - abs(x % (2 * period) - (1 * period)) - period / 2)
        )

    return triangle_wave_transform

In [19]:
# Speed @ scale test

B = 32
N = 5000
F = 5000
D = 16

l = nn.Parameter(torch.rand(N, D))
lnext = nn.Parameter(torch.rand(F, D))
c = nn.Parameter(torch.rand(B, N))
b = nn.Parameter(torch.rand(F))
# l = torch.rand(N, D, requires_grad=True)
# lnext = torch.rand(F, D, requires_grad=True)
# c = torch.rand(B, N)
# b = torch.rand(F, requires_grad=True)
l = l.cuda()
lnext = lnext.cuda()
c = c.cuda()
b = b.cuda()

app = pcnpass.apply

t = time()
cnext: torch.Tensor = app(c, l, lnext, b)
# cnext = 1 + (c @ (tri(0.1, 2)(torch.cdist(l, lnext))) / np.sqrt(l.shape[0]) + b)
# cnext += 1
print(f"Time taken: {time() - t} seconds")

t = time()
print(cnext[0, 0])
print(f"Time taken: {time() - t} seconds")

t = time()
cnext.sum().backward()
print(f"Time taken: {time() - t} seconds")


Time taken: 0.0004994869232177734 seconds
tensor(0.6575, device='cuda:0', grad_fn=<SelectBackward0>)
Time taken: 0.25350189208984375 seconds
Time taken: 3.120999336242676 seconds


In [5]:
cnext

tensor([[0.7459, 0.5493, 0.5304,  ..., 0.5658, 0.7056, 0.6523],
        [0.8264, 0.8070, 0.6457,  ..., 0.3933, 0.2267, 0.8523],
        [0.7831, 0.7814, 0.4453,  ..., 0.7721, 0.3657, 0.3119],
        ...,
        [0.7208, 0.6646, 0.4097,  ..., 0.7346, 0.4220, 0.6039],
        [0.6091, 0.8735, 0.5585,  ..., 0.4688, 0.2861, 0.5726],
        [0.6846, 0.5720, 0.6047,  ..., 0.3046, 0.1498, 0.5137]],
       device='cuda:0', grad_fn=<pcnpassBackward>)

In [7]:
B = 32
N = 500
F = 5000
D = 64

l = torch.rand(N, D, requires_grad=True)
lnext = torch.rand(F, D, requires_grad=True)
c = torch.rand(B, N, requires_grad=True)
b = torch.rand(F, requires_grad=True)

l2 = copy.deepcopy(l)
lnext2 = copy.deepcopy(lnext)
c2 = copy.deepcopy(c)
b2 = copy.deepcopy(b)

l = l.cuda()
lnext = lnext.cuda()
c = c.cuda()
b = b.cuda()
l2 = l2.cuda()
lnext2 = lnext2.cuda()
c2 = c2.cuda()
b2 = b2.cuda()

l.retain_grad()
lnext.retain_grad()
c.retain_grad()
b.retain_grad()
l2.retain_grad()
lnext2.retain_grad()
c2.retain_grad()
b2.retain_grad()

cnext: torch.Tensor = pcnpass.apply(c, l, lnext, b)
cnext2: torch.Tensor = c2 @ ((tri(0.1, 2)(torch.cdist(l2, lnext2))) / np.sqrt(l2.shape[0])) + b2

t = time()
cnext.mean().backward()
print(f"Time taken: {time() - t} seconds")

t = time()
cnext2.mean().backward()
print(f"Time taken: {time() - t} seconds")

Time taken: 0.3254985809326172 seconds
Time taken: 0.0020017623901367188 seconds


In [8]:
v1 = l.grad
v2 = l2.grad

print(torch.allclose(v1, v2))
print(v1)
print(v2)

False
tensor([[ 8.9799e-04,  1.4514e-03,  6.9525e-04,  ...,  1.3024e-03,
          3.1200e-05, -8.5248e-04],
        [ 6.5266e-04, -3.9078e-04,  2.9419e-04,  ...,  2.9573e-04,
          4.9298e-05, -6.1420e-04],
        [ 2.5319e-04,  5.6479e-04, -1.4843e-05,  ...,  9.6867e-04,
         -4.6092e-04,  1.5381e-04],
        ...,
        [-7.5668e-04,  8.4798e-04, -7.0788e-04,  ...,  7.9027e-04,
          6.7182e-04,  8.4375e-05],
        [-9.6641e-04, -1.9852e-04,  4.2675e-05,  ..., -2.8976e-04,
          1.4574e-03,  1.6807e-04],
        [-5.5577e-05, -5.8934e-04,  5.6020e-04,  ..., -7.7825e-04,
         -9.7816e-04, -1.2462e-03]], device='cuda:0')
tensor([[ 8.9799e-04,  1.4514e-03,  6.9525e-04,  ...,  1.3024e-03,
          3.1200e-05, -8.5248e-04],
        [ 6.1246e-04, -4.1406e-04,  2.7982e-04,  ...,  2.6567e-04,
          5.0411e-05, -6.0400e-04],
        [ 2.5319e-04,  5.6478e-04, -1.4840e-05,  ...,  9.6867e-04,
         -4.6091e-04,  1.5381e-04],
        ...,
        [-7.5669e-04,  

In [9]:
(cnext - cnext2).abs().max()
# (cnext - cnext2).var()
# cnext, cnext2
# cnext

tensor(3.8281e-05, device='cuda:0', grad_fn=<MaxBackward1>)

In [10]:
print(torch.allclose(cnext, cnext2, atol=1e-6)),
print(torch.allclose(c.grad, c2.grad, atol=1e-6)),
print(torch.allclose(l.grad, l2.grad, atol=1e-6)),
print(torch.allclose(lnext.grad, lnext2.grad, atol=1e-6)),
print(torch.allclose(b.grad, b2.grad))

(cnext - cnext2).max(), \
(c.grad - c2.grad).max(), \
(l.grad - l2.grad).max(), \
(lnext.grad - lnext2.grad).max(), \
(b.grad - b2.grad).max()

False
True
False
False
True


(tensor(3.8281e-05, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(8.3901e-10, device='cuda:0'),
 tensor(5.7835e-05, device='cuda:0'),
 tensor(5.4964e-05, device='cuda:0'),
 tensor(1.4552e-11, device='cuda:0'))

In [19]:
from torch.profiler import profile, record_function

loss = cnext.mean()

with profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("model_backward"):
        loss.backward()

print(prof.key_averages().table(sort_by="cuda_time_total"))

backward time:  0.0
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         model_backward         0.01%       1.678ms        50.00%        9.299s        9.299s        8.228s        50.00%        8.228s        8.228s             1  
autograd::engine::evaluate_function: pcnpassBackward...         0.00%      36.000us         0.00%     880.000us     880.000us       4.000us         0.00%        8.228s        8.228s      

In [12]:
class MySin(torch.autograd.Function):
    @staticmethod
    def forward(inp, inp2, inp3, inp4):
        """ compute forward pass of custom function """
        dum1 = torch.zeros_like(inp)
        dum2 = torch.zeros_like(inp2)
        dum3 = torch.zeros_like(inp3)
        dum4 = torch.zeros_like(inp4)
        return pcn_kernels.forward(inp, inp2, inp3, inp4)
        # pcn_kernels.forward(dum1, dum2, dum3, dum4)
        # return torch.zeros((inp.shape[0], inp3.shape[0]), requires_grad=True).cuda()
        # return pcn_kernels.forward(inp, inp2, inp3, inp4)
        # return inp#(inp + inp2 + inp3 + inp4).sin()  # compute forward pass, can also be computed by any other library

    @staticmethod
    def setup_context(ctx, inputs, _):
        inp, inp2, inp3, inp4 = inputs
        ctx.save_for_backward(inp, inp2, inp3, inp4)

    @staticmethod
    def backward(ctx, grad_out):
        """ compute product of output gradient with the 
        jacobian of your function evaluated at input """
        inp, inp2, inp3, inp4 = ctx.saved_tensors
        grad_inp = torch.zeros_like(inp)
        grad_inp2 = torch.zeros_like(inp2)
        grad_inp3 = torch.zeros_like(inp3)
        grad_inp4 = torch.zeros_like(inp4)
        # print (pcn_kernels.backward(grad_out.contiguous(), inp, inp2, inp3, inp4))
        return tuple(pcn_kernels.backward(grad_out.contiguous(), inp, inp2, inp3, inp4))
        # grad_inp = grad_out * torch.cos(inp)  # propagate gradient, can also be computed by any other library
        return grad_inp, grad_inp2, grad_inp3, grad_inp4

In [14]:
B = 32
N = 5000
F = 50000
D = 16

c = nn.Parameter(torch.rand(B, N))
l = nn.Parameter(torch.rand(N, D))
lnext = nn.Parameter(torch.rand(F, D))
b = nn.Parameter(torch.rand(F))

inp1 = nn.Parameter(torch.rand(B, N)).cuda()
inp2 = nn.Parameter(torch.rand(N, D)).cuda()
inp3 = nn.Parameter(torch.rand(F, D)).cuda()
inp4 = nn.Parameter(torch.rand(F)).cuda()

t = time()
# torch.sin(input + inp2 + inp3 + inp4).sum().backward()
print(f"Time taken: {time() - t} seconds")

t = time()
x = MySin.apply(inp1, inp2, inp3, inp4)
print(f"Time taken: {time() - t} seconds")

t = time()
x.sum().backward()
print(f"Time taken: {time() - t} seconds")

Time taken: 0.0 seconds
Time taken: 0.0010004043579101562 seconds
Time taken: 8.629997491836548 seconds


In [65]:
# It is a mystery to me what is happening with the inner mechanics of the backward pass
# the `pcn_kernel.backward` runs quickly like the forward pass, and has printable values
# but timing the full `.backward()` is very slow

# The pcn_kernel.backward doesn't even need to be called in the backward pass
# regardless of what happens in the backward pass, if the forward pass calls `pcn_kernel.forward`
# then the backward pass ends up being much slower then if that isn't called in the forward pass
# Why is the backward time affected by what happens in the forward pass ???????

In [30]:
import torch
from torch import nn
import pcn_kernels
from time import time
import numpy as np

B = 32
N = 50000
F = 50000
D = 16

t = time()
c = nn.Parameter(torch.rand(B, N)).cuda()
l = nn.Parameter(torch.rand(N, D)).cuda()
lnext = nn.Parameter(torch.rand(F, D)).cuda()
b = nn.Parameter(torch.rand(F)).cuda()
print(f"Time taken: {time() - t} seconds")

t = time()
cnext = pcn_kernels.forward(c, l, lnext, b)
print(f"Time taken: {time() - t} seconds")

# sum = cnext.sum()
# # manually calculate grad_cnext
# grad_cnext = torch.Tensor(cnext.shape).cuda()
# grad_cnext.fill_(1.0 / np.sqrt(cnext.shape[0]))

# t = time()
# grad_c, grad_l, grad_lnext, grad_b = pcn_kernels.backward(grad_cnext, c, l, lnext, b)
# print(f"Time taken: {time() - t} seconds")

Time taken: 0.024500370025634766 seconds
Time taken: 0.0 seconds


In [31]:
cnext[0, 0]

tensor(-0.2098, device='cuda:0')

In [25]:
t = time()
print(cnext)
print(f"Time taken: {time() - t} seconds")

# It seems to maybe be the case that kernel function outputs are lazy loaded?
# - If we run the above code, and above block and then this block, this block takes ~7.8 seconds
# - If we then run the top block again, the first time taken will be low (~0.01 seconds)
# - But if we run the top block twice in a row, the first time taken will be high (~7.8 seconds)
# - whichs makes me think that, here, by printing one of the grads, we force the `pcn_kernel.backward` to trigger
# - In the first time taken section in the above block, we trigger the `pcn_kernel.backward` pass
#   by trying to reset some of the inputs of the `pcn_kernel.backward` to new values

# So the moral of the story is, my backward pass simply is just slow. I should be able to fix this with a better
# threading layout and hopefully removing the need for atomic functions

tensor([[ 0.3853,  0.7242,  0.3629,  ...,  0.4866,  0.5627,  0.0840],
        [ 0.6922,  0.6477,  0.1949,  ...,  0.3359,  0.5440,  0.1395],
        [ 0.4976,  0.4154,  0.4153,  ...,  0.4049,  0.4551, -0.0340],
        ...,
        [ 0.7102,  0.9383,  0.4187,  ...,  0.3982,  0.3767, -0.2267],
        [ 0.4361,  0.7735,  0.3790,  ...,  0.3035,  0.5073, -0.1499],
        [ 0.5498,  0.7083,  0.3224,  ...,  0.5351,  0.6693, -0.0820]],
       device='cuda:0')
Time taken: 9.345498323440552 seconds


In [14]:
# pcn2d
# pcn16d
(784 + 512 + 10) * 5, \
(784 + 2048 + 10) * 5, \
(784 + 128 + 10) * 17,

(6530, 14210, 15674)

In [12]:
# xs
# s
# m
# l
((784 + 1) * 10 + (10 + 1) * 10), \
((784 + 1) * 64 + (64 + 1) * 10), \
((784 + 1) * 128 + (128 + 1) * 10), \
((784 + 1) * 128 + (128 + 1) * 128 + (128 + 1) * 128 + (128 + 1) * 10)

(7960, 50890, 101770, 134794)