In [None]:
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("dark_background")
%matplotlib inline

from datasets import load_mnist, download_mnist
from mygrad.nnet.layers.pooling import max_pool
from mygrad.tensor_manip.tiling.funcs import repeat

from mynn.layers import conv
from mynn.activations import relu, sigmoid
from mynn.losses import mean_squared_loss
from mynn.initializers import glorot_uniform
from mynn.optimizers import Adam

In [None]:
# Fetch the MNIST dataset before the next cell calls load_mnist().
# NOTE(review): presumably a no-op once the files are cached locally -- confirm
# against the `datasets` package.
download_mnist()

In [None]:
train_data, train_labels, val_data, val_labels = load_mnist()

# Rescale pixel intensities from [0, 255] to [0, 1] for both splits.
train_data, val_data = train_data / 255, val_data / 255

In [None]:
class Model:
    """ Convolutional auto-encoder.

    The encoder applies convolutions and two 2x2 max-pools, a single
    convolution sits at the bottleneck, and the decoder mirrors the encoder
    using element repetition to upsample before a final sigmoid.
    """

    def __init__(self):
        # All layers share the same kwargs dict: Glorot-uniform weights with a
        # gain of sqrt(2).
        weight_kwargs = {'gain': np.sqrt(2)}

        def conv3x3(n_in, n_out):
            """ Build a 3x3, padding-1 convolution mapping n_in -> n_out channels. """
            return conv(n_in, n_out, 3, 3, padding=1,
                        weight_initializer=glorot_uniform, weight_kwargs=weight_kwargs)

        # Construction order is kept fixed: the initializer presumably draws
        # from the global random state, so reordering would change the weights
        # each layer receives under a fixed seed.
        # encoder
        self.conv11 = conv3x3(1, 8)
        self.conv21 = conv3x3(8, 16)
        self.conv31 = conv3x3(16, 16)
        # bottleneck
        self.conv4 = conv3x3(16, 16)
        # decoder
        self.conv32 = conv3x3(16, 16)
        self.conv22 = conv3x3(16, 8)
        self.conv12 = conv3x3(8, 1)

    def __call__(self, x):
        """ Run the auto-encoder forward on a batch.

        Parameters
        ----------
        x : Union[numpy.ndarray, mygrad.Tensor]
            The data to encode and then reconstruct.

        Returns
        -------
        mygrad.Tensor
            The reconstruction of ``x``, squashed through a sigmoid.
        """
        # --- encoder ---
        h = relu(self.conv11(x))
        h = relu(self.conv21(h))
        h = max_pool(h, (2, 2), 2)      # 28x28 -> 14x14
        h = relu(self.conv31(h))
        h = max_pool(h, (2, 2), 2)      # 14x14 -> 7x7

        # --- bottleneck ---
        h = relu(self.conv4(h))

        # --- decoder ---
        h = repeat(h, 2, axis=2)
        h = repeat(h, 2, axis=3)        # 7x7 -> 14x14
        h = relu(self.conv32(h))
        h = relu(self.conv22(h))
        h = repeat(h, 2, axis=2)
        h = repeat(h, 2, axis=3)        # 14x14 -> 28x28
        return sigmoid(self.conv12(h))

    @property
    def parameters(self):
        """ Collect the parameters of every layer into one flat list.

        Returns
        -------
        List[mygrad.Tensor]
            Each layer's parameters, concatenated layer by layer.
        """
        layers = (
            self.conv11,
            self.conv21,
            self.conv31,
            self.conv4,
            self.conv12,
            self.conv22,
            self.conv32,
        )
        return [param for layer in layers for param in layer.parameters]

In [None]:
# Instantiate the auto-encoder and hand all of its parameters to an Adam
# optimizer with a learning rate of 1e-4. Both names are used as globals by
# train_epoch() and the visualisation cell.
m = Model()
optim = Adam(m.parameters, learning_rate=1e-04)

In [None]:
def train_epoch(batch_size=128):
    """ Run one shuffled pass over ``train_data``, stepping ``optim`` after every batch.

    Relies on the notebook globals ``train_data``, ``m`` and ``optim``.

    Parameters
    ----------
    batch_size : int, optional (default=128)
        Number of examples per optimizer step. The last batch is smaller when
        the dataset size is not a multiple of ``batch_size``.
    """
    order = np.arange(len(train_data))
    np.random.shuffle(order)  # visit the examples in a fresh random order
    n_examples = len(order)

    batch_starts = range(0, n_examples, batch_size)
    for step, start in enumerate(batch_starts):
        inputs = train_data[order[start:start + batch_size]]
        reconstruction = m(inputs)
        # The auto-encoder's target is its own input.
        loss = mean_squared_loss(reconstruction, inputs)
        loss.backward()        # backpropagate
        optim.step()           # apply the weight update
        loss.null_gradients()  # reset gradients before the next batch
        print(f"Batch {step} / {n_examples // batch_size}: loss {loss.data:0.4f}", end="\r")

In [None]:
# Train for five epochs; each call makes one full pass over train_data
# using train_epoch's default batch size of 128.
for epoch in range(5):
    print(f"Starting epoch {epoch}")
    train_epoch()

In [None]:
# Show one randomly chosen training image next to the model's reconstruction.
fig, ax = plt.subplots(1, 2, figsize=(10, 8))
idx = np.random.randint(len(train_data))  # random index into the training set
# squeeze() drops the singleton axes so imshow receives a 2-D image.
ax[0].imshow(train_data[idx].squeeze(), "gray")
ax[0].set_title("Original")

# np.newaxis prepends an axis so the single image is passed as a batch of one.
out = m(train_data[idx][np.newaxis])
# out.data is the array underlying the model's output tensor.
ax[1].imshow(out.data.squeeze(), "gray")
ax[1].set_title("Reconstructed")