In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


In [2]:
# Global configuration for the two-layer network experiment.
dtype = torch.float
# Prefer the first CUDA device, but fall back to the CPU so the notebook
# still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Tensor sizes used when the data and weights are created:
# x is (N, D_in), w1 is (D_in, H), w2 is (H, D_out), y is (N, D_out).
N, D_in, H, D_out = 64, 1000, 100, 10
# Step size for the manual gradient-descent weight update.
lr = 1e-6

In [3]:
class MyReLU(Function):
    """Custom autograd implementation of the ReLU activation.

    Invoke through ``MyReLU.apply(tensor)``.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` with every negative entry replaced by zero.

        The input tensor is stashed on ``ctx`` so that ``backward`` can tell
        which entries were negative.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass ``grad_output`` through, zeroed wherever the input was negative.

        Entries whose input was exactly zero keep their incoming gradient.
        """
        (saved_input,) = ctx.saved_tensors
        negative_mask = saved_input < 0
        # masked_fill returns a new tensor, so grad_output itself is untouched.
        return grad_output.masked_fill(negative_mask, 0)
    

class MySigmoid(Function):
    """Custom autograd implementation of the logistic sigmoid.

    Invoke through ``MySigmoid.apply(tensor)``.
    """

    @staticmethod
    def forward(ctx, input):
        """Compute ``1 / (1 + exp(-input))`` element-wise.

        The previous expression ``1 / 1+torch.exp(-input)`` evaluated to
        ``1 + exp(-input)`` because of operator precedence, which the
        subsequent clamp turned into a constant 1. ``torch.sigmoid`` is used
        here because it is also numerically stable for large ``|input|``.

        The *output* is saved for ``backward`` since the derivative is most
        cheaply expressed in terms of it.
        """
        sigmoid = torch.sigmoid(input)
        ctx.save_for_backward(sigmoid)
        return sigmoid

    @staticmethod
    def backward(ctx, grad_output):
        """Apply the chain rule with d(sigmoid)/d(input) = s * (1 - s).

        ``s`` is the forward output saved on ``ctx``.
        """
        sigmoid, = ctx.saved_tensors
        return grad_output * sigmoid * (1 - sigmoid)

In [4]:
# Random inputs and targets; these are fixed data, so autograd does not
# track them. dtype is passed explicitly so they always match the weights,
# regardless of the process-wide default dtype.
x = torch.randn(N, D_in, device=device, dtype=dtype, requires_grad=False)
y = torch.randn(N, D_out, device=device, dtype=dtype, requires_grad=False)

# Randomly initialised weights of the two linear layers; requires_grad=True
# makes autograd accumulate their gradients in `.grad` during backward().
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)


In [7]:
# Custom autograd Functions are invoked through their `apply` method.
# The bindings never change, so they are created once instead of on every
# iteration.
relu = MyReLU.apply
sig = MySigmoid.apply  # not used by this loop; kept so the name stays defined

for i in range(500):
    # Forward pass: linear layer -> custom ReLU -> linear layer.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Sum of squared errors over every element of the prediction.
    loss = (y_pred - y).pow(2).sum()
    print(i, loss.item())

    # Backward pass: fills w1.grad and w2.grad.
    loss.backward()

    # The update itself must not be recorded by autograd.
    with torch.no_grad():
        w1 -= lr * w1.grad
        w2 -= lr * w2.grad

        # Gradients accumulate across backward() calls, so reset them
        # before the next iteration.
        w1.grad.zero_()
        w2.grad.zero_()

0 28143212.0
1 21045562.0
2 18944580.0
3 18898318.0
4 18794294.0
5 17792824.0
6 15010102.0
7 11504648.0
8 7888748.0
9 5147717.5
10 3249941.0
11 2096483.5
12 1403198.0
13 995524.5
14 746674.25
15 588899.0625
16 482445.4375
17 406354.875
18 348779.375
19 303304.1875
20 266155.65625
21 235087.78125
22 208650.953125
23 185908.1875
24 166202.5
25 148897.78125
26 133695.75
27 120301.078125
28 108461.3671875
29 97959.4375
30 88622.9921875
31 80298.4375
32 72861.8671875
33 66199.515625
34 60222.7890625
35 54854.3515625
36 50025.3046875
37 45674.3515625
38 41744.5859375
39 38189.75
40 34967.7421875
41 32046.4375
42 29397.154296875
43 26993.328125
44 24806.212890625
45 22813.08984375
46 20994.287109375
47 19333.212890625
48 17815.28515625
49 16427.080078125
50 15158.1298828125
51 13996.072265625
52 12930.4765625
53 11952.9521484375
54 11054.74609375
55 10229.103515625
56 9469.9375
57 8771.1640625
58 8127.7421875
59 7535.0361328125
60 6989.24462890625
61 6486.255859375
62 6022.01953125
63 5593.33

455 0.00010261456191074103
456 0.00010102457599714398
457 9.936479909811169e-05
458 9.82136552920565e-05
459 9.681568917585537e-05
460 9.505078196525574e-05
461 9.386296005686745e-05
462 9.244008106179535e-05
463 9.08343426999636e-05
464 8.932205673772842e-05
465 8.856540080159903e-05
466 8.671271643834189e-05
467 8.547175821149722e-05
468 8.451578469248489e-05
469 8.328286639880389e-05
470 8.181428711395711e-05
471 8.051157055888325e-05
472 7.952332816785201e-05
473 7.858620665501803e-05
474 7.765966438455507e-05
475 7.656596426386386e-05
476 7.552810711786151e-05
477 7.420076144626364e-05
478 7.334839028771967e-05
479 7.239672413561493e-05
480 7.134681800380349e-05
481 7.050448039080948e-05
482 6.934507109690458e-05
483 6.841958384029567e-05
484 6.773741915822029e-05
485 6.65760162519291e-05
486 6.559224129887298e-05
487 6.439629942178726e-05
488 6.372742063831538e-05
489 6.281084642978385e-05
490 6.21763028902933e-05
491 6.130454130470753e-05
492 6.0389022110030055e-05
493 5.9677113