In [1]:
## Create a Siamese Network with Triplet Loss in Keras
import socket
print(socket.gethostname())

psanagpu108


In [2]:
%matplotlib notebook

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import random
import sys
import h5py
from sklearn.decomposition import PCA

print('TensorFlow version:', tf.__version__)

TensorFlow version: 2.2.0


In [3]:
# Class index of each particle (PDB ID); the IDs are listed in index order.
particle2idx = {
    pdb_id: idx
    for idx, pdb_id in enumerate([
        '1fpv', '1ss8', '3j03', '1ijg', '3iyf', '6ody',
        '6sp2', '6xs6', '7dwz', '7dx8', '7dx9',
    ])
}

# Index of each count category; the categories are listed in index order.
count2idx = {
    category: idx
    for idx, category in enumerate(['single', 'double', 'triple', 'quadruple'])
}

In [4]:
# Particle (PDB ID) for each class index; list position is the index.
idx2particle = dict(enumerate([
    '1fpv', '1ss8', '3j03', '1ijg', '3iyf', '6ody',
    '6sp2', '6xs6', '7dwz', '7dx8', '7dx9',
]))

# Count category for each index; list position is the index.
idx2count = dict(enumerate(['single', 'double', 'triple', 'quadruple']))

In [5]:
# Import diffraction image set

def load_data(num_train_samples=200, num_test_samples=50, normalize=None, seed=None):
    """
    Build random train/test sets of diffraction thumbnails from the HDF5 files.

    Every sample is assigned to one of the source files uniformly at random.
    For each file the images are then read sequentially from its 'photons'
    dataset: training images start at index 0, test images start right after
    the training images taken from that same file, so the two sets never
    share an image.  The label of a sample is the index (key of the
    module-level ``idx2particle``) of the particle ID found in the file name.

    num_train_samples: number of training samples
    num_test_samples: number of test samples
    normalize: type of intensity normalization {'variance'}
               (accepted but never read; no normalization is applied)
    seed: random seed given to np.random.seed; None leaves the global
          NumPy random state as it is

    Returns ((x_train, y_train), (x_test, y_test)) where the images are
    float32 arrays of shape (n, nR, nC) and the labels uint8 arrays of
    shape (n,).

    Raises ValueError if a file name contains none of the known particle IDs.
    """
    path = '/reg/data/ana03/scratch/xericfl/DeepProjection/DeepProjection/resnet/eric_data/'

    ### Change these names
    # One file per (count, particle) pair, grouped by count category.
    particles = ['1fpv', '6sp2', '1ijg', '6xs6', '1ss8', '7dwz',
                 '3iyf', '7dx8', '3j03', '7dx9', '6ody']
    counts = ['single', 'double', 'triple', 'quadruple']
    fnames = ['{}_5k_{}_pps_1e14_thumbnail.h5'.format(particle, count)
              for count in counts
              for particle in particles]

    numFiles = len(fnames)

    def label_of(fname):
        # Class index of the particle whose ID occurs in the file name.
        for idx, particle in idx2particle.items():
            if particle in fname:
                return idx
        raise ValueError('no known particle ID in file name: {}'.format(fname))

    if seed is not None:
        np.random.seed(seed)

    # Get image dimensions
    with h5py.File(path+fnames[0],'r') as f:
        nR, nC = f['photons'][0,:,:].shape

    # TRAIN
    x_train = np.empty((num_train_samples, nR, nC), dtype='float32')
    y_train = np.empty((num_train_samples,), dtype='uint8')
    index = np.random.randint(0,numFiles,num_train_samples)
    # Number of leading images each file contributed to the training set;
    # the test set must skip exactly that many.
    train_counts = [0] * numFiles
    for i, fname in enumerate(fnames):
        ind = np.where(index == i)[0]
        if len(ind) > 0:
            with h5py.File(path+fname,'r') as f:
                x_train[ind] = f['photons'][0:len(ind),:,:]
            print(fname, len(ind))
            y_train[ind] = label_of(fname)
            train_counts[i] = len(ind)

    # TEST
    x_test = np.empty((num_test_samples, nR, nC), dtype='float32')
    y_test = np.empty((num_test_samples,), dtype='uint8')
    index = np.random.randint(0,numFiles,num_test_samples)
    for i, fname in enumerate(fnames):
        ind = np.where(index == i)[0]
        if len(ind) > 0:
            # Skip the images already used for training from this file.  A
            # fixed average offset is not enough: the per-file training count
            # is random and can exceed num_train_samples/numFiles.
            offset = train_counts[i]
            with h5py.File(path+fname,'r') as f:
                x_test[ind] = f['photons'][offset:offset+len(ind),:,:]
            y_test[ind] = label_of(fname)

    return (x_train, y_train), (x_test, y_test)

