In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import tensorflow as tf

In [5]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout, Flatten, Input, Multiply
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv3D, MaxPool3D

In [12]:
class ColoredMNISTEnvironments():
    
    def __init__(self):
        
        self.__load_initial_data()
        self.__create_envs()
        self.__create_validation_envs()

    def __load_initial_data(self):
        (x_train, y_train), (x_test, y_test) = mnist.load_data()

        # convert to RGB
        x_train = np.stack((x_train,)*3, axis=-1)
        x_test = np.stack((x_test,)*3, axis=-1)

        # normalize
        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
        x_train /= 255
        x_test /= 255

        # binary label
        y_train = (y_train < 5).astype(int)
        y_test = (y_test < 5).astype(int)
        
        self.original_data = {
            'x_train':x_train,
            'x_test':x_test,
            'y_train':y_train,
            'y_test':y_test
        }
        
    def __create_envs(self):
        k=10**4
        self.e1 = self.__create_env(self.original_data['x_train'][:k], 
                                    self.original_data['y_train'][:k], .1)
        self.e2 = self.__create_env(self.original_data['x_train'][k:2*k], 
                                    self.original_data['y_train'][k:2*k], .2)
        self.e3 = self.__create_env(self.original_data['x_train'][2*k:3*k], 
                                    self.original_data['y_train'][2*k:3*k], .9)
        
    def __create_validation_envs(self):
        k=10**4
        i=3*k
        self.e11 = self.__create_env(self.original_data['x_train'][i:i+k], 
                                     self.original_data['y_train'][i:i+k], .1)
        self.e22 = self.__create_env(self.original_data['x_train'][i+k:i+2*k], 
                                     self.original_data['y_train'][i+k:i+2*k], .2)
        self.e33 = self.__create_env(self.original_data['x_train'][i+2*k:i+3*k], 
                                     self.original_data['y_train'][i+2*k:i+3*k], .9)
        
    def __create_env(self, x, y, e, labelflip_proba=.25):
        x = x.copy()
        y = y.copy()

        y = np.logical_xor(
            y,
            (np.random.random(size=len(y)) < labelflip_proba).astype(int)
        ).astype(int)

        color = np.logical_xor(
            y,
            (np.random.random(size=len(y)) < e).astype(int)
        )

        x[color, :, :, 2] = 0
        x[color, :, :, 1] = 0
        return tf.data.Dataset.from_tensor_slices((x, y))
        

In [15]:
def get_model(compile=False):
    
    input_images = Input(shape=(28, 28, 3))
    
    cnn = Conv2D(32, kernel_size=(3, 3),
                 activation='relu')(input_images)
    cnn = Conv2D(64, (3, 3), activation='relu')(cnn)
    cnn = MaxPooling2D(pool_size=(2, 2))(cnn)
    cnn = Dropout(0.25)(cnn)
    cnn = Flatten()(cnn)
    
    env1 = Dense(32, activation='relu')(cnn)
    env1 = Dropout(0.5)(env1)
    env1 = Dense(1, name='env1')(env1)
        
    model = Model(
        inputs=[input_images],
        outputs=[env1]
    )
    
    if compile:
        model.compile(
            loss=[
                tf.keras.losses.binary_crossentropy,
            ],
            optimizer=tf.keras.optimizers.Adadelta(),
            metrics=['accuracy']
        )
    return model

In [13]:
mnist_envs = ColoredMNISTEnvironments()

