In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

%matplotlib inline

import tensorflow as tf
import numpy as np
import pickle as pkl
#from sklearn.manifold import TSNE

from flip_gradient import flip_gradient
from utils import *

In [2]:
def _load_domain(path):
    """Load one pickled domain dict and close the file handle afterwards.

    NOTE: unpickling can execute arbitrary code, so only point this at
    dataset files that come from a trusted source.
    """
    with open(path, 'rb') as handle:
        return pkl.load(handle)


# Source domain: the split is positional (no shuffling here) -- the first
# `source_data_separation` examples are used for training, the rest for testing.
source_domain = _load_domain('./Datasets/Datael4106/source_domain.pkl')
source_data_separation = 13500
source_train = source_domain['images'][:source_data_separation]
source_train_labels = source_domain['labels'][:source_data_separation]
source_test = source_domain['images'][source_data_separation:]
source_test_labels = source_domain['labels'][source_data_separation:]

# Target domain: same positional split, with the first 90% of the examples
# used for training and the remaining 10% for testing.
target_domain = _load_domain('./Datasets/Datael4106/target_domain.pkl')
target_data_separation = int(target_domain['images'].shape[0]*0.9)
target_train = target_domain['images'][:target_data_separation]
target_train_labels = target_domain['labels'][:target_data_separation]
target_test = target_domain['images'][target_data_separation:]
target_test_labels = target_domain['labels'][target_data_separation:]

In [3]:
# Total mini-batch size. The label predictor below slices the first
# `batch_size // 2` rows of a batch when `SuperNovaModel.train` is fed as True.
batch_size = 64

