Skip to content

Commit

Permalink
Merge pull request #112 from neulab/philip
Browse files Browse the repository at this point in the history
Ported Python 2 code to Python 3
  • Loading branch information
neubig committed Jul 3, 2017
2 parents 97bcdf6 + 050c2a2 commit ffb77ec
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 13 deletions.
5 changes: 4 additions & 1 deletion xnmt/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import division, generators
import numpy as np
from collections import defaultdict, Counter
from six.moves.builtins import range, map
import math

class Evaluator(object):
Expand Down Expand Up @@ -216,8 +217,10 @@ def sim(self, word1, word2):
return -1

def seq_sim(self, l1, l2):
l1 = list(l1)
l2 = list(l2)
# compute matrix
F = [[0] * (len(l2) + 1) for i in xrange((len(l1) + 1))]
F = [[0] * (len(l2) + 1) for i in range((len(l1) + 1))]
for i in range(len(l1) + 1):
F[i][0] = i * self.gapPenalty
for j in range(len(l2) + 1):
Expand Down
5 changes: 3 additions & 2 deletions xnmt/input.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import itertools
import os
import codecs
from collections import defaultdict
from six.moves import zip
from serializer import Serializable
Expand Down Expand Up @@ -75,9 +76,9 @@ def read_file(self, filename, max_num=None):
if self.vocab is None:
self.vocab = Vocab()
sents = []
with open(filename) as f:
with codecs.open(filename, encoding='utf-8') as f:
for line in f:
words = line.decode('utf-8').strip().split()
words = line.strip().split()
sent = [self.vocab.convert(word) for word in words]
sent.append(self.vocab.convert(Vocab.ES_STR))
sents.append(SimpleSentenceInput(sent))
Expand Down
5 changes: 3 additions & 2 deletions xnmt/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def set_serialize_params_recursive(self, obj):
if isinstance(val, Serializable):
obj.serialize_params[name] = val
self.set_serialize_params_recursive(val)
elif type(val) in [type(None), bool, int, float, str, unicode, datetime.datetime, list, dict, set]:
elif type(val) in [type(None), bool, int, float, str, datetime.datetime, list, dict, set] or \
sys.version_info[0] == 2 and type(val) == unicode:
obj.serialize_params[name] = val
else:
continue
Expand Down Expand Up @@ -82,7 +83,7 @@ def get_val_to_share_or_none(self, obj, shared_params):
if cur_val:
if val is None: val = cur_val
elif cur_val != val:
print "WARNING: inconsistent shared params %s" % str(shared_params)
print("WARNING: inconsistent shared params %s" % str(shared_params))
return None
return val
def resolve_param_name(self, obj, param_descr):
Expand Down
1 change: 0 additions & 1 deletion xnmt/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def i2w_from_vocab_file(vocab_file):
return vocab

def convert(self, w):
assert isinstance(w, unicode)
if w not in self.w2i:
if self.frozen:
assert self.unk_token != None, 'Attempt to convert an OOV in a frozen vocabulary with no UNK token set'
Expand Down
7 changes: 3 additions & 4 deletions xnmt/xnmt_decode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# coding: utf-8

import codecs
from output import *
from serializer import *
import codecs
Expand Down Expand Up @@ -69,7 +70,7 @@ def xnmt_decode(args, model_elements=None):
# Perform decoding

translator.set_train(False)
with open(args.trg_file, 'wb') as fp: # Saving the translated output to a trg file
with codecs.open(args.trg_file, 'wb', encoding='utf-8') as fp: # Saving the translated output to a trg file
for src in src_corpus:
if args.max_src_len is not None and len(src) > args.max_src_len:
trg_sent = NO_DECODING_ATTEMPTED
Expand All @@ -78,9 +79,7 @@ def xnmt_decode(args, model_elements=None):
token_string = translator.translate(src, search_strategy)
trg_sent = output_generator.process(token_string)[0]

assert isinstance(trg_sent, unicode), "Expected unicode as translator output, got %s" % type(trg_sent)
trg_sent = trg_sent.encode('utf-8', errors='ignore')

#assert isinstance(trg_sent, unicode), "Expected unicode as translator output, got %s" % type(trg_sent)
fp.write(trg_sent + '\n')


Expand Down
5 changes: 3 additions & 2 deletions xnmt/xnmt_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import sys
import codecs
from evaluator import BLEUEvaluator, WEREvaluator, CEREvaluator
from options import Option, OptionParser
from xnmt_decode import NO_DECODING_ATTEMPTED
Expand All @@ -16,9 +17,9 @@ def read_data(loc_):
"""Reads the lines in the file specified in loc_ and return the list after inserting the tokens
"""
data = list()
with open(loc_) as fp:
with codecs.open(loc_, encoding="utf-8") as fp:
for line in fp:
t = line.decode('utf-8').split()
t = line.split()
data.append(t)
return data

Expand Down
1 change: 1 addition & 0 deletions xnmt/xnmt_run_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,5 @@

for line in results:
experiment_name, eval_scores = line
eval_scores = " ".join(map(str, eval_scores))
print("{:<20}|{:<40}".format(experiment_name, eval_scores))
2 changes: 1 addition & 1 deletion xnmt/xnmt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def create_model(self):
model_globals.default_layer_dim = self.args.default_layer_dim
model_globals.dropout = self.args.dropout
self.model = self.model_serializer.initialize_object(self.args.model, context)
print self.model_serializer.dump(self.model)
print(self.model_serializer.dump(self.model))

# Read in training and dev corpora
# src_vocab, trg_vocab = None, None
Expand Down

0 comments on commit ffb77ec

Please sign in to comment.