In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import tensorflow as tf

In [3]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout, Flatten, Input, Multiply
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv3D, MaxPool3D

In [4]:
class ColoredMNISTEnvironments():
    
    def __init__(self):
        
        self.__load_initial_data()
        self.__create_envs()
        self.__create_validation_envs()

    def __load_initial_data(self):
        (x_train, y_train), (x_test, y_test) = mnist.load_data()

        # convert to RGB
        x_train = np.stack((x_train,)*3, axis=-1)
        x_test = np.stack((x_test,)*3, axis=-1)

        # normalize
        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
        x_train /= 255
        x_test /= 255

        # binary label
        y_train = (y_train < 5).astype(int)
        y_test = (y_test < 5).astype(int)
        
        self.original_data = {
            'x_train':x_train,
            'x_test':x_test,
            'y_train':y_train,
            'y_test':y_test
        }
        
    def __create_envs(self):
        k=10**4
        self.e1 = self.__create_env(self.original_data['x_train'][:k], 
                                    self.original_data['y_train'][:k], .1)
        self.e2 = self.__create_env(self.original_data['x_train'][k:2*k], 
                                    self.original_data['y_train'][k:2*k], .2)
        self.e3 = self.__create_env(self.original_data['x_train'][2*k:3*k], 
                                    self.original_data['y_train'][2*k:3*k], .9)
        
    def __create_validation_envs(self):
        k=10**4
        i=3*k
        self.e11 = self.__create_env(self.original_data['x_train'][i:i+k], 
                                     self.original_data['y_train'][i:i+k], .1)
        self.e22 = self.__create_env(self.original_data['x_train'][i+k:i+2*k], 
                                     self.original_data['y_train'][i+k:i+2*k], .2)
        self.e33 = self.__create_env(self.original_data['x_train'][i+2*k:i+3*k], 
                                     self.original_data['y_train'][i+2*k:i+3*k], .9)
        
    def __create_env(self, x, y, e, labelflip_proba=.25):
        x = x.copy()
        y = y.copy()

        y = np.logical_xor(
            y,
            (np.random.random(size=len(y)) < labelflip_proba).astype(int)
        ).astype(int)

        color = np.logical_xor(
            y,
            (np.random.random(size=len(y)) < e).astype(int)
        )

        x[color, :, :, 2] = 0
        x[color, :, :, 1] = 0
        return tf.data.Dataset.from_tensor_slices((x, y))
        
    def joint_iterator(self, shuffle_sz=256, batch_sz=128):
        def 
        
        

In [90]:
def get_model(n_final_units=32, compile=False):
    
    input_images = Input(shape=(28, 28, 3))
    
    cnn = Conv2D(32, kernel_size=(3, 3),
                 activation='relu')(input_images)
    cnn = Conv2D(64, (3, 3), activation='relu')(cnn)
    cnn = MaxPooling2D(pool_size=(2, 2))(cnn)
    cnn = Dropout(0.25)(cnn)
    cnn = Flatten()(cnn)
    
    env1 = Dense(32, activation='relu')(cnn)
    env1 = Dropout(0.5)(env1)
    env1 = Dense(1, name='env1')(env1)
        
    model = Model(
        inputs=[input_images],
        outputs=[env1]
    )
    
    if compile:
        model.compile(
            loss=[
                tf.keras.losses.binary_crossentropy,
            ],
            optimizer=tf.keras.optimizers.Adadelta(),
            metrics=['accuracy']
        )
    return model

In [91]:
from collections import defaultdict
from datetime import datetime

