In [23]:
import torch

class MatrixMultiplyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2):
        ctx.save_for_backward(input1, input2)
        return torch.matmul(input1, input2)

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors
        print(grad_output.shape)
        grad_input1 = torch.matmul(grad_output, input2.t())
        grad_input2 = torch.matmul(input1.t(), grad_output)
        return grad_input1, grad_input2

# Example usage:
# Create tensors
x = torch.randn(1, 2, requires_grad=True)
y = torch.randn(2, 2, requires_grad=True)

# Call the custom function
mul = MatrixMultiplyFunction.apply
output = mul(x, y)

# Suppose we have some loss function
loss = output.sum()

# Backward pass
loss.backward()

# Gradients
print(x)
print(y)
print("x_grad", x.grad)
print("y_grad", y.grad)

torch.Size([1, 2])
tensor([[-0.6703,  1.6879]], requires_grad=True)
tensor([[1.1448, 0.8391],
        [0.0912, 2.8409]], requires_grad=True)
x_grad tensor([[1.9839, 2.9322]])
y_grad tensor([[-0.6703, -0.6703],
        [ 1.6879,  1.6879]])


In [1]:
import torch

In [3]:
A = torch.randn(2, 2, requires_grad=True)
B = torch.randn(1, 2, requires_grad=True)

In [8]:
(torch.sum(B @ A)).backward()

In [9]:
A

tensor([[-0.3815,  0.7679],
        [-1.6135,  1.3187]], requires_grad=True)

In [10]:
B

tensor([[0.8098, 1.2933]], requires_grad=True)

In [11]:
A.grad

tensor([[0.8098, 0.8098],
        [1.2933, 1.2933]])

In [12]:
B.grad

tensor([[ 0.3863, -0.2948]])

In [13]:
0.7679 - 0.3815

0.3864