In [6]:
# Load data
# seed=None: load_data only seeds NumPy when a seed is given, so the random
# file assignment (and therefore the data set) differs from run to run.
# normalize=None: load_data never reads this argument.
(x_train, y_train), (x_test, y_test) = load_data(num_train_samples=200, num_test_samples=50, normalize=None, seed=None)

1fpv_5k_single_pps_1e14_thumbnail.h5 9
6sp2_5k_single_pps_1e14_thumbnail.h5 6
1ijg_5k_single_pps_1e14_thumbnail.h5 5
6xs6_5k_single_pps_1e14_thumbnail.h5 5
1ss8_5k_single_pps_1e14_thumbnail.h5 3
7dwz_5k_single_pps_1e14_thumbnail.h5 5
3iyf_5k_single_pps_1e14_thumbnail.h5 5
7dx8_5k_single_pps_1e14_thumbnail.h5 2
3j03_5k_single_pps_1e14_thumbnail.h5 3
7dx9_5k_single_pps_1e14_thumbnail.h5 3
6ody_5k_single_pps_1e14_thumbnail.h5 2
1fpv_5k_double_pps_1e14_thumbnail.h5 6
6sp2_5k_double_pps_1e14_thumbnail.h5 4
1ijg_5k_double_pps_1e14_thumbnail.h5 3
6xs6_5k_double_pps_1e14_thumbnail.h5 4
1ss8_5k_double_pps_1e14_thumbnail.h5 7
7dwz_5k_double_pps_1e14_thumbnail.h5 8
3iyf_5k_double_pps_1e14_thumbnail.h5 6
7dx8_5k_double_pps_1e14_thumbnail.h5 3
3j03_5k_double_pps_1e14_thumbnail.h5 3
7dx9_5k_double_pps_1e14_thumbnail.h5 4
6ody_5k_double_pps_1e14_thumbnail.h5 2
1fpv_5k_triple_pps_1e14_thumbnail.h5 3
6sp2_5k_triple_pps_1e14_thumbnail.h5 2
1ijg_5k_triple_pps_1e14_thumbnail.h5 2
6xs6_5k_triple_pps_1e14_t

In [7]:
# vectorize dataset: flatten every (dim_r, dim_c) image into one row vector
print(x_train.shape)
num_train, dim_r, dim_c = x_train.shape
num_test = x_test.shape[0]
img_dim = dim_r * dim_c

# x_train / x_test are rebound to their flattened form (no intensity
# normalization happens here), so re-running this cell requires reloading
# the data first.
x_train = x_train.reshape((num_train, img_dim))
x_test = x_test.reshape((num_test, img_dim))

(200, 128, 128)


In [8]:
# plotting triplet (anchor, positive, negative)
def plot_triplet(triplet, labels):
    """
    Show the anchor, positive and negative image of one triplet side by side.

    triplet: three flattened images (anchor, positive, negative); each is
             reshaped to the module-level (dim_r, dim_c)
    labels: three single-element label arrays in the same order; each is
            turned into a particle name through idx2particle
    """
    title_formats = ('anchor: {}', 'positive: {}', 'negative: {}')

    plt.figure(figsize=(6,2))
    for position, title_format in enumerate(title_formats):
        plt.subplot(1, 3, position+1)
        image = np.reshape(triplet[position], (dim_r, dim_c))
        plt.imshow(image, cmap='binary')
        # Hide the pixel-coordinate ticks.
        plt.xticks([])
        plt.yticks([])
        particle = idx2particle[labels[position].item()]
        plt.title(title_format.format(particle))
    plt.show()

