In [10]:
import tensorflow as tf 
import matplotlib.pyplot as plt 
import numpy as np 
import os 
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [11]:
A = np.random.normal(size=[100,2])
B = np.random.normal(size=[200,2], loc=1)
data = np.concatenate([A,B])

yA1 = np.array([1]*100 + [0]*200)
yA2 = np.array([0]*100 + [1]*200)
yB = np.ones(300)
y = np.concatenate([yA1[:,np.newaxis], yA2[:,np.newaxis], yB[:,np.newaxis]], axis=1)
y = tf.convert_to_tensor(y, tf.float32)

In [12]:
class HMCModel(tf.keras.models.Model):
    def __init__(self):
        super(HMCModel, self).__init__()
        self.W1 = tf.keras.layers.Dense(10, activation='relu')
        self.W2 = tf.keras.layers.Dense(3, activation='sigmoid')
    def call(self, inputs):
        x = self.W1(inputs)
        outputs = self.W2(x)
        return outputs
    def postprocess(self, inputs):
        A = inputs[:,:-1]
        B = tf.reduce_max(inputs, axis=1)
        return tf.concat([A, B[:,tf.newaxis]], axis=1)

In [13]:
model = HMCModel()
pred = model(data)
print(pred[:5])

tf.Tensor(
[[0.45011762 0.5104567  0.4410804 ]
 [0.5028045  0.494806   0.49938732]
 [0.21106966 0.46272463 0.27008232]
 [0.65983933 0.70771116 0.67803466]
 [0.2853197  0.4600654  0.31711978]], shape=(5, 3), dtype=float32)


In [14]:
def loss_fn(y_true, y_pred):
    '''
    y_true    3 cols
    y_pred    3 cols, before postprocessing
    ''' 
    y_pred = tf.where(y_pred < 1e-20, 1e-20, y_pred)
    y_pred = tf.where(y_pred > 1-1e-20, 1-1e-20, y_pred)
    
    lossA = y_true[:,:-1] * tf.math.log(y_pred[:,:-1]) + (1-y_true[:,:-1]) * tf.math.log(1-y_pred[:,:-1])
    
    lossB_1 = y_true[:,2] * tf.math.log(tf.reduce_max(y_true * y_pred, axis=1))
    lossB_2 = (1-y_true[:,2]) * tf.math.log(1 - tf.reduce_max(y_pred, axis=1))
    lossB = (lossB_1 + lossB_2)[...,tf.newaxis]
    loss = tf.concat([lossA, lossB], axis=1)
    
    return -tf.reduce_mean(loss)

In [15]:
class HMCLogitsModel(tf.keras.models.Model):
    def __init__(self):
        super(HMCLogitsModel, self).__init__()
        self.W1 = tf.keras.layers.Dense(10, activation='relu')
        self.W2 = tf.keras.layers.Dense(3)
    def call(self, inputs):
        x = self.W1(inputs)
        outputs = self.W2(x)
        return outputs
    def get_prob(self, logits):
        return tf.math.sigmoid(logits)
    def postprocess(self, inputs):
        A = inputs[:,:-1]
        B = tf.reduce_max(inputs, axis=1)
        return tf.concat([A, B[:,tf.newaxis]], axis=1)

The gradient of $y=\max \{a, b\}$. The explicit form of the function is 
$$
y=
\begin{cases}
a, & a \geq b \\
b, & a<b
\end{cases}
$$
Hence, the gradient should be 
$$
\frac{dy}{da} = \begin{cases}
1, & a \geq b \\
0, & a < b
\end{cases}
$$
and 
$$
\frac{dy}{db} = \begin{cases}
0, & a \geq b \\
1, & a < b
\end{cases}
$$

In our setting, the gradient of $L =\max_B \{f(x_B, y_B)\}$ with respect to $X_B$ should be the following, where $y_B$ is binary, and 
$$
f(x_B, y_B) = \begin{cases} 
x_B, &y_B = 1 \\
-\infty, &y_B = 0
\end{cases}
$$
The gradient $df/dx_B$ is 
$$
\frac{df}{dx_B}=
\begin{cases}
1, &y_B = 1 \\
0, &y_B = 0
\end{cases}
$$

In [16]:
def softplus(x):
    ''' 
    return log(1+exp(-x))
    '''
    r = tf.maximum(-x, 0) + tf.math.log(1 + tf.exp(-tf.abs(x)))
    return r 

def get_cross_logits_y(y):
    @tf.custom_gradient
    def cross_logits_y(logits):
        outputs = tf.where(y == 0, -np.inf, logits)
        def grad(upstream):
            return upstream * tf.where(y == 1, 1.0, 0.0)
        return outputs, grad
    return cross_logits_y


def max_with_structure(x):
    outputs1 = x[:,:-1]
    outputs2 = tf.reduce_max(x, axis=1, keepdims=True)
    outputs = tf.concat([outputs1, outputs2], axis=1)
    outputs = tf.where(tf.math.is_inf(outputs), np.inf, outputs)
    return outputs

