PyTorch's `nn.Module` overrides the `__setattr__` method so that assigning a submodule as an attribute registers it with the model. Its `parameters` method then recurses through these registered submodules. We demonstrate a partial reimplementation below.

In [1]:
from fastai.vision.all import *


In [2]:
class MyModule:
    """Minimal stand-in for ``nn.Module`` showing how submodules get registered.

    Assigning a public attribute whose value exposes a callable ``parameters``
    records it in ``self._modules``; ``parameters`` then yields the parameters
    of every registered submodule, in assignment order.
    """

    def __init__(self, n_in, nh, n_out):
        # ``_modules`` must exist before any public attribute is assigned,
        # because ``__setattr__`` writes into it. The leading underscore keeps
        # this assignment itself out of the registry.
        self._modules = {}
        self.l1 = nn.Linear(n_in, nh)
        self.l2 = nn.Linear(nh, n_out)

    def __setattr__(self, k, v):
        # Private names (like ``_modules`` itself) are plain attributes only.
        if not k.startswith("_"):
            if callable(getattr(v, "parameters", None)):
                self._modules[k] = v
            else:
                # Plain values (ints, strings, None, ...) are not submodules.
                # Drop any submodule previously bound to this name so that
                # ``parameters`` never touches an object without parameters.
                self._modules.pop(k, None)
        super().__setattr__(k, v)

    def __repr__(self):
        return f'{self._modules}'

    def parameters(self):
        """Yield the parameters of each registered submodule, in order."""
        for module in self._modules.values():
            yield from module.parameters()


In [3]:
pickle_path = URLs.path('mnist_png')/'mnist_png.pkl'
path = untar_data(URLs.MNIST)/'training'

# Build the flattened, [0, 1]-scaled arrays only once; later runs just reload
# the pickle written below.
if not pickle_path.exists():
    pickle_path.parent.mkdir(parents=True, exist_ok=True)
    dblock = DataBlock(
        blocks=(ImageBlock(PILImageBW), CategoryBlock),
        get_items=get_image_files,
        get_y=parent_label,
        splitter=RandomSplitter(1/6, seed=0),  # hold out 1/6 for validation
    )
    ds = dblock.datasets(path)
    n_train = len(ds.train)

    # Training items come first, then validation items; n_train marks the cut.
    imgs, labels = zip(*ds.train, *ds.valid)
    xs = np.stack([np.array(img, dtype=np.float32).reshape(-1) for img in imgs]) / 255.
    ys = np.array(labels, dtype=np.int64)

    save_pickle(pickle_path, [xs[:n_train], ys[:n_train], xs[n_train:], ys[n_train:]])

    # Free the temporaries; the tensors are reloaded from the pickle below.
    del dblock, ds, n_train, imgs, labels, xs, ys

x_train, y_train, x_valid, y_valid = map(tensor, load_pickle(pickle_path))



In [4]:
n, m = x_train.shape          # n training samples, m flattened features each
# Class count as a plain int (y_train.max() alone is a 0-d tensor, which is
# awkward to pass as a layer size). Assumes labels are 0-based and contiguous.
c = int(y_train.max()) + 1
nh = 50                       # hidden-layer width

bs = 50                       # batch size

lr = 0.5                      # learning rate
epochs = 3                    # how many epochs to train for


In [5]:
# Size the output layer from the class count computed from y_train rather than
# a hard-coded 10; int() accepts both a Python int and a 0-d tensor.
mdl = MyModule(m, nh, int(c))
mdl


{'l1': Linear(in_features=784, out_features=50, bias=True), 'l2': Linear(in_features=50, out_features=10, bias=True)}

In [6]:
# Show the shape of every tensor yielded by the hand-rolled parameters() generator.
for weight in mdl.parameters():
    shape = weight.shape
    print(shape)


torch.Size([50, 784])
torch.Size([50])
torch.Size([10, 50])
torch.Size([10])
