In [None]:
import sys
import os
import numpy as np
import pickle
from numpy import array
import tensorflow as tf
import platform
import time
import rdflib
import random
from statistics import mean
from scipy.spatial import distance
import scipy
import networkx

print('Python version: %s' % platform.python_version())
print('Tensorflow version: %s' % tf.__version__)
tf.get_logger().setLevel('ERROR')

# Seed every RNG the notebook draws from: `random` (negative sampling in get_neg_sample),
# numpy, and TensorFlow (weight initialisation, dropout). Without this, runs are not reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Restrict TensorFlow to one GPU.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set before TF initialises
# its GPU runtime (the device query below); keep this above any other device access.
DEVICE = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = DEVICE

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Allocate GPU memory on demand instead of reserving it all up front.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be configured before the GPUs are initialised.
        print(e)

### Load Data

In [None]:
DATA_DIR = 'data'


def load_pickle(name, data_dir=DATA_DIR):
    """Unpickle ``<data_dir>/<name>`` and close the file handle afterwards.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this at trusted,
    locally produced files.
    """
    with open(os.path.join(data_dir, name), "rb") as handle:
        return pickle.load(handle)


# KG data
KG1_ENTITIES = load_pickle('KG1_ENTITIES')
KG2_ENTITIES = load_pickle('KG2_ENTITIES')
KG1_ATTRIBUTES = load_pickle('KG1_ATTRIBUTES')
KG2_ATTRIBUTES = load_pickle('KG2_ATTRIBUTES')

# Vocab
KG1_ent_vocab = load_pickle('KG1_ent_vocab')
KG2_ent_vocab = load_pickle('KG2_ent_vocab')
pred_vocab = load_pickle('pred_vocab')
char_vocab = load_pickle('char_vocab')

# Alignment pairs: seed (training) and test
seed_data = load_pickle('seed_data')
test_data = load_pickle('test_data')

In [None]:
# --- training schedule ---
TOTAL_EPOCHS = 400
BATCH_SIZE = 64
BUFFER_SIZE = seed_data.shape[0]      # number of seed alignment pairs
N_BATCH = BUFFER_SIZE // BATCH_SIZE   # full batches per epoch

# --- model size ---
CHR_EMBED_DIM = 64
HIDDEN_UNITS = 256                    # NOTE(review): must be divisible by NUM_HEAD
NUM_LAYERS_ATT = 4
NUM_LAYERS_REL = 4
NUM_HEAD = 8

# --- misc ---
MAX_CHARS = 10
MARGIN_LOSS = 1

# Vocabulary sizes; the +1 reserves one extra index (presumably id 0 for padding - TODO confirm).
VOCAB_ENT_SIZE = len(KG1_ent_vocab) + len(KG2_ent_vocab) + 1
VOCAB_PRE_SIZE = len(pred_vocab) + 1
VOCAB_CHR_SIZE = len(char_vocab) + 1

In [None]:
def get_neg_sample(kg2_ents, population=None):
    """Draw one distinct random negative KG2 entity per positive entity.

    Args:
        kg2_ents: sized collection of positive KG2 entities; only its length is used.
        population: entities to sample from (a mapping contributes its keys). Defaults to
            the keys of KG2_ENTITIES.

    Returns:
        np.ndarray with exactly ``len(kg2_ents)`` distinct entities. Sampling is without
        replacement, so the result is never shorter than the batch. (De-duplicating
        ``random.choices`` draws could be shorter, which breaks the "first BATCH_SIZE
        nodes are the batch" layout downstream.)

    Raises:
        ValueError: if the population has fewer entities than requested.
    """
    if population is None:
        population = KG2_ENTITIES
    candidates = list(population)  # random.sample needs a sequence; dict -> its keys
    return np.array(random.sample(candidates, len(kg2_ents)))

def get_attribute_data(entities, KG_attribute):
    """Split each entity's (predicate, attribute) pairs into two parallel arrays.

    Args:
        entities: iterable of entity ids, used as keys into ``KG_attribute``.
        KG_attribute: mapping entity -> iterable of (predicate, attribute) pairs.

    Returns:
        (pred_array, attr_array): ``np.array`` of the per-entity predicate lists and of
        the per-entity attribute lists, row order following ``entities``.
        NOTE(review): np.array needs equally long rows (pre-padded data) to build a
        regular array.
    """
    rows = [list(KG_attribute[ent]) for ent in entities]
    predicate_rows = [[pred for pred, _ in row] for row in rows]
    attribute_rows = [[attr for _, attr in row] for row in rows]
    return np.array(predicate_rows), np.array(attribute_rows)

