In [11]:
import tensorflow as tf 
from tensorflow.keras.layers import Dense 
import numpy as np 
import matplotlib.pyplot as plt 
import os 
# os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [12]:
# Synthetic two-level hierarchy data set.
# A dedicated, seeded generator keeps the data reproducible across
# "Restart Kernel -> Run All" without touching NumPy's global RNG state.
SEED = 0
rng = np.random.default_rng(SEED)

# Two 2-D Gaussian blobs: NUM_A points around (0, 0), NUM_B points around (1, 1).
NUM_A, NUM_B = 100, 200
A = rng.normal(size=[NUM_A, 2])
B = rng.normal(loc=1, size=[NUM_B, 2])
data = np.concatenate([A, B])

# Label columns: [belongs to blob A, belongs to blob B, always-on third class].
yA1 = np.array([1] * NUM_A + [0] * NUM_B)
yA2 = np.array([0] * NUM_A + [1] * NUM_B)
yB = np.ones(NUM_A + NUM_B)
y = np.concatenate([yA1[:, np.newaxis], yA2[:, np.newaxis], yB[:, np.newaxis]], axis=1)

In [13]:
# loss function
def get_loss_fn(structure, alpha, beta):
    '''
    Build the training loss: local BCE + global BCE + alpha * hierarchy penalty.

    structure     array of shape [M,M], where M is the number of all classes
                  (a hierarchical level may contain multiple classes),
                  structure[i,j] == 1 iff "class i" is a subclass of "class j"
    alpha         a hyper-parameter to balance hierarchical loss and prediction loss
    beta          a hyper-parameter to balance local prob and global prob

    returns       loss_fn(y_true, y_pred_logits) -> scalar loss tensor
    '''
    # The hierarchy does not change between calls, so the (child, parent)
    # column pairs are extracted once here instead of on every loss evaluation.
    edges = np.argwhere(np.asarray(structure) == 1)
    child_idx, parent_idx = edges[:, 0], edges[:, 1]

    def loss_fn(y_true, y_pred_logits):
        '''
        y_true           [N,K] multi-hot labels, K = K1 + K2 + ... + KM,
                         where M is the number of layers
        y_pred_logits    pair (local_logits, global_logits), each [N,K];
                         these are raw logits, the sigmoid is applied here
        '''
        local_logits, global_logits = y_pred_logits

        # Labels and logits must share a dtype for the cross-entropy op.
        y_true = tf.cast(y_true, local_logits.dtype)

        local_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=local_logits)
        )

        global_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=global_logits)
        )

        total_prob = beta * tf.math.sigmoid(local_logits) + (1-beta) * tf.math.sigmoid(global_logits)

        # A subclass must not be more probable than its parent class: penalise
        # the squared excess prob(child) - prob(parent) wherever it is positive.
        child_prob = tf.gather(total_prob, child_idx, axis=1)
        parent_prob = tf.gather(total_prob, parent_idx, axis=1)
        excess = tf.nn.relu(child_prob - parent_prob)

        # Dynamic batch size, so the loss also works when the static batch
        # dimension is unknown.
        batch_size = tf.cast(tf.shape(y_true)[0], total_prob.dtype)
        hierarchical_loss = tf.reduce_sum(tf.square(excess)) / batch_size

        return local_loss + global_loss + alpha * hierarchical_loss

    return loss_fn

