# PyTorch Examples

### Hand-Made Optimization

In [19]:
import torch
import random

In [2]:
dtype = torch.float
device = torch.device('cuda:0')

In [3]:
n, d_in, h, d_out = 64, 1000, 100, 10

In [4]:
def gauss_init(*size, grad=False):
    return torch.randn(
        *size, device=device, dtype=dtype, requires_grad=grad)

In [5]:
x = gauss_init(n, d_in)
y = gauss_init(n, d_out)

In [6]:
w1 = gauss_init(d_in, h, grad=True)
w2 = gauss_init(h, d_out, grad=True)

In [7]:
class ReLU(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, tensor):
        ctx.save_for_backward(tensor)
        return tensor.clamp(min=0)
    
    @staticmethod
    def backward(ctx, grad_output):
        tensor, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[tensor < 0] = 0
        return grad_input

In [8]:
lr = 1e-6
for t in range(500):
    relu = ReLU.apply
    y_pred = relu(x.mm(w1)).mm(w2)
    loss = (y_pred - y).pow(2).sum()
    if t % 50 == 49:
        print(f'{t + 1:3d}: {loss.item():2.3f}')
    loss.backward()
    with torch.no_grad():
        w1 -= lr * w1.grad
        w2 -= lr * w2.grad
        w1.grad.zero_()
        w2.grad.zero_()

 50: 17476.141
100: 926.359
150: 78.971
200: 8.314
250: 0.996
300: 0.130
350: 0.018
400: 0.003
450: 0.001
500: 0.000


### High-Level Wrappers

In [9]:
import torch.nn as nn

In [10]:
n, d_in, h, d_out = 64, 1000, 100, 10

In [11]:
x = torch.randn(n, d_in)
y = torch.randn(n, d_out)

In [14]:
def create_model(n_in, n_out, hidden):
    return nn.Sequential(
        nn.Linear(n_in, hidden),
        nn.ReLU(),
        nn.Linear(hidden, n_out))

In [15]:
model = create_model(d_in, d_out, h)

In [16]:
loss_fn = nn.MSELoss(reduction='sum')
lr = 1e-4

for t in range(500):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    model.zero_grad()
    loss.backward()
    
    with torch.no_grad():
        for param in model.parameters():
            param -= lr * param.grad

0 656.9984130859375
1 610.76806640625
2 570.6632690429688
3 535.0881958007812
4 503.22576904296875
5 474.0798645019531
6 447.24542236328125
7 422.4091491699219
8 399.4298095703125
9 378.1115417480469
10 358.1604919433594
11 339.4648132324219
12 321.9073181152344
13 305.31048583984375
14 289.5698547363281
15 274.5966491699219
16 260.2539978027344
17 246.54222106933594
18 233.4724578857422
19 220.9554901123047
20 208.98426818847656
21 197.55950927734375
22 186.6103973388672
23 176.169677734375
24 166.21859741210938
25 156.74423217773438
26 147.7276611328125
27 139.17779541015625
28 131.0972137451172
29 123.45451354980469
30 116.22522735595703
31 109.40193176269531
32 102.9587631225586
33 96.88693237304688
34 91.15994262695312
35 85.75618743896484
36 80.67202758789062
37 75.88554382324219
38 71.38914489746094
39 67.16324615478516
40 63.191627502441406
41 59.44762420654297
42 55.93144989013672
43 52.632930755615234
44 49.539390563964844
45 46.63654708862305
46 43.91141128540039
47 41.35310

Or even higher:

In [18]:
model = create_model(d_in, d_out, h)
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(500):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 659.7439575195312
1 509.11090087890625
2 398.7759704589844
3 315.2510681152344
4 250.39820861816406
5 197.99574279785156
6 155.25906372070312
7 120.68888092041016
8 93.35316467285156
9 72.41675567626953
10 56.80306625366211
11 45.37241744995117
12 37.13508605957031
13 31.282773971557617
14 27.118690490722656
15 24.274131774902344
16 22.35956382751465
17 21.03483009338379
18 20.03095054626465
19 19.151653289794922
20 18.26032257080078
21 17.223970413208008
22 16.011198043823242
23 14.648751258850098
24 13.194315910339355
25 11.703246116638184
26 10.253011703491211
27 8.902661323547363
28 7.668289661407471
29 6.575948715209961
30 5.636777400970459
31 4.851404190063477
32 4.220165729522705
33 3.72855281829834
34 3.357003927230835
35 3.080597162246704
36 2.8746516704559326
37 2.710175037384033
38 2.564077615737915
39 2.419102191925049
40 2.2672433853149414
41 2.106081962585449
42 1.938624382019043
43 1.7660573720932007
44 1.5917608737945557
45 1.4217323064804077
46 1.2604904174804688
47 