def loss_fn_logits(y_true, y_pred):
    cross_logits_y = get_cross_logits_y(y_true)
    loss1 = y_true * softplus(max_with_structure(cross_logits_y(y_pred)))
    loss2 = (1-y_true) * softplus(max_with_structure(y_pred)) 
    loss3 = (1-y_true) * max_with_structure(y_pred)
    return tf.reduce_mean(loss1 + loss2 + loss3)


In [17]:
model = HMCLogitsModel()
logits = model(data)
prob = model.get_prob(logits)
loss_fn(y, prob)

<tf.Tensor: shape=(), dtype=float32, numpy=0.6859324>

In [26]:
optimizer = tf.keras.optimizers.Adam(1e-3)
for epoch in range(1000):
    with tf.GradientTape() as tape:
        loss = loss_fn_logits(y, model(data))
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
        if (epoch+1)%100 ==0:
            print("epoch {}/1000, loss {}".format(epoch+1,loss))

epoch 100/1000, loss 0.2636295258998871
epoch 200/1000, loss 0.26360827684402466
epoch 300/1000, loss 0.26359012722969055
epoch 400/1000, loss 0.26357540488243103
epoch 500/1000, loss 0.263563871383667
epoch 600/1000, loss 0.26355740427970886
epoch 700/1000, loss 0.26355189085006714
epoch 800/1000, loss 0.26355046033859253
epoch 900/1000, loss 0.2635428011417389
epoch 1000/1000, loss 0.26354295015335083


In [25]:
logits = model(data)
prob = model.get_prob(logits)
prob = model.postprocess(prob)
prob

<tf.Tensor: shape=(300, 3), dtype=float32, numpy=
array([[8.09924364e-01, 1.90719336e-01, 9.99999821e-01],
       [4.07104373e-01, 5.93481362e-01, 9.99999523e-01],
       [9.87500012e-01, 1.26434490e-02, 1.00000000e+00],
       [1.78263575e-01, 8.21826220e-01, 1.00000000e+00],
       [9.97776151e-01, 2.25230376e-03, 1.00000000e+00],
       [2.61303395e-01, 7.38956928e-01, 1.00000000e+00],
       [3.35952520e-01, 6.63524389e-01, 9.99999404e-01],
       [5.09302914e-01, 4.90294874e-01, 1.00000000e+00],
       [2.60752082e-01, 7.39071190e-01, 1.00000000e+00],
       [9.35136557e-01, 6.48711026e-02, 1.00000000e+00],
       [3.30381066e-01, 6.69816911e-01, 9.99999762e-01],
       [3.22713435e-01, 6.77185774e-01, 9.99999404e-01],
       [8.89859080e-01, 1.11379363e-01, 1.00000000e+00],
       [8.70580137e-01, 1.29354805e-01, 9.99999762e-01],
       [2.11296171e-01, 7.88347065e-01, 9.99999523e-01],
       [4.48018789e-01, 5.51148891e-01, 9.99999404e-01],
       [9.66049612e-01, 3.38035412e-02

In [27]:
model = HMCModel()
optimizer = tf.keras.optimizers.Adam(1e-3)
for epoch in range(1000):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(data))
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
        if (epoch+1)%100 ==0:
            print("epoch {}/1000, loss {}".format(epoch+1,loss))


epoch 100/1000, loss 0.44226911664009094
epoch 200/1000, loss 0.36681216955184937
epoch 300/1000, loss 0.3310305178165436
epoch 400/1000, loss 0.31354355812072754
epoch 500/1000, loss 0.3051971197128296
epoch 600/1000, loss 0.3006165027618408
epoch 700/1000, loss 0.29762351512908936
epoch 800/1000, loss 0.2954781651496887
epoch 900/1000, loss 0.29385390877723694
epoch 1000/1000, loss 0.29250532388687134


In [28]:
model(data)

<tf.Tensor: shape=(300, 3), dtype=float32, numpy=
array([[0.81365883, 0.17204055, 0.9773428 ],
       [0.37136358, 0.66622984, 0.9893941 ],
       [0.9869997 , 0.00662948, 0.99985313],
       [0.11335477, 0.88553214, 0.99903435],
       [0.96491855, 0.02271631, 0.99829215],
       [0.18643028, 0.8235879 , 0.9951367 ],
       [0.5042321 , 0.49325   , 0.9765352 ],
       [0.7212457 , 0.22672553, 0.9791762 ],
       [0.24829093, 0.78402835, 0.99649924],
       [0.83555734, 0.15757784, 0.9987128 ],
       [0.26606518, 0.7567198 , 0.9867082 ],
       [0.4299449 , 0.570367  , 0.9827341 ],
       [0.7562503 , 0.22980572, 0.9613381 ],
       [0.8425975 , 0.13863695, 0.98520476],
       [0.30761474, 0.71602845, 0.9804299 ],
       [0.6539668 , 0.32716814, 0.97086287],
       [0.8400146 , 0.1429371 , 0.99299204],
       [0.22517161, 0.7835462 , 0.9899696 ],
       [0.8953694 , 0.08573895, 0.99061376],
       [0.7293102 , 0.2567695 , 0.9957604 ],
       [0.79329795, 0.20159875, 0.9966206 ],
     

In [3]:
for i in range(1,1):
    print(i)