# Homework: Not So Basic Artificial Neural Networks

Your task is to implement a simple framework for training convolutional neural networks. Although convolutional neural networks are the subject of lecture 3, we expect many students to be familiar with the topic already.

In order to successfully pass this homework, you will have to:

- Implement all the blocks in `homework_modules.ipynb` (especially the `Conv2d` and `MaxPool2d` layers). A good implementation should pass all the tests in `homework_test_modules.ipynb`.
- Work through a bit of math in `homework_differentiation.ipynb`.
- Train a CNN that has at least one `Conv2d` layer, one `MaxPool2d` layer and one `BatchNormalization` layer, and that achieves at least 97% accuracy on the MNIST test set.

Feel free to use `homework_main-basic.ipynb` for debugging or as a source of code snippets.

Note that this homework requires submitting **multiple** files. Please do not forget to include all of them when sending your work to the TA. The list of files:
- This notebook with the trained CNN
- `homework_modules.ipynb`
- `homework_differentiation.ipynb`

In [1]:
%matplotlib inline
from time import time, sleep
import numpy as np
import matplotlib.pyplot as plt
from IPython import display

In [2]:
# (re-)load layers
%run homework_modules.ipynb

In [3]:
# batch generator
def get_batches(dataset, batch_size):
    X, Y = dataset
    n_samples = X.shape[0]
        
    # Shuffle at the start of epoch
    indices = np.arange(n_samples)
    np.random.shuffle(indices)
    
    for start in range(0, n_samples, batch_size):
        end = min(start + batch_size, n_samples)
        
        batch_idx = indices[start:end]
    
        yield X[batch_idx], Y[batch_idx]
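
For reproducible debugging runs, the shuffle can be driven by an explicit random generator. The sketch below is a variant of the generator above, not a replacement for it; the name `get_batches_seeded` and its `rng` argument are assumptions introduced for illustration.

```python
import numpy as np

def get_batches_seeded(dataset, batch_size, rng):
    """Yield shuffled mini-batches; `rng` is a numpy Generator (hypothetical variant)."""
    X, Y = dataset
    n_samples = X.shape[0]
    indices = rng.permutation(n_samples)  # fresh random order on every call
    for start in range(0, n_samples, batch_size):
        batch_idx = indices[start:start + batch_size]  # slicing clips the last batch
        yield X[batch_idx], Y[batch_idx]

# Toy data: 10 samples, so batch_size=4 gives batches of 4, 4 and 2 samples.
X_toy = np.arange(20).reshape(10, 2)
y_toy = np.arange(10)
rng = np.random.default_rng(0)
batch_sizes = [len(yb) for _, yb in get_batches_seeded((X_toy, y_toy), 4, rng)]
print(batch_sizes)
```

The last batch is smaller whenever the dataset size is not a multiple of `batch_size`, which matters for layers that compute batch statistics.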

In [4]:
import mnist
X_train, y_train, X_val, y_val, X_test, y_test = mnist.load_dataset()  # your dataset

In [5]:
h, w = X_train.shape[1:]

In [6]:
# Your turn - train and evaluate conv neural network
# net = Sequential()
# net.add(Linear(2, 4))
# net.add(ReLU())
# net.add(Linear(4, 2))
# net.add(LogSoftMax())

model = Sequential()         # [batch, 1, 28, 28]
model.add(Conv2d(1, 5, 5))   # [batch, 5, 28, 28]
model.add(MaxPool2d(2))      # [batch, 5, 14, 14]
model.add(ReLU())
model.add(Conv2d(5, 10, 5))  # [batch, 10, 14, 14]
model.add(MaxPool2d(2))      # [batch, 10, 7, 7]
model.add(ReLU())
model.add(Flatten())         # [batch, 490]
model.add(Linear(490, 10))
model.add(ReLU())
model.add(BatchNormalization())
model.add(LogSoftMax())
# model.add(Conv2d())
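
The shape comments in the cell above can be checked with a few lines of arithmetic. This is only a sketch: it assumes that `Conv2d` pads its input so that height and width are preserved (as the comments indicate) and that `MaxPool2d(2)` halves both spatial dimensions. The helper names are hypothetical and are not part of `homework_modules.ipynb`.

```python
def conv2d_same_shape(shape, out_channels):
    """Shape after a stride-1 convolution that preserves height and width (assumed)."""
    batch, _, height, width = shape
    return (batch, out_channels, height, width)

def maxpool2d_shape(shape, kernel_size):
    """Shape after non-overlapping max pooling with a square window."""
    batch, channels, height, width = shape
    return (batch, channels, height // kernel_size, width // kernel_size)

shape = (128, 1, 28, 28)              # [batch, channels, height, width]
shape = conv2d_same_shape(shape, 5)   # first Conv2d
shape = maxpool2d_shape(shape, 2)     # first MaxPool2d
shape = conv2d_same_shape(shape, 10)  # second Conv2d
shape = maxpool2d_shape(shape, 2)     # second MaxPool2d

# Number of input features the Linear layer must accept after Flatten.
flat_features = shape[1] * shape[2] * shape[3]
print(shape, flat_features)
```

If your `Conv2d` does not pad its input, the spatial sizes shrink after each convolution and the input size of the `Linear` layer has to be recomputed accordingly.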

In [7]:
# model = Sequential()             # [batch, 1, 28, 28]
# model.add(Conv2d(1, 20, 5))      # [batch, 20, 28, 28]
# model.add(MaxPool2d(2))          # [batch, 20, 14, 14]
# model.add(ReLU())
# model.add(Conv2d(20, 50, 5))     # [batch, 50, 14, 14]
# model.add(MaxPool2d(2))          # [batch, 50, 7, 7]
# model.add(ReLU())
# model.add(Flatten())             # [batch, 2450]
# model.add(Linear(2450, 10))      # 50 * 7 * 7 = 2450 input features
# model.add(BatchNormalization())
# model.add(LogSoftMax())

In [8]:
# Optimizer params
optimizer_config = {'learning_rate' : 1e-1, 'momentum': 0.9}
optimizer_state = {}

# Looping params
n_epoch = 3
batch_size = 128

criterion = ClassNLLCriterion()
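
To build intuition for the `learning_rate` and `momentum` values chosen above, here is a minimal sketch of the standard SGD-with-momentum update on a toy quadratic. It is not necessarily how `sgd_momentum` is implemented in this homework; `momentum_step` is a hypothetical helper written only for this illustration.

```python
import numpy as np

def momentum_step(param, grad, velocity, learning_rate, momentum):
    """One in-place SGD-with-momentum update (hypothetical helper)."""
    # velocity <- momentum * velocity + learning_rate * grad
    np.add(momentum * velocity, learning_rate * grad, out=velocity)
    # param <- param - velocity
    param -= velocity

# Minimise f(w) = 0.5 * ||w||^2, whose gradient is w itself.
w = np.array([1.0, -2.0])
v = np.zeros_like(w)
for _ in range(200):
    momentum_step(w, grad=w.copy(), velocity=v, learning_rate=1e-1, momentum=0.9)

print(np.abs(w).max())  # close to zero after 200 steps
```

The velocity accumulates past gradients, so the effective step is larger than `learning_rate` alone suggests; if the training loss diverges, lowering the learning rate is a reasonable first thing to try.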

In [9]:
loss_history = []
model.training = True
for i in range(n_epoch):
    print('epoch {}'.format(i))
    for x_batch, y_batch in get_batches((X_train, y_train), batch_size):
        
        model.zeroGradParameters()
        x_batch = np.expand_dims(x_batch, 1)
        y_batch_encoded = np.eye(10)[y_batch]
        # Forward
        predictions = model.forward(x_batch)
        loss = criterion.forward(predictions, y_batch_encoded)
    
        # Backward
        dp = criterion.backward(predictions, y_batch_encoded)
        model.backward(x_batch, dp)
        
        # Update weights
        sgd_momentum(model.getParameters(), 
                     model.getGradParameters(), 
                     optimizer_config,
                     optimizer_state)      
        
        loss_history.append(loss)

    # Visualize
    display.clear_output(wait=True)
    plt.figure(figsize=(8, 6))
        
    plt.title("Training loss")
    plt.xlabel("#iteration")
    plt.ylabel("loss")
    plt.plot(loss_history, 'b')
    plt.show()
    
    print('Current loss: %f' % loss)    
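
The two preprocessing steps in the loop above can be checked on toy data. The sketch below shows what `np.expand_dims(x_batch, 1)` and `np.eye(10)[y_batch]` produce, and how one-hot targets select the true-class log-probabilities. It assumes the loss is the mean negative log-likelihood over the batch, which may differ in detail from your `ClassNLLCriterion`; all values are made up for illustration.

```python
import numpy as np

# Add a channel axis: [batch, 28, 28] -> [batch, 1, 28, 28].
x_toy = np.zeros((3, 28, 28))
x_toy = np.expand_dims(x_toy, 1)

# One-hot encoding: row i of the identity matrix encodes class i.
labels = np.array([2, 0, 1])
one_hot = np.eye(3)[labels]

# Hypothetical log-probabilities for 3 samples and 3 classes.
log_probs = np.log(np.array([[0.1, 0.2, 0.7],
                             [0.8, 0.1, 0.1],
                             [0.3, 0.4, 0.3]]))

# Mean negative log-likelihood of the true classes (assumed definition).
nll = -np.sum(one_hot * log_probs) / labels.shape[0]
print(x_toy.shape, one_hot.shape, nll)
```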


Print your accuracy on the test set here. It should be >97%. Don't forget to switch the network to 'evaluate' mode.
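
A minimal sketch of the accuracy computation, assuming the network outputs per-class log-probabilities of shape `[batch, 10]` and the labels are integer class indices. The `accuracy` helper and the toy values are hypothetical; how the network is switched to evaluation mode depends on your `homework_modules.ipynb`.

```python
import numpy as np

def accuracy(log_probs, labels):
    """Fraction of samples whose highest-scoring class equals the integer label."""
    predicted = np.argmax(log_probs, axis=1)
    return float(np.mean(predicted == labels))

# Toy check with hand-made scores. In the notebook, `log_probs` would come from a
# forward pass over X_test (in batches, with the channel axis added as in training)
# after the network has been switched to evaluation mode.
toy_log_probs = np.log(np.array([[0.7, 0.2, 0.1],
                                 [0.1, 0.8, 0.1],
                                 [0.3, 0.3, 0.4],
                                 [0.6, 0.3, 0.1]]))
toy_labels = np.array([0, 1, 2, 1])
print(accuracy(toy_log_probs, toy_labels))
```

Evaluating in batches keeps memory usage bounded; the per-batch predictions can be concatenated before comparing them with `y_test`.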