Skip to content

Commit

Permalink
Added more options to encdec
Browse files Browse the repository at this point in the history
  • Loading branch information
fabiencro committed Apr 20, 2017
1 parent 50b23c5 commit 6c09206
Showing 1 changed file with 71 additions and 14 deletions.
85 changes: 71 additions & 14 deletions temp_scripts/char_encdec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import numpy as np
import json

import codecs
import itertools
import collections

class SentDec(Chain):
def __init__(self, V, Hw, Hs):
Expand Down Expand Up @@ -208,19 +211,32 @@ def report_nan(self):
if param.grad is not None:
if self.xp.any(self.xp.isnan(param.grad)):
print "nan in grad of", name

def encode_voc_list(voc_list, charlist):
    """Encode each word of ``voc_list`` as an array of character indices.

    Args:
        voc_list: iterable of words (each word is iterated character by
            character).
        charlist: sequence of characters; the position of a character in
            this sequence is its index.

    Returns:
        A list with one ``np.int32`` array per word, in the order of
        ``voc_list``.

    Raises:
        KeyError: if a word contains a character absent from ``charlist``.
    """
    # If a character appears several times in charlist, its last position wins.
    chardict = {c: num for num, c in enumerate(charlist)}

    return [np.array([chardict[c] for c in w], dtype=np.int32)
            for w in voc_list]


def make_voc_list_from_text(filename, max_nb_ex = None, frequency_threshold = 2):
    """Collect the words of a UTF-8 text file that occur often enough.

    Args:
        filename: path of the text file to read (decoded as UTF-8).
        max_nb_ex: maximum number of lines to read; ``None`` reads the
            whole file.
        frequency_threshold: minimum number of occurrences a word needs in
            order to be kept.

    Returns:
        A list of the words whose count is >= ``frequency_threshold``.
    """
    counts = collections.defaultdict(int)
    # The context manager closes the file even if decoding fails midway.
    with codecs.open(filename, encoding = "utf8") as f:
        for line in itertools.islice(f, max_nb_ex):
            # NOTE: splitting on a single space means blank lines and
            # consecutive spaces produce an empty-string token, which is
            # counted like any other word.
            for w in line.strip().split(" "):
                counts[w] += 1

    # items() (rather than iteritems()) works on both Python 2 and Python 3.
    return [w for w, cnt in counts.items() if cnt >= frequency_threshold]

def make_data(filename, max_nb_ex = None, frequency_threshold = 2):
words = make_voc_list_from_text(filename, max_nb_ex = max_nb_ex, frequency_threshold = frequency_threshold)

print "collected", len(words), "words"

Expand All @@ -237,10 +253,7 @@ def make_data(filename, max_nb_ex = None, frequency_threshold = 2):
for num, c in enumerate(charlist):
chardict[c] = num

dataset = []
for w in sorted(words):
encoded = [chardict[c] for c in w]
dataset.append(np.array(encoded, dtype = np.int32))
dataset = encode_voc_list(sorted(words), charlist)

return dataset, charlist, chardict

Expand All @@ -258,7 +271,7 @@ def create_model(config):
ced = CharEncDec(V, Ec, H, nlayers_src=config["src_layers"])
return ced

def do_train_sentences(args):
# def do_train_sentences(args):