def get_subgraph(entities, KG_entities):
    """Build the undirected one-hop neighbourhood graph of a mini-batch.

    Args:
        entities: distinct entity ids of the mini-batch.
        KG_entities: mapping entity -> iterable of (predicate, neighbour) pairs.

    Returns:
        (node_ids, adjacency):
        - node_ids (np.ndarray) lists the batch entities first, in the given order,
          followed by neighbours in first-seen order.
        - adjacency (dense np.ndarray) holds at [i, j] the predicate id stored as the
          edge weight between node_ids[i] and node_ids[j]; missing edges are 0.
          NOTE(review): a predicate id of 0 would be indistinguishable from "no edge".
        Each batch entity carries a self-loop weighted pred_vocab["SELF-REL"]. When two
        edges join the same pair, the later one overwrites the weight.
    """
    assert len(entities) == len(set(entities))

    graph = networkx.Graph()

    # Self-loops first: adding them inserts the batch nodes in order, so the first
    # len(entities) rows/cols of the adjacency belong to the mini-batch.
    graph.add_weighted_edges_from(
        (ent, ent, pred_vocab["SELF-REL"]) for ent in entities
    )

    # Then the one-hop edges, weighted by their predicate id.
    graph.add_weighted_edges_from(
        (ent, neighbour, pred)
        for ent in entities
        for pred, neighbour in KG_entities[ent]
    )

    adjacency = networkx.adjacency_matrix(graph).todense()
    return np.array(graph.nodes), np.array(adjacency)
            

def get_batch_data(ent_pair):
    """Assemble model inputs for one mini-batch of alignment pairs.

    Args:
        ent_pair: (batch, 2) array; column 0 holds KG1 entity ids, column 1 their aligned
            KG2 ids.

    Returns:
        (kg1_data, kg2_data, neg_data): each a 4-tuple
        (attr_keys, attr_vals, node_list, adj_matrix) for the KG1 entities, their KG2
        counterparts, and randomly drawn KG2 negatives respectively.
    """
    left, right = ent_pair[:, 0], ent_pair[:, 1]
    negatives = get_neg_sample(right)
    # Every side must consist of distinct entities.
    for side in (left, right, negatives):
        assert len(side) == len(set(side))

    sides = (
        (left, KG1_ATTRIBUTES, KG1_ENTITIES),
        (right, KG2_ATTRIBUTES, KG2_ENTITIES),
        (negatives, KG2_ATTRIBUTES, KG2_ENTITIES),  # negatives are drawn from KG2
    )
    attribute_parts = [get_attribute_data(ents, attrs) for ents, attrs, _ in sides]
    graph_parts = [get_subgraph(ents, graph) for ents, _, graph in sides]

    return tuple(attrs + graph for attrs, graph in zip(attribute_parts, graph_parts))

In [None]:
# Batch the seed alignment pairs; drop_remainder keeps every batch at exactly BATCH_SIZE,
# which the models below rely on when slicing node lists.
train_batch = tf.data.Dataset.from_tensor_slices(seed_data).batch(BATCH_SIZE, drop_remainder=True)

# Sanity pass: build every training batch once so the assertions in get_batch_data run
# before any training starts.
for batch, ent_pair in enumerate(train_batch):
    kg1_data, kg2_data, neg_data = get_batch_data(ent_pair.numpy())

### Classes

In [None]:
def gru(units):
    """Return a GRU layer that emits the whole output sequence and its final state."""
    layer = tf.keras.layers.GRU(
        units,
        return_sequences=True,
        return_state=True,
        recurrent_initializer='glorot_uniform',
    )
    return layer

In [None]:
class MultiHeadAttention(tf.keras.Model):
    """Multi-head scaled dot-product attention that also returns a per-pair output.

    Besides the attended features, it returns ``pred_output``: the masked, pre-softmax
    scores summed over the batch axis and projected to HIDDEN_UNITS. TransGNN feeds
    this back into its per-node-pair predicate embeddings.
    """

    def __init__(self, num_head):
        super(MultiHeadAttention, self).__init__()
        self.num_head = num_head
        # NOTE(review): assumes HIDDEN_UNITS is divisible by num_head.
        self.key_size = HIDDEN_UNITS // self.num_head

        self.wq = tf.keras.layers.Dense(HIDDEN_UNITS)
        self.wk = tf.keras.layers.Dense(HIDDEN_UNITS)
        self.wv = tf.keras.layers.Dense(HIDDEN_UNITS)
        self.wn = tf.keras.layers.Dense(HIDDEN_UNITS)
        self.wp = tf.keras.layers.Dense(HIDDEN_UNITS)

    def _split_heads(self, x, batch_size):
        """(batch, length, HIDDEN_UNITS) -> (batch, num_head, length, key_size)."""
        x = tf.reshape(x, [batch_size, -1, self.num_head, self.key_size])
        return tf.transpose(x, [0, 2, 1, 3])

    def call(self, query, value, mask=None, bias=None):
        """Attend from ``query`` to ``value`` (which serves as both keys and values).

        Args:
            query: (batch, query_len, d) tensor.
            value: (batch, value_len, d) tensor.
            mask: optional multiplicative mask/gate on the scores.
                - A rank-4 mask must broadcast to (batch, heads, query_len, value_len).
                - A lower-rank mask is treated as a key-padding mask of shape
                  (batch, value_len) or (batch, value_len, 1).
                Positions where the mask is exactly 0 get a score of -1e9 before the softmax.
            bias: optional additive term broadcastable to the score tensor.

        Returns:
            node_output: (batch, query_len, HIDDEN_UNITS), attended values projected by ``wn``.
            pred_output: (query_len, value_len, HIDDEN_UNITS), scores summed over batch,
                projected by ``wp``.
            attention: (batch, heads, query_len, value_len) softmax weights.
        """
        query = self.wq(query)
        key = self.wk(value)
        value = self.wv(value)

        batch_size = query.shape[0]

        # for parallel multihead computation
        query = self._split_heads(query, batch_size)
        key = self._split_heads(key, batch_size)
        value = self._split_heads(value, batch_size)

        # (batch, h, query_len, value_len)
        score = tf.matmul(query, key, transpose_b=True) / tf.math.sqrt(tf.dtypes.cast(self.key_size, dtype=tf.float32))
        if bias is not None:
            score = score + bias
        if mask is not None:
            mask = tf.cast(mask, score.dtype)
            if len(mask.shape) < 4:
                # Key-padding mask -> (batch, 1, 1, value_len): broadcast over heads and
                # queries so padded *keys* get no attention from any query.
                mask = tf.reshape(mask, [batch_size, 1, 1, -1])
            score = score * mask
            # Decide masked positions from the mask, not from the product, so a
            # legitimately zero score is never treated as padding.
            score = tf.where(tf.equal(mask, 0), tf.ones_like(score) * -1e9, score)

        attention = tf.nn.softmax(score, axis=-1)
        context = tf.matmul(attention, value)
        context = tf.transpose(context, [0, 2, 1, 3])
        context = tf.reshape(context, [batch_size, -1, self.key_size * self.num_head])

        # (query_len, value_len, h): scores summed over the batch axis.
        # NOTE(review): masked entries contribute -1e9 here; callers re-mask downstream.
        pred_output = tf.reduce_sum(tf.transpose(score, [2, 3, 1, 0]), axis=-1)

        pred_output = self.wp(pred_output)
        node_output = self.wn(context)

        return node_output, pred_output, attention