In [38]:
def train(model, dataset, valid_dataset, epochs, lambdas, 
          dummy=tf.convert_to_tensor([1.]),
          loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True),
          accuracy_object = tf.keras.metrics.Accuracy(),
          optimizer = tf.keras.optimizers.Adam()):
  for epoch in range(epochs):
    lambda_ = lambdas[epoch]
    for (batch, (images, labels)) in enumerate(dataset):
    
      # compute penalty
      with tf.GradientTape() as tape:
        tape.watch(dummy)
        logits = model(images, training=False)
        loss_value = loss_object(labels, logits * dummy)
      accuracy_object.update_state(labels, 
                                   tf.math.greater(
                                       tf.keras.activations.sigmoid(logits),
                                       .5)
                                   )
      grads = tape.gradient(loss_value, dummy)
      penalty = tf.math.reduce_mean(loss_value * tf.math.square(grads)).numpy()
    
      # train
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        loss_value = loss_object(labels, logits)
      grads = tape.gradient(loss_value, model.trainable_variables)
      grads += penalty * lambda_
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
    
      if not batch % 30:
        tr_acc = accuracy_object.result().numpy()
        accuracy_object.reset_states()
        # validation
        for (v_batch, (v_images, v_labels)) in enumerate(valid_dataset):
          logits = model(v_images, training=False)
          accuracy_object.update_state(v_labels, 
                                       tf.math.greater(
                                         tf.keras.activations.sigmoid(logits),
                                         .5)
                                       )
        v_acc = accuracy_object.result().numpy()
        accuracy_object.reset_states()
        print ('Epoch %3d TrainLoss %.5f Penalty %.5f TrainAcc %.3f TestAcc %.3f' % (
            epoch, loss_value.numpy().mean(), penalty, tr_acc, v_acc 
        ))
        

In [122]:
from collections import defaultdict

class IRMModel(object):
    
    def __init__(self, optimizer = tf.keras.optimizers.Adam()):
        self.model = get_model()
        self.envs = ColoredMNISTEnvironments()
        self.dummy = tf.convert_to_tensor([1.])
        self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        self.optimizer = optimizer
        self.logs = defaultdict(list)
        
    def evaluate(self, env):
        accuracy = tf.keras.metrics.Accuracy()
        loss = tf.keras.metrics.BinaryCrossentropy()
        per_batch_penalties = []
        for (batch, (x, y)) in enumerate(env):
            with tf.GradientTape() as tape:
                tape.watch(self.dummy)
                logits = self.model(x, training=False)
                dummy_loss = self.loss(y, 
                                       tf.keras.activations.sigmoid(logits) * self.dummy)
            batch_grads = tape.gradient(dummy_loss, [self.dummy])
            per_batch_penalties += [
                tf.math.square(
                    dummy_loss * batch_grads
                )
            ]
            loss.update_state(y, tf.keras.activations.sigmoid(logits))
            accuracy.update_state(y, tf.math.greater(tf.keras.activations.sigmoid(logits), .5))
        return loss.result().numpy(), accuracy.result().numpy(), tf.reduce_mean(per_batch_penalties)
    
    def train(self, epochs, lambda_, batch_size=128):
        for epoch in range(epochs):
            # penalty & train metrics
            penalties = []
            for env_name, env in (('e1',self.envs.e1.batch(batch_size)), 
                                  ('e2',self.envs.e2.batch(batch_size)),):
                env_train_loss, env_train_acc, env_penalty = self.evaluate(env)
                self.logs[env_name+'-train-loss'] += [env_train_loss]
                self.logs[env_name+'-train-acc'] += [env_train_acc]
                self.logs[env_name+'-penalty'] += [env_penalty.numpy()]
                penalties += [env_penalty]
            penalty = tf.reduce_mean(penalties).numpy()
            # training
            for env_name, env in (('e1',self.envs.e1.batch(batch_size)), 
                                  ('e2',self.envs.e2.batch(batch_size)),):
                for (batch, (x, y)) in enumerate(env):
                    
                    with tf.GradientTape() as tape:
                        tape.watch(self.dummy)
                        dummy_logits = self.model(x, training=True)
                        dummy_loss = self.loss(y, dummy_logits * self.dummy)
                    dummy_grads = tape.gradient(dummy_loss, self.dummy)
                    dummy_penalty = lambda_(epoch) * dummy_grads ** 2
                    #print(dummy_penalty)
                    dummy_penalty.set_shape(self.dummy.get_shape())
                    #self.optimizer.apply_gradients(zip(dummy_penalty, self.model.trainable_variables))
                    
                    with tf.GradientTape() as tape:
                        logits = self.model(x, training=True)
                        loss_value = self.loss(y, tf.keras.activations.sigmoid(logits)) + dummy_penalty
                    grads = tape.gradient(loss_value, 
                                          self.model.trainable_variables)
                    self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
            # evaluation
            ood_loss, ood_acc, ood_penalty = self.evaluate(self.envs.e3.batch(batch_size))
            self.logs['ood-loss'] += [ood_loss]
            self.logs['ood-acc'] += [ood_acc]
            self.logs['ood-penalty'] += [ood_penalty.numpy()]
            for env_name, env in (('e11',self.envs.e11.batch(batch_size)), 
                                  ('e22',self.envs.e22.batch(batch_size)),):
                env_train_loss, env_train_acc, env_penalty = self.evaluate(env)
                self.logs[env_name+'-test-loss'] += [env_train_loss]
                self.logs[env_name+'-test-acc'] += [env_train_acc]
            # monitoring
            print ('Epoch %3d'%epoch, end=' ')
            prefix=''
            for k, v in sorted(self.logs.items()):
                if k.split('-')[0] != prefix:
                    prefix = k.split('-')[0]
                    print(end='| ')
                print('%s:%.5f' % (k, v[-1]), end=' ')
            print()

