In [1]:
# Code in file autograd/two_layer_net_autograd.py
import torch

device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Create random Tensors for weights; setting requires_grad=True means that we
# want to compute gradients for these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
  # Forward pass: compute predicted y using operations on Tensors. Since w1 and
  # w2 have requires_grad=True, operations involving these Tensors will cause
  # PyTorch to build a computational graph, allowing automatic computation of
  # gradients. Since we are no longer implementing the backward pass by hand we
  # don't need to keep references to intermediate values.
  y_pred = x.mm(w1).clamp(min=0).mm(w2)
  
  # Compute and print loss. Loss is a Tensor of shape (), and loss.item()
  # is a Python number giving its value.
  loss = (y_pred - y).pow(2).sum()
  print(t, loss.item())

  # Use autograd to compute the backward pass. This call will compute the
  # gradient of loss with respect to all Tensors with requires_grad=True.
  # After this call w1.grad and w2.grad will be Tensors holding the gradient
  # of the loss with respect to w1 and w2 respectively.
  loss.backward()

  # Update weights using gradient descent. For this step we just want to mutate
  # the values of w1 and w2 in-place; we don't want to build up a computational
  # graph for the update steps, so we use the torch.no_grad() context manager
  # to prevent PyTorch from building a computational graph for the updates
  with torch.no_grad():
    w1 -= learning_rate * w1.grad
    w2 -= learning_rate * w2.grad

    # Manually zero the gradients after running the backward pass
    w1.grad.zero_()
    w2.grad.zero_()

0 30313842.0
1 24868706.0
2 20328038.0
3 15414803.0
4 10798616.0
5 7119085.0
6 4632601.5
7 3076083.5
8 2140034.25
9 1568987.5
10 1207572.625
11 965398.0
12 793777.5625
13 665683.6875
14 566085.4375
15 486304.9375
16 421032.75
17 366753.40625
18 321016.78125
19 282176.5625
20 249043.9375
21 220565.5
22 195942.78125
23 174571.703125
24 155941.40625
25 139624.109375
26 125302.96875
27 112674.484375
28 101514.0234375
29 91609.7578125
30 82810.7734375
31 74978.171875
32 67989.796875
33 61741.59765625
34 56144.2890625
35 51120.76953125
36 46605.5859375
37 42541.3984375
38 38883.71875
39 35590.65234375
40 32613.33203125
41 29916.109375
42 27471.88671875
43 25253.13671875
44 23237.65234375
45 21404.9375
46 19734.12109375
47 18208.875
48 16815.478515625
49 15540.6181640625
50 14374.3759765625
51 13305.8330078125
52 12325.337890625
53 11424.703125
54 10597.390625
55 9836.236328125
56 9135.650390625
57 8490.5458984375
58 7895.990234375
59 7347.3779296875
60 6840.99267578125
61 6373.14990234375
62

384 0.0011200535809621215
385 0.0010830069659277797
386 0.0010468747932463884
387 0.0010123841930180788
388 0.000979167758487165
389 0.0009470457443967462
390 0.0009163279319182038
391 0.0008867966243997216
392 0.0008581003057770431
393 0.0008313871803693473
394 0.0008055148646235466
395 0.0007797854486852884
396 0.0007560129160992801
397 0.0007330838707275689
398 0.0007111167069524527
399 0.0006882844027131796
400 0.0006685475818812847
401 0.0006484755431301892
402 0.0006283598486334085
403 0.0006106384098529816
404 0.0005919559625908732
405 0.0005746936076320708
406 0.0005583213642239571
407 0.0005424652481451631
408 0.000526918622199446
409 0.0005121440044604242
410 0.0004983730614185333
411 0.00048456439981237054
412 0.0004710329230874777
413 0.0004581603279802948
414 0.00044666504254564643
415 0.0004334957047831267
416 0.00042230618419125676
417 0.0004104089457541704
418 0.0003996596788056195
419 0.0003903617907781154
420 0.00037982489448040724
421 0.0003695096238516271
422 0.0003

In [2]:
import jax
import jax.numpy as jnp
import numpy as onp

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs
x = onp.random.randn(N, D_in)
y = onp.random.randn(N, D_out)

# Create random Tensors for weights; setting requires_grad=True means that we
# want to compute gradients for these Tensors during the backward pass.
w1 = jnp.array(onp.random.randn(D_in, H))
w2 = jnp.array(onp.random.randn(H, D_out))

@jax.jit
def loss(w1, w2, x, y):
    y_pred = jnp.dot(jnp.maximum(jnp.dot(x, w1), 0), w2)
    return ((y_pred - y) ** 2).sum()

dloss = jax.jit(jax.grad(loss, argnums=[0,1]))

learning_rate = 1e-6
for t in range(500):
    loss_value = loss(w1, w2, x, y)
    print(t, loss_value)
    d_w1, d_w2 = dloss(w1, w2, x, y)
    w1 -= learning_rate * d_w1
    w2 -= learning_rate * d_w2




0 30240600.0
1 27949160.0
2 29831936.0
3 30861894.0
4 27333242.0
5 19576946.0
6 11437878.0
7 5958857.0
8 3111738.0
9 1797752.1
10 1186809.8
11 876889.3
12 696943.1
13 577497.1
14 489667.44
15 420882.16
16 365005.12
17 318683.4
18 279786.22
19 246822.83
20 218691.45
21 194518.95
22 173611.81
23 155462.58
24 139633.56
25 125772.41
26 113596.08
27 102855.625
28 93356.91
29 84923.71
30 77411.43
31 70701.4
32 64689.633
33 59293.625
34 54445.09
35 50069.074
36 46115.812
37 42535.348
38 39286.457
39 36329.805
40 33637.582
41 31180.621
42 28934.76
43 26879.441
44 24994.201
45 23263.55
46 21672.06
47 20208.088
48 18861.02
49 17616.734
50 16468.158
51 15406.117
52 14422.955
53 13511.82
54 12666.695
55 11882.179
56 11152.98
57 10474.873
58 9844.158
59 9256.335
60 8708.143
61 8196.76
62 7719.424
63 7273.2837
64 6856.7407
65 6467.5454
66 6103.326
67 5762.1533
68 5442.46
69 5142.5713
70 4861.057
71 4596.852
72 4348.671
73 4115.436
74 3896.4011
75 3690.3567
76 3496.4724
77 3313.8298
78 3141.818
79 29

In [3]:
import haiku

In [4]:
import haiku as hk

In [13]:
import haiku as hk
import jax.numpy as jnp

@hk.transform
def make_mlp():
    return hk.Sequential([
      hk.Linear(H), jax.nn.relu,
      hk.Linear(D_out)])

@hk.transform
def loss_fn(y_pred, y):
    return ((y_pred - y) ** 2).sum()

make_mlp.init()

TypeError: init_fn() missing 1 required positional argument: 'rng'

In [6]:
# Initial parameter values are typically random. In JAX you need a key in order
# to generate random numbers and so Haiku requires you to pass one in.
rng = jax.random.PRNGKey(42)

# `init` runs your function, as such we need an example input. Typically you can
# pass "dummy" inputs (e.g. ones of the same shape and dtype) since initialization
# is not usually data dependent.
images, labels = next(input_dataset)

# The result of `init` is a nested data structure of all the parameters in your
# network. You can pass this into `apply`.
params = loss_obj.init(rng, images, labels)

NameError: name 'input_dataset' is not defined