def do_train(args):
Expand Down Expand Up @@ -339,13 +352,16 @@ def train_once(print_loss = False, use_gumbel = False, temperature = 1, sample =
serializers.save_npz(args.dest + "char_encdec.model", ced)
num_iter += 1

def do_eval(args):
import IPython
config=json.load(open(args.config))
def load_encdec_from_config(config_fn, model_fn):
    """Rebuild a model from its JSON config and load its saved weights.

    Args:
        config_fn: path of the JSON configuration file; its "indexer" entry
            is the path of a JSON file holding the character list.
        model_fn: path of the npz file holding the serialized parameters.

    Returns:
        A tuple ``(ced, charlist, chardict)`` where ``ced`` is the model
        returned by ``create_model`` with the weights loaded, ``charlist``
        is the list read from the indexer file and ``chardict`` maps each
        character to its position in ``charlist``.
    """
    with open(config_fn) as config_file:
        config = json.load(config_file)
    ced = create_model(config)

    with open(config["indexer"], "r") as indexer_file:
        charlist = json.load(indexer_file)
    chardict = {c: i for i, c in enumerate(charlist)}

    # Only the explicit model_fn argument is used: no `args` object exists
    # in this function.
    serializers.load_npz(model_fn, ced)
    return ced, charlist, chardict

def do_eval(args):
ced, charlist, chardict = load_encdec_from_config(args.config, args.model)

if args.gpu is not None:
chainer.cuda.Device(args.gpu).use()
Expand All @@ -366,6 +382,35 @@ def dec(hx):

IPython.embed()

def generate_voc_encodings(encoder, charlist, voc_list, mb_size=1024):
    """Encode every word of ``voc_list`` with ``encoder``, batch by batch.

    Args:
        encoder: object exposing ``xp`` (its array module) and
            ``compute_h(batch, train=..., use_workaround=...)``.
        charlist: character list used to turn words into index arrays.
        voc_list: list of words to encode.
        mb_size: number of words sent to the encoder per call.

    Returns:
        A numpy array made by vertically stacking the per-batch encodings,
        in the order of ``voc_list``.

    Raises:
        ValueError: if ``voc_list`` is empty (nothing to stack) or
            ``mb_size`` is 0.
    """
    dataset = encode_voc_list(voc_list, charlist)
    print("voc_size: %d" % len(dataset))
    xp = encoder.xp

    encodings_list = []
    for start in range(0, len(dataset), mb_size):
        batch = dataset[start:start + mb_size]
        if xp is not np:
            # The encoder works on another array module: move the batch to it.
            batch = [xp.array(bb) for bb in batch]
        encodings = encoder.compute_h(batch, train = False, use_workaround = True)
        # Bring the result back to host memory so that np.vstack below
        # receives numpy arrays even when the encoder runs on a GPU
        # (to_cpu returns numpy arrays unchanged).
        host_data = chainer.cuda.to_cpu(encodings.data)
        # Drops the leading axis; presumably that axis has size 1 (e.g. a
        # single-layer hidden state) -- TODO confirm against compute_h.
        encodings_list.append(host_data.reshape(host_data.shape[1:]))
        print("processed %d" % (start + mb_size))

    return np.vstack(encodings_list)

def do_generate_voc_table(args):
    """Entry point of the ``encode`` subcommand.

    Builds the vocabulary of ``args.filename``, loads the model described
    by ``args.config`` / ``args.model`` (moving it to GPU ``args.gpu`` when
    given), encodes every word, then writes two files:

    * ``args.dest + ".gen.voc_list"``: the vocabulary as JSON;
    * ``args.dest + ".gen.encodings"``: the encodings, saved with
      ``np.savez`` under the key ``enc`` (numpy appends ``.npz`` to the
      file name).

    Row order of the encodings matches the order of the saved vocabulary.
    """
    voc_list = make_voc_list_from_text(
        args.filename,
        max_nb_ex = args.max_nb_ex,
        frequency_threshold = args.frequency_threshold)
    ced, charlist, chardict = load_encdec_from_config(args.config, args.model)
    if args.gpu is not None:
        chainer.cuda.Device(args.gpu).use()
        ced = ced.to_gpu(args.gpu)
    encodings = generate_voc_encodings(ced.enc, charlist, voc_list, mb_size=args.mb_size)

    # Close the vocabulary file explicitly so its content is flushed to
    # disk before the function returns.
    with open(args.dest + ".gen.voc_list", "w") as voc_file:
        json.dump(voc_list, voc_file)
    np.savez(args.dest + ".gen.encodings", enc=encodings)

def main():
import argparse
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
Expand Down Expand Up @@ -395,11 +440,23 @@ def main():
parser_eval.add_argument("config")
parser_eval.add_argument("model")
parser_eval.add_argument("--gpu", type = int)

parser_generate_encodings = subparsers.add_parser('encode', description="Generate encodings", help="Encode", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser_generate_encodings.add_argument("config")
parser_generate_encodings.add_argument("model")
parser_generate_encodings.add_argument("filename")
parser_generate_encodings.add_argument("dest")
parser_generate_encodings.add_argument("--frequency_threshold", type = int, default = 2)
parser_generate_encodings.add_argument("--max_nb_ex", type = int)
parser_generate_encodings.add_argument("--mb_size", type = int, default = 1024)
parser_generate_encodings.add_argument("--gpu", type = int)

args = parser.parse_args()

func = {"make_data": do_make_data,
"train": do_train,
"eval": do_eval}[args.__subcommand_name]
"eval": do_eval,
"encode": do_generate_voc_table}[args.__subcommand_name]

func(args)

Expand Down

0 comments on commit 6c09206

Please sign in to comment.