[View in Colaboratory](https://colab.research.google.com/github/cyyeh/2018AI_summer_school/blob/master/MNIST_DANN_v2.ipynb)

# Implement Domain-Adversarial Neural Networks (DANN)

**First click File -> Save a copy in Drive, then continue working in the copy.**


Reference: 

https://arxiv.org/abs/1505.07818 (Ganin et al., JMLR 2016)

https://github.com/pumpikano/tf-dann  (source code)

https://www.tensorflow.org/

![alt text](https://goo.gl/ivg4Q7)


## Get Ready

Check that all the required files are present, and define the utility functions.

1.   Download the required files with !wget and list them with !ls. You should see datalab, mnistm_data.pkl, DANN.png, and MNIST_model.png.
2.   Define the functions used later, such as convolution, max pooling, and shuffling.



In [0]:
!wget https://www.dropbox.com/s/e227qfod9b5f0ed/mnistm_data.pkl
!wget https://www.dropbox.com/s/da5ukrtqlusex1z/MNIST_model.png
!wget https://www.dropbox.com/s/77v4rase9pz4lv7/DANN.png

In [0]:
### list files on workspace ###
!ls

In [0]:
### construct utilities ### 
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

# Model construction utilities below adapted from
# https://www.tensorflow.org/versions/r0.8/tutorials/mnist/pros/index.html#deep-mnist-for-experts
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')


def shuffle_aligned_list(data):
    """Shuffle arrays in a list by shuffling each array identically."""
    num = data[0].shape[0]
    p = np.random.permutation(num)
    return [d[p] for d in data]


def batch_generator(data, batch_size, shuffle=True):
    """Generate batches of data.
    
    Given a list of array-like objects, generate batches of a given
    size by yielding a list of array-like objects corresponding to the
    same slice of each input.
    """
    if shuffle:
        data = shuffle_aligned_list(data)

    batch_count = 0
    while True:
        # Start a new pass once the next full batch would run past the end of the data
        if batch_count * batch_size + batch_size > len(data[0]):
            batch_count = 0

            if shuffle:
                data = shuffle_aligned_list(data)

        start = batch_count * batch_size
        end = start + batch_size
        batch_count += 1
        yield [d[start:end] for d in data]


def imshow_grid(images, shape=[2, 8]):
    """Plot images in a grid of a given shape."""
    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)

    size = shape[0] * shape[1]
    for i in range(size):
        grid[i].axis('off')
        grid[i].imshow(images[i])  # The ImageGrid object works as a list of axes.

    plt.show()


def plot_embedding(X, y, d, title=None):
    """Plot an embedding X with the class label y colored by the domain d."""
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    X = (X - x_min) / (x_max - x_min)

    # Draw each point as its class digit, colored by its domain
    plt.figure(figsize=(10,10))
    ax = plt.subplot(111)
    for i in range(X.shape[0]):
        # plot colored number
        plt.text(X[i, 0], X[i, 1], str(y[i]),
                 color=plt.cm.bwr(d[i] / 1.),
                 fontdict={'weight': 'bold', 'size': 9})

    plt.xticks([]), plt.yticks([])
    if title is not None:
        plt.title(title)
        
print ('Library and utility functions prepared.')
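
As a quick sanity check of the utilities above, the sketch below uses toy data (not the real datasets) to show why shuffle_aligned_list applies one permutation to every array: each image stays paired with its label. It also computes how many half-batches the generator yields per pass over the 55,000 MNIST training images, assuming the batch_size of 64 defined later in the notebook (so each domain contributes 32 images per batch).

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "images" and labels: image i stores the value 10 * i, label i stores i
images = np.arange(10) * 10
labels = np.arange(10)

# One permutation applied to every array, as shuffle_aligned_list does
p = rng.permutation(len(labels))
shuffled_images, shuffled_labels = images[p], labels[p]

# Each image is still paired with its own label after shuffling
assert np.array_equal(shuffled_images, shuffled_labels * 10)

# Full half-batches available in one pass, and images left over at the end
num_train, half_batch = 55000, 64 // 2
full_batches = num_train // half_batch
print(full_batches, num_train - full_batches * half_batch)  # 1718 24
```

The 24 leftover images are skipped in that pass; since batch_generator reshuffles before starting the next pass, a different set of images is left out each time.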

## Data Processing

In [0]:
%matplotlib inline
import tensorflow as tf
import numpy as np
import pickle as pkl
from sklearn.manifold import TSNE
from tensorflow.examples.tutorials.mnist import input_data

# Suppress the deprecation warnings printed while loading MNIST
tf.logging.set_verbosity(tf.logging.ERROR)
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Process MNIST
# Binarize MNIST and replicate it into 3 channels to match MNIST-M's RGB images
mnist_train = (mnist.train.images > 0).reshape(55000, 28, 28, 1).astype(np.uint8) * 255
mnist_train = np.concatenate([mnist_train, mnist_train, mnist_train], 3)
mnist_test = (mnist.test.images > 0).reshape(10000, 28, 28, 1).astype(np.uint8) * 255
mnist_test = np.concatenate([mnist_test, mnist_test, mnist_test], 3)

# Load MNIST-M
with open('mnistm_data.pkl', 'rb') as f:
    mnistm = pkl.load(f)
mnistm_train = mnistm['train']
mnistm_test = mnistm['test']
mnistm_valid = mnistm['valid']

# Compute pixel mean for normalizing data
pixel_mean = np.vstack([mnist_train, mnistm_train]).mean((0, 1, 2))

# Create a mixed dataset (500 MNIST + 500 MNIST-M test images) for t-SNE visualization
num_test = 500
combined_test_imgs = np.vstack([mnist_test[:num_test], mnistm_test[:num_test]])
combined_test_labels = np.vstack([mnist.test.labels[:num_test], mnist.test.labels[:num_test]])
combined_test_domain = np.vstack([np.tile([1., 0.], [num_test, 1]),
        np.tile([0., 1.], [num_test, 1])])
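
The preprocessing above can be checked on a toy 2x2 image (made-up values, not real MNIST data): every positive pixel becomes 255, the single channel is copied three times to match the RGB shape of MNIST-M, and the per-channel mean computed like pixel_mean is later subtracted inside the model as (X - pixel_mean) / 255.

```python
import numpy as np

# A toy flattened 2x2 grayscale image with intensities in [0, 1], shape (N, H*W)
flat = np.array([[0.0, 0.3, 0.9, 0.0]])

# Binarize (any positive pixel -> 255) and reshape to (N, H, W, 1)
binary = (flat > 0).reshape(1, 2, 2, 1).astype(np.uint8) * 255
# Replicate the single channel three times to get an RGB-shaped array
rgb = np.concatenate([binary, binary, binary], 3)
print(rgb.shape)  # (1, 2, 2, 3)

# Per-channel mean over the batch, height and width axes, as for pixel_mean
pixel_mean = rgb.mean((0, 1, 2))
# The model applies the same normalization to float inputs
normalized = (rgb.astype(np.float32) - pixel_mean) / 255.
```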


## Show Example Images

In [0]:
print ('MNIST examples:')
imshow_grid(mnist_train)
print ('MNIST-M examples:')
imshow_grid(mnistm_train)

## Show Model Architecture

In [0]:
### Execute this to show general DANN model ###
from IPython.display import Image
Image('DANN.png', width=900, height=400)

In [0]:
### Execute this to show DANN model on MNIST dataset ###
from IPython.display import Image
Image('MNIST_model.png', width=900, height=200)

## Define Gradient Reversal Layer

TensorFlow lets users define custom gradient functions. The code below relies on the following concepts.

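1.   Gradient reversal layer  
In practice, a gradient reversal layer is an identity layer whose gradient is multiplied by a negative sign (and scaled by l in the code below) during backpropagation, while its feed-forward computation stays unchanged.

2.   gradient_override_map     
TensorFlow has many predefined operations, such as identity and matmul. gradient_override_map takes a dictionary whose key is the name of an operation type and whose value is the name of a registered gradient function, in the format gradient_override_map({"Operation A": "Operation B"}); the double quotes indicate that both are strings.
For operations created inside this context (the with block), operation A computes its gradient with the gradient function registered as B, while A's feed-forward computation stays unchanged. We can therefore define a new gradient (FlipGradient) that multiplies the incoming gradient by a negative sign, and then use gradient_override_map to replace the gradient of the Identity operation with FlipGradient.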

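3.   decorator @  
Because gradient_override_map expects gradient names that TensorFlow already knows, we need to register our custom FlipGradient with TensorFlow under a string name. This is done with a decorator.  
In Python, a function can be passed as an argument to another function or class. A decorator takes a function, passes it through another function or class, and assigns the result back to the original name. As the name suggests, the decorator "decorates" the function and then replaces the original. Concretely, a decorator does:
function = decorator(function)  
Syntactically, you write @decorator_name on the line above the function definition, for example:  
```
@my_decorator
def my_func(stuff):
    do_things()
# The three lines above are equivalent to the three lines below
def my_func(stuff):
    do_things()
my_func = my_decorator(my_func)
```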
4.  @RegisterGradient(grad_name)  
def _flip_gradients(op, grad):  
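The decorator passes the function _flip_gradients to the RegisterGradient class, which registers it under the name grad_name. After registration, TensorFlow knows this gradient function by name, so the method from point 2 (gradient_override_map) can use it in place of the gradient of the Identity operation.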
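
To make points 2 to 4 concrete without TensorFlow, here is a toy, pure-Python model of a gradient registry. The names _GRADIENT_REGISTRY, register_gradient and gradient_of are hypothetical; they only mimic the roles of RegisterGradient and gradient_override_map and are not TensorFlow's implementation.

```python
# Toy gradient registry (hypothetical, for illustration only; not TensorFlow's code)
_GRADIENT_REGISTRY = {}


class register_gradient:
    """Decorator class that stores a gradient function under a string name."""

    def __init__(self, name):
        self.name = name

    def __call__(self, fn):
        if self.name in _GRADIENT_REGISTRY:
            raise KeyError(f"gradient {self.name!r} is already registered")
        _GRADIENT_REGISTRY[self.name] = fn
        return fn  # the decorated name still refers to the original function


def identity_grad(grad):
    # Ordinary gradient of an identity op: pass the upstream gradient through
    return grad


_GRADIENT_REGISTRY["Identity"] = identity_grad


@register_gradient("FlipGradient0")
def flip_grad(grad, l=1.0):
    # Reversed gradient: flip the sign and scale by l
    return -grad * l


def gradient_of(op_type, grad, override_map=None):
    """Look up the gradient for op_type, honouring an override map."""
    name = (override_map or {}).get(op_type, op_type)
    return _GRADIENT_REGISTRY[name](grad)


print(gradient_of("Identity", 2.0))                                 # 2.0
print(gradient_of("Identity", 2.0, {"Identity": "FlipGradient0"}))  # -2.0
```

Registering "FlipGradient0" a second time raises KeyError in this toy. TensorFlow likewise rejects registering the same gradient name twice, which is why FlipGradientBuilder builds a fresh name on every call.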

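In summary, this code registers a custom gradient, overrides the gradient of the Identity operation with it, and wraps these steps in the class FlipGradientBuilder for convenient calls. Because TensorFlow does not allow the same gradient name to be registered twice, each call builds a new name from self.num_calls (FlipGradient0, FlipGradient1, and so on). When a FlipGradientBuilder instance is called like a function (\__call__), the input x passes through an Identity layer with the custom gradient (i.e., the gradient reversal layer), and the output y is returned.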
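
The effect of the gradient reversal layer can be checked by hand on a toy scalar chain (all values made up): a feature f = w * x passes through the layer into a domain "classifier" z = v * f with squared-error loss 0.5 * (z - t)^2. The squared error stands in for the notebook's softmax cross-entropy only to keep the arithmetic simple. The classifier weight v receives its ordinary gradient, while the feature weight w receives the ordinary gradient multiplied by -λ, so the feature extractor is pushed to increase the domain loss.

```python
def grl_forward(x):
    # Forward pass: the gradient reversal layer is an identity map
    return x


def grl_backward(upstream_grad, lam):
    # Backward pass: flip the sign and scale by lambda (l in FlipGradientBuilder)
    return -lam * upstream_grad


# Toy values (made up for illustration)
x, w, v, t, lam = 2.0, 0.5, 3.0, 1.0, 0.5

# Forward: feature -> GRL -> domain logit -> squared-error loss
f = w * x
g = grl_forward(f)
z = v * g
loss = 0.5 * (z - t) ** 2

# Backward by the chain rule
dz = z - t                   # dL/dz
dv = dz * g                  # domain classifier weight: ordinary gradient
dg = dz * v                  # gradient arriving at the GRL output
df = grl_backward(dg, lam)   # gradient leaving the GRL toward the feature
dw_reversed = df * x         # feature weight with the GRL
dw_plain = dg * x            # what w would get without the GRL

print(dv, dw_reversed, dw_plain)  # 2.0 -6.0 12.0
```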


---
Reference:

register gradient:
https://www.tensorflow.org/api_docs/python/tf/RegisterGradient

operation:
https://www.tensorflow.org/api_docs/python/tf/Operation

decorator (@ symbol):
https://blog.techbridge.cc/2018/06/15/python-decorator-introduction/

gradient_override_map:
https://stackoverflow.com/questions/41391718/tensorflows-gradient-override-map-function



In [6]:
### construct gradient reversal layer ###

from tensorflow.python.framework import ops


class FlipGradientBuilder(object):
    def __init__(self):
        self.num_calls = 0

    def __call__(self, x, l=1.0):
        grad_name = "FlipGradient%d" % self.num_calls
        @ops.RegisterGradient(grad_name)
        def _flip_gradients(op, grad):
            return [tf.negative(grad) * l]
        
        g = tf.get_default_graph()
        with g.gradient_override_map({"Identity": grad_name}):
            y = tf.identity(x)
            
        self.num_calls += 1
        return y
    
flip_gradient = FlipGradientBuilder()

print('Construct gradient reversal layer.')

Construct gradient reversal layer.


## Define Model

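DANN has one unusual design feature: source-domain data goes into both the label predictor and the domain classifier, while target-domain data goes only into the domain classifier. TensorFlow's tf.cond lets us route the data this way. It usually takes three inputs:  
tf.cond(pred, true_fn, false_fn)   
pred is a boolean condition; if it evaluates to true, tf.cond returns the result of calling true_fn, otherwise the result of false_fn. Both branches are passed as functions (the code below uses lambdas).

In addition, to make it easy to separate source and target data, the source samples are placed in the first half of each batch and the target samples in the second half.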
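
The batch layout and the tf.cond switch can be mimicked in NumPy. In this sketch route_for_label_predictor is a hypothetical helper and the batch of 8 rows with 3-dimensional features is made up; a Python if plays the role of tf.cond, and slicing the first half plays the role of tf.slice. As in the notebook, the flag is True only for dann training; source-only and target-only training and evaluation feed False.

```python
import numpy as np

batch_size = 8   # hypothetical small batch; the notebook uses 64
feature_dim = 3

# Source rows first, target rows second
source_feats = np.ones((batch_size // 2, feature_dim))
target_feats = np.zeros((batch_size // 2, feature_dim))
features = np.vstack([source_feats, target_feats])


def route_for_label_predictor(features, train, batch_size):
    """NumPy stand-in for the tf.cond/tf.slice switch (hypothetical helper)."""
    if train:
        # dann training: keep only the first half, i.e. the labeled source rows
        return features[: batch_size // 2]
    # Evaluation or single-domain training: use every row
    return features


# Domain labels for the same batch: [1, 0] for source, [0, 1] for target
domain_labels = np.vstack([np.tile([1., 0.], [batch_size // 2, 1]),
                           np.tile([0., 1.], [batch_size // 2, 1])])

print(route_for_label_predictor(features, True, batch_size).shape)   # (4, 3)
print(route_for_label_predictor(features, False, batch_size).shape)  # (8, 3)
```

The domain classifier always sees the whole batch together with domain_labels, which is why the source/target split by position is enough to build both branches.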

---
Reference:  
tf.cond:  
https://www.tensorflow.org/api_docs/python/tf/cond  


In [0]:
batch_size = 64

class MNISTModel(object):
    """Simple MNIST domain adaptation model."""
    def __init__(self):
        self._build_model()
  
    def _build_model(self):
        
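        # image size = 28 x 28 x 3
        self.X = tf.placeholder(tf.uint8, [None, 28, 28, 3])
        # number of classes = 10 (digits 0-9), one-hot
        self.y = tf.placeholder(tf.float32, [None, 10])
        # one-hot domain label: [1, 0] = MNIST (source), [0, 1] = MNIST-M (target)
        self.domain = tf.placeholder(tf.float32, [None, 2])
        # gradient reversal scale (lambda), fed from the training schedule
        self.l = tf.placeholder(tf.float32, [])
        # True only in dann mode: route just the source half of the batch to the label predictor
        self.train = tf.placeholder(tf.bool, [])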
        
        X_input = (tf.cast(self.X, tf.float32) - pixel_mean) / 255.
        
        # CNN model for feature extraction (Green Part)
        with tf.variable_scope('feature_extractor'):

            W_conv0 = weight_variable([5, 5, 3, 32])
            b_conv0 = bias_variable([32])
            h_conv0 = tf.nn.relu(conv2d(X_input, W_conv0) + b_conv0)
            h_pool0 = max_pool_2x2(h_conv0)
            
            W_conv1 = weight_variable([5, 5, 32, 48])
            b_conv1 = bias_variable([48])
            h_conv1 = tf.nn.relu(conv2d(h_pool0, W_conv1) + b_conv1)
            h_pool1 = max_pool_2x2(h_conv1)
            
            # The domain-invariant feature
            self.feature = tf.reshape(h_pool1, [-1, 7*7*48])
            
        # MLP for class prediction (Blue Part)
        with tf.variable_scope('label_predictor'):
            
            # Switches to route target examples (second half of batch) differently
            # depending on train or test mode.
            all_features = lambda: self.feature
            source_features = lambda: tf.slice(self.feature, [0, 0], [batch_size // 2, -1])
            classify_feats = tf.cond(self.train, source_features, all_features)
            
            all_labels = lambda: self.y
            source_labels = lambda: tf.slice(self.y, [0, 0], [batch_size // 2, -1])
            self.classify_labels = tf.cond(self.train, source_labels, all_labels)
            
            W_fc0 = weight_variable([7 * 7 * 48, 100])
            b_fc0 = bias_variable([100])
            h_fc0 = tf.nn.relu(tf.matmul(classify_feats, W_fc0) + b_fc0)

            W_fc1 = weight_variable([100, 100])
            b_fc1 = bias_variable([100])
            h_fc1 = tf.nn.relu(tf.matmul(h_fc0, W_fc1) + b_fc1)

            W_fc2 = weight_variable([100, 10])
            b_fc2 = bias_variable([10])
            logits = tf.matmul(h_fc1, W_fc2) + b_fc2
            
            self.pred = tf.nn.softmax(logits)
            self.pred_loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=self.classify_labels)

        # Small MLP for domain prediction with adversarial loss (Pink Part)
        with tf.variable_scope('domain_predictor'):
            
            # Flip the gradient when backpropagating through this operation
            feat = flip_gradient(self.feature, self.l)
            
            d_W_fc0 = weight_variable([7 * 7 * 48, 100])
            d_b_fc0 = bias_variable([100])
            d_h_fc0 = tf.nn.relu(tf.matmul(feat, d_W_fc0) + d_b_fc0)
            
            d_W_fc1 = weight_variable([100, 2])
            d_b_fc1 = bias_variable([2])
            d_logits = tf.matmul(d_h_fc0, d_W_fc1) + d_b_fc1
            
            self.domain_pred = tf.nn.softmax(d_logits)
            self.domain_loss = tf.nn.softmax_cross_entropy_with_logits(logits=d_logits, labels=self.domain)


In [0]:
# Build the model graph
graph = tf.get_default_graph()
with graph.as_default():
    model = MNISTModel()
    
    learning_rate = tf.placeholder(tf.float32, [])
    
    pred_loss = tf.reduce_mean(model.pred_loss)
    domain_loss = tf.reduce_mean(model.domain_loss)
    total_loss = pred_loss + domain_loss

    regular_train_op = tf.train.MomentumOptimizer(learning_rate, 0.9).minimize(pred_loss)
    dann_train_op = tf.train.MomentumOptimizer(learning_rate, 0.9).minimize(total_loss)
    
    # Evaluation
    correct_label_pred = tf.equal(tf.argmax(model.classify_labels, 1), tf.argmax(model.pred, 1))
    label_acc = tf.reduce_mean(tf.cast(correct_label_pred, tf.float32))
    correct_domain_pred = tf.equal(tf.argmax(model.domain, 1), tf.argmax(model.domain_pred, 1))
    domain_acc = tf.reduce_mean(tf.cast(correct_domain_pred, tf.float32))


## Training Process

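Training has three modes: dann, source, and target. Each mode takes about 2 minutes to train.


1.   dann  
Uses images from both the source and target domains, and labels from the source domain only.
2.   source  
Uses only source-domain images and labels.
3.   target   
Uses only target-domain images and labels. In general, target-domain labels are not available and must not be used; this mode uses them only to obtain an upper bound for comparison.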
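
Whichever mode is chosen, train_and_evaluate derives the gradient reversal weight λ (l) and the learning rate from the training progress p = step / num_steps, using the schedule its comments attribute to the paper: λ ramps from 0 toward 1 while the learning rate decays from 0.01. In source and target modes l is still fed, but regular_train_op only minimizes pred_loss, which does not depend on it. The sketch below reproduces the schedule in plain Python; the function name dann_schedule is hypothetical and the parameter names gamma, alpha, beta follow the paper's notation.

```python
import math


def dann_schedule(step, num_steps, lr0=0.01, gamma=10.0, alpha=10.0, beta=0.75):
    """Return (lambda, learning_rate) for a training step (hypothetical helper)."""
    p = step / num_steps                             # training progress in [0, 1)
    lam = 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0   # adaptation weight
    lr = lr0 / (1.0 + alpha * p) ** beta             # annealed learning rate
    return lam, lr


num_steps = 8600
for step in (0, 2150, 4300, 8599):
    lam, lr = dann_schedule(step, num_steps)
    print(f"step {step:5d}  lambda={lam:.3f}  lr={lr:.5f}")
```

Starting λ at 0 keeps the noisy domain classifier from disturbing the feature extractor early in training, which is the motivation the paper gives for the ramp.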


In [0]:
def train_and_evaluate(training_mode, graph, model, num_steps=8600, verbose=False):
    """Helper to run the model with different training modes."""

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()

        # Batch generators
        gen_source_batch = batch_generator(
            [mnist_train, mnist.train.labels], batch_size // 2)
        gen_target_batch = batch_generator(
            [mnistm_train, mnist.train.labels], batch_size // 2)
        gen_source_only_batch = batch_generator(
            [mnist_train, mnist.train.labels], batch_size)
        gen_target_only_batch = batch_generator(
            [mnistm_train, mnist.train.labels], batch_size)

        domain_labels = np.vstack([np.tile([1., 0.], [batch_size // 2, 1]),
                                   np.tile([0., 1.], [batch_size // 2, 1])])

        # Training loop
        for i in range(num_steps):
            
            # Adaptation param and learning rate schedule as described in the paper
            p = float(i) / num_steps
            l = 2. / (1. + np.exp(-10. * p)) - 1
            lr = 0.01 / (1. + 10 * p)**0.75

            # Training step
            if training_mode == 'dann':

                X0, y0 = next(gen_source_batch)
                X1, y1 = next(gen_target_batch)
                X = np.vstack([X0, X1])
                y = np.vstack([y0, y1])

                _, batch_loss, dloss, ploss, d_acc, p_acc = sess.run(
                    [dann_train_op, total_loss, domain_loss, pred_loss, domain_acc, label_acc],
                    feed_dict={model.X: X, model.y: y, model.domain: domain_labels,
                               model.train: True, model.l: l, learning_rate: lr})

                if verbose and i % 100 == 0:
                    print('loss: {}  d_acc: {}  p_acc: {}  p: {}  l: {}  lr: {}'.format(
                            batch_loss, d_acc, p_acc, p, l, lr))

            elif training_mode == 'source':
                X, y = next(gen_source_only_batch)
                _, batch_loss = sess.run([regular_train_op, pred_loss],
                                     feed_dict={model.X: X, model.y: y, model.train: False,
                                                model.l: l, learning_rate: lr})

            elif training_mode == 'target':
                X, y = next(gen_target_only_batch)
                _, batch_loss = sess.run([regular_train_op, pred_loss],
                                     feed_dict={model.X: X, model.y: y, model.train: False,
                                                model.l: l, learning_rate: lr})

        # Compute final evaluation on test data
        source_acc = sess.run(label_acc,
                            feed_dict={model.X: mnist_test, model.y: mnist.test.labels,
                                       model.train: False})

        target_acc = sess.run(label_acc,
                            feed_dict={model.X: mnistm_test, model.y: mnist.test.labels,
                                       model.train: False})
        
        test_domain_acc = sess.run(domain_acc,
                            feed_dict={model.X: combined_test_imgs,
                                       model.domain: combined_test_domain, model.l: 1.0})
        
        test_emb = sess.run(model.feature, feed_dict={model.X: combined_test_imgs})
        
    return source_acc, target_acc, test_domain_acc, test_emb


print('\nSource only training')
source_acc, target_acc, _, source_only_emb = train_and_evaluate('source', graph, model)
print('Source (MNIST) accuracy:', source_acc)
print('Target (MNIST-M) accuracy:', target_acc)

print('\nDomain adaptation training')
source_acc, target_acc, d_acc, dann_emb = train_and_evaluate('dann', graph, model)
print('Source (MNIST) accuracy:', source_acc)
print('Target (MNIST-M) accuracy:', target_acc)

print('\nTarget only training')
source_acc, target_acc, _, target_only_emb = train_and_evaluate('target', graph, model)
print('Source (MNIST) accuracy:', source_acc)
print('Target (MNIST-M) accuracy:', target_acc)

## Visualization

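Visualize the results. Each point is drawn as its true digit label at the t-SNE coordinates of its learned feature; blue digits are source-domain (MNIST) samples and red digits are target-domain (MNIST-M) samples.  
In the source-only result (top figure), the blue points already form well-separated clusters, consistent with high source-domain accuracy, while the red points are more scattered.  
In the dann result (bottom figure), the blue clusters remain intact and the red points cluster noticeably better, indicating that DANN clearly improves performance on the target domain.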
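
The two steps plot_embedding applies before drawing can be seen on a made-up 4-point embedding: each axis is min-max scaled to [0, 1], and the one-hot domain label is reduced with argmax to 0 (MNIST) or 1 (MNIST-M). That index is passed to matplotlib's blue-white-red colormap (plt.cm.bwr), so 0 lands on the blue end and 1 on the red end, which is why source digits appear blue and target digits red.

```python
import numpy as np

# Toy 2-D embedding (standing in for t-SNE output) for four points
X = np.array([[-10.0, 5.0], [0.0, 15.0], [10.0, 25.0], [30.0, 5.0]])
# One-hot domain labels in the notebook's convention: [1, 0] = MNIST, [0, 1] = MNIST-M
domain_onehot = np.array([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])

# Min-max scale each axis to [0, 1], as plot_embedding does
x_min, x_max = X.min(0), X.max(0)
X_scaled = (X - x_min) / (x_max - x_min)

# argmax turns each one-hot label into a colormap position: 0 (blue) or 1 (red)
d = domain_onehot.argmax(1)
print(d.tolist())  # [0, 0, 1, 1]
```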

In [0]:
# t-SNE takes a few minutes to run
tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=3000)
source_only_tsne = tsne.fit_transform(source_only_emb)

tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=3000)
dann_tsne = tsne.fit_transform(dann_emb)
        
plot_embedding(source_only_tsne, combined_test_labels.argmax(1), combined_test_domain.argmax(1), 'Source only')
plot_embedding(dann_tsne, combined_test_labels.argmax(1), combined_test_domain.argmax(1), 'Domain Adaptation')