Commit
#8: added save/restore features
Johannes Ruthmann committed Jul 11, 2018
1 parent 5db76de commit 2c38f5e
Showing 5 changed files with 106 additions and 35 deletions.
9 changes: 8 additions & 1 deletion .gitignore
@@ -138,4 +138,11 @@ venv.bak/
benchmarks/
result/
.ifiske/
tests/tmp
tests/tmp

# Tensorflow models
checkpoint
*.data-*
*.index
/*.json
*.meta
39 changes: 28 additions & 11 deletions README.md
@@ -43,17 +43,34 @@ To compute a knowledge graph embedding, first import datasets and set configure
from openke.models import TransE

# Read the dataset
base = Dataset("./benchmarks/fb15k.nt")

# Set the knowledge embedding model class.
model = TransE(50, 1.0, base.shape)

# Train the model.
base.train(500, model, count=100, negatives=(1,0), bern=False, workers=4)

# Save the result.
model.save("./result")

ds = Dataset("./benchmarks/fb15k.nt")

# Configure parameters
folds = 20
neg_ent = 2
neg_rel = 0


# Set the knowledge embedding model class.
def model():
    return TransE(50, 1.0, ds.ent_count, ds.rel_count, batch_size=ds.size // folds, variants=1 + neg_rel + neg_ent)


# Train the model. It is saved in the process.
model = ds.train(
    model,
    folds=folds,
    epochs=20,
    post_epoch=print,
    prefix="./TransE",
    neg_ent=neg_ent,
    neg_rel=neg_rel,
    bern=False,
    workers=4,
)

# Save the embedding to a JSON file
model.save_to_json("TransE.json")
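For reference, the exported file can be read back with the standard library. The following is a minimal sketch (not part of this commit), assuming the layout `save_to_json` writes: a JSON object mapping each parameter name to nested lists.

import json

import numpy as np

# Parse the JSON export; each value is a nested list of floats.
with open("TransE.json") as f:
    parameters = json.load(f)

# Convert each parameter back into a numpy array, e.g. the entity matrix.
embeddings = {name: np.array(values) for name, values in parameters.items()}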

## Interfaces

5 changes: 4 additions & 1 deletion examples/train_transe.py
@@ -21,9 +21,12 @@ def model():
    folds=folds,
    epochs=20,
    post_epoch=print,
    prefix="./result",
    prefix="./TransE",
    neg_ent=neg_ent,
    neg_rel=neg_rel,
    bern=False,
    workers=4,
)

# Save the embedding to a JSON file
model.save_to_json("TransE.json")
44 changes: 30 additions & 14 deletions openke/Config.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
import ctypes
import datetime as dt
import os
from ctypes import cdll, c_void_p, c_int64

import numpy as np
@@ -86,7 +88,7 @@ def __len__(self):
# types = (np.int64, np.int64, np.int64, np.float32)
# batches = [np.zeros(batch_size_neg, dtype=t) for t in types]
# h_addr, t_addr, l_addr, y_addr = [get_array_pointer(x) for x in batches]
#
# TODO: Move randReset
# self.__library.randReset(workers, seed)
#
# sampling = self.__library.bernSampling if bern else self.__library.sampling
@@ -96,9 +98,10 @@ def __len__(self):

# TODO: Add selection for best-performing model
def train(self, model_constructor, folds=1, epochs=1, model_count=1, prefix='best', post_epoch=None,
post_batch=None, **kwargs):
post_batch=None, autosave: float = 30, continue_training=True, **kwargs):
"""
Training algorithm over the whole dataset.
Training algorithm over the whole dataset. The model is saved as a TensorFlow binary at the end of the training
and every `autosave` minutes.
:param model_constructor: Parameterless constructor of the embedding model to train
:param folds: Number of batches
@@ -107,11 +110,14 @@ def train(self, model_constructor, folds=1, epochs=1, model_count=1, prefix='bes
:param prefix: Prefix to save the model
:param post_epoch: Callback at the end of each epoch (receiving epoch number and loss)
:param post_batch: Callback at the end of each batch (receiving batches and loss)
:param autosave: Time in minutes after which the model is autosaved.
:param continue_training: Whether training should resume from an existing model found under `prefix`. If False,
the model is trained from scratch.
:param kwargs: Optional kwargs for the batch creation. Possible values are: neg_ent, neg_rel, bern, workers,
seed
"""
# Prepare batches
neg_ent = kwargs.get("neg_ent", 0)
neg_ent = kwargs.get("neg_ent", 1)
neg_rel = kwargs.get("neg_rel", 0)
bern = kwargs.get("bern", True)
workers = kwargs.get("workers", 1)
@@ -126,6 +132,10 @@ def train(self, model_constructor, folds=1, epochs=1, model_count=1, prefix='bes

# create model
m = model_constructor()
if os.path.exists(prefix + ".index") and continue_training:
print(f"Found model with prefix {prefix}. Continuing training ...")
m.restore(prefix)
datetime_next_save = dt.datetime.now() + dt.timedelta(minutes=autosave)

for epoch in range(epochs):
loss = 0.0
@@ -144,6 +154,12 @@ def train(self, model_constructor, folds=1, epochs=1, model_count=1, prefix='bes
if post_epoch:
post_epoch(epoch, loss)

# Save
if dt.datetime.now() > datetime_next_save:
print(f"Autosave in epoch {epoch} ...")
m.save(prefix)
datetime_next_save = dt.datetime.now() + dt.timedelta(minutes=autosave)

m.save(prefix)
return m
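As a usage sketch (not from the diff), the two new parameters work together as follows: a rerun that finds `<prefix>.index` on disk resumes from the saved weights, and a checkpoint is written once more than `autosave` minutes have passed, checked at the end of each epoch. The dataset `ds` and `model_constructor` below are assumptions:

m = ds.train(
    model_constructor,
    folds=20,
    epochs=100,
    prefix="./TransE",       # resumes if ./TransE.index already exists
    autosave=10,             # checkpoint roughly every 10 minutes (checked per epoch)
    continue_training=True,  # set to False to retrain from scratch
)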

@@ -244,16 +260,16 @@ def __init__(self):
def __getitem__(self, key):
if key in self.__dict:
return self.__dict[key]
l = cdll.LoadLibrary(key)
l.sampling.argtypes = [c_void_p, c_void_p, c_void_p, c_void_p, c_int64, c_int64, c_int64, c_int64]
l.bernSampling.argtypes = l.sampling.argtypes
l.query_head.argtypes = [c_void_p, c_int64, c_int64]
l.query_tail.argtypes = [c_int64, c_void_p, c_int64]
l.query_rel.argtypes = [c_int64, c_int64, c_void_p]
l.importTrainFiles.argtypes = [c_void_p, c_int64, c_int64]
l.randReset.argtypes = [c_int64, c_int64]
self.__dict[key] = l
return l
lib = cdll.LoadLibrary(key)
lib.sampling.argtypes = [c_void_p, c_void_p, c_void_p, c_void_p, c_int64, c_int64, c_int64, c_int64]
lib.bernSampling.argtypes = lib.sampling.argtypes
lib.query_head.argtypes = [c_void_p, c_int64, c_int64]
lib.query_tail.argtypes = [c_int64, c_void_p, c_int64]
lib.query_rel.argtypes = [c_int64, c_int64, c_void_p]
lib.importTrainFiles.argtypes = [c_void_p, c_int64, c_int64]
lib.randReset.argtypes = [c_int64, c_int64]
self.__dict[key] = lib
return lib


_l = _Library()
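A brief usage sketch for the cache above (the shared-object path is a placeholder): the first lookup loads the library via `cdll.LoadLibrary` and configures its `argtypes`; repeated lookups with the same key return the cached handle.

lib = _l["./libopenke.so"]          # loaded and configured on first access
assert lib is _l["./libopenke.so"]  # later lookups hit the cache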
44 changes: 36 additions & 8 deletions openke/models/Base.py
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
import json

import numpy as np
import tensorflow as tf
from tensorflow.python.training.saver import Saver

from openke import norm

@@ -90,7 +91,7 @@ def __init__(self, ent_count, rel_count, batch_size=0, variants=0, optimizer=Non
self.__prediction = self._predict_def()
grads_and_vars = optimizer.compute_gradients(self.__loss)
self.__training = optimizer.apply_gradients(grads_and_vars)
self.__saver = Saver()
self.__saver = tf.train.Saver()
self.__session.run(tf.global_variables_initializer())

def __iter__(self):
@@ -172,13 +173,40 @@ def entity(self, head=None):
}
return self.__session.run(self._entity, feed)

def save(self, fileprefix):
"""Writes the model's state into persistent memory."""
self.__saver.save(self.__session, fileprefix)
def save(self, prefix: str, step: int = None):
"""
Save the model to filesystem.
:param prefix: File prefix for the model
:param step: Step of the model (appended to prefix)
"""
if step is not None:
self.__saver.save(self.__session, prefix, global_step=step)
else:
self.__saver.save(self.__session, prefix)

def save_to_json(self, filename: str):
"""
Save the embedding as JSON file. The JSON file contains the embedding parameters (e.g. entity and relation
matrices). These parameters depend on the model.
:param filename: Filename for the output JSON file
"""
content = {}
for var_name in self.__parameters:
with self.__graph.as_default():
with self.__session.as_default():
content[var_name] = self.__session.run(self.__parameters[var_name]).tolist()
with open(filename, "w") as f:
f.write(json.dumps(content))

def restore(self, fileprefix):
"""Reads a model from persistent memory."""
self.__saver.restore(self.__session, fileprefix)
def restore(self, prefix: str):
"""
Reads a model from the filesystem.
:param prefix: File prefix of the model to load
"""
self.__saver.restore(self.__session, prefix)
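Together, `save` and `restore` round-trip the TensorFlow checkpoint files (`<prefix>.index`, `<prefix>.data-*`, `<prefix>.meta`, plus a `checkpoint` state file, which the new `.gitignore` entries cover). A hedged sketch, where `model` stands for any constructed subclass of this base class:

model.save("./TransE")          # writes ./TransE.index, .data-*, .meta
model.save("./TransE", step=5)  # with global_step: writes ./TransE-5.*

fresh = model_constructor()     # must build the same graph shape
fresh.restore("./TransE")       # loads the saved variables into its session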

def get_positive_instance(self, in_batch=True):
if in_batch:
