In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import tensorflow as tf

In [5]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout, Flatten, Input, Multiply
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv3D, MaxPool3D

In [12]:
class ColoredMNISTEnvironments():
    
    def __init__(self):
        
        self.__load_initial_data()
        self.__create_envs()
        self.__create_validation_envs()

    def __load_initial_data(self):
        (x_train, y_train), (x_test, y_test) = mnist.load_data()

        # convert to RGB
        x_train = np.stack((x_train,)*3, axis=-1)
        x_test = np.stack((x_test,)*3, axis=-1)

        # normalize
        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
        x_train /= 255
        x_test /= 255

        # binary label
        y_train = (y_train < 5).astype(int)
        y_test = (y_test < 5).astype(int)
        
        self.original_data = {
            'x_train':x_train,
            'x_test':x_test,
            'y_train':y_train,
            'y_test':y_test
        }
        
    def __create_envs(self):
        k=10**4
        self.e1 = self.__create_env(self.original_data['x_train'][:k], 
                                    self.original_data['y_train'][:k], .1)
        self.e2 = self.__create_env(self.original_data['x_train'][k:2*k], 
                                    self.original_data['y_train'][k:2*k], .2)
        self.e3 = self.__create_env(self.original_data['x_train'][2*k:3*k], 
                                    self.original_data['y_train'][2*k:3*k], .9)
        
    def __create_validation_envs(self):
        k=10**4
        i=3*k
        self.e11 = self.__create_env(self.original_data['x_train'][i:i+k], 
                                     self.original_data['y_train'][i:i+k], .1)
        self.e22 = self.__create_env(self.original_data['x_train'][i+k:i+2*k], 
                                     self.original_data['y_train'][i+k:i+2*k], .2)
        self.e33 = self.__create_env(self.original_data['x_train'][i+2*k:i+3*k], 
                                     self.original_data['y_train'][i+2*k:i+3*k], .9)
        
    def __create_env(self, x, y, e, labelflip_proba=.25):
        x = x.copy()
        y = y.copy()

        y = np.logical_xor(
            y,
            (np.random.random(size=len(y)) < labelflip_proba).astype(int)
        ).astype(int)

        color = np.logical_xor(
            y,
            (np.random.random(size=len(y)) < e).astype(int)
        )

        x[color, :, :, 2] = 0
        x[color, :, :, 1] = 0
        return tf.data.Dataset.from_tensor_slices((x, y))
        

In [15]:
def get_model(compile=False):
    
    input_images = Input(shape=(28, 28, 3))
    
    cnn = Conv2D(32, kernel_size=(3, 3),
                 activation='relu')(input_images)
    cnn = Conv2D(64, (3, 3), activation='relu')(cnn)
    cnn = MaxPooling2D(pool_size=(2, 2))(cnn)
    cnn = Dropout(0.25)(cnn)
    cnn = Flatten()(cnn)
    
    env1 = Dense(32, activation='relu')(cnn)
    env1 = Dropout(0.5)(env1)
    env1 = Dense(1, name='env1')(env1)
        
    model = Model(
        inputs=[input_images],
        outputs=[env1]
    )
    
    if compile:
        model.compile(
            loss=[
                tf.keras.losses.binary_crossentropy,
            ],
            optimizer=tf.keras.optimizers.Adadelta(),
            metrics=['accuracy']
        )
    return model

In [13]:
mnist_envs = ColoredMNISTEnvironments()

In [38]:
def train(model, dataset, valid_dataset, epochs, lambdas, 
          dummy=tf.convert_to_tensor([1.]),
          loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True),
          accuracy_object = tf.keras.metrics.Accuracy(),
          optimizer = tf.keras.optimizers.Adam()):
  for epoch in range(epochs):
    lambda_ = lambdas[epoch]
    for (batch, (images, labels)) in enumerate(dataset):
    
      # compute penalty
      with tf.GradientTape() as tape:
        tape.watch(dummy)
        logits = model(images, training=False)
        loss_value = loss_object(labels, logits * dummy)
      accuracy_object.update_state(labels, 
                                   tf.math.greater(
                                       tf.keras.activations.sigmoid(logits),
                                       .5)
                                   )
      grads = tape.gradient(loss_value, dummy)
      penalty = tf.math.reduce_mean(loss_value * tf.math.square(grads)).numpy()
    
      # train
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        loss_value = loss_object(labels, logits)
      grads = tape.gradient(loss_value, model.trainable_variables)
      grads += penalty * lambda_
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
    
      if not batch % 30:
        tr_acc = accuracy_object.result().numpy()
        accuracy_object.reset_states()
        # validation
        for (v_batch, (v_images, v_labels)) in enumerate(valid_dataset):
          logits = model(v_images, training=False)
          accuracy_object.update_state(v_labels, 
                                       tf.math.greater(
                                         tf.keras.activations.sigmoid(logits),
                                         .5)
                                       )
        v_acc = accuracy_object.result().numpy()
        accuracy_object.reset_states()
        print ('Epoch %3d TrainLoss %.5f Penalty %.5f TrainAcc %.3f TestAcc %.3f' % (
            epoch, loss_value.numpy().mean(), penalty, tr_acc, v_acc 
        ))
        