In [9]:
# create a batch of triplets
def create_batch(batch_size, anchor_label=None):
    """
    Sample a batch of (anchor, positive, negative) triplets from the training set.

    Reads the module-level x_train (flattened images), y_train (labels),
    img_dim and num_train.

    batch_size: number of triplets to draw
    anchor_label: if not None, every anchor is drawn from this class
                  (0 is a valid label); otherwise anchors are drawn from
                  the whole training set

    Returns [anchors, positives, negatives], [anchor_labels, positive_labels,
    negative_labels]: three (batch_size, img_dim) image arrays and three
    (batch_size,) integer label arrays.

    Raises ValueError if anchor_label has no training sample, or if an
    anchor's class is the only class present (no negative exists).
    """
    anchors = np.zeros((batch_size, img_dim))
    positives = np.zeros((batch_size, img_dim))
    negatives = np.zeros((batch_size, img_dim))

    anchor_labels = np.zeros((batch_size,), dtype=int)
    positive_labels = np.zeros((batch_size,), dtype=int)
    negative_labels = np.zeros((batch_size,), dtype=int)

    # Candidate anchors for a requested class.  Compare against None so that
    # class 0 can be requested; flatnonzero always yields a 1-D index array,
    # even when only one sample matches.
    anchor_pool = None
    if anchor_label is not None:
        anchor_pool = np.flatnonzero(y_train == anchor_label)
        if anchor_pool.size == 0:
            raise ValueError('no training sample with label {}'.format(anchor_label))

    for i in range(batch_size):
        if anchor_pool is not None:
            index = np.random.choice(anchor_pool)
        else:
            # randint's upper bound is exclusive: num_train covers every sample.
            index = np.random.randint(0, num_train)
        y = y_train[index]

        same_class = y_train == y
        indices_for_pos = np.flatnonzero(same_class)
        indices_for_neg = np.flatnonzero(~same_class)
        if indices_for_neg.size == 0:
            raise ValueError('no training sample with a label other than {}'.format(y))

        # Prefer a positive that is not the anchor itself (such a pair has
        # zero distance); fall back to the anchor when it is alone in its class.
        other_pos = indices_for_pos[indices_for_pos != index]
        if other_pos.size > 0:
            indices_for_pos = other_pos

        pos_idx = np.random.choice(indices_for_pos)
        neg_idx = np.random.choice(indices_for_neg)

        anchors[i] = x_train[index]
        positives[i] = x_train[pos_idx]
        negatives[i] = x_train[neg_idx]

        anchor_labels[i] = y
        positive_labels[i] = y_train[pos_idx]
        negative_labels[i] = y_train[neg_idx]

    # Returns image batches and (anc, pos, neg) labels
    return [anchors, positives, negatives], [anchor_labels, positive_labels, negative_labels]

In [10]:
# visualize a batch of triplet
# anchor_label=1 is particle '1ss8' in idx2particle
triplet, labels = create_batch(batch_size=1, anchor_label=1)
plot_triplet(triplet, labels)

<IPython.core.display.Javascript object>

In [11]:
# Embedding model: two fully connected layers, ReLU followed by sigmoid
#### Play with this value during testing.
emb_dim = 64

# Dense implements the operation: output = activation(dot(input, kernel) + bias)

#### TODO: Implement CNN or Resnet model you want for embeddings here
hidden_layer = tf.keras.layers.Dense(emb_dim, activation='relu', input_shape=(img_dim,))
# The sigmoid keeps every embedding component between 0 and 1.
output_layer = tf.keras.layers.Dense(emb_dim, activation='sigmoid')
embedding_model = tf.keras.models.Sequential([hidden_layer, output_layer])

