In [1]:
import torch

In [8]:
class MyReLU(torch.autograd.Function):
    '''
    Must extend the torch.autograd.Function and have forward and backward
    methods
    '''
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)
    
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

In [9]:
dtype = torch.float
device = torch.device('cpu')

In [10]:
BATCH = 64
IN = 1000
HIDDEN = 100
OUT = 10

ETA = 1e-6

In [11]:
x = torch.randn(BATCH, IN, device=device, dtype=dtype)
y = torch.randn(BATCH, OUT, device=device, dtype=dtype)

W1 = torch.randn(
    IN, HIDDEN, device=device, dtype=dtype, requires_grad=True)
W2 = torch.randn(
    HIDDEN, OUT, device=device, dtype=dtype, requires_grad=True)

In [13]:
for epoch in range(100):
    relu = MyReLU.apply
    y_pred = relu(x.mm(W1)).mm(W2)
    loss = (y_pred - y).pow(2).sum()
    print(epoch, loss.item())
    loss.backward()
    with torch.no_grad():
        W1 -= ETA * W1.grad
        W2 -= ETA * W2.grad
        W1.grad.zero_()
        W2.grad.zero_()

0 33060372.0
1 5842550.0
2 5492347.5
3 5479184.0
4 5698065.0
5 6042964.0
6 6423330.5
7 6717053.0
8 6838993.0
9 6696388.5
10 6295803.0
11 5647687.0
12 4858834.0
13 4013733.0
14 3213731.5
15 2506150.0
16 1921395.25
17 1455071.625
18 1097206.75
19 826463.8125
20 625172.6875
21 476058.46875
22 366137.53125
23 284869.0
24 224606.140625
25 179552.1875
26 145663.8125
27 119888.0625
28 100064.1875
29 84626.6171875
30 72446.2109375
31 62695.33984375
32 54787.9453125
33 48289.1328125
34 42874.04296875
35 38311.34375
36 34420.77734375
37 31068.646484375
38 28153.607421875
39 25599.015625
40 23343.779296875
41 21341.0859375
42 19552.27734375
43 17947.69921875
44 16502.865234375
45 15196.126953125
46 14011.6015625
47 12934.1005859375
48 11952.53125
49 11056.7607421875
50 10237.90625
51 9487.759765625
52 8799.15625
53 8166.677734375
54 7585.46142578125
55 7049.75634765625
56 6555.984375
57 6100.12158203125
58 5679.1337890625
59 5290.06640625
60 4929.94287109375
61 4596.6533203125
62 4287.7900390625
