In [1]:
import torch

class MyReLU(torch.autograd.Function):
    """Hand-written ReLU with an explicit backward pass.

    Use through ``MyReLU.apply(tensor)``. The forward pass clamps negative
    entries to zero. The backward pass zeroes the incoming gradient wherever
    the saved input was strictly negative. Entries equal to zero therefore
    let the gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``max(input, 0)`` element-wise, stashing ``input`` for backward.

        ``ctx`` is the autograd context object; tensors saved on it are
        available again in :meth:`backward`.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Map dLoss/dOutput to dLoss/dInput.

        The incoming gradient is copied, and positions where the forward
        input was negative are set to zero.
        """
        (saved_input,) = ctx.saved_tensors
        negative_mask = saved_input < 0
        # masked_fill is out-of-place, so grad_output itself is never mutated.
        return grad_output.masked_fill(negative_mask, 0)
    
    
dtype = torch.float
# Prefer the first CUDA device, but fall back to the CPU so the script does
# not crash on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# N rows in x / y; x is (N, D_in), hidden layer width H, y is (N, D_out).
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and targets.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialised weights of the two linear layers; autograd tracks them.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

lr = 1e-6
epochs = 500
losses = []
# Loop-invariant: bind the custom autograd function once, not every epoch.
relu = MyReLU.apply
for epoch in range(epochs):
    # Forward pass: linear -> custom ReLU -> linear.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Sum of squared errors over all entries.
    loss = (y_pred - y).pow(2).sum()
    print(epoch, loss)

    loss.backward()

    # Plain gradient-descent step. no_grad keeps the in-place weight update
    # out of the autograd graph.
    with torch.no_grad():
        w1 -= lr * w1.grad
        w2 -= lr * w2.grad
    # backward() accumulates into .grad, so reset it before the next epoch.
    w1.grad.zero_()
    w2.grad.zero_()
    losses.append(loss.item())

0 tensor(30036410., device='cuda:0', grad_fn=<SumBackward0>)
1 tensor(27299036., device='cuda:0', grad_fn=<SumBackward0>)
2 tensor(25758444., device='cuda:0', grad_fn=<SumBackward0>)
3 tensor(22357958., device='cuda:0', grad_fn=<SumBackward0>)
4 tensor(16910556., device='cuda:0', grad_fn=<SumBackward0>)
5 tensor(11136128., device='cuda:0', grad_fn=<SumBackward0>)
6 tensor(6690526.5000, device='cuda:0', grad_fn=<SumBackward0>)
7 tensor(3930013.5000, device='cuda:0', grad_fn=<SumBackward0>)
8 tensor(2404234.2500, device='cuda:0', grad_fn=<SumBackward0>)
9 tensor(1587338.3750, device='cuda:0', grad_fn=<SumBackward0>)
10 tensor(1136970.8750, device='cuda:0', grad_fn=<SumBackward0>)
11 tensor(870804.6250, device='cuda:0', grad_fn=<SumBackward0>)
12 tensor(699541.1250, device='cuda:0', grad_fn=<SumBackward0>)
13 tensor(579888.8125, device='cuda:0', grad_fn=<SumBackward0>)
14 tensor(490440.2188, device='cuda:0', grad_fn=<SumBackward0>)
15 tensor(420169.0625, device='cuda:0', grad_fn=<SumBackw

131 tensor(69.8311, device='cuda:0', grad_fn=<SumBackward0>)
132 tensor(65.9322, device='cuda:0', grad_fn=<SumBackward0>)
133 tensor(62.2578, device='cuda:0', grad_fn=<SumBackward0>)
134 tensor(58.7933, device='cuda:0', grad_fn=<SumBackward0>)
135 tensor(55.5272, device='cuda:0', grad_fn=<SumBackward0>)
136 tensor(52.4486, device='cuda:0', grad_fn=<SumBackward0>)
137 tensor(49.5436, device='cuda:0', grad_fn=<SumBackward0>)
138 tensor(46.8046, device='cuda:0', grad_fn=<SumBackward0>)
139 tensor(44.2211, device='cuda:0', grad_fn=<SumBackward0>)
140 tensor(41.7838, device='cuda:0', grad_fn=<SumBackward0>)
141 tensor(39.4841, device='cuda:0', grad_fn=<SumBackward0>)
142 tensor(37.3144, device='cuda:0', grad_fn=<SumBackward0>)
143 tensor(35.2663, device='cuda:0', grad_fn=<SumBackward0>)
144 tensor(33.3352, device='cuda:0', grad_fn=<SumBackward0>)
145 tensor(31.5111, device='cuda:0', grad_fn=<SumBackward0>)
146 tensor(29.7899, device='cuda:0', grad_fn=<SumBackward0>)
147 tensor(28.1649, devi

266 tensor(0.0540, device='cuda:0', grad_fn=<SumBackward0>)
267 tensor(0.0514, device='cuda:0', grad_fn=<SumBackward0>)
268 tensor(0.0489, device='cuda:0', grad_fn=<SumBackward0>)
269 tensor(0.0465, device='cuda:0', grad_fn=<SumBackward0>)
270 tensor(0.0442, device='cuda:0', grad_fn=<SumBackward0>)
271 tensor(0.0421, device='cuda:0', grad_fn=<SumBackward0>)
272 tensor(0.0400, device='cuda:0', grad_fn=<SumBackward0>)
273 tensor(0.0381, device='cuda:0', grad_fn=<SumBackward0>)
274 tensor(0.0362, device='cuda:0', grad_fn=<SumBackward0>)
275 tensor(0.0345, device='cuda:0', grad_fn=<SumBackward0>)
276 tensor(0.0328, device='cuda:0', grad_fn=<SumBackward0>)
277 tensor(0.0312, device='cuda:0', grad_fn=<SumBackward0>)
278 tensor(0.0297, device='cuda:0', grad_fn=<SumBackward0>)
279 tensor(0.0283, device='cuda:0', grad_fn=<SumBackward0>)
280 tensor(0.0269, device='cuda:0', grad_fn=<SumBackward0>)
281 tensor(0.0256, device='cuda:0', grad_fn=<SumBackward0>)
282 tensor(0.0244, device='cuda:0', grad

402 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
403 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
404 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
405 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
406 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
407 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
408 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
409 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
410 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
411 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
412 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
413 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
414 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
415 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
416 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
417 tensor(0.0002, device='cuda:0', grad_fn=<SumBackward0>)
418 tensor(0.0002, device='cuda:0', grad