[Source: PyTorch tutorial (Korean) — two-layer network with a custom autograd Function](https://tutorials.pytorch.kr/beginner/examples_autograd/two_layer_net_custom_function.html)

In [1]:
import torch

In [2]:
class MyReLU(torch.autograd.Function):
    """Rectified linear unit written as a custom autograd Function.

    Both hooks are static methods: they are invoked through ``MyReLU.apply``
    without creating an instance, and they receive a context object ``ctx``
    (used to carry data from forward to backward) instead of ``self``.
    """

    @staticmethod
    def forward(ctx, input):
        """Return the elementwise ``max(input, 0)``.

        The raw input is saved on ``ctx`` because backward needs to know
        which positions were negative.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Zero the upstream gradient where the saved input was strictly negative.

        Every other position (including inputs exactly equal to 0) passes
        ``grad_output`` through unchanged. A new tensor is returned;
        ``grad_output`` itself is not modified.
        """
        (saved_input,) = ctx.saved_tensors
        negative_mask = saved_input < 0
        return grad_output.masked_fill(negative_mask, 0)

In [3]:
# All tensors are single-precision and live on the CPU.
dtype = torch.float
device = torch.device('cpu')

N = 64        # batch size
D_in = 1000   # input dimension
H = 100       # hidden-layer dimension
D_out = 10    # output dimension

# Random inputs and targets; these do not require gradients.
x = torch.randn((N, D_in), dtype=dtype, device=device)
y = torch.randn((N, D_out), dtype=dtype, device=device)

# Weights of the two linear layers; autograd tracks gradients for them.
w1 = torch.randn((D_in, H), dtype=dtype, device=device, requires_grad=True)
w2 = torch.randn((H, D_out), dtype=dtype, device=device, requires_grad=True)

In [4]:
# Plain gradient descent on a sum-of-squares loss for 501 steps.
learning_rate = 1e-6

# Bind the custom ReLU once: ``apply`` is the entry point that records
# MyReLU.forward/backward on the autograd graph. It does not change between
# iterations, so there is no reason to re-fetch it inside the loop.
relu = MyReLU.apply

for t in range(501):
    # Forward pass: hidden = relu(x @ w1), prediction = hidden @ w2.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Sum of squared errors between prediction and target.
    loss = (y_pred - y).pow(2).sum()
    if t % 10 == 0:
        print(t, loss.item())

    # Backward pass: populates w1.grad and w2.grad, running
    # MyReLU.backward for the activation.
    loss.backward()

    # Update weights outside autograd so the in-place steps are not recorded.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # .grad accumulates across backward() calls; reset it for the next step.
        w1.grad.zero_()
        w2.grad.zero_()

0 23202500.0
10 3400719.25
20 246944.171875
30 73458.0390625
40 29632.986328125
50 13511.732421875
60 6623.73583984375
70 3419.353271484375
80 1834.195068359375
90 1013.7418212890625
100 574.943115234375
110 332.53082275390625
120 195.54351806640625
130 116.6374740600586
140 70.45059967041016
150 43.01255798339844
160 26.506183624267578
170 16.46858024597168
180 10.30650806427002
190 6.490759372711182
200 4.110825538635254
210 2.616396903991699
220 1.6724348068237305
230 1.0731499195098877
240 0.6910210847854614
250 0.44622835516929626
260 0.28895053267478943
270 0.18748614192008972
280 0.12192117422819138
290 0.07943622767925262
300 0.05188418924808502
310 0.0339406281709671
320 0.022280234843492508
330 0.014676293358206749
340 0.00972140021622181
350 0.006483951583504677
360 0.004368926864117384
370 0.0029841966461390257
380 0.0020649568177759647
390 0.0014575822278857231
400 0.0010565704433247447
410 0.0007763088215142488
420 0.0005830462905578315
430 0.0004468579718377441
440 0.000