In [None]:
import itertools
import matplotlib

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow.keras.backend as K

from tqdm import tqdm
from scipy.spatial import distance

In [None]:
# Global matplotlib text style for every figure in this notebook.
# 'normal' is not a font family name (it is a valid *weight*/*style* value), so
# matplotlib would log "findfont: Font family ['normal'] not found" for every
# text element and fall back to its default. Name a real generic family instead.
font = {'family': 'sans-serif',
        'weight': 'bold',
        'size': 22}

matplotlib.rc('font', **font)

In [None]:
# Seed every random number generator the notebook relies on.
# NumPy drives the np.random.permutation shuffle of the training subset;
# TensorFlow has its own global seed, which np.random.seed does not touch,
# so it is seeded explicitly as well to make Restart & Run All repeatable.
np.random.seed(0)
tf.random.set_seed(0)

In [None]:
# Load MNIST and rescale the pixel values by 1/255.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train, x_test = x_train / 255., x_test / 255.

# One-hot encode the ten-class labels.
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Integer digit labels recovered from the one-hot training labels.
y = np.argmax(y_train, axis=1)

# Keep only the training digits 1 and 8.
x_train_1s, y_train_1s = x_train[y == 1], y[y == 1]
x_train_8s, y_train_8s = x_train[y == 8], y[y == 8]

x_train_1s_8s = np.concatenate((x_train_1s, x_train_8s))
y_train_1s_8s = np.concatenate((y_train_1s, y_train_8s))

# Binary relabelling: digit 8 becomes class 0, digit 1 stays class 1.
y_train_1s_8s[y_train_1s_8s == 8] = 0

# Shuffle images and labels with one shared permutation.
order = np.random.permutation(len(x_train_1s_8s))

x_train_1s_8s, y_train_1s_8s = x_train_1s_8s[order], y_train_1s_8s[order]

print(y_train_1s_8s)

# One-hot encode the binary labels.
y_train_1s_8s = tf.keras.utils.to_categorical(y_train_1s_8s, 2)

# Flatten every image to a 784-element vector (kept as a list of arrays).
x_train_1s_8s_reshape = [image.reshape((784)) for image in x_train_1s_8s]
print(np.array(x_train_1s_8s_reshape).shape)

In [None]:
# --- Experiment configuration -------------------------------------------
n_hidden_neurons = 4  # units in the hidden 'dense_1' layer
lr = 1e-3             # learning rate handed to the Adam optimizer
epochs = 200          # number of epochs passed to model.fit
batch_size = 8        # batch size for model.fit; pair indices are drawn from range(batch_size)
limit = 256           # how many of the shuffled training samples are used
alpha = 0.0 # ignore polysemantic term in loss function.

In [None]:
# Two inputs: `i` carries the flattened 784-pixel image, `i2` carries a
# two-element label vector. `i2` does not feed any layer; it is a model input
# so that label-dependent tensors can be built from it.
i = tf.keras.layers.Input(shape=(784,))
i2 = tf.keras.layers.Input(shape=(2,))

# Hidden ReLU layer; `e` is its output tensor.
hidden_layer = tf.keras.layers.Dense(units=n_hidden_neurons, activation='relu', name='dense_1')
e = hidden_layer(i)

# Two-way softmax on top of the hidden layer; `o` is the model output.
output_layer = tf.keras.layers.Dense(units=2, activation='softmax', name='softmax')
o = output_layer(e)

model = tf.keras.Model(inputs=[i, i2], outputs=o)

In [None]:
def calculate_term(e, c0, c1):
  """Return the pairwise product term for entries `c0` and `c1` of `e`.

  Each selected entry `e[c]` is divided by its largest absolute value
  (plus 1e-9 so the division never hits zero), and the two scaled entries
  are multiplied with `K.dot` after transposing the first one.

  Args:
    e: indexable tensor; `e[c0]` and `e[c1]` are the entries compared
      (presumably one activation vector per sample in the batch).
    c0: index of the first entry.
    c1: index of the second entry.

  Returns:
    The plain integer 0 when `c0 == c1`, otherwise the tensor produced by
    `K.dot(K.transpose(a), b)` for the two scaled entries.
  """
  # An entry paired with itself contributes nothing. Bail out before creating
  # any backend ops so the degenerate case costs nothing.
  if c0 == c1:
    return 0

  a = K.expand_dims(e[c0] / (K.max(K.abs(e[c0])) + 1e-9))
  b = K.expand_dims(e[c1] / (K.max(K.abs(e[c1])) + 1e-9))
  return K.dot(K.transpose(a), b)

In [None]:
#https://stackoverflow.com/questions/62454500/how-to-use-tensorflow-custom-loss-for-a-keras-model
def my_loss(y_true, y_pred, e, c0, c1, alpha):
    """Categorical cross-entropy plus an `alpha`-weighted mean pairwise term.

    Args:
        y_true: target tensor handed to the cross-entropy.
        y_pred: predicted probabilities (the cross-entropy is built with
            `from_logits=False`).
        e: tensor whose entries are compared pairwise by `calculate_term`.
        c0: sequence of first indices, one per pair.
        c1: sequence of second indices; `c1[k]` is paired with `c0[k]`.
        alpha: weight applied to the mean of the pairwise terms.

    Returns:
        The cross-entropy loss, with `alpha * mean(pairwise terms)` added
        when at least one pair is given.
    """
    n_pairs = len(c0)

    p_loss = 0
    for k in range(n_pairs):
        p_loss += calculate_term(e, c0[k], c1[k])

    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(y_true=y_true,
                                                                      y_pred=y_pred)
    # With no pairs (e.g. a batch size of 1) there is nothing to average;
    # skip the term instead of dividing by zero.
    if n_pairs:
        loss += alpha * (p_loss / n_pairs)
    return loss