In [None]:
#### LOOKER
class HistoricalEmbeddings(tf.keras.Model):
    """Two entity-embedding tables, each followed by four stacked Dense projections.

    The "batch" branch (batch_embed + lin_transform_batch_*) embeds the mini-batch
    entities. The "hist" branch (hist_embed + lin_transform_hist_*) embeds the remaining
    subgraph nodes. Node lists are expected to hold the BATCH_SIZE mini-batch entities
    first (as produced by get_subgraph).

    NOTE(review): the Dense layers have no activation, so each 4-layer stack composes to a
    single affine map.
    """

    def __init__(self):
        super(HistoricalEmbeddings, self).__init__()

        def make_table(layer_name):
            return tf.keras.layers.Embedding(VOCAB_ENT_SIZE,
                                             HIDDEN_UNITS,
                                             mask_zero=False,
                                             trainable=True,
                                             name=layer_name)

        self.batch_embed = make_table("node_batch_embedding")
        self.hist_embed = make_table("node_hist_embedding")

        (self.lin_transform_batch_1, self.lin_transform_batch_2,
         self.lin_transform_batch_3, self.lin_transform_batch_4) = [
            tf.keras.layers.Dense(HIDDEN_UNITS) for _ in range(4)]

        (self.lin_transform_hist_1, self.lin_transform_hist_2,
         self.lin_transform_hist_3, self.lin_transform_hist_4) = [
            tf.keras.layers.Dense(HIDDEN_UNITS) for _ in range(4)]

    def _project_batch(self, node_ids):
        """Batch branch: table lookup, then the four batch projections in order."""
        out = self.batch_embed(node_ids)
        for layer in (self.lin_transform_batch_1, self.lin_transform_batch_2,
                      self.lin_transform_batch_3, self.lin_transform_batch_4):
            out = layer(out)
        return out

    def _project_hist(self, node_ids):
        """Hist branch: table lookup, then the four hist projections in order."""
        out = self.hist_embed(node_ids)
        for layer in (self.lin_transform_hist_1, self.lin_transform_hist_2,
                      self.lin_transform_hist_3, self.lin_transform_hist_4):
            out = layer(out)
        return out

    def call(self, ent_list):
        """Embed a node list: the first BATCH_SIZE ids via the batch branch, the rest via
        the hist branch, concatenated along axis 0 -> (NUM_NODES, HIDDEN_UNITS)."""
        return tf.concat([self._project_batch(ent_list[:BATCH_SIZE]),
                          self._project_hist(ent_list[BATCH_SIZE:])], axis=0)

    def get_batch_embed(self, ent_list):
        """Batch-branch embeddings of the first BATCH_SIZE ids."""
        return self._project_batch(ent_list[:BATCH_SIZE])

    def get_hist_embed(self, ent_list):
        """Hist-branch embeddings of the first BATCH_SIZE ids.

        Note: this slices the mini-batch ids, not the neighbours. The training loop
        compares the result with the structural encoding of those same batch nodes.
        """
        return self._project_hist(ent_list[:BATCH_SIZE])
    
hist_embed = HistoricalEmbeddings()

# Smoke test: embed the node lists of the first training batch and report their shapes.
for batch, ent_pair in enumerate(train_batch):
    kg1_data, kg2_data, neg_data = get_batch_data(ent_pair.numpy())

    kg1_attr_keys, kg1_attr_vals, kg1_node_list, kg1_adj_matrix = kg1_data
    kg2_attr_keys, kg2_attr_vals, kg2_node_list, kg2_adj_matrix = kg2_data
    neg_attr_keys, neg_attr_vals, neg_node_list, neg_adj_matrix = neg_data

    kg1_ents, kg2_ents, neg_ents = (
        hist_embed(nodes) for nodes in (kg1_node_list, kg2_node_list, neg_node_list)
    )
    break