378 8.900892176089137e-12
379 9.950396409608153e-12
380 9.479477966478633e-12
381 1.03895945666177e-11
382 9.869464620559931e-12
383 9.409639734059283e-12
384 9.091978038500148e-12
385 8.654951755282525e-12
386 8.748583454898373e-12
387 9.26712699633736e-12
388 1.0567278822815052e-11
389 1.0595058684559344e-11
390 9.954990824734278e-12
391 8.468659801197376e-12
392 1.019943938335155e-11
393 9.019391136733113e-12
394 9.697418215659503e-12
395 9.737591809277912e-12
396 9.88620817155006e-12
397 9.475691065130576e-12
398 1.0397168369313814e-11
399 9.840643924730053e-12
400 9.759054675484435e-12
401 9.815187725081831e-12
402 1.0050993023980048e-11
403 9.56973823101892e-12
404 1.02861712203417e-11
405 9.982699562816055e-12
406 9.365376529846259e-12
407 9.37100310544059e-12
408 9.970998852970592e-12
409 9.670951539586525e-12
410 8.88483384087202e-12
411 9.422816693582803e-12
412 8.837818497864358e-12
413 8.587594177433822e-12
414 8.99909660678766e-12
415 9.427633153313852e-12
416 9.3163168152

### Custom Models

In [21]:
class DynamicNet(torch.nn.Module):
    
    def __init__(self, d_in, h, d_out):
        super().__init__()
        self.input = torch.nn.Linear(d_in, h)
        self.hid = torch.nn.Linear(h, h)
        self.out = torch.nn.Linear(h, d_out)
        
    def forward(self, x):
        h_relu = self.input(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.hid(h_relu).clamp(min=0)
        y_pred = self.out(h_relu)
        return y_pred

In [22]:
n, d_in, h, d_out = 64, 1000, 100, 10
x = torch.randn(n, d_in)
y = torch.randn(n, d_out)

In [23]:
model = DynamicNet(d_in, h, d_out)

In [24]:
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(t, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 632.2785034179688
1 629.6087036132812
2 628.8663330078125
3 626.359619140625
4 623.2782592773438
5 610.8251342773438
6 625.4854125976562
7 640.6768188476562
8 611.5213623046875
9 559.4812622070312
10 596.5060424804688
11 606.8108520507812
12 592.0465087890625
13 604.9483642578125
14 570.630615234375
15 603.3956298828125
16 586.6893310546875
17 601.2735595703125
18 599.6747436523438
19 528.777587890625
20 595.2610473632812
21 502.68060302734375
22 589.1446533203125
23 468.232421875
24 551.0310668945312
25 541.9570922851562
26 570.1929931640625
27 518.3908081054688
28 503.78765869140625
29 362.59710693359375
30 531.4132690429688
31 262.67828369140625
32 309.8822937011719
33 289.5657043457031
34 265.58929443359375
35 461.9869689941406
36 217.54742431640625
37 195.22694396972656
38 172.5098419189453
39 173.1092529296875
40 132.93177795410156
41 141.65406799316406
42 260.1700134277344
43 332.4174499511719
44 298.0873107910156
45 111.49224090576172
46 202.90924072265625
47 84.430908203125


374 0.2299944907426834
375 0.5247659087181091
376 0.15967269241809845
377 1.001258373260498
378 0.365831196308136
379 0.9101892709732056
380 0.7365949749946594
381 0.41733309626579285
382 0.2767772674560547
383 0.10856592655181885
384 0.3884049654006958
385 2.359867572784424
386 0.3563497066497803
387 1.5377123355865479
388 0.4109915494918823
389 1.77914559841156
390 0.2908664643764496
391 0.29451417922973633
392 0.4459017813205719
393 0.5956968665122986
394 2.792590856552124
395 0.19102106988430023
396 0.22566959261894226
397 0.1916453242301941
398 1.2661058902740479
399 1.2612683773040771
400 11.058863639831543
401 1.2765499353408813
402 5.298357009887695
403 22.176620483398438
404 1.7295728921890259
405 4.586020469665527
406 32.28859329223633
407 0.18667566776275635
408 3.5252554416656494
409 10.149150848388672
410 3.3871796131134033
411 0.3173706531524658
412 0.4350637197494507
413 2.6338613033294678
414 0.7793236970901489
415 34.47877502441406
416 0.6375674605369568
417 13.1125745