In [None]:
irm = IRMModel()
def lambda_scheduler(epoch):
    if epoch < 2:
        return 0
    elif epoch < 10:
        return 100
    else:
        return 10000
irm.train(30, lambda_scheduler)

Epoch   0 | e1-penalty:0.00139 | e1-train-acc:0.51760 e1-train-loss:0.67736 e11-test-acc:0.89810 | e11-test-loss:1.56031 e2-penalty:0.00142 | e2-train-acc:0.51130 e2-train-loss:0.68162 e22-test-acc:0.80210 | e22-test-loss:3.04178 ood-acc:0.10520 | ood-loss:13.74406 ood-penalty:0.09088 
Epoch   1 | e1-penalty:0.00245 | e1-train-acc:0.90570 e1-train-loss:1.44479 e11-test-acc:0.89810 | e11-test-loss:1.56024 e2-penalty:0.00064 | e2-train-acc:0.79900 e2-train-loss:3.06701 e22-test-acc:0.80210 | e22-test-loss:3.04194 ood-acc:0.10520 | ood-loss:13.74440 ood-penalty:0.09088 
Epoch   2 | e1-penalty:0.00245 | e1-train-acc:0.90570 e1-train-loss:1.44468 e11-test-acc:0.89810 | e11-test-loss:1.56038 e2-penalty:0.00064 | e2-train-acc:0.79900 e2-train-loss:3.06696 e22-test-acc:0.80210 | e22-test-loss:3.04194 ood-acc:0.10520 | ood-loss:13.74579 ood-penalty:0.09088 
Epoch   3 | e1-penalty:0.00245 | e1-train-acc:0.90570 e1-train-loss:1.44483 e11-test-acc:0.89810 | e11-test-loss:1.56049 e2-penalty:0.00064

In [None]:
for k, v in irm.logs.items():
    if 'loss' not in k:
        continue
    plt.plot(v, label=k)
plt.legend()

In [None]:
for k, v in irm.logs.items():
    if 'acc' not in k:
        continue
    plt.plot(v, label=k)
plt.legend()

In [None]:
for k, v in irm.logs.items():
    if 'pen' not in k:
        continue
    plt.plot(v, label=k)
plt.legend()