for label, embedded in (('kg1_ents', kg1_ents), ('kg2_ents', kg2_ents), ('neg_ents', neg_ents)):
    print('Shape of %s: %s' % (label, embedded.shape))

In [None]:
class AttAggregator(tf.keras.Model):
    """Transformer encoder over an entity's attribute slots, pooled to a single vector.

    Each slot combines a predicate embedding and an encoded attribute value. Stacked
    self-attention + feed-forward layers (NUM_LAYERS_ATT of them) follow. A final
    attention step from a pooled context vector produces the entity representation.
    """

    def __init__(self):
        super(AttAggregator, self).__init__()

        self.W1 = tf.keras.layers.Dense(HIDDEN_UNITS)
        self.W2 = tf.keras.layers.Dense(HIDDEN_UNITS)

        # Every per-layer sub-module is sized by NUM_LAYERS_ATT (this encoder's depth), so
        # the attention and feed-forward lists can never disagree in length.
        self.self_att_obj = [MultiHeadAttention(NUM_HEAD) for _ in range(NUM_LAYERS_ATT)]
        self.self_att_obj_dropout = [tf.keras.layers.Dropout(0.1) for _ in range(NUM_LAYERS_ATT)]
        self.self_att_obj_norm = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(NUM_LAYERS_ATT)]

        self.lin_transform_1_o = [tf.keras.layers.Dense(HIDDEN_UNITS * NUM_LAYERS_ATT, activation='relu') for _ in range(NUM_LAYERS_ATT)]
        self.lin_transform_2_o = [tf.keras.layers.Dense(HIDDEN_UNITS) for _ in range(NUM_LAYERS_ATT)]
        self.dropout_o = [tf.keras.layers.Dropout(0.1) for _ in range(NUM_LAYERS_ATT)]
        self.batch_norm_o = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(NUM_LAYERS_ATT)]

        self.Wc = tf.keras.layers.Dense(HIDDEN_UNITS, activation='relu')
        self.final_transform = MultiHeadAttention(NUM_HEAD)

    def call(self, predicates, objects, mask=None):
        """Encode and pool the attribute slots.

        Args:
            predicates: (batch, slots, d) predicate embeddings.
            objects: (batch, slots, d) encoded attribute values.
            mask: optional (batch, slots[, 1]) float mask, 1 for real slots and 0 for padding.

        Returns:
            (batch, HIDDEN_UNITS) entity vector and (batch, slots) pooling attention
            (averaged over heads).
        """
        key_mask = None
        slot_mask = None
        if mask is not None:
            num_slots = mask.shape[1]
            # (batch, 1, 1, slots): hides padded slots as *keys* for every head and query.
            key_mask = tf.reshape(tf.cast(mask, tf.float32), [-1, 1, 1, num_slots])
            # (batch, slots, 1): zeroes padded slots out of the pooled context.
            slot_mask = tf.reshape(tf.cast(mask, tf.float32), [-1, num_slots, 1])

        residual = tf.nn.tanh(self.W1(objects) + self.W2(predicates))
        for i in range(NUM_LAYERS_ATT):
            attended, _, _ = self.self_att_obj[i](residual, residual, key_mask)
            attended = self.self_att_obj_dropout[i](attended)
            attended = self.self_att_obj_norm[i](residual + attended)

            ffn = self.lin_transform_2_o[i](self.lin_transform_1_o[i](attended))
            ffn = self.dropout_o[i](ffn)
            # Post-LN residual around the feed-forward sub-layer adds that sub-layer's own
            # input, so the attention output is carried forward.
            residual = self.batch_norm_o[i](attended + ffn)
        output = residual

        pooled = output if slot_mask is None else output * slot_mask
        ctx = tf.reduce_sum(pooled, axis=1, keepdims=True)
        ctx = self.Wc(ctx)
        final_output, _, attention = self.final_transform(ctx, output, key_mask)
        attention = tf.reduce_mean(attention, axis=1)

        return tf.reshape(final_output, [final_output.shape[0], final_output.shape[-1]]), tf.reshape(attention, [attention.shape[0], attention.shape[-1]])

