diff --git a/main.py b/main.py index eaf337f..f571ecf 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import os import pprint import tensorflow as tf @@ -30,6 +31,9 @@ def main(_): count = [] word2idx = {} + if not os.path.exists(FLAGS.checkpoint_dir): + os.makedirs(FLAGS.checkpoint_dir) + train_data = read_data('%s/%s.train.txt' % (FLAGS.data_dir, FLAGS.data_name), count, word2idx) valid_data = read_data('%s/%s.valid.txt' % (FLAGS.data_dir, FLAGS.data_name), count, word2idx) test_data = read_data('%s/%s.test.txt' % (FLAGS.data_dir, FLAGS.data_name), count, word2idx)