In [1]:
import json
import os

import numpy as np
import tensorflow as tf

import config
import models


In [2]:
# Benchmark folders: FB15K for the initial TransE run, FB15K_OOV for the
# freeze run that introduces new entities/relations.
train_data_path, test_data_path = "./benchmarks/FB15K/", "./benchmarks/FB15K_OOV/"

# Outputs of the initial TransE run: tf.Saver model file and JSON embeddings.
train_file_path, train_embedding_path = "./res/model.vec.tf", "./res/embedding.vec.json"

# Outputs of the TransE_freeze run (separate files, so nothing is overwritten).
test_file_path, test_embedding_path = "./res/model_new.vec.tf", "./res/embedding_new.vec.json"

# Run TransE to create the initial embeddings

In [3]:
"""
Method:

Run the normal transe example (example_train_transe.py)
Write the embeddings as a file that can be read
use embeddings to initialize embedding layer in TransE_freeze.py

Append random embeddings for new entities and relations
set config freeze_train_embeddings = true
figure out how to update only the embeddings for a certain set of indices
figure out how to make sure that we only see examples using new items to speed convergence
compare new+old embeddings
"""
os.environ['CUDA_VISIBLE_DEVICES']='7'
con = config.Config()
#True: Input test files from the same folder.
con.set_in_path(train_data_path)
con.set_test_link_prediction(True)
con.set_test_triple_classification(True)
con.set_work_threads(8)
con.set_train_times(10)
con.set_nbatches(20)
con.set_alpha(0.001)
con.set_margin(1.0)
con.set_bern(0)
con.set_dimension(100)
con.set_ent_neg_rate(1)
con.set_rel_neg_rate(0)
con.set_opt_method("SGD")

In [4]:
# Models will be exported via tf.Saver() automatically, to train_file_path.
# NOTE(review): the second argument (0) is interpreted by
# config.Config.set_export_files; presumably an export interval where 0 means
# "export once at the end" -- confirm.
con.set_export_files(train_file_path, 0)
# Model parameters will be exported to a JSON file automatically; the
# TransE_freeze run below reads this same file back as its initializer.
con.set_out_files(train_embedding_path)
# Initialize experimental settings. Presumably consumes the settings configured
# above, so keep it after all setters.
con.init()

In [5]:

# Set the knowledge embedding model (plain TransE for the initial run).
# Running this cell prints TensorFlow deprecation warnings (tf.contrib sunset,
# keep_dims, ...); they do not stop the run.
con.set_model(models.TransE)



For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Use `tf.global_variables_initializer` instead.


In [6]:
# Train the model; the library prints one loss line per epoch.
# NOTE(review): the per-epoch `time` it reports is ~1e-6 s, far too small for a
# training epoch, so it does not appear to measure the epoch itself. Use %%time
# on this cell for the real wall-clock cost.
con.run()

Epoch: 0, loss: 301717.369140625, time: 9.5367431640625e-07
Epoch: 1, loss: 127094.7919921875, time: 1.1920928955078125e-06
Epoch: 2, loss: 87462.60961914062, time: 7.152557373046875e-07
Epoch: 3, loss: 70860.0908203125, time: 0.0
Epoch: 4, loss: 61540.795166015625, time: 9.5367431640625e-07
Epoch: 5, loss: 55482.537841796875, time: 9.5367431640625e-07
Epoch: 6, loss: 50941.429443359375, time: 9.5367431640625e-07
Epoch: 7, loss: 47697.72802734375, time: 9.5367431640625e-07
Epoch: 8, loss: 45372.07763671875, time: 9.5367431640625e-07
Epoch: 9, loss: 44106.403564453125, time: 1.1920928955078125e-06


# Run TransE_freeze using the embeddings produced in the step above

In [7]:
# Restrict to GPU 7. NOTE: CUDA reads this variable only when it is first
# initialized in the process, so if the TransE run above already executed in
# this kernel this line has no effect; it matters only when starting here.
os.environ['CUDA_VISIBLE_DEVICES']='7'
# Fail fast: the freeze run is only meaningful when initialized from the
# embeddings exported by the TransE run above.
if not os.path.isfile(train_embedding_path):
    raise FileNotFoundError(
        "Initial embeddings not found at {}; run the TransE training cells "
        "first.".format(train_embedding_path)
    )
con = config.Config()
# Input training files from test_data_path (FB15K_OOV), which extends FB15K
# with new entities/relations.
con.set_in_path(test_data_path)
# True: input test files from the same folder.
con.set_test_link_prediction(True)
con.set_test_triple_classification(True)
# Hyperparameters mirror the initial TransE run so results are comparable;
# the dimension must match the embeddings loaded below.
con.set_work_threads(8)
con.set_train_times(10)
con.set_nbatches(20)
con.set_alpha(0.001)
con.set_margin(1.0)
con.set_bern(0)
con.set_dimension(100)
con.set_ent_neg_rate(1)
con.set_rel_neg_rate(0)
con.set_opt_method("SGD")
# Intended to keep the pre-existing embeddings fixed and train only the
# appended rows. NOTE(review): the comparison cells below report that old
# row 0 changed, so verify this actually freezes them.
con.set_freeze_train_embeddings(True)
# The same JSON file holds both "ent_embeddings" and "rel_embeddings".
con.set_ent_embedding_initializer(train_embedding_path)
con.set_rel_embedding_initializer(train_embedding_path)

In [8]:
# Models will be exported via tf.Saver() automatically, to test_file_path
# (a different file from the first run, so its model is not overwritten).
con.set_export_files(test_file_path, 0)
# Model parameters will be exported to test_embedding_path; they are compared
# against the original embeddings further down.
con.set_out_files(test_embedding_path)
# Initialize experimental settings.
con.init()
# Set the knowledge embedding model: the freeze variant of TransE, which
# presumably honours the freeze/initializer settings configured above.
con.set_model(models.TransE_freeze)

In [9]:
# Train the freeze model; the library prints one loss line per epoch.
# The first-epoch loss (~42k) starts below where the initial run ended (~44k),
# which is consistent with starting from the pretrained embeddings.
# NOTE(review): the reported per-epoch `time` values (~1e-6 s) are not a
# meaningful epoch duration.
con.run()

Epoch: 0, loss: 42157.26647949219, time: 0.0
Epoch: 1, loss: 40747.789794921875, time: 9.5367431640625e-07
Epoch: 2, loss: 38899.87634277344, time: 1.1920928955078125e-06
Epoch: 3, loss: 37705.41516113281, time: 1.1920928955078125e-06
Epoch: 4, loss: 36526.37060546875, time: 1.1920928955078125e-06
Epoch: 5, loss: 35860.90673828125, time: 9.5367431640625e-07
Epoch: 6, loss: 34948.32189941406, time: 0.0
Epoch: 7, loss: 34169.264404296875, time: 1.1920928955078125e-06
Epoch: 8, loss: 33043.1103515625, time: 9.5367431640625e-07
Epoch: 9, loss: 31239.524658203125, time: 9.5367431640625e-07


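# Compare new and old embeddings

If freezing works, every entity and relation that existed before retraining should be unchanged, and only the newly appended rows should differ.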

In [10]:
# Load the embeddings exported by the initial TransE run. Read from the
# configured path rather than a duplicated literal, so changing the path in the
# config cell cannot silently compare against a stale file.
with open(train_embedding_path, "r") as f:
    old_embeddings = json.load(f)
old_ent_embeddings = old_embeddings["ent_embeddings"]
old_rel_embeddings = old_embeddings["rel_embeddings"]


In [11]:
# Load the embeddings exported by the TransE_freeze run, again via the
# configured path instead of a duplicated literal.
with open(test_embedding_path, "r") as f:
    new_embeddings = json.load(f)
new_ent_embeddings = new_embeddings["ent_embeddings"]
new_rel_embeddings = new_embeddings["rel_embeddings"]

In [12]:
# Freeze check for entities. New rows are appended after the old ones, so the
# first len(old) rows should be unchanged if freezing works. This checks every
# pre-existing row, not just row 0. Exact equality is intended: frozen values
# should round-trip unchanged through the JSON export.
old_ent_embeddings == new_ent_embeddings[:len(old_ent_embeddings)]

False

In [13]:
# Freeze check for relations. The new relation is appended at the end
# (1345 old vs 1346 new), so every pre-existing row should be unchanged if
# freezing works. This checks all of them, not just row 0.
old_rel_embeddings == new_rel_embeddings[:len(old_rel_embeddings)]

False

In [14]:
# Number of relations in the embeddings exported by the freeze run.
n_rel_new = len(new_rel_embeddings)
n_rel_new

1346

In [15]:
# Number of relations in the embeddings exported by the initial run.
n_rel_old = len(old_rel_embeddings)
n_rel_old

1345

In [16]:
# First newly appended relation: it sits right after the old rows, so derive
# the index (1345 here) from the old table size instead of hard-coding it.
# NOTE(review): the printed values cluster around -0.02 in steps of 0.002,
# which looks like a near-constant initialization barely moved by training --
# worth checking how new rows are initialized.
new_rel_embeddings[len(old_rel_embeddings)]

[-0.018950240686535835,
 -0.018950240686535835,
 -0.018950240686535835,
 -0.022950241342186928,
 -0.018950240686535835,
 -0.02095024101436138,
 -0.018950240686535835,
 -0.02095024101436138,
 -0.024950241670012474,
 -0.018950240686535835,
 -0.022950241342186928,
 -0.02095024101436138,
 -0.018950240686535835,
 -0.018950240686535835,
 -0.02095024101436138,
 -0.018950240686535835,
 -0.02095024101436138,
 -0.02095024101436138,
 -0.018950240686535835,
 -0.018950240686535835,
 -0.022950241342186928,
 -0.018950240686535835,
 -0.02095024101436138,
 -0.018950240686535835,
 -0.018950240686535835,
 -0.018950240686535835,
 -0.02095024101436138,
 -0.018950240686535835,
 -0.01695024035871029,
 -0.022950241342186928,
 -0.02095024101436138,
 -0.02095024101436138,
 -0.02095024101436138,
 -0.02095024101436138,
 -0.022950241342186928,
 -0.018950240686535835,
 -0.02095024101436138,
 -0.022950241342186928,
 -0.02095024101436138,
 -0.018950240686535835,
 -0.018950240686535835,
 -0.022950241342186928,
 -0.020