In [None]:
class TransGNN(tf.keras.Model):
    """Graph transformer over a mini-batch's one-hop subgraph.

    Node states start from the module-level ``hist_embed`` model. Each of the
    NUM_LAYERS_REL layers runs gated self-attention restricted to graph edges (the gate is
    computed from the per-pair predicate embeddings), then a feed-forward sub-layer. The
    predicate embeddings are updated alongside the nodes.
    """

    def __init__(self):
        super(TransGNN, self).__init__()
        self.pred_lin_trans_sigmoid = [tf.keras.layers.Dense(1, activation="sigmoid") for _ in range(NUM_LAYERS_REL)]
        # NOTE(review): kept so attribute names stay stable, but not used in `call`: the
        # additive predicate bias was passed as bias=None, so computing it only produced
        # variables without gradients.
        self.pred_lin_trans_bias = [tf.keras.layers.Dense(1) for _ in range(NUM_LAYERS_REL)]

        self.self_att_nodes = [MultiHeadAttention(NUM_HEAD) for _ in range(NUM_LAYERS_REL)]
        self.self_att_nodes_dropout = [tf.keras.layers.Dropout(0.1) for _ in range(NUM_LAYERS_REL)]
        self.self_att_nodes_norm = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(NUM_LAYERS_REL)]
        self.self_att_preds_dropout = [tf.keras.layers.Dropout(0.1) for _ in range(NUM_LAYERS_REL)]
        self.self_att_preds_norm = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(NUM_LAYERS_REL)]

        # Feed-forward sub-layers are per relational layer, so size them by NUM_LAYERS_REL.
        self.lin_transform_1_o = [tf.keras.layers.Dense(HIDDEN_UNITS * NUM_LAYERS_ATT, activation='relu') for _ in range(NUM_LAYERS_REL)]
        self.lin_transform_2_o = [tf.keras.layers.Dense(HIDDEN_UNITS) for _ in range(NUM_LAYERS_REL)]
        self.dropout_o = [tf.keras.layers.Dropout(0.1) for _ in range(NUM_LAYERS_REL)]
        self.batch_norm_o = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(NUM_LAYERS_REL)]

    def call(self, node_list, pred_embed, adj_mask, mask=None):
        """Encode the subgraph and return the mini-batch nodes' representations.

        Args:
            node_list: (NUM_NODES,) entity ids; the first BATCH_SIZE are the mini-batch nodes,
                the rest are one-hop neighbours.
            pred_embed: (NUM_NODES, NUM_NODES, HIDDEN_UNITS) predicate embedding per node pair.
            adj_mask: (NUM_NODES, NUM_NODES, 1) float mask, 1 where an edge exists.
            mask: unused; kept for interface compatibility.

        Returns:
            (BATCH_SIZE, HIDDEN_UNITS) final states of the batch nodes; the list of per-layer
            node states, each (1, NUM_NODES, HIDDEN_UNITS); and the batch nodes' last-layer
            attention averaged over heads, (BATCH_SIZE, NUM_NODES).
        """
        # Layer-0 node features come from the module-level `hist_embed` model, which this
        # model does not own (its variables are trained and checkpointed separately).
        residual_node = tf.expand_dims(hist_embed(node_list), axis=0)  # (1, NUM_NODES, d)
        residual_pred = pred_embed

        hidden_states = list()

        for i in range(NUM_LAYERS_REL):
            # Edge gate in [0, 1], zeroed where there is no edge; MultiHeadAttention turns
            # zero-gated scores into -1e9, so attention stays on graph neighbours.
            pred_gate = self.pred_lin_trans_sigmoid[i](residual_pred) * adj_mask      # (N, N, 1)
            pred_gate = tf.expand_dims(tf.transpose(pred_gate, [2, 0, 1]), axis=0)    # (1, 1, N, N)

            node_output, pred_output, attention = self.self_att_nodes[i](residual_node, residual_node, mask=pred_gate, bias=None)

            node_output = self.self_att_nodes_dropout[i](node_output)
            node_output = self.self_att_nodes_norm[i](residual_node + node_output)

            pred_output = self.self_att_preds_dropout[i](pred_output)
            residual_pred = self.self_att_preds_norm[i](residual_pred + pred_output)

            ffn = self.lin_transform_2_o[i](self.lin_transform_1_o[i](node_output))
            ffn = self.dropout_o[i](ffn)
            # Post-LN residual around the feed-forward sub-layer adds that sub-layer's own
            # input, so the attention output is carried forward.
            residual_node = self.batch_norm_o[i](node_output + ffn)

            hidden_states.append(residual_node)

        final_output = residual_node
        attention = tf.reduce_mean(attention, axis=1)
        # return only the batch nodes
        return tf.reshape(final_output[:, :BATCH_SIZE, :], [BATCH_SIZE, final_output.shape[-1]]), hidden_states, tf.reshape(attention[:, :BATCH_SIZE, :], [BATCH_SIZE, attention.shape[-1]])