In [14]:
class HMC_LSTM(tf.keras.Model):
    '''
    Hierarchical multi-label classifier built from a single LSTM-style cell.

    The cell (one shared set of gate layers) is applied once per hierarchy
    level; every step feeds a level-specific "local" head, and the last step
    additionally feeds one "global" head covering all classes.
    '''
    def __init__(self, units, beta, num_classes_list, num_layers):
        '''
        units               number of neurons in the LSTM cell
        beta                weight of the local probability in postprocess();
                            the global probability is weighted by (1 - beta)
        num_classes_list    a list of number of classes in each layer
        num_layers          integer, len(num_classes_list) == num_layers
        '''
        super().__init__()
        self.units = units
        self.beta = beta
        self.num_layers = num_layers

        self.input_gate = Dense(units, activation='sigmoid')
        self.output_gate = Dense(units, activation='sigmoid')
        self.forget_gate = Dense(units, activation='sigmoid')
        self.candidate_gate = Dense(units, activation='tanh')

        # One local output head per hierarchy level.
        self.local_W = [Dense(num_classes_list[i]) for i in range(self.num_layers)]

        # Dense expects a plain Python int for `units`, not a tensor.
        self.global_W = Dense(int(sum(num_classes_list)))

    def call(self, inputs):
        '''
        inputs     [N,k], where N is the number of examples

        returns    (local_logits, global_logits):
                   local_logits is the concatenation of the per-level heads,
                   global_logits has sum(num_classes_list) columns
        '''

        # init sequence: the input features joined with an all-zero hidden state.
        # The batch size is read dynamically so an unknown static batch
        # dimension does not break the zero-state construction.
        batch_size = tf.shape(inputs)[0]
        x = tf.concat(
            [inputs, tf.zeros([batch_size, self.units], dtype=inputs.dtype)],
            axis=-1
        )
        # Initial cell state; shape [units] broadcasts over the batch.
        candidate_last = tf.zeros([self.units], dtype=inputs.dtype)

        # per-level output logits
        local_probs = list()

        # main loop: one cell step per hierarchy level
        for i in range(self.num_layers):
            # gates
            forget_prob = self.forget_gate(x)
            input_prob = self.input_gate(x)
            output_prob = self.output_gate(x)
            candidate_hat = self.candidate_gate(x)

            # new candidate (cell state)
            candidate = candidate_hat * input_prob + forget_prob * candidate_last

            # outputs (hidden state)
            outputs = output_prob * tf.math.tanh(candidate)

            # update: the next step sees the raw inputs plus this step's output
            candidate_last = candidate
            x = tf.concat([inputs, outputs], axis=-1)

            # local logits for level i
            local_probs.append(self.local_W[i](outputs))

        # After the loop x is [inputs, last outputs], which is exactly what
        # the global head consumes.
        global_logits = self.global_W(x)
        local_logits = tf.concat(local_probs, axis=-1)

        return local_logits, global_logits

    def postprocess(self, local_logits, global_logits):
        '''
        Turn the two logit tensors into probabilities.

        returns    (total_prob, local_prob, global_prob), where
                   total_prob = beta * local_prob + (1 - beta) * global_prob
        '''
        local_prob = tf.nn.sigmoid(local_logits)
        global_prob = tf.nn.sigmoid(global_logits)
        total_prob = self.beta * local_prob + (1-self.beta) * global_prob
        return total_prob, local_prob, global_prob

In [15]:
# Hyper-parameters are defined once so the model and the loss cannot drift
# apart (both blend local/global probabilities with the same beta, and the
# number of layers always matches the class list).
units = 10
beta = 0.5
num_classes_list = [2, 1]
alpha = 0.5

model = HMC_LSTM(units, beta, num_classes_list, len(num_classes_list))
# One forward pass on the data before training starts.
local_logits, global_logits = model(data)
# Classes 0 and 1 are subclasses of class 2 (structure[i,j] == 1 iff i is a subclass of j).
structure = np.array([[0,0,1],[0,0,1],[0,0,0]])
loss_fn = get_loss_fn(structure, alpha, beta)

def get_train_fn(model, loss_fn, optimizer):
    '''
    Build a single-step training function.

    model        callable mapping x to predictions; must expose trainable_variables
    loss_fn      callable loss_fn(y, predictions) -> scalar loss
    optimizer    object with apply_gradients(iterable of (gradient, variable))

    returns      train_fn(x, y) that performs one update and returns the loss
    '''
    def train_fn(x, y):
        # Only the forward pass is recorded; differentiation and the
        # parameter update happen after the tape context has closed.
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x))
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
        return loss
    return train_fn

