# Part 3: Training and Fine-tuning on Fashion MNIST and MNIST
Training a neural network with a huge number of parameters on a small dataset hurts its ability to generalize and often results in overfitting. Therefore, in practice, one more often fine-tunes an existing network that was trained on a larger dataset by continuing to train it on the smaller dataset. To get familiar with the fine-tuning procedure, in this problem you will train a model from scratch on the Fashion MNIST dataset and then fine-tune it on the MNIST dataset. Note that we use these two toy datasets only because of limited computational resources; in most cases, models are pre-trained on ImageNet and then fine-tuned on smaller datasets.

* <b>Learning Objective:</b> In Part 2, you implemented a convolutional neural network to perform a classification task in TensorFlow. In this part of the assignment, we will show you how to use TensorFlow to fine-tune a trained network on a different task.
* <b>Provided Code:</b> We provide the dataset downloading and preprocessing code, as well as the conv2d() and fc() functions, for building the model used in the fine-tuning task.
* <b>TODOs:</b> Train a model from scratch on the Fashion MNIST dataset and then fine-tune it on the MNIST dataset. Show both the training loss and the training accuracy.

In [4]:
import numpy as np
import os.path as osp
import os
import subprocess

def download_data(download_root='data/', dataset='mnist'):
    """Download and decompress the four IDX files of MNIST or Fashion MNIST."""
    if dataset == 'mnist':
        data_url = 'http://yann.lecun.com/exdb/mnist/'
    elif dataset == 'fashion_mnist':
        data_url = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
    else:
        raise ValueError('Please specify mnist or fashion_mnist.')

    data_dir = osp.join(download_root, dataset)
    if osp.exists(data_dir):
        print('The dataset was downloaded.')
        return
    # makedirs also creates download_root (e.g. 'data/') if it is missing
    os.makedirs(data_dir)

    keys = ['train-images-idx3-ubyte.gz', 't10k-images-idx3-ubyte.gz',
            'train-labels-idx1-ubyte.gz', 't10k-labels-idx1-ubyte.gz']

    for k in keys:
        url = data_url + k
        target_path = osp.join(data_dir, k)
        print('Downloading ', k)
        # -f: fail on HTTP errors, -L: follow redirects;
        # check_call raises CalledProcessError if the command fails
        subprocess.check_call(['curl', '-fL', url, '-o', target_path])
        print('Unzip ', k)
        # gzip -d replaces "<name>.gz" with the decompressed "<name>"
        subprocess.check_call(['gzip', '-d', target_path])


def load_data(data_dir):
    """Return train/test images (N, 28, 28, 1) and labels (N,) as float64 arrays."""
    num_train = 60000
    num_test = 10000

    def load_file(filename, num, shape):
        # Read the raw bytes in binary mode and close the file afterwards
        with open(osp.join(data_dir, filename), 'rb') as fd:
            loaded = np.fromfile(file=fd, dtype=np.uint8)
        # Skip the `num`-byte IDX header; np.float was removed in NumPy 1.24
        return loaded[num:].reshape(shape).astype(np.float64)

    train_image = load_file('train-images-idx3-ubyte', 16, (num_train, 28, 28, 1))
    train_label = load_file('train-labels-idx1-ubyte', 8, num_train)
    test_image = load_file('t10k-images-idx3-ubyte', 16, (num_test, 28, 28, 1))
    test_label = load_file('t10k-labels-idx1-ubyte', 8, num_test)
    return train_image, train_label, test_image, test_label
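
The hard-coded offsets in `load_file` (16 bytes for image files, 8 for label files) skip the IDX header. The sketch below instead parses the header, which also confirms that the expected kind of file was loaded. It assumes the standard IDX layout: two zero bytes, a data-type byte (0x08 for unsigned bytes), a byte giving the number of dimensions, and then one big-endian 32-bit size per dimension. A header therefore takes 4 + 4 * ndim bytes. `read_idx` is a hypothetical helper and is not part of the provided code.

```python
import struct
import numpy as np

def read_idx(raw):
    """Parse an uncompressed IDX buffer (bytes) into a uint8 NumPy array.

    Assumes the standard IDX layout: two zero bytes, a data-type byte
    (0x08 = unsigned byte), a byte with the number of dimensions, then one
    big-endian uint32 size per dimension, followed by the data itself.
    """
    zero1, zero2, dtype_code, ndim = struct.unpack('>BBBB', raw[:4])
    if zero1 != 0 or zero2 != 0 or dtype_code != 0x08:
        raise ValueError('not an unsigned-byte IDX file')
    header_len = 4 + 4 * ndim  # 16 bytes for images (ndim=3), 8 for labels (ndim=1)
    shape = struct.unpack('>' + 'I' * ndim, raw[4:header_len])
    data = np.frombuffer(raw, dtype=np.uint8, offset=header_len)
    return data.reshape(shape)

# Tiny fake label file: magic 2049 (0x00000801), 3 items, labels 7, 2, 1
fake_labels = struct.pack('>II', 2049, 3) + bytes([7, 2, 1])
print(read_idx(fake_labels))  # [7 2 1]

# Tiny fake image file: magic 2051 (0x00000803), 2 images of 2x2 pixels
fake_images = struct.pack('>IIII', 2051, 2, 2, 2) + bytes(range(8))
print(read_idx(fake_images).shape)  # (2, 2, 2)
```

For a real file such as `train-images-idx3-ubyte`, you would read the file in binary mode and pass its bytes to `read_idx`. The result should have shape `(60000, 28, 28)`, which matches the `(num_train, 28, 28, 1)` reshape in `load_data` up to the trailing channel axis.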

In [1]:
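# Download MNIST and Fashion MNIST
os.makedirs('data/', exist_ok=True)  # make sure the download root exists
download_data(download_root='data/', dataset='mnist')
download_data(download_root='data/', dataset='fashion_mnist')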


In [3]:
import tensorflow as tf
# tf.contrib (and therefore slim) is only available in TensorFlow 1.x
import tensorflow.contrib.slim as slim
import matplotlib.pyplot as plt
%matplotlib inline
#############################################################################
# TODO: Train the model on Fashion MNIST from scratch                       #
# and then fine-tune it on MNIST                                            #
# Collect the training loss and the training accuracy                       #
# fetched from each iteration                                               #
# After the two stages of the training, the length of                       #
# total_loss and total_accuracy should be                                   #
# 2 * num_epoch * num_train / batch_size = 2 * 5 * 60000 / 100 = 6000       #
#############################################################################
# Train the model on Fashion MNIST

# Train the model on MNIST
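
The two-stage procedure in the TODO can be sketched independently of TensorFlow. The sketch below is a minimal NumPy stand-in and rests on three assumptions:

* A softmax-regression classifier replaces the model built with the provided conv2d() and fc() functions.
* Random synthetic arrays replace Fashion MNIST and MNIST.
* `train_stage`, `softmax_xent` and the synthetic arrays are hypothetical names, not part of the assignment.

What carries over to the real task is the structure. The second stage starts from the weights produced by the first stage instead of a fresh initialization. Both stages also append the loss and accuracy of every iteration to the same two lists, so each list ends up with 2 * num_epoch * num_train / batch_size entries.

```python
import numpy as np


def softmax_xent(logits, labels):
    """Return mean cross-entropy, accuracy and softmax probabilities."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    loss = -log_probs[np.arange(len(labels)), labels].mean()
    acc = (logits.argmax(axis=1) == labels).mean()
    return loss, acc, np.exp(log_probs)


def train_stage(W, b, images, labels, num_epoch, batch_size, lr,
                total_loss, total_accuracy, rng):
    """Minibatch SGD starting from (W, b); appends one loss/accuracy per iteration."""
    num_train = images.shape[0]
    X = images.reshape(num_train, -1)  # flatten 28x28x1 images to 784 features
    for _ in range(num_epoch):
        order = rng.permutation(num_train)  # reshuffle every epoch
        for start in range(0, num_train, batch_size):
            idx = order[start:start + batch_size]
            xb, yb = X[idx], labels[idx]
            loss, acc, probs = softmax_xent(xb @ W + b, yb)
            total_loss.append(loss)
            total_accuracy.append(acc)
            # Gradient of the mean cross-entropy with respect to the logits
            grad = probs.copy()
            grad[np.arange(len(yb)), yb] -= 1.0
            grad /= len(yb)
            W -= lr * (xb.T @ grad)
            b -= lr * grad.sum(axis=0)
    return W, b


rng = np.random.default_rng(0)
num_epoch, batch_size, num_train, num_classes = 5, 100, 600, 10
# Hypothetical synthetic stand-ins for Fashion MNIST and MNIST
fashion_x = rng.random((num_train, 28, 28, 1))
fashion_y = rng.integers(0, num_classes, num_train)
mnist_x = rng.random((num_train, 28, 28, 1))
mnist_y = rng.integers(0, num_classes, num_train)

total_loss, total_accuracy = [], []
# Stage 1: train from scratch, starting from a fresh random initialization
W = 0.01 * rng.standard_normal((28 * 28, num_classes))
b = np.zeros(num_classes)
W, b = train_stage(W, b, fashion_x, fashion_y, num_epoch, batch_size, 0.01,
                   total_loss, total_accuracy, rng)
# Stage 2: fine-tune, i.e. continue training the stage-1 weights on the new data
W, b = train_stage(W, b, mnist_x, mnist_y, num_epoch, batch_size, 0.01,
                   total_loss, total_accuracy, rng)

print(len(total_loss), 2 * num_epoch * num_train // batch_size)  # 60 60
```

With the real datasets, num_train = 60000 and batch_size = 100 give 600 iterations per epoch. The lists therefore reach the 6000 entries stated in the TODO.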


In [3]:
#loss, accuracy on Fashion MNIST & MNIST
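
Both stages append to the same lists, so the position of each entry tells you its stage and epoch. The sketch below assumes that `total_loss` and `total_accuracy` are flat lists of length 2 * num_epoch * num_train / batch_size, as the TODO requires. Reshaping such a list to (stage, epoch, iteration) gives per-epoch averages for Fashion MNIST (stage 0) and MNIST (stage 1). `summarize` is a hypothetical helper. The values in the usage lines are placeholders for checking shapes, not results.

```python
import numpy as np

def summarize(values, num_epoch=5, iters_per_epoch=600):
    """Average a flat per-iteration list into shape (2 stages, num_epoch)."""
    arr = np.asarray(values, dtype=np.float64)
    expected = 2 * num_epoch * iters_per_epoch
    if arr.size != expected:
        raise ValueError(f'expected {expected} entries, got {arr.size}')
    # Row 0: Fashion MNIST (training from scratch); row 1: MNIST (fine-tuning)
    return arr.reshape(2, num_epoch, iters_per_epoch).mean(axis=2)

# Placeholder values, only to show the shapes: 6000 entries = 2 * 5 * 600
fake_loss = np.linspace(2.3, 0.3, 6000)
per_epoch = summarize(fake_loss)
print(per_epoch.shape)  # (2, 5)
for stage, name in enumerate(['Fashion MNIST', 'MNIST']):
    print(name, 'final-epoch mean loss:', round(float(per_epoch[stage, -1]), 3))
```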

In [2]:
# Plot the training loss and the training accuracy
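
Here is a minimal plotting sketch. It assumes the same flat `total_loss` and `total_accuracy` lists, with the Fashion MNIST stage filling the first half and the MNIST stage the second. It draws both curves against the iteration index and marks the iteration where fine-tuning on MNIST begins. `plot_curves` is a hypothetical helper, and the placeholder curves at the bottom are synthetic, not training results. In the notebook, `%matplotlib inline` renders the returned figure at the end of the cell.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_curves(total_loss, total_accuracy):
    """Plot per-iteration loss and accuracy; mark the start of fine-tuning."""
    iters = np.arange(len(total_loss))
    switch = len(total_loss) // 2  # first iteration of the MNIST stage
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
    for ax, values, name in [(ax_loss, total_loss, 'Training loss'),
                             (ax_acc, total_accuracy, 'Training accuracy')]:
        ax.plot(iters, values, linewidth=0.8)
        ax.axvline(switch, color='red', linestyle='--',
                   label='start of MNIST fine-tuning')
        ax.set_xlabel('Iteration')
        ax.set_ylabel(name)
        ax.set_title(name)
        ax.legend()
    fig.tight_layout()
    return fig

# Placeholder curves (6000 iterations), only to exercise the function
x = np.arange(6000)
fig = plot_curves(2.3 * np.exp(-x / 1500), 1 - 0.9 * np.exp(-x / 1500))
```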