class SuperNovaModel(object):
    """Graph definition of a domain-adversarial image classifier.

    Constructing an instance builds a convolutional feature extractor that
    feeds two heads: a 2-way label predictor and a 2-way domain predictor
    whose input passes through `flip_gradient`.

    Placeholders to feed:
        X:      float32, [None, 21, 21, 3] -- input images.
        y:      float32, [None, 2] -- class labels, used as the `labels`
                argument of the softmax cross-entropy.
        domain: float32, [None, 2] -- domain labels, used the same way for
                the domain head.
        l:      float32 scalar handed to `flip_gradient`.
        train:  bool scalar; True routes only the first `batch_size // 2`
                rows of the batch through the label predictor, False routes
                the whole batch.

    Tensors exposed: `feature`, `classify_labels`, `pred`, `pred_loss`,
    `domain_pred`, `domain_loss`.
    """

    def __init__(self):
        """Build the whole graph immediately on construction."""
        self._build_model()
    
    def _build_model(self):
        """Create placeholders, the feature extractor and both prediction heads."""
        
        self.X = tf.placeholder(tf.float32, [None, 21, 21, 3])
        self.y = tf.placeholder(tf.float32, [None, 2])
        self.domain = tf.placeholder(tf.float32, [None, 2])
        self.l = tf.placeholder(tf.float32, [])
        self.train = tf.placeholder(tf.bool, [])
        
        # Input normalisation is currently disabled; images are used as fed.
        #X_input = (tf.cast(self.X, tf.float32) - pixel_mean) / 255.
        X_input = self.X
        
        # CNN model for feature extraction: five conv+ReLU layers with a
        # 2x2 max-pool after the second and after the fifth.
        # NOTE(review): weight_variable / bias_variable / conv2d / max_pool_2x2
        # are not defined in this notebook; presumably they come from
        # `from utils import *` -- confirm their padding and stride there.
        with tf.variable_scope('feature_extractor'):

            W_conv0 = weight_variable([4, 4, 3, 32])
            b_conv0 = bias_variable([32])
            h_conv0 = tf.nn.relu(conv2d(X_input, W_conv0) + b_conv0)
            
            W_conv1 = weight_variable([3, 3, 32, 32])
            b_conv1 = bias_variable([32])
            h_conv1 = tf.nn.relu(conv2d(h_conv0, W_conv1) + b_conv1)
            
            h_pool0 = max_pool_2x2(h_conv1)
            
            W_conv2 = weight_variable([3, 3, 32, 64])
            b_conv2 = bias_variable([64])
            h_conv2 = tf.nn.relu(conv2d(h_pool0, W_conv2) + b_conv2)
            
            W_conv3 = weight_variable([3, 3, 64, 64])
            b_conv3 = bias_variable([64])
            h_conv3 = tf.nn.relu(conv2d(h_conv2, W_conv3) + b_conv3)
            
            W_conv4 = weight_variable([3, 3, 64, 64])
            b_conv4 = bias_variable([64])
            h_conv4 = tf.nn.relu(conv2d(h_conv3, W_conv4) + b_conv4)
            
            h_pool1 = max_pool_2x2(h_conv4)
            
            # The domain-invariant feature, flattened to one row per example.
            # NOTE(review): 6*6*64 assumes h_pool1 is 6x6 spatially with 64
            # channels, which depends on the padding used by the helpers above
            # -- TODO confirm.
            self.feature = tf.reshape(h_pool1, [-1, 6*6*64])
            
        # MLP for class prediction
        with tf.variable_scope('label_predictor',reuse=tf.AUTO_REUSE):
            
            # Switches to route target examples (second half of batch) differently
            # depending on train or test mode: with self.train True only the first
            # batch_size // 2 rows are classified, otherwise the whole batch is.
            all_features = lambda: self.feature
            source_features = lambda: tf.slice(self.feature, [0, 0], [batch_size // 2, -1])
            classify_feats = tf.cond(self.train, source_features, all_features)
            
            # Labels are sliced the same way so they stay aligned with classify_feats.
            all_labels = lambda: self.y
            source_labels = lambda: tf.slice(self.y, [0, 0], [batch_size // 2, -1])
            self.classify_labels = tf.cond(self.train, source_labels, all_labels)
            
            # NOTE(review): every dropout layer below hard-codes training=True, so
            # dropout is applied on every run of this graph, including runs where
            # self.train is fed as False -- confirm this is intended.
            dense0 = tf.layers.dense(inputs=classify_feats,units=6*6*64,activation=tf.nn.relu)
            dropout0 = tf.layers.dropout(inputs=dense0, rate=0.5, training=True)
            d_logits0 = tf.layers.dense(inputs=dropout0, units=64)
            
            dense1 = tf.layers.dense(inputs=d_logits0,units=64,activation=tf.nn.relu)
            dropout1 = tf.layers.dropout(inputs=dense1, rate=0.7, training=True)
            d_logits1 = tf.layers.dense(inputs=dropout1, units=64)
            
            dense2 = tf.layers.dense(inputs=d_logits1,units=64,activation=tf.nn.relu)
            dropout2 = tf.layers.dropout(inputs=dense2, rate=0.7, training=True)
            logits = tf.layers.dense(inputs=dropout2, units=2)
            
            # Class probabilities and the (not yet averaged) cross-entropy
            # against the routed labels.
            self.pred = tf.nn.softmax(logits)
            self.pred_loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=self.classify_labels)


        # Small MLP for domain prediction with adversarial loss
        with tf.variable_scope('domain_predictor'):
            
            # Flip the gradient when backpropagating through this operation
            # (self.l is the second argument of the imported helper).
            feat = flip_gradient(self.feature, self.l)
            
            # The local names dense0 / dropout0 / ... are rebound here; the
            # domain head always sees the full batch (no train-dependent slicing).
            dense0 = tf.layers.dense(inputs=feat,units=6*6*64,activation=tf.nn.relu)
            dropout0 = tf.layers.dropout(inputs=dense0, rate=0.5, training=True)
            d_logits0 = tf.layers.dense(inputs=dropout0, units=64)
            
            dense1 = tf.layers.dense(inputs=d_logits0,units=64,activation=tf.nn.relu)
            dropout1 = tf.layers.dropout(inputs=dense1, rate=0.7, training=True)
            d_logits = tf.layers.dense(inputs=dropout1, units=2)
            
            # Domain probabilities and the (not yet averaged) cross-entropy
            # against the fed domain labels.
            self.domain_pred = tf.nn.softmax(d_logits)
            self.domain_loss = tf.nn.softmax_cross_entropy_with_logits(logits=d_logits, labels=self.domain)



In [4]:
# Assemble the training graph: model, losses, optimizers and accuracy metrics.
graph = tf.get_default_graph()
with graph.as_default():
    model = SuperNovaModel()

    # Scalar placeholder for the learning rate.
    learning_rate = tf.placeholder(dtype=tf.float32, shape=[])

    # Batch-averaged losses; the combined objective is their plain sum.
    pred_loss = tf.reduce_mean(input_tensor=model.pred_loss)
    domain_loss = tf.reduce_mean(input_tensor=model.domain_loss)
    total_loss = pred_loss + domain_loss

    # One optimizer for the label loss alone, one for the combined loss.
    regular_train_op = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9).minimize(loss=pred_loss)
    dann_train_op = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9).minimize(loss=total_loss)

    # Evaluation: fraction of examples whose arg-max prediction matches the
    # arg-max of the label, for the class head and for the domain head.
    correct_label_pred = tf.equal(
        x=tf.argmax(input=model.classify_labels, axis=1),
        y=tf.argmax(input=model.pred, axis=1))
    label_acc = tf.reduce_mean(
        input_tensor=tf.cast(x=correct_label_pred, dtype=tf.float32))
    correct_domain_pred = tf.equal(
        x=tf.argmax(input=model.domain, axis=1),
        y=tf.argmax(input=model.domain_pred, axis=1))
    domain_acc = tf.reduce_mean(
        input_tensor=tf.cast(x=correct_domain_pred, dtype=tf.float32))


Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See `tf.nn.softmax_cross_entropy_with_logits_v2`.



In [5]:
# Number of output classes (width of the one-hot encoding).
classes = 2


def fix_label_dimension(labels):
    """Convert one binary label per example into a one-hot float array.

    A label equal to 1 sets column 1 of its row; every other label value
    sets column 0.

    Args:
        labels: array-like holding exactly one label per example, e.g. a
            list, an array of shape (n,) or a column of shape (n, 1).

    Returns:
        A float64 ndarray of shape (n, classes) with a single 1 per row.

    Raises:
        ValueError: if `labels` holds more than one value per example.
    """
    label_array = np.asarray(labels)
    the_batch = label_array.shape[0]
    if label_array.size != the_batch:
        raise ValueError(
            "expected exactly one label per example, got shape {}".format(
                label_array.shape))

    # Column index per row: 1 where the label equals 1, otherwise 0.
    hot_column = (label_array.reshape(the_batch) == 1).astype(int)

    reshaped = np.zeros((the_batch, classes))
    reshaped[np.arange(the_batch), hot_column] = 1
    return reshaped


In [None]:
BASE_LOG_DIR = "./logs/"
TARGET_DIR = "target/"
SOURCE_DIR = "source/"

def train_and_evaluate(training_mode, graph, model, num_steps=8600, verbose=False):
    """Train the model in one mode and report test accuracy on both domains.

    Relies on module-level names built in earlier cells: the train/test
    arrays, `batch_size`, `fix_label_dimension`, the train ops, the loss and
    accuracy tensors and the `learning_rate` placeholder. `batch_generator`
    is presumably provided by `from utils import *`.

    Args:
        training_mode: 'dann' (half source / half target batches, optimised
            with `dann_train_op`), 'source' or 'target' (single-domain
            batches, optimised with `regular_train_op`).
        graph: graph on which the session is opened.
        model: object exposing the `X`, `y`, `domain`, `l` and `train`
            placeholders.
        num_steps: number of training iterations.
        verbose: if True, print a progress line every 250 steps; only the
            'dann' mode prints.

    Returns:
        Tuple `(source_acc, target_acc)`: label accuracy on the source test
        split and on the target test split.

    Raises:
        ValueError: if `training_mode` is not one of the supported modes.
            Previously such a value skipped every training step and an
            untrained model was evaluated without any warning.
    """
    if training_mode not in ('dann', 'source', 'target'):
        raise ValueError(
            "training_mode must be 'dann', 'source' or 'target', got {!r}".format(
                training_mode))

    with tf.Session(graph=graph) as sess:

        tf.global_variables_initializer().run()

        # Batch generators: half-size batches for the mixed 'dann' mode,
        # full-size batches for the single-domain modes.
        gen_source_batch = batch_generator(
            [source_train, source_train_labels], batch_size // 2)
        gen_target_batch = batch_generator(
            [target_train, target_train_labels], batch_size // 2)
        gen_source_only_batch = batch_generator(
            [source_train, source_train_labels], batch_size)
        gen_target_only_batch = batch_generator(
            [target_train, target_train_labels], batch_size)

        # Domain labels for a mixed batch: first half [1, 0], second half [0, 1],
        # matching the source-then-target stacking used below.
        domain_labels = np.vstack([np.tile([1., 0.], [batch_size // 2, 1]),
                                   np.tile([0., 1.], [batch_size // 2, 1])])

        gamma = 10.

        # Training loop
        for i in range(num_steps):

            # Adaptation param and learning rate schedule as described in the paper
            p = float(i) / num_steps
            adapt_l = 2. / (1. + np.exp(-gamma * p)) - 1
            lr = 0.01 / (1. + 10 * p)**0.75

            if training_mode == 'dann':
                X0, y0 = next(gen_source_batch)
                X1, y1 = next(gen_target_batch)
                X = np.vstack([X0, X1])
                y = np.vstack([fix_label_dimension(y0), fix_label_dimension(y1)])

                _, batch_loss, dloss, ploss, d_acc, p_acc = sess.run(
                    [dann_train_op, total_loss, domain_loss, pred_loss, domain_acc, label_acc],
                    feed_dict={model.X: X, model.y: y, model.domain: domain_labels,
                               model.train: True, model.l: adapt_l, learning_rate: lr})

                if verbose and i % 250 == 0:
                    print('loss: {}  d_acc: {}  p_acc: {}  p: {}  l: {}  lr: {}'.format(
                            batch_loss, d_acc, p_acc, p, adapt_l, lr))

            else:
                # 'source' and 'target' differ only in which generator feeds the step.
                # model.train is fed as False so the whole batch reaches the label head.
                if training_mode == 'source':
                    single_domain_gen = gen_source_only_batch
                else:
                    single_domain_gen = gen_target_only_batch
                X, y = next(single_domain_gen)
                y = fix_label_dimension(y)
                _, batch_loss = sess.run([regular_train_op, pred_loss],
                                     feed_dict={model.X: X, model.y: y, model.train: False,
                                                model.l: adapt_l, learning_rate: lr})

        # Compute final evaluation on test data
        source_acc = sess.run(label_acc,
                            feed_dict={model.X: source_test, model.y: fix_label_dimension(source_test_labels),
                                       model.train: False})

        target_acc = sess.run(label_acc,
                            feed_dict={model.X: target_test, model.y: fix_label_dimension(target_test_labels),
                                       model.train: False})

        # Disabled extras (domain accuracy and feature embedding on a combined
        # test set). They need `combined_test_imgs` / `combined_test_domain`,
        # which are not defined in the visible notebook:
        #
        #   test_domain_acc = sess.run(domain_acc,
        #                       feed_dict={model.X: combined_test_imgs,
        #                                  model.domain: combined_test_domain, model.l: 1.0})
        #   test_emb = sess.run(model.feature, feed_dict={model.X: combined_test_imgs})
        #
        # and the function would then return
        #   source_acc, target_acc, test_domain_acc, test_emb

    return source_acc, target_acc

# Alternative baseline run (source-only training), kept disabled for reference.
# It matches the two-value return of train_and_evaluate:
#
# print('\nSource only training')
# source_acc, target_acc = train_and_evaluate('source', graph, model)
# print('Source (Super Nova) accuracy:', source_acc)
# print('Target (Real Super Nova) accuracy:', target_acc)

# Domain-adversarial run: 9000 steps with a progress line every 250 steps.
print('\nDomain adaptation training')
source_acc, target_acc = train_and_evaluate(
    'dann', graph, model, num_steps=9000, verbose=True)
print('Source (Super Nova) accuracy:', source_acc)
print('Target (Real Super Nova) accuracy:', target_acc)


Domain adaptation training
loss: 4.577471733093262  d_acc: 0.4375  p_acc: 0.5  p: 0.0  l: 0.0  lr: 0.01
loss: 1.2242934703826904  d_acc: 0.765625  p_acc: 0.46875  p: 0.011111111111111112  l: 0.05549847010902642  lr: 0.00924021086472307
loss: 1.3370829820632935  d_acc: 0.703125  p_acc: 0.4375  p: 0.022222222222222223  l: 0.1106561105247379  lr: 0.008602751305990648
loss: 1.3952481746673584  d_acc: 0.5625  p_acc: 0.46875  p: 0.03333333333333333  l: 0.16514041292462944  lr: 0.008059274488676564
loss: 1.3872456550598145  d_acc: 0.5625  p_acc: 0.4375  p: 0.044444444444444446  l: 0.21863508368712115  lr: 0.0075896957722755174
loss: 1.377767562866211  d_acc: 0.5  p_acc: 0.53125  p: 0.05555555555555555  l: 0.2708471185167214  lr: 0.007179362054645374
loss: 1.3831708431243896  d_acc: 0.578125  p_acc: 0.5  p: 0.06666666666666667  l: 0.32151273753163445  lr: 0.006817316198804996
loss: 1.3864843845367432  d_acc: 0.5  p_acc: 0.53125  p: 0.07777777777777778  l: 0.3704019533306313  lr: 0.00649519052

loss: 1.3872383832931519  d_acc: 0.5625  p_acc: 0.4375  p: 0.7222222222222222  l: 0.9985407095480496  lr: 0.002059482435009419
