In [None]:
"""
original implementation of separation loss in EPI paper is follow:
"""
def separation_loss(y_true, y_pred):

    y_true = tf.squeeze(y_true)
    env_id, _ = tf.unique(y_true)

    mu = []
    sigma = []
    for i in range(EPI.NUM_OF_ENVS):
        idx = tf.where(tf.equal(y_true, env_id[i]))  # indices of y_true equivalent to unique env_id[i]
        traj = tf.gather(y_pred, idx)   # corresponding list of trajectories in y_pred
        mu.append(tf.squeeze(K.mean(traj, axis=0)))  # trajectories mean as item of mu
        this_sigma = tf.maximum(K.mean(K.std(traj, axis=0))-0.1, 0)  # std of trajectories minus 0.1 as item of sigma
        sigma.append(this_sigma)

    mu = tf.stack(mu)
    r = tf.reduce_sum(mu * mu, 1)
    r = tf.reshape(r, [-1, 1])
    D = (r - 2 * tf.matmul(mu, tf.transpose(mu)) + tf.transpose(r))/tf.constant(EPI.EMBEDDING_DIMENSION, dtype=tf.float32)
    D = tf.sqrt(D + tf.eye(EPI.NUM_OF_ENVS, dtype=tf.float32))
    distance = K.mean(tf.reduce_sum(0.1 - tf.minimum(D, 0.1)))

    sigma = tf.stack(sigma)

    return (distance + K.mean(sigma))*0.01

'''
understandings of above:
we have mean embeddings of trajectories over same idex: t1=[a1, a2], t2=[b1, b2], as item of mu
we want to measure the distance of the two mean embeddings, which is:
dis = sqrt((a1-b1)^2+(a2-b2)^2), which is a L2 norm.
so we calculate the distance in matrix representation:
we have T=[t1, t2]
r = reduce_sum(T*T, 1), which is L2 norm of T, giving:
r = [a1^2+a2^2, b1^2+b2^2]=[A, B]
so, r-2r*r^T+r^T=[[A, A],[B,B]]-2*[[A, C],[C, B]]+[[A,B],[A,B]]
where C = a1*b1+a2*b2
r-2r*r^T+r^T= [[0, A+B-2C],[B-2C+A, 0]]
and A+B-2C = (a1-b1)^2 + (a2-b2)^2
EPI.EMBEDDING_DIMENSION is for normalization, and tf.eye is for reasonably sqrt.
sigma is the std term, which needs to be minimised.
'''

In [11]:
import numpy as np
import tensorflow as tf
a=[0.2, 1.]
a=np.array(a).astype('float64').reshape(-1)
print(type(np.array(a).astype('float64').reshape(-1)[0]))
np.unique(a)


<type 'numpy.float64'>


array([ 0.2,  1. ])

In [25]:
a = np.random.randint(5,size=(2,2,2))
print(a)

[[[2 0]
  [4 3]]

 [[4 2]
  [2 3]]]


In [26]:
mu=[]
for env in a:
    print(env)
    print(np.squeeze(np.mean(env, axis=0)))
    mu.append(np.squeeze(np.mean(env, axis=0)))

[[2 0]
 [4 3]]
[ 3.   1.5]
[[4 2]
 [2 3]]
[ 3.   2.5]


In [27]:
sigma=[]
for env in a:
    print(env)
    print(np.std(env, axis=0)-0.1)
    sigma.append(np.std(env, axis=0)-0.1)

[[2 0]
 [4 3]]
[ 0.9  1.4]
[[4 2]
 [2 3]]
[ 0.9  0.4]


In [28]:
print(mu)
mu = np.stack(mu)
print(mu)

[array([ 3. ,  1.5]), array([ 3. ,  2.5])]
[[ 3.   1.5]
 [ 3.   2.5]]


In [32]:
r=np.sum(mu*mu, axis=1)  # element-wise product
print(r)
r=r.reshape(-1,1)
r

[ 11.25  15.25]


array([[ 11.25],
       [ 15.25]])

In [37]:
D=(r-2*np.matmul(mu, mu.T)+r.T)
print(r)
print(-2*np.matmul(mu, mu.T))
print(r.T)
print(D)
D=D/len(D)
print(D)

[[ 11.25]
 [ 15.25]]
[[-22.5 -25.5]
 [-25.5 -30.5]]
[[ 11.25  15.25]]
[[ 0.  1.]
 [ 1.  0.]]
[[ 0.   0.5]
 [ 0.5  0. ]]


In [38]:
np.sqrt(D+np.eye(len(D)))

array([[ 1.        ,  0.70710678],
       [ 0.70710678,  1.        ]])