class IRMModel(object):
    
    def __init__(self, model = get_model(), optimizer = tf.keras.optimizers.Adam()):
        self.model = model
        self.optimizer = optimizer
        self.envs = ColoredMNISTEnvironments()
        self.dummy = tf.convert_to_tensor([1.])
        self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        self.logs = defaultdict(list)
        self.logdir = "/home/e.diemert/tflogs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")
        self.file_writer = tf.summary.create_file_writer(self.logdir + "/metrics")
        self.file_writer.set_as_default()

        
    def evaluate(self, env):
        accuracy = tf.keras.metrics.Accuracy()
        loss = tf.keras.metrics.BinaryCrossentropy()
        per_batch_penalties = []
        for (batch, (x, y)) in enumerate(env):
            with tf.GradientTape() as tape:
                tape.watch(self.dummy)
                logits = self.model(x, training=False)
                dummy_loss = self.loss(y, logits) * self.dummy
            batch_grads = tape.gradient(dummy_loss, [self.dummy])
            per_batch_penalties += [
                tf.math.square(
                    dummy_loss * batch_grads
                )
            ]
            loss.update_state(y, logits)
            accuracy.update_state(y, tf.math.greater(tf.keras.activations.sigmoid(logits), .5))
        return loss.result().numpy(), accuracy.result().numpy(), tf.reduce_mean(per_batch_penalties)
    
    def compute_penalty(self, x, y):
        with tf.GradientTape() as tape:
            tape.watch(self.dummy)
            dummy_logits = self.model(x, training=True)
            dummy_loss = self.loss(y, dummy_logits * self.dummy)
        dummy_grads = tape.gradient(dummy_loss, self.dummy)
        dummy_penalty = dummy_grads ** 2
        return dummy_penalty
        
    def batch_gradients(self, x, y, penalty):
        with tf.GradientTape() as tape:
            tape.watch(penalty)
            logits = self.model(x, training=True)
            loss_value = self.loss(y, logits) + penalty
        grads = tape.gradient(loss_value, self.model.trainable_variables)
        return loss_value, grads
    
    def do_evaluations(self, epoch, print_=True, batch_size=128):
        if print_:
            print('-'*80)
            print("epoch:", epoch)
        ood_loss, ood_acc, ood_penalty = self.evaluate(self.envs.e3.shuffle(2*batch_size).batch(batch_size))
        self.logs['ood-loss'] += [ood_loss]
        self.logs['ood-acc'] += [ood_acc]
        self.logs['ood-penalty'] += [ood_penalty.numpy()]
        self.log_event('ood_loss', ood_loss, epoch)
        self.log_event('ood_acc', ood_acc, epoch)
        self.log_event('ood_pen', ood_penalty.numpy(), epoch)
        if print_:
            print('ood  loss %.5f acc: %.3f'%(ood_loss, ood_acc))   
        for env_name, env in (('e11',self.envs.e11.shuffle(2*batch_size).batch(batch_size)), 
                              ('e22',self.envs.e22.shuffle(2*batch_size).batch(batch_size)),):
            env_loss, env_acc, env_penalty = self.evaluate(env)
            self.logs[env_name+'-test-loss'] += [env_loss]
            self.logs[env_name+'-test-acc'] += [env_acc]
            self.log_event(env_name+'_loss', env_loss, epoch)
            self.log_event(env_name+'_acc', env_acc, epoch)
            self.log_event(env_name+'_pen', env_penalty.numpy(), epoch)
            if print_:
                print('%s loss %.5f acc: %.3f'%(env_name, env_loss, env_acc))
        if print_:
            print('-'*80)

    def log_event(self, event, value, epoch):
        tf.summary.scalar(event, data=value, step=epoch)
                
    def train(self, epochs, lambda_, batch_size=128):
        for epoch in range(epochs):
            d1 = self.envs.e1.shuffle(2*batch_size).batch(batch_size).__iter__()
            d2 = self.envs.e2.shuffle(2*batch_size).batch(batch_size).__iter__()
            batch = 0
            while True:
                try:
                    x1, y1 = d1.next()
                    x2, y2 = d2.next()
                    pen1 = self.compute_penalty(x1, y1)
                    pen2 = self.compute_penalty(x2, y2)
                    self.logs['e1'+'-train-penalty'] += pen1.numpy()
                    self.logs['e2'+'-train-penalty'] += pen2.numpy()
                    pen = tf.reduce_mean([pen1, pen2])
                    l1, grads1 = self.batch_gradients(x1, y1, lambda_(epoch) * pen)
                    l2, grads2 = self.batch_gradients(x2, y2, lambda_(epoch) * pen)
                    grads = [ tf.reduce_mean([grads1[_], grads2[_]], axis=0) for _ in range(len(grads1)) ]
                    self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

                    self.logs['e1'+'-train-loss'] += [l1]
                    self.logs['e2'+'-train-loss'] += [l2]
                    if not batch % 10:
                        print("%4d"%batch, 
                              "tr-l1: %.5f"%l1.numpy(), "tr-p1: %.5f"%(pen1.numpy()*lambda_(epoch)), 
                              "tr-l2: %.5f"%l2.numpy(), "tr-p2: %.5f"%(pen2.numpy()*lambda_(epoch)))
                    batch += 1
                except StopIteration:
                    break
            self.do_evaluations(epoch, batch_size=batch_size)