embedding_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 64)                1048640   
_________________________________________________________________
dense_1 (Dense)              (None, 64)                4160      
Total params: 1,052,800
Trainable params: 1,052,800
Non-trainable params: 0
_________________________________________________________________


In [12]:
# Embed a single flattened training image as a quick check of the model output.
example = x_train[0]
# predict works on batches, so add a leading batch axis of length 1.
example_emb = embedding_model.predict(np.expand_dims(example, axis=0))
print("example embedding: ", example_emb, example_emb.shape)

example embedding:  [[9.4853413e-01 9.6557206e-01 2.8427631e-02 8.2111853e-04 2.5722035e-03
  8.5019842e-03 9.6162361e-01 9.7033453e-01 1.2850632e-04 1.4093502e-03
  1.4436816e-01 1.2295955e-07 9.9813277e-01 4.4753689e-02 8.7166530e-01
  9.5472413e-01 5.2598127e-05 1.2837446e-02 9.5952612e-01 3.4524432e-01
  8.1070352e-01 2.7635890e-01 3.8083854e-05 1.6171287e-03 1.7525910e-05
  8.6938787e-01 9.8784757e-01 1.0000000e+00 1.0643871e-02 3.3560891e-02
  9.9592751e-01 8.8809931e-01 9.9893278e-01 9.7113198e-01 1.0000000e+00
  9.9799430e-01 3.1509513e-05 1.5605147e-02 8.9048856e-01 9.9962270e-01
  9.9706405e-01 3.2561343e-02 6.0468596e-01 8.4481835e-03 9.9999857e-01
  6.3189269e-08 1.8165328e-02 1.6858324e-01 9.9984467e-01 1.0724999e-02
  6.6092191e-04 2.5776723e-01 8.9098209e-01 6.6776864e-02 4.8348814e-01
  9.9992275e-01 9.8882550e-01 9.9995995e-01 1.5124576e-02 6.1243668e-06
  9.9858308e-01 6.4468267e-04 8.9017719e-01 3.8862386e-01]] (1, 64)


In [13]:
# Plot the components of the example embedding against their position.
plt.figure(figsize=(5,5))
plt.plot(example_emb[0],'x-')
plt.xlabel("features")
plt.ylabel("weight")
plt.show()

<IPython.core.display.Javascript object>

In [14]:
# Siamese network: the same embedding model (shared weights) is applied to
# the anchor, positive and negative inputs.
triplet_inputs = [tf.keras.layers.Input(shape=(img_dim,)) for _ in range(3)]
in_anc, in_pos, in_neg = triplet_inputs

em_anc = embedding_model(in_anc)
em_pos = embedding_model(in_pos)
em_neg = embedding_model(in_neg)

# One output row per triplet: [anchor | positive | negative] embeddings,
# which the triplet loss slices apart again.
out = tf.keras.layers.concatenate([em_anc, em_pos, em_neg], axis=1)

net = tf.keras.models.Model(inputs=triplet_inputs, outputs=out)

net.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 16384)]      0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 16384)]      0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None, 16384)]      0                                            
__________________________________________________________________________________________________
sequential (Sequential)         (None, 64)           1052800     input_1[0][0]                    
                                                                 input_2[0][0]                

# Triplet Loss
A loss function that tries to pull the Embeddings of Anchor and Positive Examples closer, and tries to push the Embeddings of Anchor and Negative Examples away from each other.

Mean squared difference between the Anchor and Positive embeddings of one example, taken over the $D$ embedding dimensions, is:
$
\begin{equation}
d_p = \frac{1}{D}\sum_{k=0}^{D-1}\left(f(a)_k - f(p)_k\right)^2
\end{equation}
$

Mean squared difference between the Anchor and Negative embeddings of the same example is:
$
\begin{equation}
d_n = \frac{1}{D}\sum_{k=0}^{D-1}\left(f(a)_k - f(n)_k\right)^2
\end{equation}
$