In [None]:
# Every unordered pair of positions inside one batch, split into two parallel
# lists: c0 holds the first position of each pair, c1 the second.
index_pairs = list(itertools.combinations(range(batch_size), 2))
c0 = [first for first, _ in index_pairs]
c1 = [second for _, second in index_pairs]

In [None]:
# Show how many index pairs were generated (c0 and c1 have one entry per pair).
print(len(c0))

In [None]:
# Attach the custom loss to the model: labels come from the `i2` input,
# predictions from `o`, and the pairwise term is computed on `e`.
pair_first = np.array(c0, dtype='int32')
pair_second = np.array(c1, dtype='int32')
model.add_loss(my_loss(i2, o, e, pair_first, pair_second, alpha))

In [None]:
def my_acc(y_true, y_pred):
    """Element-wise agreement between thresholded predictions and targets.

    `y_pred` is binarised at 0.5 (values strictly greater become 1.0, the
    rest 0.0) and compared with `y_true` entry by entry.

    Returns a float32 tensor holding 1.0 where the entries match and 0.0
    where they differ.
    """
    threshold = tf.constant([0.5])
    hard_pred = tf.cast(tf.math.greater(y_pred, threshold), dtype='float32')
    matches = tf.math.equal(y_true, hard_pred)
    return tf.cast(matches, dtype='float32')
    
# No `loss` is given to compile(): the training objective was attached to the
# model with model.add_loss(), so only the optimizer is configured here.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), 
              loss=None)

# Report my_acc (computed from the label input `i2` and the output `o`)
# under the name 'acc', aggregated with 'mean'.
model.add_metric(my_acc(i2, o), name='acc', aggregation='mean')

In [None]:
# Train on the first `limit` shuffled samples. The labels are passed as the
# second model input (not as `y`) because the loss reads them from there.
train_images = np.array(x_train_1s_8s_reshape)[:limit]
train_labels = y_train_1s_8s[:limit]

model.fit(x=[train_images, train_labels],
          y=None,
          batch_size=batch_size,
          epochs=epochs)

In [None]:
# Sub-model that maps the image input `i` to the hidden-layer output `e`,
# i.e. it exposes the 'dense_1' activations.
embed = tf.keras.Model(inputs=i, outputs=e)

In [None]:
# Sync the hidden-layer weights from `model` into `embed`. The layer is looked
# up by name rather than by position so the lookup does not depend on how
# Keras orders the layers of the two-input model.
embed.get_layer('dense_1').set_weights(model.get_layer('dense_1').get_weights())

# Hidden-layer activations for every prepared sample.
a = embed.predict(np.array(x_train_1s_8s_reshape))

# Integer class per sample, computed once instead of on every loop iteration.
labels = np.argmax(y_train_1s_8s, axis=1)

# Sum the activations per class over the first `limit` samples.
activations_by_class = {}
n_used = len(x_train_1s_8s_reshape[:limit])
# The loop variable is deliberately not called `i`: that name holds the
# model's input tensor and must stay intact for cells that rebuild models.
for idx in tqdm(range(n_used)):
    label = labels[idx]
    activation = np.squeeze(a[idx])
    if label in activations_by_class:
        activations_by_class[label] += activation
    else:
        # Copy the first row: np.squeeze returns a view into `a`, and the
        # in-place `+=` above would otherwise overwrite that row of `a`.
        activations_by_class[label] = activation.copy()

In [None]:
# One horizontal bar chart per class showing the summed hidden activations,
# with one bar per entry of the summed vector.
fig = plt.figure(figsize=(20,20))
# The panel counter is deliberately not called `i`: that name holds the
# model's input tensor and must not be rebound by plotting code.
for panel_idx, (k, summed) in enumerate(activations_by_class.items()):
    values = np.squeeze(summed)
    n_bars = len(values)
    y_pos = np.arange(n_bars)

    ax = fig.add_subplot(1, 2, panel_idx + 1)
    ax.set_title(k)
    ax.barh(y_pos, values, align='center')
    # Vertical reference line at zero spanning the full bar range.
    ax.plot([0, 0], [-1, n_bars], 'k-')
    ax.set_ylim(-1, n_bars)

In [None]:
# Show the incoming weights of each hidden unit as a 28x28 image.
W = embed.layers[1].get_weights()
kernel = W[0]  # first weight array; columns are indexed per hidden unit below

fig = plt.figure(figsize=(20,20))
# Smallest square grid that has room for every hidden unit.
dim = int(np.ceil(np.sqrt(n_hidden_neurons)))
for j in range(n_hidden_neurons):
  column = kernel[:, j]
  # Scale the column to unit Euclidean length before display.
  x_j = column / np.sqrt(np.dot(column, column))
  ax = fig.add_subplot(dim, dim, j + 1)
  ax.imshow(x_j.reshape((28,28)))
  ax.axis('off')
plt.show()