In [None]:
irm = IRMModel(model = get_model(n_final_units=128), 
               optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3))
def lambda_scheduler(epoch):
    if epoch < 2:
        return 0
    elif epoch < 5:
        return 10000
    else:
        return 100000
irm.train(30, lambda_scheduler, batch_size=256)

   0 tr-l1: 0.68825 tr-p1: 0.00000 tr-l2: 0.69205 tr-p2: 0.00000
   3 tr-l1: 0.43447 tr-p1: 0.00000 tr-l2: 0.51341 tr-p2: 0.00000
   6 tr-l1: 0.47731 tr-p1: 0.00000 tr-l2: 0.52048 tr-p2: 0.00000
   9 tr-l1: 0.30967 tr-p1: 0.00000 tr-l2: 0.54819 tr-p2: 0.00000
  12 tr-l1: 0.34170 tr-p1: 0.00000 tr-l2: 0.46500 tr-p2: 0.00000
  15 tr-l1: 0.35993 tr-p1: 0.00000 tr-l2: 0.52421 tr-p2: 0.00000
  18 tr-l1: 0.39303 tr-p1: 0.00000 tr-l2: 0.43595 tr-p2: 0.00000
  21 tr-l1: 0.33504 tr-p1: 0.00000 tr-l2: 0.54854 tr-p2: 0.00000
  24 tr-l1: 0.36328 tr-p1: 0.00000 tr-l2: 0.49512 tr-p2: 0.00000
  27 tr-l1: 0.38166 tr-p1: 0.00000 tr-l2: 0.53743 tr-p2: 0.00000
  30 tr-l1: 0.33367 tr-p1: 0.00000 tr-l2: 0.47294 tr-p2: 0.00000
  33 tr-l1: 0.39552 tr-p1: 0.00000 tr-l2: 0.54928 tr-p2: 0.00000
  36 tr-l1: 0.29414 tr-p1: 0.00000 tr-l2: 0.46333 tr-p2: 0.00000
  39 tr-l1: 0.18042 tr-p1: 0.00000 tr-l2: 0.39414 tr-p2: 0.00000
--------------------------------------------------------------------------------
epoch: 0


  36 tr-l1: 173.41801 tr-p1: 292.35021 tr-l2: 173.49834 tr-p2: 53.86415
  39 tr-l1: 1619.60046 tr-p1: 3196.16571 tr-l2: 1619.71313 tr-p2: 42.73580
