In [1]:
import itertools
import matplotlib

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow.keras.backend as K

from tqdm import tqdm
from scipy.spatial import distance

In [2]:
font = {'family' : 'normal',
        'weight' : 'bold',
        'size'   : 22}

matplotlib.rc('font', **font)

In [3]:
np.random.seed(0)

In [4]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train / 255.
x_test = x_test / 255.

y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

y = np.argmax(y_train, axis=1)
x_train_1s = x_train[np.where(y==1)]
x_train_8s = x_train[np.where(y==8)]
y_train_1s = y[np.where(y==1)]
y_train_8s = y[np.where(y==8)]

x_train_1s_8s = np.concatenate((x_train_1s, x_train_8s))
y_train_1s_8s = np.concatenate((y_train_1s, y_train_8s))

y_train_1s_8s[np.where(y_train_1s_8s==8)] = 0

order = np.random.permutation(len(x_train_1s_8s))

x_train_1s_8s = x_train_1s_8s[order]
y_train_1s_8s = y_train_1s_8s[order]

print(y_train_1s_8s)

y_train_1s_8s = tf.keras.utils.to_categorical(y_train_1s_8s, 2)

x_train_1s_8s_reshape = [x_train_1s_8s[i].reshape((784)) for i in range(len(x_train_1s_8s))]
print(np.array(x_train_1s_8s_reshape).shape)

[0 1 1 ... 0 0 1]
(12593, 784)


In [5]:
n_hidden_neurons = 4
lr = 0.0001
epochs = 200
batch_size = 8
limit = 256
alpha = 1.0

In [6]:
i = tf.keras.layers.Input(shape=(28*28,))
i2 = tf.keras.layers.Input(shape=(2,))
e = tf.keras.layers.Dense(n_hidden_neurons, activation='relu', name='dense_1')(i)
o = tf.keras.layers.Dense(2, activation='softmax', name='softmax')(e)
model = tf.keras.Model(inputs=[i, i2], outputs=o)

2022-05-25 19:43:03.347465: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2022-05-25 19:43:03.366410: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fedc4119310 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2022-05-25 19:43:03.366428: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version


In [27]:
def calculate_term(e, c0, c1):
    a = tf.gather(indices=c0, params=e) / (K.max(K.abs(tf.gather(indices=c0, params=e))) + 1e-9)
    b = tf.gather(indices=c1, params=e) / (K.max(K.abs(tf.gather(indices=c1, params=e))) + 1e-9)
    #print(a.shape)
    #print(K.dot(a, K.transpose(b)))
    return K.sum(K.dot(a, K.transpose(b)))

In [28]:
#https://stackoverflow.com/questions/62454500/how-to-use-tensorflow-custom-loss-for-a-keras-model
#@tf.function
def my_loss(y_true, y_pred, e, c0, c1, alpha):
#     p_loss = 0.
#     for i in range(len(c0)):
#         p_loss += (1. - tf.cast(tf.equal(y_true[c0[i]], y_true[c1[i]]), dtype='float32')) \
#                * calculate_term(e, c0[i], c1[i])
    p_loss = (1. - (tf.cast(tf.equal(K.argmax(tf.gather(indices=c0, params=y_true)), 
                                     K.argmax(tf.gather(indices=c1, params=y_true))), dtype='float32'))) \
           * calculate_term(e, c0, c1)
    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(y_true=y_true, 
                                                                      y_pred=y_pred)

    loss = loss + alpha * (p_loss / len(c0))
    return loss

In [29]:
c0, c1 = [], []
for c in list(itertools.combinations(range(batch_size), 2)):
    c0.append(c[0])
    c1.append(c[1])

In [30]:
print(len(c0))

28


In [31]:
model.add_loss(my_loss(i2, o, e, np.array(c0, dtype='int32'), np.array(c1, dtype='int32'), alpha))

In [32]:
def my_acc(y_true, y_pred):
    y_pred = tf.cast(tf.math.greater(y_pred, tf.constant([0.5])), dtype='float32')
    return tf.cast(tf.math.equal(y_true, y_pred), dtype='float32')

def cce(y_true, y_pred):
    return tf.keras.losses.CategoricalCrossentropy(
        from_logits=False,
        reduction=tf.keras.losses.Reduction.NONE
    )(y_true=y_true, y_pred=y_pred)