In [39]:
train(
    get_model(), 
    mnist_envs.e1.shuffle(256).batch(128), 
    mnist_envs.e11.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.69673 Penalty 0.00000 TrainAcc 0.539 TestAcc 0.795
Epoch   0 TrainLoss 0.39744 Penalty 0.00158 TrainAcc 0.900 TestAcc 0.903
Epoch   0 TrainLoss 0.34157 Penalty 0.00084 TrainAcc 0.898 TestAcc 0.903


In [28]:
train(
    get_model(), 
    mnist_envs.e1.shuffle(256).batch(128), 
    mnist_envs.e2.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.70284 Penalty 0.00004 TrainAcc 0.881 TestAcc 0.793
Epoch   0 TrainLoss 0.37105 Penalty 0.00120 TrainAcc 0.901 TestAcc 0.796
Epoch   0 TrainLoss 0.52029 Penalty 0.03031 TrainAcc 0.897 TestAcc 0.796


In [29]:
train(
    get_model(), 
    mnist_envs.e11.shuffle(256).batch(128), 
    mnist_envs.e22.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.69566 Penalty 0.00001 TrainAcc 0.880 TestAcc 0.493
Epoch   0 TrainLoss 0.33700 Penalty 0.00000 TrainAcc 0.890 TestAcc 0.810
Epoch   0 TrainLoss 0.27881 Penalty 0.00393 TrainAcc 0.902 TestAcc 0.810


In [30]:
train(
    get_model(), 
    mnist_envs.e2.shuffle(256).batch(128), 
    mnist_envs.e1.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.68854 Penalty 0.00000 TrainAcc 0.878 TestAcc 0.491
Epoch   0 TrainLoss 0.55199 Penalty 0.00109 TrainAcc 0.782 TestAcc 0.899
Epoch   0 TrainLoss 0.53440 Penalty 0.00001 TrainAcc 0.795 TestAcc 0.899


In [18]:
train(
    get_model(), 
    mnist_envs.e1.shuffle(256).batch(128), 
    mnist_envs.e3.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.69673 Penalty 0.00001 TrainAcc 0.874 TestAcc 0.329
Epoch   0 TrainLoss 0.32922 Penalty 0.00074 TrainAcc 0.886 TestAcc 0.099
Epoch   0 TrainLoss 0.33516 Penalty 0.00146 TrainAcc 0.899 TestAcc 0.099


In [26]:
train(
    get_model(), 
    mnist_envs.e1.concatenate(mnist_envs.e2).shuffle(256).batch(128),
    mnist_envs.e3.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.70098 Penalty 0.00001 TrainAcc 0.878 TestAcc 0.379
Epoch   0 TrainLoss 0.38374 Penalty 0.00027 TrainAcc 0.887 TestAcc 0.099
Epoch   0 TrainLoss 0.43384 Penalty 0.00162 TrainAcc 0.898 TestAcc 0.099
Epoch   0 TrainLoss 0.55759 Penalty 0.00041 TrainAcc 0.858 TestAcc 0.099
Epoch   0 TrainLoss 0.51440 Penalty 0.00195 TrainAcc 0.801 TestAcc 0.099
Epoch   0 TrainLoss 0.51244 Penalty 0.00595 TrainAcc 0.792 TestAcc 0.099


In [34]:
train(
    get_model(), 
    mnist_envs.e1.concatenate(mnist_envs.e2).shuffle(256).batch(128),
    mnist_envs.e3.batch(128), 
    epochs = 10, 
    lambdas = [0,0,1,1,5,5,10,10,100,1000]
)

Epoch   0 TrainLoss 0.69155 Penalty 0.00002 TrainAcc 0.735 TestAcc 0.491
Epoch   0 TrainLoss 0.41020 Penalty 0.00093 TrainAcc 0.887 TestAcc 0.099
Epoch   0 TrainLoss 0.32322 Penalty 0.00237 TrainAcc 0.898 TestAcc 0.099
Epoch   0 TrainLoss 0.49490 Penalty 0.00702 TrainAcc 0.859 TestAcc 0.101
Epoch   0 TrainLoss 0.44098 Penalty 0.00200 TrainAcc 0.798 TestAcc 0.104
Epoch   0 TrainLoss 0.49879 Penalty 0.00003 TrainAcc 0.794 TestAcc 0.105
Epoch   1 TrainLoss 0.42940 Penalty 0.00689 TrainAcc 0.784 TestAcc 0.108
Epoch   1 TrainLoss 0.35474 Penalty 0.00009 TrainAcc 0.899 TestAcc 0.099
Epoch   1 TrainLoss 0.29648 Penalty 0.00002 TrainAcc 0.900 TestAcc 0.099
Epoch   1 TrainLoss 0.49480 Penalty 0.00059 TrainAcc 0.855 TestAcc 0.104
Epoch   1 TrainLoss 0.49123 Penalty 0.00192 TrainAcc 0.799 TestAcc 0.105
Epoch   1 TrainLoss 0.43876 Penalty 0.00322 TrainAcc 0.795 TestAcc 0.116
Epoch   2 TrainLoss 0.39052 Penalty 0.00156 TrainAcc 0.789 TestAcc 0.119
Epoch   2 TrainLoss 0.33577 Penalty 0.00003 TrainAc

KeyboardInterrupt: 