Merge pull request #24 from inikdom/dev
Dev
inikdom committed Aug 16, 2016
2 parents 6739129 + 1a3e2d2 commit d16b7b2
Showing 4 changed files with 90 additions and 19 deletions.
10 changes: 3 additions & 7 deletions README.md
@@ -23,15 +23,11 @@ To run with preset hyper-parameters just run:

`$ python train.py`

To try other hyper-parameters:
To try other hyper-parameters, you can change them in the config file. If you want to provide your own:

`$ python train.py --dropout=0.8 --hidden_size=120` etc...

It will automatically check whether the parameters in checkpoint_dir match; if not, it will
create a new directory inside checkpoint_dir to train the new network.

The modifiable hyper-parameters are:
`$ python train.py --config_file="path_to_config"`
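
The config file is a plain INI file read with Python's `ConfigParser`. As a rough illustration only (the values below are placeholders, not the project's defaults), a script along these lines would produce a file with the sections and keys that `sample.py`'s `read_config_file` expects:

```python
# Illustrative only: writes a config skeleton with the sections/keys that
# sample.py's read_config_file looks up. All values here are placeholders.
import ConfigParser

config = ConfigParser.ConfigParser()

config.add_section("sentiment_network_params")
for key, value in [("num_layers", "2"), ("hidden_size", "128"),
                   ("dropout", "0.8"), ("batch_size", "64"),
                   ("train_frac", "0.9"), ("learning_rate", "0.001"),
                   ("lr_decay_factor", "0.97"), ("grad_clip", "5"),
                   ("max_epoch", "50"), ("max_vocab_size", "20000")]:
    config.set("sentiment_network_params", key, value)

config.add_section("general")
config.set("general", "use_config_file_if_checkpoint_exists", "false")
config.set("general", "max_seq_length", "200")
config.set("general", "steps_per_checkpoint", "100")

with open("config.ini", "w") as f:
    config.write(f)
```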

Description of hyper-parameters:

| Name | Type | Description |
| :-------------------:|:-------------:|:-------------------------------------------|
60 changes: 50 additions & 10 deletions sample.py
@@ -8,7 +8,6 @@
Written by: Dominik Kaukinen
'''
import tensorflow as tf
from tensorflow.models.rnn import rnn, rnn_cell, seq2seq
from tensorflow.python.platform import gfile
import numpy as np
import sys
@@ -18,16 +17,18 @@
import util.dataprocessor
import models.sentiment
import util.vocabmapping
import ConfigParser

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('checkpoint_dir', 'data/checkpoints/', 'Directory to store/restore checkpoints')
flags.DEFINE_string('text', 'Hello World!', 'Text to sample with.')
flags.DEFINE_string('config_file', 'config.ini', 'Path to configuration file.')
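# Example invocation with the flags defined above (defaults are shown in the
# definitions; the values below are placeholders):
#   python sample.py --text="Great movie!" --checkpoint_dir=data/checkpoints/ --config_file=config.ini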

def main():
    vocab_mapping = util.vocabmapping.VocabMapping()
    with tf.Session() as sess:
        model = loadModel(sess, vocab_mapping.getSize())
        model = load_model(sess, vocab_mapping.getSize())
        if model == None:
            return
        max_seq_length = model.max_seq_length
@@ -90,14 +91,53 @@ def loadModel(session, vocab_size):
        model = None
    return model

'''
Restore training hyper parameters.
This is a hack mostly, but I couldn't find another way to do this.
Ultimately, I don't think it's that bad.
'''
def restoreHyperParameters():
    path = os.path.join(FLAGS.checkpoint_dir, "hyperparams.npy")
    return np.load(path)
def load_model(session, vocab_size):
    hyper_params = read_config_file()
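    # Hyper-parameters now come from the config file rather than the old
    # serialized hyperparams.npy. The hard-coded 1.0 and trailing 1 below are
    # presumably the dropout value and the batch size used for sampling
    # (an assumption from context; not stated in this diff).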
    model = models.sentiment.SentimentModel(vocab_size,
                                            hyper_params["hidden_size"],
                                            1.0,
                                            hyper_params["num_layers"],
                                            hyper_params["grad_clip"],
                                            hyper_params["max_seq_length"],
                                            hyper_params["learning_rate"],
                                            hyper_params["lr_decay_factor"],
                                            1)
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
        print "Reading model parameters from {0}".format(ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print "Double check you got the checkpoint_dir right..."
        print "Model not found..."
        model = None
    return model

def read_config_file():
    '''
    Reads in config file, returns dictionary of network params
    '''
    config = ConfigParser.ConfigParser()
    config.read(FLAGS.config_file)
    dic = {}
    sentiment_section = "sentiment_network_params"
    general_section = "general"
    dic["num_layers"] = config.getint(sentiment_section, "num_layers")
    dic["hidden_size"] = config.getint(sentiment_section, "hidden_size")
    dic["dropout"] = config.getfloat(sentiment_section, "dropout")
    dic["batch_size"] = config.getint(sentiment_section, "batch_size")
    dic["train_frac"] = config.getfloat(sentiment_section, "train_frac")
    dic["learning_rate"] = config.getfloat(sentiment_section, "learning_rate")
    dic["lr_decay_factor"] = config.getfloat(sentiment_section, "lr_decay_factor")
    dic["grad_clip"] = config.getint(sentiment_section, "grad_clip")
    dic["use_config_file_if_checkpoint_exists"] = config.getboolean(general_section,
                                                                    "use_config_file_if_checkpoint_exists")
    dic["max_epoch"] = config.getint(sentiment_section, "max_epoch")
    dic["max_vocab_size"] = config.getint(sentiment_section, "max_vocab_size")
    dic["max_seq_length"] = config.getint(general_section,
                                          "max_seq_length")
    dic["steps_per_checkpoint"] = config.getint(general_section,
                                                "steps_per_checkpoint")
    return dic

def tokenize(text):
    text = text.decode('utf-8')
5 changes: 3 additions & 2 deletions train.py
@@ -43,9 +43,10 @@ def main():
    vocabmapping = util.vocabmapping.VocabMapping()
    vocab_size = vocabmapping.getSize()
    print "Vocab size is: {0}".format(vocab_size)
    path = "data/processed/"
    path = os.path.join(FLAGS.data_dir, "processed/")
    infile = [f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]
    #randomize data order
    print infile
    data = np.load(os.path.join(path, infile[0]))
    for i in range(1, len(infile)):
        data = np.vstack((data, np.load(os.path.join(path, infile[i]))))
@@ -103,7 +104,7 @@ def main():
                loss += test_loss
                test_accuracy += accuracy
            normalized_test_loss, normalized_test_accuracy = loss / len(model.test_data), test_accuracy / len(model.test_data)
            checkpoint_path = os.path.join(path, "sentiment{0}.ckpt".format(normalized_test_accuracy))
            checkpoint_path = os.path.join(FLAGS.checkpoint_dir, "sentiment{0}.ckpt".format(normalized_test_accuracy))
            model.saver.save(sess, checkpoint_path, global_step=model.global_step)
            writer.add_summary(str_summary, step)
            print "Avg Test Loss: {0}, Avg Test Accuracy: {1}".format(normalized_test_loss, normalized_test_accuracy)
34 changes: 34 additions & 0 deletions util/hyperparams.py
@@ -0,0 +1,34 @@
'''
This is the main logic for serializing and deserializing dictionaries
of hyperparameters (for use in checkpoint restoration and sampling)
'''
import os
import pickle

class HyperParameterHandler(object):
    def __init__(self, path):
        self.file_path = os.path.join(path, "hyperparams.p")

    def saveParams(self, dic):
        with open(self.file_path, 'wb') as handle:
            pickle.dump(dic, handle)

    def getParams(self):
        with open(self.file_path, 'rb') as handle:
            return pickle.load(handle)

    def checkExists(self):
        '''
        Checks if hyper parameter file exists
        '''
        return os.path.exists(self.file_path)

    def checkChanged(self, new_params):
        if self.checkExists():
            old_params = self.getParams()
            return old_params["num_layers"] != new_params["num_layers"] or\
                old_params["hidden_size"] != new_params["hidden_size"] or\
                old_params["max_seq_length"] != new_params["max_seq_length"] or\
                old_params["max_vocab_size"] != new_params["max_vocab_size"]
        else:
            return False
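
As a rough sketch of how the new helper might be used (not part of this commit; the checkpoint directory is the default from sample.py's flags, and the hyper-parameter values are placeholders):

```python
# Hypothetical usage of HyperParameterHandler, e.g. before building a model.
import util.hyperparams as hyperparams

new_params = {"num_layers": 2, "hidden_size": 128,
              "max_seq_length": 200, "max_vocab_size": 20000}

handler = hyperparams.HyperParameterHandler("data/checkpoints/")
if handler.checkChanged(new_params):
    # Network shape in the config no longer matches the stored checkpoint.
    print "Hyper-parameters differ from the saved checkpoint's."
elif not handler.checkExists():
    # First run in this checkpoint directory: record the parameters for later.
    handler.saveParams(new_params)
```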