--------------------------------------------------------------------------------
epoch: 6
ood  loss 11.50330 acc: 0.114
e11 loss 1.35769 acc: 0.894
e22 loss 2.71768 acc: 0.791
--------------------------------------------------------------------------------
   0 tr-l1: 606.13275 tr-p1: 789.53719 tr-l2: 606.34174 tr-p2: 422.31246
   3 tr-l1: 1072.59705 tr-p1: 94.71638 tr-l2: 1072.70813 tr-p2: 2049.82460
   6 tr-l1: 528.60510 tr-p1: 712.35346 tr-l2: 528.78760 tr-p2: 344.35086
   9 tr-l1: 452.87686 tr-p1: 717.43303 tr-l2: 452.99728 tr-p2: 187.62981
  12 tr-l1: 304.70078 tr-p1: 413.93936 tr-l2: 304.83075 tr-p2: 194.75438
  15 tr-l1: 438.26999 tr-p1: 874.22030 tr-l2: 438.37668 tr-p2: 1.67936
  18 tr-l1: 739.67218 tr-p1: 268.68649 tr-l2: 739.81818 tr-p2: 1210.05280
  21 tr-l1: 951.71881 tr-p1: 313.29759 tr-l2: 951.95477 tr-p2: 1589.63073
  24 tr-l1

   6 tr-l1: 628.97571 tr-p1: 787.58858 tr-l2: 629.16901 tr-p2: 469.86928
   9 tr-l1: 455.14041 tr-p1: 17.63229 tr-l2: 455.24326 tr-p2: 891.93527
  12 tr-l1: 286.90491 tr-p1: 148.98566 tr-l2: 287.06476 tr-p2: 424.16952
  15 tr-l1: 423.03836 tr-p1: 540.88244 tr-l2: 423.14734 tr-p2: 304.56006
  18 tr-l1: 311.90085 tr-p1: 516.00537 tr-l2: 312.02698 tr-p2: 107.22094
  21 tr-l1: 92.80791 tr-p1: 24.04306 tr-l2: 92.93174 tr-p2: 160.94819
  24 tr-l1: 69.18248 tr-p1: 83.72321 tr-l2: 69.28125 tr-p2: 54.03581
  27 tr-l1: 410.28043 tr-p1: 616.57415 tr-l2: 410.44238 tr-p2: 203.39701
  30 tr-l1: 860.18689 tr-p1: 27.28302 tr-l2: 860.39526 tr-p2: 1692.38765
  33 tr-l1: 132.84303 tr-p1: 264.16534 tr-l2: 132.92201 tr-p2: 0.86572
  36 tr-l1: 632.34473 tr-p1: 1153.93661 tr-l2: 632.45691 tr-p2: 110.23809
  39 tr-l1: 895.75134 tr-p1: 716.81915 tr-l2: 895.64392 tr-p2: 1074.01237
--------------------------------------------------------------------------------
epoch: 13
ood  loss 10.40044 acc: 0.119
e11 loss 1.

  30 tr-l1: 479.09314 tr-p1: 795.03721 tr-l2: 479.25421 tr-p2: 162.63939
  33 tr-l1: 422.65970 tr-p1: 777.49738 tr-l2: 422.81619 tr-p2: 67.23967
  36 tr-l1: 435.70200 tr-p1: 867.77704 tr-l2: 435.79816 tr-p2: 3.08469
  39 tr-l1: 176.08873 tr-p1: 333.97598 tr-l2: 176.20604 tr-p2: 17.76657
--------------------------------------------------------------------------------
epoch: 19
ood  loss 10.31274 acc: 0.145
e11 loss 1.33506 acc: 0.889
e22 loss 2.59996 acc: 0.788
--------------------------------------------------------------------------------
   0 tr-l1: 444.80771 tr-p1: 886.28577 tr-l2: 444.85614 tr-p2: 2.81689
   3 tr-l1: 293.00046 tr-p1: 1.37412 tr-l2: 293.09686 tr-p2: 584.07146
   6 tr-l1: 358.71786 tr-p1: 138.68758 tr-l2: 358.85172 tr-p2: 578.21255
   9 tr-l1: 1110.45605 tr-p1: 482.43487 tr-l2: 1110.71082 tr-p2: 1738.03847
  12 tr-l1: 227.25642 tr-p1: 439.18039 tr-l2: 227.40192 tr-p2: 14.77835
  15 tr-l1: 730.25861 tr-p1: 1198.88233 tr-l2: 730.32355 tr-p2: 261.04646
  18 tr-l1: 582.6