In [None]:
!pip install -Uqq fastai

In [None]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from fastai.vision.all import *
from fastai import *
from torchvision.datasets import MNIST
from torchvision import transforms

In [None]:
path = untar_data(URLs.MNIST)

In [None]:
dls = ImageDataLoaders.from_folder(path, 'training', 'testing')

In [None]:
dls.show_batch()

In [None]:
x_b, y_b = dls.one_batch()

In [None]:
x_b.max(), x_b.min(), x_b.shape, y_b

In [None]:
def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dropout=0.2, *args, **kwargs):
    """
    Build the Conv2d -> PReLU -> BatchNorm2d -> Dropout2d unit repeated in the model.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution (also sizes the BatchNorm2d).
        kernel_size: convolution kernel size (int or per-dimension tuple).
            Unless ``padding`` is passed explicitly, padding of ``kernel_size // 2``
            (per dimension) is used. For odd kernels without dilation, stride 1
            keeps the spatial size and stride 2 halves it (rounding up).
        stride: convolution stride.
        dropout: probability given to ``nn.Dropout2d``.
        *args: accepted for backward compatibility only; ignored.
        **kwargs: extra keyword arguments forwarded to ``nn.Conv2d``
            (e.g. ``padding``, ``bias``, ``dilation``).

    Returns:
        nn.Sequential of the four layers in the order above; index 0 is the Conv2d.
    """
    # The old hard-coded padding=1 was only "same" padding for 3x3 kernels, and it
    # made a caller-supplied padding raise a duplicate-keyword TypeError.
    if 'padding' in kwargs:
        padding = kwargs.pop('padding')
    elif isinstance(kernel_size, int):
        padding = kernel_size // 2
    else:
        padding = tuple(k // 2 for k in kernel_size)
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                  stride=stride, padding=padding, **kwargs),
        nn.PReLU(),
        nn.BatchNorm2d(out_channels),
        nn.Dropout2d(dropout),
    )

In [None]:
# Convolutional trunk: the two stride-2 blocks each halve the spatial size.
feature_layers = [
    conv_block(3, 32),               # 3x28x28   --> 32x28x28
    conv_block(32, 64),              # 32x28x28  --> 64x28x28
    conv_block(64, 64, stride=2),    # 64x28x28  --> 64x14x14
    conv_block(64, 128),             # 64x14x14  --> 128x14x14
    conv_block(128, 128, stride=2),  # 128x14x14 --> 128x7x7
]
# Classifier head: global max-pool, flatten, then one linear layer producing
# the 10 per-class logits consumed by CrossEntropyLoss.
head_layers = [
    nn.AdaptiveMaxPool2d(1),  # 128x7x7 --> 128x1x1
    nn.Flatten(),             # 128x1x1 --> 128
    nn.Linear(128, 10),       # 128 features --> 10 outputs
]
# Flat Sequential, so model[i] still indexes the blocks/layers in order (0..7)
model = nn.Sequential(*feature_layers, *head_layers)

In [None]:
cbs = [
    EarlyStoppingCallback()
]

In [None]:
learn = Learner(dls, model, nn.CrossEntropyLoss(), metrics=[metrics.accuracy, metrics.error_rate], cbs=cbs)

In [None]:
learn.lr_find()

In [None]:
LR = 1e-3

In [None]:
learn.fit_one_cycle(3, LR)

In [None]:
learn.recorder.plot_loss()

In [None]:
learn.recorder.plot_sched()

In [None]:
# Pick which conv block we want to inspect (index into the Sequential model)
layer = 4
# Switch to inference mode before probing. Otherwise, depending on what ran last
# (e.g. lr_find or an interrupted fit), the model may still be in train mode:
# Dropout2d would randomly zero feature maps, and BatchNorm would normalise with
# batch statistics and update its running stats (no_grad does not prevent that).
learn.model.eval()
# "Hook" the output of the Conv2d (index 0) inside that block
with hook_output(learn.model[layer][0]) as hook:
    # Pass one batch through the model; the hook captures the conv output
    with torch.no_grad():
        learn.model(x_b)
        # Store the outputs
        outputs = hook.stored

In [None]:
# Pick an index in the batch
idx = 2
n_filters = outputs.shape[1]    # channels produced by the hooked conv
output_size = outputs.shape[2]  # spatial height of each activation map
# Pick up to 9 *distinct* filters: np.random.choice defaults to sampling with
# replacement, which could show the same map twice. A fixed-seed generator keeps
# the selection stable across re-runs; min() guards layers with fewer than 9 filters.
n_show = min(9, n_filters)
rng = np.random.default_rng(42)
sampled_filters = rng.choice(n_filters, size=n_show, replace=False)
# Show the image, titled with its class label rather than the tensor repr
show_image(x_b[idx], figsize=(8,8), title=dls.vocab[int(y_b[idx])])

In [None]:
# Show the activation maps of the sampled filters in a near-square grid
# (3x3 for 9 filters), sized to however many filters were actually sampled.
n_panels = len(sampled_filters)
n_cols = max(1, int(np.ceil(np.sqrt(n_panels))))
n_rows = max(1, int(np.ceil(n_panels / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
for ax, fidx in zip(axes.flat, sampled_filters):
    ax.matshow(outputs[idx, fidx].cpu().numpy(), cmap='Greys_r')
    ax.set_title(f'filter {fidx}')
    # Pixel coordinates carry no information here, so hide the ticks
    ax.set_xticks([])
    ax.set_yticks([])
# Blank any leftover grid cells when n_panels does not fill the grid exactly
for ax in axes.flat[n_panels:]:
    ax.set_axis_off()
fig.suptitle(f'Block {layer} conv activations for batch item {idx}')
fig.tight_layout()