In the previous notebook we explicitly referred to `.weights` and `.bias`, but that approach does not generalize. Instead, we rely on PyTorch's infrastructure to register the relevant weights so that they can be retrieved with `.parameters()` and then updated during training.

In [1]:
from fastai.vision.all import *


In [2]:
class MLP(nn.Module):
    """Two-layer perceptron computing ``l2(relu(l1(x)))``.

    Each layer is assigned as an attribute, so ``nn.Module`` registers it as
    a submodule and its weights appear in ``named_children()`` and
    ``parameters()``.

    Args:
        n_in: number of input features.
        nh: width of the hidden layer.
        n_out: number of output features.
    """

    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # Assignment order (l1, l2, relu) fixes the iteration order of
        # named_children() and parameters().
        self.l1 = nn.Linear(n_in, nh)
        self.l2 = nn.Linear(nh, n_out)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.l1(x)
        out = self.relu(out)
        return self.l2(out)



In [3]:
# Load MNIST as flattened float arrays. The preprocessed train/valid split is
# cached in a pickle so later runs can skip decoding the PNG files.
pickle_path = URLs.path('mnist_png')/'mnist_png.pkl'
# NOTE(review): this runs untar_data even when the pickle cache already exists.
path = untar_data(URLs.MNIST)/'training'

# Cache miss: build the dataset from the image files and write the pickle.
if not pickle_path.exists():
    pickle_path.parent.mkdir(parents=True, exist_ok=True)
    # Black-and-white images, labels from `parent_label` (the parent directory
    # name); 1/6 of the items go to the validation split, seeded for a
    # reproducible split.
    ds = DataBlock(
        blocks = (ImageBlock(PILImageBW), CategoryBlock),
        get_items = get_image_files,
        get_y = parent_label,
        splitter = RandomSplitter(1/6, seed=0)
    ).datasets(path)

    # Transpose the (image, label) pairs (train items first, then valid) into
    # one sequence of images and one sequence of labels.
    xs, ys = zip(*ds.train, *ds.valid)
    # Flatten each image to a 1-D float32 vector and divide by 255 (scales
    # to [0, 1] assuming 8-bit pixels).
    xs = np.stack(L(map(lambda x: np.array(x, dtype=np.float32).reshape(-1), xs))) / 255.
    ys = np.array(ys, dtype=np.int64)

    # Train items precede valid items in xs/ys, so split at len(ds.train).
    x_train, x_valid = xs[:len(ds.train)], xs[len(ds.train):]
    y_train, y_valid = ys[:len(ds.train)], ys[len(ds.train):]

    save_pickle(pickle_path, [x_train, y_train, x_valid, y_valid])

    # Drop the temporaries so both the cache-hit and cache-miss paths end up
    # with the tensors loaded from the pickle below.
    del ds, xs, ys, x_train, y_train, x_valid, y_valid

# Always load from the cache and convert each array to a tensor.
# NOTE(review): load_pickle presumably unpickles, which is only safe because
# the file is produced locally by the branch above.
x_train, y_train, x_valid, y_valid = map(tensor, load_pickle(pickle_path))


In [4]:
class Model(nn.Module):
    """Linear -> ReLU -> Linear network applied as a sequence of layers.

    The layers live in an ``nn.ModuleList`` rather than a plain Python list.
    ``nn.Module`` only registers submodules that are assigned as Modules, so
    a plain list would hide every weight from ``parameters()``, ``zero_grad()``
    and ``.to(device)``, and training would update nothing.

    Args:
        n_in: number of input features.
        nh: width of the hidden layer.
        n_out: number of output features.
    """

    def __init__(self, n_in, nh, n_out):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, n_out)])

    def forward(self, x):
        # Defining forward (not overriding __call__) keeps nn.Module.__call__
        # and its hooks in the call path; model(x) still dispatches here.
        for layer in self.layers:
            x = layer(x)
        return x


In [5]:
# Data dimensions: n training rows and m features per row. c is the class
# count, computed as the largest label + 1 (assumes labels are 0..c-1).
n, m = x_train.shape
c = 1 + y_train.max()
nh = 50      # hidden-layer width

bs = 50      # batch size
lr = 0.5     # learning rate
epochs = 3   # how many epochs to train for


In [6]:
# Size the output layer from the data (c classes) rather than a hard-coded 10;
# int() converts c in case it is a 0-d tensor.
model = MLP(m, nh, int(c))
model.l1, model


(Linear(in_features=784, out_features=50, bias=True),
 MLP(
   (l1): Linear(in_features=784, out_features=50, bias=True)
   (l2): Linear(in_features=50, out_features=10, bias=True)
   (relu): ReLU()
 ))

In [7]:
# Print each direct submodule registered on the model, keyed by attribute name.
for child_name, child in model.named_children():
    print(f"{child_name}: {child}")


l1: Linear(in_features=784, out_features=50, bias=True)
l2: Linear(in_features=50, out_features=10, bias=True)
relu: ReLU()


In [8]:
# Print the shape of every parameter tensor registered on the model.
for shape in (param.shape for param in model.parameters()):
    print(shape)


torch.Size([50, 784])
torch.Size([50])
torch.Size([10, 50])
torch.Size([10])


In [9]:
# Loss on raw logits; F.cross_entropy applies log-softmax internally.
loss_func = F.cross_entropy

def accuracy(out, yb):
    """Return the fraction of rows whose argmax over dim 1 equals ``yb``.

    ``out`` holds per-class scores along dim 1 and ``yb`` the integer
    targets; the result is a 0-d float tensor.
    """
    predicted = out.argmax(dim=1)
    hits = (predicted == yb).float()
    return hits.mean()
def report(loss, preds, yb):
    """Print ``loss`` and the accuracy of ``preds`` against ``yb`` to 2 decimals."""
    line = f'{loss:.2f}, {accuracy(preds, yb):.2f}'
    print(line)


In [10]:
def fit():
    """Train the global ``model`` with plain minibatch SGD.

    Reads the notebook globals ``model``, ``x_train``, ``y_train``, ``n``,
    ``bs``, ``lr``, ``epochs`` and ``loss_func``. Every tensor returned by
    ``model.parameters()`` is updated. After each epoch, ``report`` prints
    the loss and accuracy of that epoch's last minibatch.
    """
    for _ in range(epochs):
        for start in range(0, n, bs):
            stop = min(n, start + bs)
            xb, yb = x_train[start:stop], y_train[start:stop]
            preds = model(xb)
            loss = loss_func(preds, yb)
            loss.backward()
            # Apply the step in place without recording it in autograd, then
            # clear the accumulated gradients before the next minibatch.
            with torch.no_grad():
                for param in model.parameters():
                    param -= param.grad * lr
                model.zero_grad()
        report(loss, preds, yb)


In [11]:
# Run training: one "loss, accuracy" line is printed per epoch.
fit()

0.22, 0.94
0.17, 0.96
0.10, 0.98
