In [1]:
import torch
from torch import nn

In [2]:
# detach() breaks the graph: z holds y's value (10.) but has no grad_fn leading back to w.
x = torch.tensor(5., dtype=torch.float32)
w = torch.tensor(2., dtype=torch.float32, requires_grad=True)
y = w * x
z = y.detach()
print(z.requires_grad)  # False: z is cut off from autograd
u = torch.tensor(2., dtype=torch.float32, requires_grad=True)
v = z * u  # v depends on u, with z treated as a constant; w is unreachable from v

False


In [3]:
v.backward()

In [4]:
u.grad, z, u.grad == z

(tensor(10.), tensor(10.), tensor(True))

In [5]:
w.grad  # displays nothing (None): the detach above removed w from v's graph, so backward never reached it

In [6]:
class LinFn(torch.autograd.Function):
    """Instrumented custom autograd op computing the elementwise product ``x * w``.

    The forward prints its inputs, then the product together with its storage
    pointer and ``requires_grad`` flag. This lets the notebook show that the
    product built inside ``forward`` does not require grad, and that the tensor
    handed back shares that storage (compare with ``y.data_ptr()`` afterwards).
    The backward prints the operands it saved.
    """

    @staticmethod
    def forward(ctx, x, w):
        product = x * w
        print(x, w)
        # Each operand is the other's local derivative, so both are kept.
        ctx.save_for_backward(x, w)
        print(product, product.data_ptr(), product.requires_grad)
        # detach() hands back a tensor over the same storage (same data_ptr).
        return product.detach()

    @staticmethod
    def backward(ctx, grad_o):
        x, w = ctx.saved_tensors
        print("backward x: , w: ", x, w)
        want_x, want_w = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        # d(x*w)/dx = w and d(x*w)/dw = x, each scaled by the upstream gradient.
        grad_x = (w * grad_o) if want_x else None
        grad_w = (x * grad_o) if want_w else None
        return grad_x, grad_w
            
        

In [7]:
def linear(x, w):
    """Apply the instrumented ``LinFn`` op, returning ``x * w`` with a custom backward."""
    apply_op = LinFn.apply
    return apply_op(x, w)

In [8]:
# Run the instrumented op: the prints show the inputs, then the product's value,
# storage pointer and requires_grad flag as seen inside forward.
x = torch.tensor(5., dtype=torch.float32)
w = torch.tensor(2., dtype=torch.float32, requires_grad=True)
y = linear(x, w)
print(y)  # the returned tensor carries grad_fn=<LinFnBackward> since w requires grad

tensor(5.) tensor(2., requires_grad=True)
tensor(10.) 94420417797760 False
tensor(10., grad_fn=<LinFnBackward>)


In [9]:
y.data_ptr()

94420417797760

In [10]:
# Same call under torch.no_grad(): forward still runs (and prints), but the
# result is not attached to the graph, so no grad_fn is shown.
x = torch.tensor(5., dtype=torch.float32)
w = torch.tensor(2., dtype=torch.float32, requires_grad=True)
with torch.no_grad():
    y = linear(x, w)
print(y)

tensor(5.) tensor(2., requires_grad=True)
tensor(10.) 94420417801216 False
tensor(10.)


In [11]:
class LinFn1(torch.autograd.Function):
    """Custom autograd op computing the elementwise product ``x * w``.

    The forward is silent. The backward prints the saved operands, so the
    notebook can show which backward nodes actually execute when ops are chained.
    """

    @staticmethod
    def forward(ctx, x, w):
        result = x * w
        # Each operand is the other's partial derivative, so keep both.
        ctx.save_for_backward(x, w)
        return result.detach()

    @staticmethod
    def backward(ctx, grad_o):
        x, w = ctx.saved_tensors
        print("backward x: , w: ", x, w)
        need = ctx.needs_input_grad
        # Chain rule for a product: scale the other operand by the upstream grad.
        grad_x = (w * grad_o) if need[0] else None
        grad_w = (x * grad_o) if need[1] else None
        return grad_x, grad_w

In [12]:
def linear1(x, w):
    """Apply the silent-forward ``LinFn1`` op, returning ``x * w`` with a custom backward."""
    apply_op = LinFn1.apply
    return apply_op(x, w)

In [13]:
# Chain two LinFn1 ops with a detach in between: backward from v can only reach
# the second op, so only one "backward" line is printed and w gets no gradient.
x = torch.tensor(5., dtype=torch.float32)
w = torch.tensor(2., dtype=torch.float32, requires_grad=True)
y = linear1(x, w)
z = y.detach()
print(z.requires_grad)
u = torch.tensor(2., dtype=torch.float32, requires_grad=True)
v = linear1(z, u)

False