For each example, we want:
$
\begin{equation}
d_p \leq d_n
\end{equation}
$

Therefore,
$
\begin{equation}
d_p - d_n \leq 0
\end{equation}
$

This condition is quite easily satisfied during the training.

We will make it non-trivial by adding a margin (alpha):
$
\begin{equation}
d_p - d_n + \alpha \leq 0
\end{equation}
$

Given the condition above, the Triplet Loss L is defined as:
$
\begin{equation}
L = max(d_p - d_n + \alpha, 0)
\end{equation}
$

In [15]:
def triplet_loss(alpha, emb_dim):
    """
    Build a Keras-compatible triplet loss.

    alpha: margin added to (positive distance - negative distance)
    emb_dim: width of one embedding; y_pred is expected to hold the anchor,
             positive and negative embeddings concatenated along axis 1

    Returns loss(y_true, y_pred), which yields one value per example:
    max(dp - dn + alpha, 0), where dp / dn are the mean squared differences
    between the anchor and the positive / negative embedding.
    """
    def loss(y_true, y_pred):
        # y_true is not used; the triplet structure lives entirely in y_pred.
        anchor = y_pred[:, :emb_dim]
        positive = y_pred[:, emb_dim:2*emb_dim]
        negative = y_pred[:, 2*emb_dim:]

        pos_dist = tf.reduce_mean(tf.square(anchor - positive), axis=1)
        neg_dist = tf.reduce_mean(tf.square(anchor - negative), axis=1)

        margin_violation = pos_dist - neg_dist + alpha
        return tf.maximum(margin_violation, 0.)
    return loss

In [16]:
class PCAPlotter(tf.keras.callbacks.Callback):
    """
    Keras callback that redraws, after every epoch, a 2-component PCA of the
    test-set embeddings (left axes) and the training loss curve (right axes).
    """

    def __init__(self, plt, embedding_model, x_test, y_test):
        """
        plt: the matplotlib.pyplot module used to create the figure
        embedding_model: model whose predict() output is projected with PCA
        x_test: inputs to embed at every redraw
        y_test: labels used to colour the scatter points
        """
        super().__init__()
        self.embedding_model = embedding_model
        self.x_test = x_test
        self.y_test = y_test
        self.fig = plt.figure(figsize=(10, 10))
        self.ax1 = plt.subplot(1, 2, 1)
        self.ax2 = plt.subplot(1, 2, 2)
        plt.ion()

        # One loss value per finished epoch.
        self.losses = []

    def _draw_embedding(self):
        # Embed the test inputs and scatter their first two principal components.
        embeddings = self.embedding_model.predict(self.x_test)
        projected = PCA(n_components=2).fit_transform(embeddings)
        self.ax1.clear()
        self.ax1.scatter(projected[:, 0], projected[:, 1], c=self.y_test, cmap='seismic')
        self.ax1.set_title('PCA embedding of Siamese embedding')
        self.ax1.set_xlabel('principal component 1')
        self.ax1.set_ylabel('principal component 2')

    def _draw_loss(self, epoch):
        # Loss of epochs 0 .. epoch-1 against the epoch number.
        self.ax2.clear()
        self.ax2.plot(range(epoch), self.losses)
        self.ax2.set_xlabel('Epochs')
        self.ax2.set_ylabel('Loss')
        self.ax2.set_title('Loss vs Epochs')

    def plot(self, epoch=None, plot_loss=False):
        """
        Redraw the figure.

        epoch: number of recorded losses; only read when plot_loss is True
        plot_loss: also redraw the loss curve
        """
        self._draw_embedding()
        if plot_loss:
            self._draw_loss(epoch)
        self.fig.canvas.draw()

    def on_train_begin(self, logs=None):
        # Start with an empty loss history and show the initial embedding.
        self.losses = []
        self.fig.show()
        self.fig.canvas.draw()
        self.plot()

    def on_epoch_end(self, epoch, logs=None):
        # epoch is zero-based, so epoch+1 losses exist once this one is stored.
        self.losses.append(logs.get('loss'))
        self.plot(epoch+1, plot_loss=True)

