diff --git a/README.md b/README.md
index cbd3bfaaf..21cd91b79 100644
--- a/README.md
+++ b/README.md
@@ -23,15 +23,11 @@ To run with preset hyper-parameters just run:

 `$ python train.py`

-To try other hyper-parameters:
+To try other hyper-parameters, you can change them in the config file. If you want to provide your own config file:

-`$ python train.py --dropout=0.8 --hidden_size=120` etc...
-
-It will automatically check if parameters in checkpoint_dir match, if not it will
-make a new directory inside checkpoint_dir to train the new network
-
-The modifiable hyper-parameters are:
+`$ python train.py --config_file="path_to_config"`

+Description of hyper-parameters:

 | Name                 | Type          | Description                                 |
 | :-------------------:|:-------------:|:-------------------------------------------|
diff --git a/sample.py b/sample.py
index 728702488..6290ea96a 100644
--- a/sample.py
+++ b/sample.py
@@ -8,7 +8,6 @@ Written by: Dominik Kaukinen
 '''

 import tensorflow as tf
-from tensorflow.models.rnn import rnn, rnn_cell, seq2seq
 from tensorflow.python.platform import gfile
 import numpy as np
 import sys
@@ -18,16 +17,18 @@ import util.dataprocessor
 import models.sentiment
 import util.vocabmapping
+import ConfigParser

 flags = tf.app.flags
 FLAGS = flags.FLAGS

 flags.DEFINE_string('checkpoint_dir', 'data/checkpoints/', 'Directory to store/restore checkpoints')
 flags.DEFINE_string('text', 'Hello World!', 'Text to sample with.')
+flags.DEFINE_string('config_file', 'config.ini', 'Path to configuration file.')

 def main():
     vocab_mapping = util.vocabmapping.VocabMapping()
     with tf.Session() as sess:
-        model = loadModel(sess, vocab_mapping.getSize())
+        model = load_model(sess, vocab_mapping.getSize())
         if model == None:
             return
         max_seq_length = model.max_seq_length
@@ -90,14 +91,53 @@ def loadModel(session, vocab_size):
         model = None
     return model

-'''
-Restore training hyper parameters.
-This is a hack mostly, but I couldn't find another way to do this.
-Ultimately, I don't think is that bad.
-'''
-def restoreHyperParameters():
-    path = os.path.join(FLAGS.checkpoint_dir, "hyperparams.npy")
-    return np.load(path)
+def load_model(session, vocab_size):
+    hyper_params = read_config_file()
+    model = models.sentiment.SentimentModel(vocab_size,
+                                            hyper_params["hidden_size"],
+                                            1.0,
+                                            hyper_params["num_layers"],
+                                            hyper_params["grad_clip"],
+                                            hyper_params["max_seq_length"],
+                                            hyper_params["learning_rate"],
+                                            hyper_params["lr_decay_factor"],
+                                            1)
+    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
+    if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
+        print "Reading model parameters from {0}".format(ckpt.model_checkpoint_path)
+        model.saver.restore(session, ckpt.model_checkpoint_path)
+    else:
+        print "Double check you got the checkpoint_dir right..."
+        print "Model not found..."
+        model = None
+    return model
+
+def read_config_file():
+    '''
+    Reads in config file, returns dictionary of network params
+    '''
+    config = ConfigParser.ConfigParser()
+    config.read(FLAGS.config_file)
+    dic = {}
+    sentiment_section = "sentiment_network_params"
+    general_section = "general"
+    dic["num_layers"] = config.getint(sentiment_section, "num_layers")
+    dic["hidden_size"] = config.getint(sentiment_section, "hidden_size")
+    dic["dropout"] = config.getfloat(sentiment_section, "dropout")
+    dic["batch_size"] = config.getint(sentiment_section, "batch_size")
+    dic["train_frac"] = config.getfloat(sentiment_section, "train_frac")
+    dic["learning_rate"] = config.getfloat(sentiment_section, "learning_rate")
+    dic["lr_decay_factor"] = config.getfloat(sentiment_section, "lr_decay_factor")
+    dic["grad_clip"] = config.getint(sentiment_section, "grad_clip")
+    dic["use_config_file_if_checkpoint_exists"] = config.getboolean(general_section,
+                                                                    "use_config_file_if_checkpoint_exists")
+    dic["max_epoch"] = config.getint(sentiment_section, "max_epoch")
+    dic["max_vocab_size"] = config.getint(sentiment_section, "max_vocab_size")
+    dic["max_seq_length"] = config.getint(general_section,
+                                          "max_seq_length")
+    dic["steps_per_checkpoint"] = config.getint(general_section,
+                                                "steps_per_checkpoint")
+    return dic

 def tokenize(text):
     text = text.decode('utf-8')
diff --git a/train.py b/train.py
index 4141d8f29..82550b3cf 100644
--- a/train.py
+++ b/train.py
@@ -43,9 +43,10 @@ def main():
     vocabmapping = util.vocabmapping.VocabMapping()
     vocab_size = vocabmapping.getSize()
     print "Vocab size is: {0}".format(vocab_size)
-    path = "data/processed/"
+    path = os.path.join(FLAGS.data_dir, "processed/")
     infile = [f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]
     #randomize data order
+    print infile
     data = np.load(os.path.join(path, infile[0]))
     for i in range(1, len(infile)):
         data = np.vstack((data, np.load(os.path.join(path, infile[i]))))
@@ -103,7 +104,7 @@ def main():
             loss += test_loss
             test_accuracy += accuracy
         normalized_test_loss, normalized_test_accuracy = loss / len(model.test_data), test_accuracy / len(model.test_data)
-        checkpoint_path = os.path.join(path, "sentiment{0}.ckpt".format(normalized_test_accuracy))
+        checkpoint_path = os.path.join(FLAGS.checkpoint_dir, "sentiment{0}.ckpt".format(normalized_test_accuracy))
         model.saver.save(sess, checkpoint_path, global_step=model.global_step)
         writer.add_summary(str_summary, step)
         print "Avg Test Loss: {0}, Avg Test Accuracy: {1}".format(normalized_test_loss, normalized_test_accuracy)
diff --git a/util/hyperparams.py b/util/hyperparams.py
new file mode 100644
index 000000000..fa3061a00
--- /dev/null
+++ b/util/hyperparams.py
@@ -0,0 +1,34 @@
+'''
+This is the main logic for serializing and deserializing dictionaries
+of hyperparameters (for use in checkpoint restoration and sampling)
+'''
+import os
+import pickle
+
+class HyperParameterHandler(object):
+    def __init__(self, path):
+        self.file_path = os.path.join(path, "hyperparams.p")
+
+    def saveParams(self, dic):
+        with open(self.file_path, 'wb') as handle:
+            pickle.dump(dic, handle)
+
+    def getParams(self):
+        with open(self.file_path, 'rb') as handle:
+            return pickle.load(handle)
+
+    def checkExists(self):
+        '''
+        Checks if hyper parameter file exists
+        '''
+        return os.path.exists(self.file_path)
+
+    def checkChanged(self, new_params):
+        if self.checkExists():
+            old_params = self.getParams()
+            return old_params["num_layers"] != new_params["num_layers"] or\
+                old_params["hidden_size"] != new_params["hidden_size"] or\
+                old_params["max_seq_length"] != new_params["max_seq_length"] or\
+                old_params["max_vocab_size"] != new_params["max_vocab_size"]
+        else:
+            return False
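
For reference, `read_config_file()` in sample.py above implies a `config.ini` laid out roughly as follows. This is only a sketch: the section and option names are taken from the code in this diff, while the values are placeholder assumptions, not the project's defaults.

```ini
; Hypothetical example config.ini. Option names match what read_config_file()
; requests; the values are placeholders and should be tuned for your data.
[general]
use_config_file_if_checkpoint_exists = True
max_seq_length = 200
steps_per_checkpoint = 100

[sentiment_network_params]
num_layers = 2
hidden_size = 128
dropout = 0.5
batch_size = 32
train_frac = 0.7
learning_rate = 0.01
lr_decay_factor = 0.97
grad_clip = 5
max_epoch = 50
max_vocab_size = 20000
```

A file like this can then be passed to either script, e.g. `python sample.py --config_file="path_to_config"`, since sample.py now defines its own `config_file` flag defaulting to `config.ini`.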