In [14]:
v.backward()

backward x: , w:  tensor(10.) tensor(2., requires_grad=True)


In [15]:
u.grad

tensor(10.)

In [16]:
w.grad  # displays nothing (None): the first LinFn1 backward never ran because of the detach

In [17]:
# Same chain without detach: backward now runs through both LinFn1 nodes
# (two "backward" lines printed), so both u and w receive gradients.
x = torch.tensor(5., dtype=torch.float32)
w = torch.tensor(2., dtype=torch.float32, requires_grad=True)
y = linear1(x, w)
z = y  # no detach: z keeps y's grad_fn
print(z.requires_grad)
u = torch.tensor(2., dtype=torch.float32, requires_grad=True)
v = linear1(z, u)

True


In [18]:
v.backward()

backward x: , w:  tensor(10., grad_fn=<LinFn1Backward>) tensor(2., requires_grad=True)
backward x: , w:  tensor(5.) tensor(2., requires_grad=True)


In [19]:
u.grad

tensor(10.)

In [20]:
w.grad # v = z * u = y * u = x * w * u  =>  dv/dw = x * u = 5 * 2 = 10

tensor(10.)

In [180]:
class QueryLinear(torch.autograd.Function):
    """Custom autograd op computing ``inputs @ weights`` with hand-written gradients.

    ``inputs`` has shape ``(..., hidden)``; the notebook uses ``(batch, seq, hidden)``.
    ``weights`` is a 2-D ``(hidden, out)`` matrix. The result has shape ``(..., out)``.
    """

    @staticmethod
    def forward(ctx, inputs, weights):
        """Return ``inputs @ weights`` with shape ``(*inputs.shape[:-1], weights.size(1))``."""
        # Flatten every leading dim into rows so a single 2-D matmul does the work.
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. transposed ones.
        q = inputs.reshape(-1, inputs.size(-1)) @ weights
        ctx.save_for_backward(inputs, weights)
        # Use the explicit output width instead of -1 so empty batches still reshape.
        return q.reshape(*inputs.shape[:-1], weights.size(1)).detach()

    @staticmethod
    def backward(ctx, grad_q):
        """Return ``(grad_inputs, grad_weights)``; each is None when not required.

        Raises:
            Nothing beyond what the underlying tensor ops raise.
        """
        inputs, weights = ctx.saved_tensors
        grad_input, grad_weight = None, None
        if ctx.needs_input_grad[0]:
            # dL/dX = dL/dQ @ W^T, shape (..., hidden).
            grad_input = grad_q @ weights.t()
        if ctx.needs_input_grad[1]:
            # dL/dW = X^T @ dL/dQ with all leading dims flattened into rows, shape (hidden, out).
            # One matmul avoids materialising the (..., out, hidden) outer-product tensor.
            flat_inputs = inputs.reshape(-1, inputs.size(-1))
            flat_grad = grad_q.reshape(-1, grad_q.size(-1))
            grad_weight = flat_inputs.t() @ flat_grad
        return grad_input, grad_weight

In [181]:
# Exercise QueryLinear; a later cell repeats this setup with plain autograd for comparison.
torch.manual_seed(5)
bz, sz, hz = 2, 6, 32  # presumably batch / sequence / hidden sizes (leading dims and last dim of x)
fdim = 64  # output width: y comes out as (bz, sz, fdim)
x = torch.randn(bz, sz, hz, requires_grad=True)
w = nn.Parameter(torch.Tensor(hz, fdim))  # torch.Tensor(...) is uninitialised memory; filled right below
with torch.no_grad():
    w.data.copy_(torch.randn_like(w.data))
y = QueryLinear.apply(x, w)
y.shape

torch.Size([2, 6, 64])

In [182]:
z = y.sum()
z.backward()
# NOTE(review): the printed lines below came from an earlier QueryLinear.backward that
# contained debug prints; the class as currently defined prints nothing, so this output
# is stale (re-run the notebook top to bottom to refresh it).

Grad q shape:  torch.Size([2, 6, 64])
1s:  torch.Size([2, 6, 32]) torch.Size([2, 6, 64])
1s:  torch.Size([2, 6, 1, 32]) torch.Size([2, 6, 64])
grad_weight :  torch.Size([32, 64])


In [183]:
x.grad[0][0]

tensor([  3.3197,   6.0255,  -0.6870,  -8.2403,   3.2846,  11.9100,   9.0050,
         -3.3683,  -1.7355,   0.5828,   6.0587,  -5.2056,   4.4373,  -1.7716,
          2.5015,  -9.2717,   1.4272,  -2.0466,   6.3974,  -6.8130,   6.5871,
         -0.6450, -19.5809,  13.9310,   9.5542,  -0.8859,  -1.9697,  -4.5802,
          0.4948,  12.9984,  -4.4692,   1.4250])