In [17]:
# data generation
def data_generator(batch_size, emb_dim):
    """
    Endlessly yield (inputs, targets) pairs for training the Siamese network.

    batch_size: number of triplets per batch
    emb_dim: embedding width; the targets have 3*emb_dim columns

    inputs is the [anchors, positives, negatives] list from create_batch;
    targets is an all-zero placeholder array of shape (batch_size, 3*emb_dim).
    """
    target_shape = (batch_size, 3*emb_dim)
    while True:
        triplets, _ = create_batch(batch_size=batch_size)
        yield triplets, np.zeros(target_shape)

In [42]:
# model training

#### TODO: Run bigger batch size on SDF or another psana node.
batch_size = 200
epochs = 20
# Always run at least one step per epoch.  The plain quotient is 0 as soon as
# batch_size exceeds num_train, and fit() then runs no batch at all, which is
# the likely cause of the UnboundLocalError (about a "logs" variable) seen
# with large batch sizes.
steps_per_epoch = max(1, num_train // batch_size)

net.compile(loss=triplet_loss(alpha=0.2, emb_dim=emb_dim), optimizer='adam')

# Test samples shown by the PCA callback during training (slicing clips to
# the number of available test samples).
X, Y = x_test[:1000], y_test[:1000]

In [43]:
# Train on triplets produced endlessly by data_generator; the targets it
# yields are all-zero placeholders because triplet_loss never reads y_true.
# Console output is disabled; progress is shown by the PCAPlotter figure.
_ = net.fit(
    data_generator(batch_size, emb_dim),
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    verbose=False,
    callbacks=[
        PCAPlotter(plt, embedding_model, X, Y)
    ]
)

<IPython.core.display.Javascript object>

In [20]:
# https://medium.com/@crimy/one-shot-learning-siamese-networks-and-triplet-loss-with-keras-2885ed022352

def compute_dist(a,b):
    """Return the squared Euclidean distance between a and b (sum of squared differences)."""
    delta = a - b
    return np.square(delta).sum()

def compute_probs(network,X,Y):
    '''
    Score every unordered pair of samples by the negative squared Euclidean
    distance between their embeddings.

    Input
        network : object whose predict(X) returns one embedding per sample,
                  as an array with m rows
        X : array holding the m samples to evaluate (first axis = sample)
        Y : sequence of length m containing the true class of each sample

    Returns
        probs : array of shape (m*(m-1)/2,); entry for pair (i, j), i < j,
                is minus the squared distance of the two embeddings, so a
                higher score means "more likely the same class".  Pairs are
                ordered by i, then j.  The values are scores, not normalised
                probabilities.
        y : array of the same shape; 1 where the pair shares a class, else 0
    '''
    m = X.shape[0]
    nbevaluation = m*(m-1)//2
    probs = np.zeros((nbevaluation))
    y = np.zeros((nbevaluation))

    #Compute all embeddings for all pics with current network
    embeddings = network.predict(X)
    labels = np.asarray(Y)

    # Handle one anchor row at a time against all later rows: this keeps the
    # pair order of a nested i < j loop while the distances of a whole row
    # are computed in one vectorised step.
    k = 0
    for i in range(m-1):
        n_pairs = m - 1 - i
        diff = embeddings[i+1:] - embeddings[i]
        sq_dist = np.sum(np.square(diff).reshape(n_pairs, -1), axis=1)
        probs[k:k+n_pairs] = -sq_dist
        y[k:k+n_pairs] = labels[i+1:] == labels[i]
        k += n_pairs
    return probs,y

from sklearn.metrics import roc_auc_score, roc_curve
def compute_metrics(probs,yprobs):
    '''
    Compute the ROC curve and its area for a set of pair scores.

    Input
        probs : score of each pair (higher = more likely positive)
        yprobs : true binary label of each pair

    Returns
        fpr : Increasing false positive rates such that element i is the false positive rate of predictions with score >= thresholds[i]
        tpr : Increasing true positive rates such that element i is the true positive rate of predictions with score >= thresholds[i].
        thresholds : Decreasing thresholds on the decision function used to compute fpr and tpr. thresholds[0] represents no instances being predicted and is arbitrarily set to max(y_score) + 1
        auc : Area Under the ROC Curve metric
    '''
    # Area first, then the curve itself (fpr, tpr, thresholds).
    area = roc_auc_score(yprobs, probs)
    return (*roc_curve(yprobs, probs), area)

In [21]:
# Number of test samples to compare pairwise.  The slices below clip to the
# size of the test set, so fewer samples are used when x_test holds fewer
# than n_val (it holds 50 when loaded with num_test_samples=50).
n_val = 100
probs,yprob = compute_probs(embedding_model,x_test[:n_val],y_test[:n_val])

In [36]:
# ROC statistics of the pair scores; the two prints dump the full fpr and
# tpr arrays.
fpr, tpr, thresholds, auc = compute_metrics(probs,yprob)
print(fpr)
print(tpr)

[0.00000000e+00 0.00000000e+00 8.96057348e-04 8.96057348e-04
 1.79211470e-03 1.79211470e-03 2.68817204e-03 2.68817204e-03
 5.37634409e-03 5.37634409e-03 6.27240143e-03 6.27240143e-03
 8.96057348e-03 8.96057348e-03 1.07526882e-02 1.07526882e-02
 1.61290323e-02 1.61290323e-02 1.88172043e-02 1.88172043e-02
 2.41935484e-02 2.41935484e-02 2.77777778e-02 2.77777778e-02
 3.67383513e-02 3.67383513e-02 4.12186380e-02 4.12186380e-02
 4.92831541e-02 4.92831541e-02 5.10752688e-02 5.10752688e-02
 5.19713262e-02 5.19713262e-02 5.64516129e-02 5.64516129e-02
 6.09318996e-02 6.09318996e-02 6.36200717e-02 6.36200717e-02
 6.63082437e-02 6.63082437e-02 8.15412186e-02 8.15412186e-02
 8.51254480e-02 8.51254480e-02 8.96057348e-02 8.96057348e-02
 9.31899642e-02 9.31899642e-02 9.85663082e-02 9.85663082e-02
 1.02150538e-01 1.02150538e-01 1.03942652e-01 1.03942652e-01
 1.04838710e-01 1.04838710e-01 1.08422939e-01 1.08422939e-01
 1.17383513e-01 1.17383513e-01 1.22759857e-01 1.22759857e-01
 1.33512545e-01 1.335125

In [37]:
# Plot sensitivity as function of false positive rate (fpr)
# NOTE(review): sensitivity is by definition the true positive rate (tpr)
# itself; the ratio tpr/(fpr+tpr) plotted here is a different quantity.
# Confirm which one is intended before relying on the axis label.
rate_sum = fpr + tpr
# Guard the 0/0 where fpr and tpr are both zero (the start of the ROC
# curve): those points stay NaN and are simply not drawn, without a
# RuntimeWarning.
sensitivity = np.divide(tpr, rate_sum,
                        out=np.full(rate_sum.shape, np.nan),
                        where=rate_sum > 0)
plt.figure()
plt.plot(fpr,sensitivity)
plt.xlabel('False positive rate')
plt.ylabel('Sensitivity')
plt.show()

  


<IPython.core.display.Javascript object>

In [38]:
# ROC curve and Area under the curve (AUC)
plt.figure()
# Measured ROC curve.
plt.plot(fpr, tpr, color='darkorange', lw=2, linestyle='-')
# Diagonal reference line from (0, 0) to (1, 1).
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.title('AUC: {:.3f}'.format(auc))
plt.show()

<IPython.core.display.Javascript object>

In [39]:
# Print Dp and Dn
from scipy.spatial.distance import cdist

# anchor_label=9 is particle '7dx8' in idx2particle
triplet, labels = create_batch(batch_size=1, anchor_label=9)

# Rows: embedding of the anchor, the positive and the negative image.
myBatch = np.zeros((3, emb_dim))
for row, images in enumerate(triplet):
    myBatch[row,:] = embedding_model.predict(images)
plot_triplet(triplet, labels)

# Pairwise Euclidean distances between the three embeddings.  Stored under
# its own name so the test labels Y used by the training cell are not
# overwritten.  Note these are Euclidean distances, not the mean squared
# differences used inside triplet_loss.
triplet_dist = cdist(myBatch, myBatch, 'euclidean')
print("Dp: {}, Dn: {}".format(triplet_dist[0,1],triplet_dist[0,2]))

print(len(myBatch[0]))

<IPython.core.display.Javascript object>

Dp: 0.0014571444916656718, Dn: 0.004181834262707305
64


In [40]:
# Look at class embedding distance

# Embed the first num_sample test images (never more than are available)
num_sample = min(50, len(x_test))
num_classes = len(idx2particle)
junk = x_test[:num_sample]
junky = y_test[:num_sample]
# One batched predict call instead of one call per image.
emb_junk = np.zeros((num_sample,emb_dim))
emb_junk[:] = embedding_model.predict(junk)

# Calculate euclidean distance between embeddings (own name, so the test
# labels Y used by the training cell are not overwritten)
pair_dist = cdist(emb_junk,emb_junk,'euclidean')

# Populate class embedding distance matrix: every unordered sample pair
# (i < j) is added symmetrically to the entries of its two classes.
distance_matrix = np.zeros((num_classes,num_classes))
counter = np.zeros((num_classes,num_classes))
pair_rows, pair_cols = np.triu_indices(num_sample, k=1)
pair_values = pair_dist[pair_rows, pair_cols]
row_classes = junky[pair_rows].astype(int)
col_classes = junky[pair_cols].astype(int)
# np.add.at accumulates correctly when the same class pair occurs many times.
np.add.at(distance_matrix, (row_classes, col_classes), pair_values)
np.add.at(distance_matrix, (col_classes, row_classes), pair_values)
np.add.at(counter, (row_classes, col_classes), 1)
np.add.at(counter, (col_classes, row_classes), 1)
# Mean distance per class pair; pairs that never occur stay NaN (no 0/0 warning).
distance_matrix = np.divide(distance_matrix, counter,
                            out=np.full_like(distance_matrix, np.nan),
                            where=counter > 0)

# Visualize
plt.figure(); plt.imshow(distance_matrix,interpolation='None');
plt.title("Class embedding distance"); plt.colorbar(); plt.show()

<IPython.core.display.Javascript object>

In [41]:
# See which classes are hard and easy to distinguish from each other

myClass = 1 # Choose a class index (key of idx2particle) here

# select hardest and easiest classes based on mean distance
dist = distance_matrix[myClass,:]
cscore = dist[myClass]

# Other classes ordered from closest to farthest.  Classes without a defined
# mean distance (NaN: no sample pair was available) are left out; argsort
# would otherwise place them last and report one of them as "easiest".
challenging = [i for i in np.argsort(dist)
               if i != myClass and not np.isnan(dist[i])]
if not challenging:
    raise ValueError('no other class has a defined distance to class {}'.format(myClass))

# Smallest mean distance = hardest to tell apart; largest = easiest.
hardest, easiest = challenging[0], challenging[-1]
hscore, escore = dist[hardest], dist[easiest]

print("For PDB {} ({:.3f}), hardest class to distinguish is PDB {} ({:.3f}) and easiest is PDB {} ({:.3f})"
      .format(idx2particle[myClass],cscore,idx2particle[hardest],hscore,idx2particle[easiest],escore))

For PDB 1ss8 (1.203), hardest class to distinguish is digit 7dx9 (1.180) and easiest digit is 6sp2 (6.166)
