In [1]:
import torch

In [2]:
# requires_grad=True makes autograd record operations on x (the old Variable wrapper is no longer needed)
x = torch.ones(2, 2, requires_grad=True)
print(x)

tensor([[1., 1.],
        [1., 1.]], requires_grad=True)


In [3]:
y = x + 2
print(y)

tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)


In [4]:
print(y.grad_fn)

<AddBackward0 object at 0x00000211C09DB730>


In [10]:
z = y * y * 3
out = z.mean()
print(z, out)

tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)


In [11]:
print(z)

tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>)


In [12]:
print(out)

tensor(27., grad_fn=<MeanBackward0>)


In [6]:
out.backward()  # computes d(out)/dx and stores it in x.grad

In [8]:
print(x)

tensor([[1., 1.],
        [1., 1.]], requires_grad=True)


In [9]:
print(x.grad)

tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])


The Chain Rule:

The chain rule is a fundamental concept in calculus that lets us compute the derivative of a composite function. It states that the derivative of a composite function w(u(x)) is the derivative of the outer function w, evaluated at the inner function u(x), multiplied by the derivative of the inner function u(x):

dw/dx = dw/du * du/dx
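As a quick numerical check of the rule, the sketch below compares dw/du * du/dx with a central finite difference. The composite u(x) = x^2, w(u) = sin(u) and the helper name `central_difference` are illustrative choices, not part of the notebook or of PyTorch.

```python
import math

def central_difference(f, x, h=1e-5):
    """Approximate f'(x) with a symmetric finite difference."""
    return (f(x + h) - f(x - h)) / (2 * h)

# Illustrative composite w(u(x)) with u(x) = x**2 and w(u) = sin(u)
def u(x):
    return x ** 2

def w(u_val):
    return math.sin(u_val)

x0 = 0.7
dw_du = math.cos(u(x0))  # outer derivative, evaluated at the inner value u(x0)
du_dx = 2 * x0           # inner derivative
chain_rule = dw_du * du_dx
numeric = central_difference(lambda x: w(u(x)), x0)
print(chain_rule, numeric)  # the two values agree closely
```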

Applying the Chain Rule:

In your code, every step of the computation is visible:

x is a 2x2 tensor of ones with elements x_ij (i is the row index and j is the column index).
y is another 2x2 tensor with elements y_ij = x_ij + 2.
What we want to explain is the gradient of out with respect to each element of x (i.e., the entries of x.grad). Here's how the chain rule applies:
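To make the setup concrete without autograd, the forward computation from the notebook can be mirrored with nested Python lists. This is a sketch; `forward` is a hypothetical helper name.

```python
def forward(x):
    """Return (y, z, out) for the notebook's computation on a 2x2 nested list."""
    y = [[x_ij + 2 for x_ij in row] for row in x]       # y_ij = x_ij + 2
    z = [[3 * y_ij ** 2 for y_ij in row] for row in y]  # z_ij = 3 * y_ij^2
    n = sum(len(row) for row in z)
    out = sum(sum(row) for row in z) / n                # out = mean(z)
    return y, z, out

x = [[1.0, 1.0], [1.0, 1.0]]
y, z, out = forward(x)
print(y)    # [[3.0, 3.0], [3.0, 3.0]]
print(z)    # [[27.0, 27.0], [27.0, 27.0]]
print(out)  # 27.0
```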

1. Outer Function:

In your code, backward() is called directly on out, so out itself is the quantity being differentiated and the outermost factor is simply:

d (out) / d (out) = 1

If out were instead passed into a loss, for example the mean squared error (MSE) commonly used in neural networks:

loss = 1/n * sum((out - target)^2)

where n is the number of elements in out (here 1, since out is a scalar) and target is the desired output value (a constant or another tensor, depending on the task), the chain would gain one extra factor:

d (loss) / d (out) = 2/n * (out - target)
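A short sketch of this optional MSE step for a scalar out (n = 1). The notebook defines no target, so `target = 25.0` below is an arbitrary illustrative value; the check confirms that 2/n * (out - target) agrees with a finite difference.

```python
def mse(out, target, n=1):
    """Mean squared error for a scalar prediction (n = 1 element)."""
    return (out - target) ** 2 / n

def mse_grad(out, target, n=1):
    """Analytic derivative d(loss)/d(out) = 2/n * (out - target)."""
    return 2 / n * (out - target)

out = 27.0      # value of out printed by the notebook
target = 25.0   # hypothetical target, not present in the original code
h = 1e-5
analytic = mse_grad(out, target)
numeric = (mse(out + h, target) - mse(out - h, target)) / (2 * h)
print(analytic)  # 4.0
```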

2. out (Inner Function):

out is calculated as the mean of the elements in z:

out = 1/4 * sum(z_ij)

where z_ij represents an element in the 2x2 tensor z.

The derivative of out with respect to each element z_ij in z is simply:

d (out) / d (z_ij) = 1/4 (since all elements in z contribute equally to the mean)
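The 1/4 factor can be checked numerically by nudging a single element of z and watching how the mean moves. `mean_2x2` and `partial_wrt` are hypothetical helper names for this sketch.

```python
def mean_2x2(z):
    return sum(sum(row) for row in z) / 4

def partial_wrt(f, z, i, j, h=1e-5):
    """Central difference of f with respect to element z[i][j]."""
    plus = [row[:] for row in z]
    minus = [row[:] for row in z]
    plus[i][j] += h
    minus[i][j] -= h
    return (f(plus) - f(minus)) / (2 * h)

z = [[27.0, 27.0], [27.0, 27.0]]
grads = [[partial_wrt(mean_2x2, z, i, j) for j in range(2)] for i in range(2)]
print(grads)  # every entry is approximately 0.25
```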

3. z (Inner Function):

Each element z_ij in z is calculated as:

z_ij = 3 * (y_ij)^2

The derivative of z_ij with respect to y_ij is:

d (z_ij) / d (y_ij) = 6 * y_ij
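A quick check that d (z_ij) / d (y_ij) = 6 * y_ij holds at several points, not only at y_ij = 3 (plain Python, central differences; the sample points are arbitrary):

```python
def z_of_y(y):
    return 3 * y ** 2   # z_ij = 3 * y_ij^2

def dz_dy(y):
    return 6 * y        # analytic derivative

h = 1e-5
for y in [-2.0, 0.0, 3.0, 10.0]:
    numeric = (z_of_y(y + h) - z_of_y(y - h)) / (2 * h)
    print(y, dz_dy(y), round(numeric, 4))
```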

4. y (Inner Function):

Each element y_ij in y is related to x_ij through:

y_ij = x_ij + 2

The derivative of y_ij with respect to x_ij is simply:

d (y_ij) / d (x_ij) = 1

Putting It Together:

Because each x_ij affects out only through its own y_ij and z_ij, the chain has a single path per element, and the gradient of out with respect to each x_ij is:

d (out) / d (x_ij) = d (out) / d (z_ij) * d (z_ij) / d (y_ij) * d (y_ij) / d (x_ij)

                   = (1/4) * (6 * y_ij) * 1

                   = 1.5 * y_ij

Since every x_ij = 1, every y_ij = 3, so each entry of x.grad is 1.5 * 3 = 4.5, which matches the printed x.grad.

If out were fed into a loss such as the MSE from step 1, the upstream factor d (loss) / d (out) would multiply in as well:

d (loss) / d (x_ij) = (2/n * (out - target)) * (1/4) * (6 * y_ij) * 1

                    = (3 * (out - target) * y_ij) / n

Note: In that case the numerical values in x.grad would depend on both y_ij (and hence the initial x_ij) and the target value.
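The whole backward pass can also be done by hand in plain Python: the sketch below multiplies the local derivatives element by element and compares the result with a finite-difference gradient of the full forward computation. It reproduces the 4.5 entries of x.grad without autograd; `forward_out`, `manual_grad`, and `numeric_grad` are hypothetical names.

```python
def forward_out(x):
    """out = mean(3 * (x_ij + 2)^2) over a 2x2 nested list."""
    z = [[3 * (x_ij + 2) ** 2 for x_ij in row] for row in x]
    return sum(sum(row) for row in z) / 4

def manual_grad(x, upstream=1.0):
    """Chain rule: upstream * d(out)/d(z_ij) * d(z_ij)/d(y_ij) * d(y_ij)/d(x_ij)."""
    grad = []
    for row in x:
        grad_row = []
        for x_ij in row:
            y_ij = x_ij + 2
            # (1/4) from the mean, 6 * y_ij from 3y^2, 1 from y = x + 2
            grad_row.append(upstream * (1 / 4) * (6 * y_ij) * 1)
        grad.append(grad_row)
    return grad

def numeric_grad(f, x, h=1e-5):
    """Central-difference gradient of f with respect to every x[i][j]."""
    grad = [[0.0] * len(row) for row in x]
    for i, row in enumerate(x):
        for j in range(len(row)):
            plus = [r[:] for r in x]
            minus = [r[:] for r in x]
            plus[i][j] += h
            minus[i][j] -= h
            grad[i][j] = (f(plus) - f(minus)) / (2 * h)
    return grad

x = [[1.0, 1.0], [1.0, 1.0]]
print(manual_grad(x))  # [[4.5, 4.5], [4.5, 4.5]], same as x.grad
print(numeric_grad(forward_out, x))
```

Passing `upstream=2 * (out - target)` (with n = 1) would give the MSE-loss version of the gradient instead; the default of 1.0 corresponds to calling backward() directly on out, as the notebook does.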