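Finally, we demonstrate automatic differentiation with PyTorch's autograd: calling `loss.backward()` computes the gradients, which we then read from each parameter's `.grad` attribute.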

In [1]:
from fastai.vision.all import *

In [2]:
class Linear(nn.Module):
    """Affine layer computing ``inp @ w + b`` with hand-created weights.

    ``w`` has shape ``(n_in, n_out)`` and is drawn from a standard normal
    (``torch.randn``, no fan-in scaling). ``b`` has shape ``(n_out,)`` and
    starts at zero.

    Args:
        n_in: number of input features (last dimension of ``inp``).
        n_out: number of output features.
    """
    def __init__(self, n_in, n_out):
        super().__init__()
        # nn.Parameter sets requires_grad=True *and* registers the tensor with
        # the module, so .parameters(), state_dict() and .to(device) include it.
        # A bare requires_grad_() tensor still receives .grad but is invisible
        # to optimizers and device moves.
        self.w = nn.Parameter(torch.randn(n_in, n_out))
        self.b = nn.Parameter(torch.zeros(n_out))

    def forward(self, inp):
        """Return ``inp @ w + b``; ``b`` broadcasts over the leading dims."""
        return inp @ self.w + self.b


In [3]:
class Model(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) whose call returns the MSE loss.

    Args:
        n_in: number of input features.
        nh: hidden-layer width.
        n_out: number of outputs. The loss reshapes ``targ`` to ``(N, 1)``,
            so as written the shapes only line up for ``n_out == 1``.
    """
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # nn.ModuleList (rather than a plain list) registers each layer as a
        # submodule, so .parameters(), .to(device), .train()/.eval() and
        # state_dict() all reach them. Indexing (self.layers[0]) still works.
        self.layers = nn.ModuleList([Linear(n_in, nh), nn.ReLU(), Linear(nh, n_out)])

    def forward(self, x, targ):
        """Run ``x`` through every layer and return the MSE against ``targ``.

        Defining ``forward`` (instead of overriding ``__call__``) keeps
        nn.Module's hook machinery intact; ``model(x, targ)`` still works.
        """
        for layer in self.layers:
            x = layer(x)
        # assumes targ is 1-D (N,); add a trailing axis to match the (N, 1) output
        return F.mse_loss(x, targ[:, None])


In [4]:

pickle_path = URLs.path('mnist_png')/'mnist_png.pkl'
path = untar_data(URLs.MNIST)/'training'


def _cache_mnist_arrays(src, dest):
    """Build flattened MNIST arrays from the image folders under `src` and
    pickle them to `dest` as [x_train, y_train, x_valid, y_valid].

    Images are loaded as greyscale, flattened to 1-D float32 rows and divided
    by 255; labels come from each file's parent folder name, encoded as int64.
    A seeded RandomSplitter holds out 1/6 of the images for validation.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    dsets = DataBlock(
        blocks=(ImageBlock(PILImageBW), CategoryBlock),
        get_items=get_image_files,
        get_y=parent_label,
        splitter=RandomSplitter(1/6, seed=0),
    ).datasets(src)

    # Transpose (image, label) pairs from train followed by valid into two tuples.
    images, labels = zip(*dsets.train, *dsets.valid)
    pixels = np.stack([np.array(img, dtype=np.float32).reshape(-1) for img in images]) / 255.
    targets = np.array(labels, dtype=np.int64)

    # Training items come first in the concatenation above.
    cut = len(dsets.train)
    save_pickle(dest, [pixels[:cut], targets[:cut], pixels[cut:], targets[cut:]])


# Only rebuild the cache when it is missing; the helper's temporaries are
# released when it returns.
if not pickle_path.exists():
    _cache_mnist_arrays(path, pickle_path)

x_train, y_train, x_valid, y_valid = map(tensor, load_pickle(pickle_path))


In [5]:
# n samples, each a flattened row of m pixels.
n, m = tuple(x_train.shape)
# Number of classes, assuming labels run 0..c-1 (c is a 0-d tensor, not an int).
c = y_train.max().add(1)
# Width of the hidden layer.
nh = 50

In [6]:
# Single-output model that regresses the label directly; MSE needs float targets.
model = Model(m, nh, 1)
loss = model(x_train, y_train.float())
# Backpropagate: fills .grad on the leaf tensors that require grad (each layer's w and b).
loss.backward()


In [7]:
# First layer of the model (a Linear); show the bias gradient left by loss.backward().
l0 = next(iter(model.layers))
l0.b.grad


tensor([ 2.9768e+00, -2.7384e+00, -8.0852e+00,  6.2184e-01, -1.1352e+01,
        -2.7832e+01,  8.5264e+01, -4.6342e+01, -8.1143e+00, -1.2306e+01,
        -1.1445e+02,  1.9698e+01,  3.4802e+01,  4.7113e+01,  9.5526e+01,
         3.3336e+01,  4.0315e+00,  6.1828e+00, -2.4543e-01,  1.1454e+01,
         1.6382e+01,  3.3253e+01,  2.9294e-02,  1.2234e+01, -3.0248e+01,
        -5.9352e+01,  2.9054e+01,  7.1098e+01, -1.6847e+01,  1.8057e+01,
         1.4760e+01,  2.0049e+01,  1.1648e+02,  9.3475e+01, -5.9332e-01,
         2.8878e+01,  1.8214e+01,  5.0956e+01,  1.3740e+02, -1.0519e+02,
        -1.1980e+02, -1.6477e+01, -5.6699e+00,  3.5579e+01,  3.3396e+00,
         8.3904e+01, -5.9529e+01, -2.5959e+01,  7.4153e+01,  2.1316e+02])