def ps_term(y_true, e, c0, c1):
    p_loss = (1. - (tf.cast(tf.equal(K.argmax(tf.gather(indices=c0, params=y_true)), 
                                     K.argmax(tf.gather(indices=c1, params=y_true))), dtype='float32'))) \
           * calculate_term(e, c0, c1)
    return tf.convert_to_tensor(p_loss)
    
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), 
              loss=None)

model.add_metric(my_acc(i2, o), name='acc', aggregation='mean')
model.add_metric(cce(i2, o), name='cce', aggregation='mean')
model.add_metric(ps_term(i2, e, c0, c1), name='ps', aggregation='mean')



In [None]:
model.fit([np.array(x_train_1s_8s_reshape)[:limit], y_train_1s_8s[:limit]],
          y=None,
          batch_size=batch_size,
          epochs=epochs)

Train on 256 samples
Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200
Epoch 50/200
Epoch 51/200
Epoch 52/200
Epoch 53/200
Epoch 54/200
Epoch 55/200
Epoch 56/200
Epoch 57/200
Epoch 58/200
Epoch 59/200
Epoch 60/200
Epoch 61/200
Epoch 62/200
Epoch 63/200
Epoch 64/200


Epoch 65/200
Epoch 66/200
Epoch 67/200
Epoch 68/200
Epoch 69/200
Epoch 70/200
Epoch 71/200
Epoch 72/200
Epoch 73/200
Epoch 74/200
Epoch 75/200
Epoch 76/200
Epoch 77/200
Epoch 78/200
Epoch 79/200
Epoch 80/200
Epoch 81/200
Epoch 82/200
Epoch 83/200
Epoch 84/200
Epoch 85/200
Epoch 86/200
Epoch 87/200
Epoch 88/200
Epoch 89/200
Epoch 90/200
Epoch 91/200
Epoch 92/200
Epoch 93/200
Epoch 94/200
Epoch 95/200
Epoch 96/200
Epoch 97/200
Epoch 98/200
Epoch 99/200
Epoch 100/200
Epoch 101/200
Epoch 102/200
Epoch 103/200
Epoch 104/200
Epoch 105/200
Epoch 106/200
Epoch 107/200
Epoch 108/200
Epoch 109/200
Epoch 110/200
Epoch 111/200
Epoch 112/200
Epoch 113/200
Epoch 114/200
Epoch 115/200
Epoch 116/200
Epoch 117/200
Epoch 118/200
  8/256 [..............................] - ETA: 0s - loss: 2.6725 - acc: 0.7500 - cce: 0.6681 - ps: 0.0000e+00

In [None]:
embed = tf.keras.Model(inputs=i, outputs=e)

In [None]:
embed.layers[1].set_weights(model.layers[1].get_weights())
activations_by_class = {}
a = embed.predict(np.array(x_train_1s_8s_reshape))
for i in tqdm(range(len(x_train_1s_8s_reshape[:limit]))):
    try:
        activations_by_class[np.argmax(y_train_1s_8s, axis=1)[i]] += np.squeeze(a[i])
    except KeyError:
        activations_by_class[np.argmax(y_train_1s_8s, axis=1)[i]] = np.squeeze(a[i])

In [None]:
fig = plt.figure(figsize=(20,20))
for i,k in enumerate(activations_by_class.keys()):
    y_pos = np.arange(len(activations_by_class[k]))
    ax = fig.add_subplot(1, 2, i+1)
    ax.set_title(k)
    ax.barh(y_pos, np.squeeze(activations_by_class[k]), align='center')
    ax.plot([0,0],[-1,len(np.squeeze(activations_by_class[k]))], 'k-')
    ax.set_ylim(-1,len(np.squeeze(activations_by_class[k])))

In [None]:
W = embed.layers[1].get_weights()
fig = plt.figure(figsize=(20,20))
dim = int(np.ceil(np.sqrt(n_hidden_neurons)))
for j in range(n_hidden_neurons):
  x_j = W[0][:,j] / np.sqrt(np.sum(np.dot(W[0][:,j], W[0][:,j].T)))
  ax = fig.add_subplot(dim,dim,j+1)
  ax.imshow(x_j.reshape((28,28)))
  plt.axis('off')
plt.show()