train_fn = get_train_fn(model, loss_fn, tf.keras.optimizers.Adam(1e-2))

num_epochs = 5000
log_every = 500

# The labels never change, so convert them to a float32 tensor once instead
# of on every epoch.
y_tensor = tf.convert_to_tensor(y, tf.float32)

# Full-batch training: every epoch is a single update on the whole data set.
for epoch in range(num_epochs):
    loss = train_fn(data, y_tensor)
    if (epoch + 1) % log_every == 0:
        print("epoch {}/{}, loss = {}".format(epoch + 1, num_epochs, loss))


epoch 500/5000, loss = 0.5115510821342468
epoch 1000/5000, loss = 0.4327765107154846
epoch 1500/5000, loss = 0.3604480028152466
epoch 2000/5000, loss = 0.316402405500412
epoch 2500/5000, loss = 0.2916376292705536
epoch 3000/5000, loss = 0.2715415060520172
epoch 3500/5000, loss = 0.2570025622844696
epoch 4000/5000, loss = 0.23144850134849548
epoch 4500/5000, loss = 0.2094428986310959
epoch 5000/5000, loss = 0.18716368079185486


In [16]:
def get_area(x1, y1, x2, y2):
    '''
    Area of the trapezoid under the straight segment (x1, y1) -> (x2, y2).

    Callers are expected to pass x1 < x2 (strictly).
    '''
    width = x2 - x1
    height_sum = y1 + y2
    return height_sum * width / 2

def get_score(model, x, y):
    '''
    Area under the precision/recall curve of the model's blended probability.

    model    object called as model(x) -> (local_logits, global_logits) and
             providing postprocess(local_logits, global_logits), whose first
             return value is the total probability
    x        model inputs
    y        0/1 labels with the same shape as the total probability

    Precision and recall are micro-averaged over all entries for thresholds
    0.01, 0.02, ..., 0.99. Thresholds where either is undefined (0/0) are
    skipped. The curve is anchored at (precision=0, recall=1) and integrated
    with the trapezoid rule along the precision axis.
    '''
    local_logits, global_logits = model(x)
    total_prob, _, _ = model.postprocess(local_logits, global_logits)

    # Count in float64 NumPy arrays: this avoids int16 accumulators (which
    # overflow above 32767 hits) and mixed-dtype tensor/array arithmetic.
    total_prob = np.asarray(total_prob)
    y = np.asarray(y, dtype=np.float64)

    precision_list = [0.0]
    recall_list = [1.0]
    for threshold in np.arange(0.01, 1, 0.01):
        pred = (total_prob >= threshold).astype(np.float64)
        tp = np.sum(pred * y)
        fp = np.sum(pred * (1 - y))
        fn = np.sum((1 - pred) * y)

        # Skip thresholds where precision or recall would be 0/0.
        if tp + fp == 0 or tp + fn == 0:
            continue

        precision_list.append(tp / (tp + fp))
        recall_list.append(tp / (tp + fn))

    precision_np = np.array(precision_list)
    recall_np = np.array(recall_list)

    # Sort by precision; among equal precisions put the highest recall first.
    # Sorting on precision alone leaves tied points in an arbitrary order,
    # which changes the trapezoid leading into the tie and hence the area.
    idx = np.lexsort((-recall_np, precision_np))
    precision_np = precision_np[idx]
    recall_np = recall_np[idx]

    area = 0
    for i in range(len(precision_np)-1):
        area += get_area(precision_np[i], recall_np[i], precision_np[i+1], recall_np[i+1])

    return area

In [17]:
# Evaluate the trained model on the training data itself.
local_logits, global_logits = model(data)
total_prob, local_prob, global_prob = model.postprocess(local_logits, global_logits)

au_prc = get_score(model, data, y)
print("AU(PRC) = {}".format(au_prc))

AU(PRC) = 0.9990706177379405