In [39]:
train(
    get_model(), 
    mnist_envs.e1.shuffle(256).batch(128), 
    mnist_envs.e11.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.69673 Penalty 0.00000 TrainAcc 0.539 TestAcc 0.795
Epoch   0 TrainLoss 0.39744 Penalty 0.00158 TrainAcc 0.900 TestAcc 0.903
Epoch   0 TrainLoss 0.34157 Penalty 0.00084 TrainAcc 0.898 TestAcc 0.903


In [28]:
train(
    get_model(), 
    mnist_envs.e1.shuffle(256).batch(128), 
    mnist_envs.e2.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.70284 Penalty 0.00004 TrainAcc 0.881 TestAcc 0.793
Epoch   0 TrainLoss 0.37105 Penalty 0.00120 TrainAcc 0.901 TestAcc 0.796
Epoch   0 TrainLoss 0.52029 Penalty 0.03031 TrainAcc 0.897 TestAcc 0.796


In [29]:
train(
    get_model(), 
    mnist_envs.e11.shuffle(256).batch(128), 
    mnist_envs.e22.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.69566 Penalty 0.00001 TrainAcc 0.880 TestAcc 0.493
Epoch   0 TrainLoss 0.33700 Penalty 0.00000 TrainAcc 0.890 TestAcc 0.810
Epoch   0 TrainLoss 0.27881 Penalty 0.00393 TrainAcc 0.902 TestAcc 0.810


In [30]:
train(
    get_model(), 
    mnist_envs.e2.shuffle(256).batch(128), 
    mnist_envs.e1.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.68854 Penalty 0.00000 TrainAcc 0.878 TestAcc 0.491
Epoch   0 TrainLoss 0.55199 Penalty 0.00109 TrainAcc 0.782 TestAcc 0.899
Epoch   0 TrainLoss 0.53440 Penalty 0.00001 TrainAcc 0.795 TestAcc 0.899


In [18]:
train(
    get_model(), 
    mnist_envs.e1.shuffle(256).batch(128), 
    mnist_envs.e3.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.69673 Penalty 0.00001 TrainAcc 0.874 TestAcc 0.329
Epoch   0 TrainLoss 0.32922 Penalty 0.00074 TrainAcc 0.886 TestAcc 0.099
Epoch   0 TrainLoss 0.33516 Penalty 0.00146 TrainAcc 0.899 TestAcc 0.099


In [26]:
train(
    get_model(), 
    mnist_envs.e1.concatenate(mnist_envs.e2).shuffle(256).batch(128),
    mnist_envs.e3.shuffle(256).batch(128), 
    epochs = 1, 
    lambdas = [0]
)

Epoch   0 TrainLoss 0.70098 Penalty 0.00001 TrainAcc 0.878 TestAcc 0.379
Epoch   0 TrainLoss 0.38374 Penalty 0.00027 TrainAcc 0.887 TestAcc 0.099
Epoch   0 TrainLoss 0.43384 Penalty 0.00162 TrainAcc 0.898 TestAcc 0.099
Epoch   0 TrainLoss 0.55759 Penalty 0.00041 TrainAcc 0.858 TestAcc 0.099
Epoch   0 TrainLoss 0.51440 Penalty 0.00195 TrainAcc 0.801 TestAcc 0.099
Epoch   0 TrainLoss 0.51244 Penalty 0.00595 TrainAcc 0.792 TestAcc 0.099


In [42]:
train(
    get_model(), 
    mnist_envs.e1.concatenate(mnist_envs.e2).shuffle(256).batch(128),
    mnist_envs.e3.batch(128), 
    optimizer = tf.keras.optimizers.Adam(lr=1e-3),
    epochs = 100, 
    lambdas = [0,0,1,1,5,5,10,10,100,100]+[100 for _ in range(30)]+[1000 for _ in range(100)]
)

Epoch   0 TrainLoss 0.69530 Penalty 0.00001 TrainAcc 0.699 TestAcc 0.180
Epoch   0 TrainLoss 0.34775 Penalty 0.00008 TrainAcc 0.897 TestAcc 0.099
Epoch   0 TrainLoss 0.29930 Penalty 0.00041 TrainAcc 0.899 TestAcc 0.099
Epoch   0 TrainLoss 0.51504 Penalty 0.00152 TrainAcc 0.857 TestAcc 0.099
Epoch   0 TrainLoss 0.49004 Penalty 0.00412 TrainAcc 0.800 TestAcc 0.103
Epoch   0 TrainLoss 0.49998 Penalty 0.00018 TrainAcc 0.795 TestAcc 0.134
Epoch   1 TrainLoss 0.38836 Penalty 0.00542 TrainAcc 0.788 TestAcc 0.207
Epoch   1 TrainLoss 0.35345 Penalty 0.00056 TrainAcc 0.895 TestAcc 0.099
Epoch   1 TrainLoss 0.25682 Penalty 0.00021 TrainAcc 0.897 TestAcc 0.100
Epoch   1 TrainLoss 0.45060 Penalty 0.00483 TrainAcc 0.858 TestAcc 0.120
Epoch   1 TrainLoss 0.43763 Penalty 0.00013 TrainAcc 0.795 TestAcc 0.107
Epoch   1 TrainLoss 0.40321 Penalty 0.00057 TrainAcc 0.795 TestAcc 0.112
Epoch   2 TrainLoss 0.38274 Penalty 0.00265 TrainAcc 0.783 TestAcc 0.136
Epoch   2 TrainLoss 0.66056 Penalty 0.00058 TrainAc

Epoch  18 TrainLoss 0.65585 Penalty 0.00064 TrainAcc 0.485 TestAcc 0.491
Epoch  19 TrainLoss 0.68069 Penalty 0.00059 TrainAcc 0.501 TestAcc 0.491
Epoch  19 TrainLoss 0.60833 Penalty 0.00210 TrainAcc 0.494 TestAcc 0.491
Epoch  19 TrainLoss 0.65107 Penalty 0.00095 TrainAcc 0.493 TestAcc 0.491
Epoch  19 TrainLoss 0.64511 Penalty 0.00107 TrainAcc 0.493 TestAcc 0.491
Epoch  19 TrainLoss 0.67929 Penalty 0.00004 TrainAcc 0.495 TestAcc 0.491
Epoch  19 TrainLoss 0.65077 Penalty 0.00050 TrainAcc 0.490 TestAcc 0.491
Epoch  20 TrainLoss 0.62838 Penalty 0.00156 TrainAcc 0.511 TestAcc 0.491
Epoch  20 TrainLoss 0.63637 Penalty 0.00136 TrainAcc 0.490 TestAcc 0.491
Epoch  20 TrainLoss 0.65323 Penalty 0.00051 TrainAcc 0.488 TestAcc 0.491
Epoch  20 TrainLoss 0.66017 Penalty 0.00061 TrainAcc 0.497 TestAcc 0.491
Epoch  20 TrainLoss 0.65978 Penalty 0.00041 TrainAcc 0.499 TestAcc 0.491
Epoch  20 TrainLoss 0.67520 Penalty 0.00022 TrainAcc 0.488 TestAcc 0.491
Epoch  21 TrainLoss 0.64334 Penalty 0.00072 TrainAc

Epoch  37 TrainLoss 0.68373 Penalty 0.00009 TrainAcc 0.495 TestAcc 0.491
Epoch  37 TrainLoss 0.61106 Penalty 0.00145 TrainAcc 0.489 TestAcc 0.491
Epoch  38 TrainLoss 0.65559 Penalty 0.00225 TrainAcc 0.514 TestAcc 0.491
Epoch  38 TrainLoss 0.69601 Penalty 0.00054 TrainAcc 0.488 TestAcc 0.491
Epoch  38 TrainLoss 0.67328 Penalty 0.00017 TrainAcc 0.492 TestAcc 0.491
Epoch  38 TrainLoss 0.67506 Penalty 0.00029 TrainAcc 0.499 TestAcc 0.491
Epoch  38 TrainLoss 0.64334 Penalty 0.00054 TrainAcc 0.495 TestAcc 0.491
Epoch  38 TrainLoss 0.66000 Penalty 0.00032 TrainAcc 0.487 TestAcc 0.491
Epoch  39 TrainLoss 0.63655 Penalty 0.00176 TrainAcc 0.507 TestAcc 0.491
Epoch  39 TrainLoss 0.64213 Penalty 0.00142 TrainAcc 0.490 TestAcc 0.491
Epoch  39 TrainLoss 0.63971 Penalty 0.00077 TrainAcc 0.495 TestAcc 0.491
Epoch  39 TrainLoss 0.64742 Penalty 0.00019 TrainAcc 0.492 TestAcc 0.491
Epoch  39 TrainLoss 0.66113 Penalty 0.00094 TrainAcc 0.498 TestAcc 0.491
Epoch  39 TrainLoss 0.64639 Penalty 0.00135 TrainAc

Epoch  56 TrainLoss 3.51669 Penalty 43.28022 TrainAcc 0.490 TestAcc 0.491
Epoch  56 TrainLoss 3.27270 Penalty 34.87691 TrainAcc 0.497 TestAcc 0.491
Epoch  56 TrainLoss 3.88842 Penalty 58.55404 TrainAcc 0.496 TestAcc 0.491
Epoch  56 TrainLoss 3.15368 Penalty 31.21534 TrainAcc 0.489 TestAcc 0.491
Epoch  57 TrainLoss 3.04960 Penalty 28.22221 TrainAcc 0.515 TestAcc 0.491
Epoch  57 TrainLoss 3.78259 Penalty 53.91565 TrainAcc 0.488 TestAcc 0.491
Epoch  57 TrainLoss 3.58733 Penalty 45.98753 TrainAcc 0.493 TestAcc 0.491
Epoch  57 TrainLoss 3.83266 Penalty 56.10445 TrainAcc 0.496 TestAcc 0.491
Epoch  57 TrainLoss 3.35366 Penalty 37.57531 TrainAcc 0.496 TestAcc 0.491
Epoch  57 TrainLoss 3.26261 Penalty 34.59906 TrainAcc 0.488 TestAcc 0.491
Epoch  58 TrainLoss 3.60527 Penalty 46.70341 TrainAcc 0.502 TestAcc 0.491
Epoch  58 TrainLoss 3.85468 Penalty 57.10163 TrainAcc 0.491 TestAcc 0.491
Epoch  58 TrainLoss 3.53703 Penalty 44.11044 TrainAcc 0.490 TestAcc 0.491
Epoch  58 TrainLoss 3.32925 Penalty 36

Epoch  74 TrainLoss 5.10643 Penalty 133.14168 TrainAcc 0.489 TestAcc 0.491
Epoch  75 TrainLoss 5.53648 Penalty 169.69397 TrainAcc 0.499 TestAcc 0.491
Epoch  75 TrainLoss 5.47085 Penalty 163.73083 TrainAcc 0.492 TestAcc 0.491
Epoch  75 TrainLoss 6.17694 Penalty 235.66266 TrainAcc 0.489 TestAcc 0.491
Epoch  75 TrainLoss 5.59600 Penalty 175.22749 TrainAcc 0.497 TestAcc 0.491
Epoch  75 TrainLoss 5.18309 Penalty 139.23059 TrainAcc 0.499 TestAcc 0.491
Epoch  75 TrainLoss 5.63491 Penalty 178.90918 TrainAcc 0.485 TestAcc 0.491
Epoch  76 TrainLoss 5.89967 Penalty 205.33162 TrainAcc 0.506 TestAcc 0.491
Epoch  76 TrainLoss 5.13645 Penalty 135.50642 TrainAcc 0.493 TestAcc 0.491
Epoch  76 TrainLoss 5.59128 Penalty 174.78654 TrainAcc 0.491 TestAcc 0.491
Epoch  76 TrainLoss 5.61031 Penalty 176.57764 TrainAcc 0.496 TestAcc 0.491
Epoch  76 TrainLoss 5.71710 Penalty 186.85449 TrainAcc 0.496 TestAcc 0.491
Epoch  76 TrainLoss 5.73645 Penalty 188.75815 TrainAcc 0.488 TestAcc 0.491
Epoch  77 TrainLoss 6.094

Epoch  93 TrainLoss 7.44500 Penalty 412.66125 TrainAcc 0.494 TestAcc 0.491
Epoch  93 TrainLoss 7.12464 Penalty 361.64871 TrainAcc 0.489 TestAcc 0.491
Epoch  93 TrainLoss 6.57541 Penalty 284.29327 TrainAcc 0.498 TestAcc 0.491
Epoch  93 TrainLoss 7.15959 Penalty 366.99728 TrainAcc 0.497 TestAcc 0.491
Epoch  93 TrainLoss 7.17724 Penalty 369.71875 TrainAcc 0.486 TestAcc 0.491
Epoch  94 TrainLoss 8.20724 Penalty 552.82996 TrainAcc 0.501 TestAcc 0.491
Epoch  94 TrainLoss 7.31336 Penalty 391.15677 TrainAcc 0.495 TestAcc 0.491
Epoch  94 TrainLoss 8.01876 Penalty 515.60962 TrainAcc 0.492 TestAcc 0.491
Epoch  94 TrainLoss 7.69401 Penalty 455.46729 TrainAcc 0.495 TestAcc 0.491
Epoch  94 TrainLoss 8.51811 Penalty 618.05701 TrainAcc 0.496 TestAcc 0.491
Epoch  94 TrainLoss 7.96199 Penalty 504.73544 TrainAcc 0.485 TestAcc 0.491
Epoch  95 TrainLoss 7.85102 Penalty 483.92383 TrainAcc 0.512 TestAcc 0.491
Epoch  95 TrainLoss 6.59682 Penalty 287.07990 TrainAcc 0.493 TestAcc 0.491
Epoch  95 TrainLoss 8.005