In [None]:
#### ENCODER
class Encoder(tf.keras.Model):
    """Encodes one side of a mini-batch from its attributes and from its one-hop subgraph."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.chr_embed = tf.keras.layers.Embedding(VOCAB_CHR_SIZE,
                                                   CHR_EMBED_DIM,
                                                   mask_zero=False,
                                                   trainable=True,
                                                   name="source_char_embedding")
        # Shared by attribute keys and by the adjacency matrix (both hold predicate ids).
        self.pre_embed = tf.keras.layers.Embedding(VOCAB_PRE_SIZE,
                                                   HIDDEN_UNITS,
                                                   mask_zero=False,
                                                   trainable=True,
                                                   name="source_predicate_embedding")
        self.gru_char = gru(HIDDEN_UNITS)
        self.trans_att = AttAggregator()
        self.trans_rel = TransGNN()

    def call(self, att_keys, att_values, node_list, adj_matrix):
        """Encode a side of the batch.

        Args:
            att_keys: (batch, slots) predicate ids per attribute slot; 0 marks padding.
            att_values: (batch, slots, chars) character ids per attribute value; 0 marks padding.
            node_list: (NUM_NODES,) entity ids, mini-batch entities first.
            adj_matrix: (NUM_NODES, NUM_NODES) predicate id per node pair; 0 means no edge.

        Returns:
            att_rep, node_rep, hidden_states, att_attention, node_attention
            (see AttAggregator.call and TransGNN.call for shapes).
        """
        # --- attribute data ---
        num_slots = att_values.shape[1]
        num_chars = att_values.shape[2]
        # Fold the slots into the batch axis so the character GRU runs once over
        # (batch * slots) sequences instead of once per slot in a Python loop.
        chars = tf.reshape(att_values, [-1, num_chars])                         # (batch*slots, chars)
        mask_char = tf.expand_dims(1 - tf.cast(tf.equal(chars, 0), dtype=tf.float32), axis=2)
        char_states, _ = self.gru_char(self.chr_embed(chars))                   # (batch*slots, chars, H)
        # Sum the GRU outputs over real characters only.
        value_vecs = tf.reduce_sum(char_states * mask_char, axis=1)             # (batch*slots, H)
        att_values = tf.reshape(value_vecs, [-1, num_slots, value_vecs.shape[-1]])  # (batch, slots, H)

        mask_att = tf.expand_dims(1 - tf.cast(tf.equal(att_keys, 0), dtype=tf.float32), axis=2)  # (batch, slots, 1)
        att_keys = self.pre_embed(att_keys)                                     # (batch, slots, H)
        att_rep, att_attention = self.trans_att(att_keys, att_values, mask_att)  # entity representation based on attributes

        # --- graph data ---
        pred_embed = self.pre_embed(adj_matrix)  # (NUM_NODES, NUM_NODES, H)
        adj_mask = tf.expand_dims(1 - tf.cast(tf.equal(adj_matrix, 0), dtype=tf.float32), axis=2)  # (NUM_NODES, NUM_NODES, 1)
        node_rep, hidden_states, node_attention = self.trans_rel(node_list, pred_embed, adj_mask)  # entity representation based on structure

        return att_rep, node_rep, hidden_states, att_attention, node_attention
    
    
encoder = Encoder()

# Smoke test: run the encoder on the first training batch and report every output shape.
for batch, ent_pair in enumerate(train_batch):
    kg1_data, kg2_data, neg_data = get_batch_data(ent_pair.numpy())

    kg1_attr_keys, kg1_attr_vals, kg1_node_list, kg1_adj_matrix = kg1_data
    kg2_attr_keys, kg2_attr_vals, kg2_node_list, kg2_adj_matrix = kg2_data
    neg_attr_keys, neg_attr_vals, neg_node_list, neg_adj_matrix = neg_data

    # Each *_data tuple is already in Encoder.call's argument order.
    kg1_att_embed, kg1_node_embed, kg1_states, kg1_att_attention, kg1_node_attention = encoder(*kg1_data)
    kg2_att_embed, kg2_node_embed, kg2_states, kg2_att_attention, kg2_node_attention = encoder(*kg2_data)
    neg_att_embed, neg_node_embed, neg_states, neg_att_attention, neg_node_attention = encoder(*neg_data)
    break

_shape_report = (
    ('kg1_att_embed', kg1_att_embed), ('kg2_att_embed', kg2_att_embed), ('neg_att_embed', neg_att_embed),
    ('kg1_node_embed', kg1_node_embed), ('kg2_node_embed', kg2_node_embed), ('neg_node_embed', neg_node_embed),
    ('kg1_att_attention', kg1_att_attention), ('kg2_att_attention', kg2_att_attention), ('neg_att_attention', neg_att_attention),
    ('kg1_node_attention', kg1_node_attention), ('kg2_node_attention', kg2_node_attention), ('neg_node_attention', neg_node_attention),
)
for label, tensor in _shape_report:
    print('Shape of %s: %s' % (label, tensor.shape))

### Training

In [None]:
def distill_loss(real, pred):
    """Sum over rows of (1 - cosine similarity) between ``real`` and ``pred``.

    Both inputs are L2-normalised along axis 1 before the row-wise dot product.
    """
    cos_sim = tf.reduce_sum(
        tf.math.l2_normalize(real, 1) * tf.math.l2_normalize(pred, 1),
        axis=1,
        keepdims=True,
    )
    return tf.reduce_sum(1 - cos_sim)

def margin_based_ranking_loss(x, pos_sample, neg_sample):
    """Hinge loss on row-wise L1 distances.

    Returns sum over rows of max(0, |x - pos|_1 - |x - neg|_1 + MARGIN_LOSS).
    """
    def l1_distance(a, b):
        return tf.reduce_sum(tf.abs(a - b), axis=1, keepdims=True)

    violation = l1_distance(x, pos_sample) - l1_distance(x, neg_sample) + MARGIN_LOSS
    return tf.reduce_sum(tf.maximum(violation, 0))

In [None]:
### TEST ###

from tqdm.notebook import tqdm

def get_test_embeddings():
    """Encode the evaluation entities with the current encoder (inference mode).

    Returns:
        KG1_embed: (n_test, d) np.ndarray of attribute embeddings of the KG1 test entities.
            Row i is aligned with row i of KG2_embed.
        KG2_embed: (n_test + n_seed, d) np.ndarray. The test KG2 entities come first; the
            seed-pair KG2 entities follow as extra candidates.
        attentions: (KG1_att_atts, KG2_att_atts, KG1_node_atts, KG2_node_atts), lists holding
            one per-entity attention tensor each, in the same row order as the embeddings.

    NOTE(review): batches use drop_remainder=True (the models need exactly BATCH_SIZE batch
    nodes), so the last len(test_data) % BATCH_SIZE test pairs are never evaluated.
    """
    test_batch = tf.data.Dataset.from_tensor_slices(test_data).batch(BATCH_SIZE, drop_remainder=True)

    kg1_chunks, kg2_chunks = [], []
    KG1_att_atts, KG2_att_atts = [], []
    KG1_node_atts, KG2_node_atts = [], []

    def encode_side(entities, attributes, graph):
        """Encode one side of a batch without dropout; returns (embed, att_atts, node_atts)."""
        attr_keys, attr_vals = get_attribute_data(entities, attributes)
        node_list, adj_matrix = get_subgraph(entities, graph)
        att_embed, _, _, att_attention, node_attention = encoder(
            attr_keys, attr_vals, node_list, adj_matrix, training=False)
        return att_embed.numpy(), tf.unstack(att_attention, axis=0), tf.unstack(node_attention, axis=0)

    # Test pairs: KG1 row i and KG2 row i are the gold alignment. Only the two real sides
    # are built; unlike get_batch_data, no negatives are sampled just to be discarded.
    for ent_pair in test_batch:
        pairs = ent_pair.numpy()

        embed, att_atts, node_atts = encode_side(pairs[:, 0], KG1_ATTRIBUTES, KG1_ENTITIES)
        kg1_chunks.append(embed)
        KG1_att_atts += att_atts
        KG1_node_atts += node_atts

        embed, att_atts, node_atts = encode_side(pairs[:, 1], KG2_ATTRIBUTES, KG2_ENTITIES)
        kg2_chunks.append(embed)
        KG2_att_atts += att_atts
        KG2_node_atts += node_atts

    # Remaining KG2 candidates: the KG2 side of the seed (training) pairs.
    for ent_pair in train_batch:
        embed, att_atts, node_atts = encode_side(ent_pair.numpy()[:, 1], KG2_ATTRIBUTES, KG2_ENTITIES)
        kg2_chunks.append(embed)
        KG2_att_atts += att_atts
        KG2_node_atts += node_atts

    # Concatenate once at the end (concatenating inside the loops is quadratic).
    empty = np.zeros((0, HIDDEN_UNITS), dtype=np.float32)
    KG1_embed = np.concatenate(kg1_chunks, axis=0) if kg1_chunks else empty
    KG2_embed = np.concatenate(kg2_chunks, axis=0) if kg2_chunks else empty

    attentions = (KG1_att_atts, KG2_att_atts, KG1_node_atts, KG2_node_atts)
    return KG1_embed, KG2_embed, attentions
    

def test():
    print ("compute embeddings... ")
    KG1_embed, KG2_embed, _ = get_test_embeddings()
    
    print ("compute similarity... ")
    #sim = scipy.spatial.distance.cdist(KG1_embed, KG2_embed, metric='cosine')
    sim = scipy.spatial.distance.cdist(KG1_embed, KG2_embed, metric='cityblock')
    print ("sorting results... ")
    sim = sim.argsort()
    
    hits_1 = list()
    hits_10 = list()

    for (idx, similarity) in enumerate(sim):
        top_res = similarity

        if idx in top_res[:1]:
            hits_1.append(1)
        else:
            hits_1.append(0)
        if idx in top_res[:10]:
            hits_10.append(1)
        else:
            hits_10.append(0)
    return mean(hits_1), mean(hits_10)

# Baseline Hits@1 / Hits@10 before training (the cell displays the returned tuple).
test()

In [None]:
# --- optimisation hyper-parameters ---
LR = 0.0001      # Adam learning rate
BETA = 0.01      # weight of the layer-to-layer smoothness term (aux_loss) in the training loss
GAMMA = 1e-4     # constant offset inside that smoothness term
optimizer = tf.keras.optimizers.Adam(LR)

# --- checkpointing ---
# hist_embed is listed separately: the encoder only references it as a module-level global,
# so it is not part of the encoder's own variables.
checkpoint_dir = 'output/model/'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer, encoder=encoder, hist_embed=hist_embed)

In [None]:
EPOCHS = TOTAL_EPOCHS

# Make sure the output folders exist before the first write.
os.makedirs("log", exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)


def _log(log_file, message):
    """Print ``message`` and append it, plus a newline, to ``log_file``."""
    print(message)
    log_file.write(message + "\n")


best_hit = 0

# `with` guarantees the log is closed even if training raises or is interrupted.
with open("log/log", "w", buffering=1) as log_file:
    for epoch in range(EPOCHS):

        start = time.time()
        all_loss_list = list()
        db_loss_list = list()
        dh_loss_list = list()
        mbr_loss_list = list()

        for (batch, ent_pair) in enumerate(train_batch):
            kg1_data, kg2_data, neg_data = get_batch_data(ent_pair.numpy())

            kg1_attr_keys, kg1_attr_vals, kg1_node_list, kg1_adj_matrix = kg1_data
            kg2_attr_keys, kg2_attr_vals, kg2_node_list, kg2_adj_matrix = kg2_data
            neg_attr_keys, neg_attr_vals, neg_node_list, neg_adj_matrix = neg_data

            with tf.GradientTape() as tape:
                # training=True switches on the encoder's Dropout layers. Keras propagates
                # it to the inner layers; without it they run in inference mode.
                kg1_att_embed, kg1_node_embed, kg1_states, _, _ = encoder(kg1_attr_keys, kg1_attr_vals, kg1_node_list, kg1_adj_matrix, training=True)
                kg2_att_embed, kg2_node_embed, kg2_states, _, _ = encoder(kg2_attr_keys, kg2_attr_vals, kg2_node_list, kg2_adj_matrix, training=True)
                neg_att_embed, neg_node_embed, neg_states, _, _ = encoder(neg_attr_keys, neg_attr_vals, neg_node_list, neg_adj_matrix, training=True)

                # contrastive (margin ranking) loss on the attribute embeddings
                kg1_final_embed = tf.concat([kg1_att_embed], axis=-1)
                kg2_final_embed = tf.concat([kg2_att_embed], axis=-1)
                neg_final_embed = tf.concat([neg_att_embed], axis=-1)
                m_loss = margin_based_ranking_loss(kg1_final_embed, kg2_final_embed, neg_final_embed)

                # distillation: batch-branch embeddings vs attribute representations
                kg1_batch_embed = hist_embed.get_batch_embed(kg1_node_list)
                kg2_batch_embed = hist_embed.get_batch_embed(kg2_node_list)
                neg_batch_embed = hist_embed.get_batch_embed(neg_node_list)
                db_loss1 = distill_loss(kg1_att_embed, kg1_batch_embed[:BATCH_SIZE, :])
                db_loss2 = distill_loss(kg2_att_embed, kg2_batch_embed[:BATCH_SIZE, :])
                db_loss3 = distill_loss(neg_att_embed, neg_batch_embed[:BATCH_SIZE, :])
                db_loss = (db_loss1 + db_loss2 + db_loss3) / 3

                # distillation: hist-branch embeddings vs structural representations
                kg1_hist_embed = hist_embed.get_hist_embed(kg1_node_list)
                kg2_hist_embed = hist_embed.get_hist_embed(kg2_node_list)
                neg_hist_embed = hist_embed.get_hist_embed(neg_node_list)
                dh_loss1 = distill_loss(kg1_node_embed, kg1_hist_embed[:BATCH_SIZE, :])
                dh_loss2 = distill_loss(kg2_node_embed, kg2_hist_embed[:BATCH_SIZE, :])
                dh_loss3 = distill_loss(neg_node_embed, neg_hist_embed[:BATCH_SIZE, :])
                dh_loss = (dh_loss1 + dh_loss2 + dh_loss3) / 3

                the_loss = m_loss + db_loss + dh_loss

                # to ensure closeness of consecutive layers' node states
                aux_loss = 0
                for i in range(1, len(kg1_states)):
                    aux_loss += tf.nn.l2_loss((kg1_states[i] + GAMMA) - kg1_states[i-1]) \
                                + tf.nn.l2_loss((kg2_states[i] + GAMMA) - kg2_states[i-1]) \
                                + tf.nn.l2_loss((neg_states[i] + GAMMA) - neg_states[i-1])
                the_loss = the_loss + (BETA * aux_loss)

                all_loss_list.append(the_loss.numpy())
                db_loss_list.append(db_loss.numpy())
                dh_loss_list.append(dh_loss.numpy())
                mbr_loss_list.append(m_loss.numpy())

            variables = encoder.trainable_variables + hist_embed.trainable_variables
            gradients = tape.gradient(the_loss, variables)
            optimizer.apply_gradients(zip(gradients, variables))

            if (batch + 1) % 10 == 0:
                log_str = ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch + 1, mean(all_loss_list)))
                log_str += (' MBRL {:.4f} DB {:.4f} DH {:.4f}'.format(mean(mbr_loss_list), mean(db_loss_list), mean(dh_loss_list)))
                _log(log_file, log_str)

        if (epoch + 1) % 10 == 0:
            hit1, hit10 = test()
            # '{:.4f}' (precision 4); the former '{:4f}' was only a minimum width of 4.
            _log(log_file, 'Epoch {} Hit-1 {:.4f} Hit-10 {:.4f}'.format(epoch + 1, hit1, hit10))

            if hit1 > best_hit:
                checkpoint.save(file_prefix=checkpoint_prefix)
                best_hit = hit1

        _log(log_file, 'Epoch {} Loss {:.4f}'.format(epoch + 1, mean(all_loss_list)))
        _log(log_file, 'Time taken for 1 epoch {} sec\n'.format(time.time() - start))

    _log(log_file, "DONE")

In [None]:
# LOAD EXISTING MODEL
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
if latest_ckpt is None:
    # checkpoint.restore(None) would silently do nothing; report it instead.
    print('No checkpoint found in %s; keeping the current weights.' % checkpoint_dir)
else:
    checkpoint.restore(latest_ckpt)