In [184]:
x.grad[-1][0]

tensor([  3.3197,   6.0255,  -0.6870,  -8.2403,   3.2846,  11.9100,   9.0050,
         -3.3683,  -1.7355,   0.5828,   6.0587,  -5.2056,   4.4373,  -1.7716,
          2.5015,  -9.2717,   1.4272,  -2.0466,   6.3974,  -6.8130,   6.5871,
         -0.6450, -19.5809,  13.9310,   9.5542,  -0.8859,  -1.9697,  -4.5802,
          0.4948,  12.9984,  -4.4692,   1.4250])

In [185]:
w.grad

tensor([[ 0.6579,  0.6579,  0.6579,  ...,  0.6579,  0.6579,  0.6579],
        [ 3.7910,  3.7910,  3.7910,  ...,  3.7910,  3.7910,  3.7910],
        [ 4.2263,  4.2263,  4.2263,  ...,  4.2263,  4.2263,  4.2263],
        ...,
        [-1.6648, -1.6648, -1.6648,  ..., -1.6648, -1.6648, -1.6648],
        [ 0.0623,  0.0623,  0.0623,  ...,  0.0623,  0.0623,  0.0623],
        [ 1.5735,  1.5735,  1.5735,  ...,  1.5735,  1.5735,  1.5735]])

In [186]:
w.grad.sum()

tensor(468.4965)

In [187]:
# Reference run: identical setup and seed, but gradients come from autograd through
# a plain x @ w. The x.grad / w.grad values shown here should match QueryLinear's.
torch.manual_seed(5)
bz, sz, hz = 2, 6, 32
fdim = 64
x = torch.randn(bz, sz, hz, requires_grad=True)
w = nn.Parameter(torch.Tensor(hz, fdim))
with torch.no_grad():
    w.data.copy_(torch.randn_like(w.data))
y = x @ w
z = y.sum()
z.backward()
x.grad[0][0], x.grad[-1][0]

(tensor([  3.3197,   6.0255,  -0.6870,  -8.2403,   3.2846,  11.9100,   9.0050,
          -3.3683,  -1.7355,   0.5828,   6.0587,  -5.2056,   4.4373,  -1.7716,
           2.5015,  -9.2717,   1.4272,  -2.0466,   6.3974,  -6.8130,   6.5871,
          -0.6450, -19.5809,  13.9310,   9.5542,  -0.8859,  -1.9697,  -4.5802,
           0.4948,  12.9984,  -4.4692,   1.4250]),
 tensor([  3.3197,   6.0255,  -0.6870,  -8.2403,   3.2846,  11.9100,   9.0050,
          -3.3683,  -1.7355,   0.5828,   6.0587,  -5.2056,   4.4373,  -1.7716,
           2.5015,  -9.2717,   1.4272,  -2.0466,   6.3974,  -6.8130,   6.5871,
          -0.6450, -19.5809,  13.9310,   9.5542,  -0.8859,  -1.9697,  -4.5802,
           0.4948,  12.9984,  -4.4692,   1.4250]))

In [188]:
w.grad

tensor([[ 0.6579,  0.6579,  0.6579,  ...,  0.6579,  0.6579,  0.6579],
        [ 3.7910,  3.7910,  3.7910,  ...,  3.7910,  3.7910,  3.7910],
        [ 4.2263,  4.2263,  4.2263,  ...,  4.2263,  4.2263,  4.2263],
        ...,
        [-1.6648, -1.6648, -1.6648,  ..., -1.6648, -1.6648, -1.6648],
        [ 0.0623,  0.0623,  0.0623,  ...,  0.0623,  0.0623,  0.0623],
        [ 1.5735,  1.5735,  1.5735,  ...,  1.5735,  1.5735,  1.5735]])

In [189]:
w.grad.sum()

tensor(468.4965)

In [163]:
a = torch.randn(2, 8, 3, 2)
b =   torch.randn(2, 8, 2, 4) 
print(a.size(), b.size())
torch.matmul(a , b).shape

torch.Size([2, 8, 3, 2]) torch.Size([2, 8, 2, 4])


torch.Size([2, 8, 3, 4])

In [141]:
(torch.randn(2, 6, 1, 32) * torch.randn(2, 6, 64, 1) ).shape

torch.Size([2, 6, 64, 32])

In [192]:
(torch.randn(2, 6, 64, 1) @ torch.randn(2, 6, 1, 32) ).shape

torch.Size([2, 6, 64, 32])

In [176]:
(torch.randn(2, 6, 64, 1) @ torch.randn(2, 6, 1, 32) ).shape

torch.Size([2, 6, 64, 32])