Skip to content

Commit

Permalink
Simplify vocab (#475)
Browse files Browse the repository at this point in the history
* simplify vocab (some tests broken)

* fix unit tests
  • Loading branch information
msperber authored and neubig committed Jul 30, 2018
1 parent b297bba commit 77e2baf
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 113 deletions.
25 changes: 17 additions & 8 deletions test/test_beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from xnmt.scorer import Softmax
from xnmt.translator import DefaultTranslator
from xnmt.search_strategy import BeamSearch, GreedySearch
from xnmt.vocab import Vocab

class TestForcedDecodingOutputs(unittest.TestCase):

Expand All @@ -29,9 +30,11 @@ def setUp(self):
layer_dim = 512
xnmt.events.clear()
ParamManager.init_param_col()
src_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
trg_vocab = Vocab(vocab_file="examples/data/head.en.vocab")
self.model = DefaultTranslator(
src_reader=PlainTextReader(),
trg_reader=PlainTextReader(),
src_reader=PlainTextReader(vocab=src_vocab),
trg_reader=PlainTextReader(vocab=trg_vocab),
src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
Expand Down Expand Up @@ -64,9 +67,11 @@ def setUp(self):
layer_dim = 512
xnmt.events.clear()
ParamManager.init_param_col()
src_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
trg_vocab = Vocab(vocab_file="examples/data/head.en.vocab")
self.model = DefaultTranslator(
src_reader=PlainTextReader(),
trg_reader=PlainTextReader(),
src_reader=PlainTextReader(vocab=src_vocab),
trg_reader=PlainTextReader(vocab=trg_vocab),
src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
Expand Down Expand Up @@ -99,9 +104,11 @@ def setUp(self):
layer_dim = 512
xnmt.events.clear()
ParamManager.init_param_col()
src_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
trg_vocab = Vocab(vocab_file="examples/data/head.en.vocab")
self.model = DefaultTranslator(
src_reader=PlainTextReader(),
trg_reader=PlainTextReader(),
src_reader=PlainTextReader(vocab=src_vocab),
trg_reader=PlainTextReader(vocab=trg_vocab),
src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
Expand Down Expand Up @@ -137,9 +144,11 @@ def setUp(self):
layer_dim = 512
xnmt.events.clear()
ParamManager.init_param_col()
src_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
trg_vocab = Vocab(vocab_file="examples/data/head.en.vocab")
self.model = DefaultTranslator(
src_reader=PlainTextReader(),
trg_reader=PlainTextReader(),
src_reader=PlainTextReader(vocab=src_vocab),
trg_reader=PlainTextReader(vocab=trg_vocab),
src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
Expand Down
19 changes: 13 additions & 6 deletions test/test_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from xnmt.translator import DefaultTranslator
from xnmt.scorer import Softmax
from xnmt.search_strategy import GreedySearch
from xnmt.vocab import Vocab

class TestForcedDecodingOutputs(unittest.TestCase):

Expand All @@ -29,9 +30,11 @@ def setUp(self):
layer_dim = 512
xnmt.events.clear()
ParamManager.init_param_col()
src_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
trg_vocab = Vocab(vocab_file="examples/data/head.en.vocab")
self.model = DefaultTranslator(
src_reader=PlainTextReader(),
trg_reader=PlainTextReader(),
src_reader=PlainTextReader(vocab=src_vocab),
trg_reader=PlainTextReader(vocab=trg_vocab),
src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
Expand Down Expand Up @@ -66,9 +69,11 @@ def setUp(self):
layer_dim = 512
xnmt.events.clear()
ParamManager.init_param_col()
src_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
trg_vocab = Vocab(vocab_file="examples/data/head.en.vocab")
self.model = DefaultTranslator(
src_reader=PlainTextReader(),
trg_reader=PlainTextReader(),
src_reader=PlainTextReader(vocab=src_vocab),
trg_reader=PlainTextReader(vocab=trg_vocab),
src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
Expand Down Expand Up @@ -102,9 +107,11 @@ def setUp(self):
layer_dim = 512
xnmt.events.clear()
ParamManager.init_param_col()
src_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
trg_vocab = Vocab(vocab_file="examples/data/head.en.vocab")
self.model = DefaultTranslator(
src_reader=PlainTextReader(),
trg_reader=PlainTextReader(),
src_reader=PlainTextReader(vocab=src_vocab),
trg_reader=PlainTextReader(vocab=trg_vocab),
src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
Expand Down
4 changes: 2 additions & 2 deletions test/test_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from xnmt.embedder import PretrainedSimpleWordEmbedder
from xnmt.param_collection import ParamManager
import xnmt.events

from xnmt.vocab import Vocab

class PretrainedSimpleWordEmbedderSanityTest(unittest.TestCase):
def setUp(self):
xnmt.events.clear()
self.input_reader = PlainTextReader()
self.input_reader = PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.ja.vocab"))
list(self.input_reader.read_sents('examples/data/head.ja'))
ParamManager.init_param_col()

Expand Down
6 changes: 4 additions & 2 deletions test/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def setUp(self):
xnmt.events.clear()
ParamManager.init_param_col()

self.src_reader = PlainTextReader()
self.trg_reader = PlainTextReader()
src_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
trg_vocab = Vocab(vocab_file="examples/data/head.en.vocab")
self.src_reader = PlainTextReader(vocab=src_vocab)
self.trg_reader = PlainTextReader(vocab=trg_vocab)
self.src_data = list(self.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.trg_reader.read_sents("examples/data/head.en"))

Expand Down
2 changes: 1 addition & 1 deletion test/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def setUp(self):
self.hyp = ["the taro met the hanako".split()]
self.ref = ["taro met hanako".split()]

vocab = Vocab()
vocab = Vocab(i2w=["the","taro","met","hanako"])
self.hyp_id = list(map(vocab.convert, self.hyp[0]))
self.ref_id = list(map(vocab.convert, self.ref[0]))

Expand Down
11 changes: 4 additions & 7 deletions test/test_segmenting.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def setUp(self):
ParamManager.init_param_col()
self.segment_encoder_bilstm = BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim)
self.segment_composer = SumComposer()
self.src_reader = CharFromWordTextReader()
self.trg_reader = PlainTextReader()
self.src_reader = CharFromWordTextReader(vocab=Vocab(vocab_file="examples/data/head.ja.charvocab"))
self.trg_reader = PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.en.vocab"))
self.loss_calculator = AutoRegressiveMLELoss()


Expand Down Expand Up @@ -184,8 +184,8 @@ def setUp(self):
xnmt.events.clear()
ParamManager.init_param_col()
self.segment_composer = SumComposer()
self.src_reader = CharFromWordTextReader()
self.trg_reader = PlainTextReader()
self.src_reader = CharFromWordTextReader(vocab=Vocab(vocab_file="examples/data/head.ja.charvocab"))
self.trg_reader = PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.en.vocab"))
self.loss_calculator = AutoRegressiveMLELoss()
self.segmenting_encoder = SegmentingSeqTransducer(
segment_composer = self.segment_composer,
Expand Down Expand Up @@ -225,7 +225,6 @@ def inp_emb(self, idx=0):
def test_lookup_composer(self):
enc = self.segmenting_encoder
word_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
word_vocab.freeze()
enc.segment_composer = LookupComposer(
word_vocab = word_vocab,
src_vocab = self.src_reader.vocab,
Expand All @@ -236,7 +235,6 @@ def test_lookup_composer(self):
def test_charngram_composer(self):
enc = self.segmenting_encoder
word_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
word_vocab.freeze()
enc.segment_composer = CharNGramComposer(
word_vocab = word_vocab,
src_vocab = self.src_reader.vocab,
Expand All @@ -247,7 +245,6 @@ def test_charngram_composer(self):
def test_add_multiple_segment_composer(self):
enc = self.segmenting_encoder
word_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
word_vocab.freeze()
enc.segment_composer = SumMultipleComposer(
composers = [
LookupComposer(word_vocab = word_vocab,
Expand Down
16 changes: 8 additions & 8 deletions test/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def setUp(self):
xnmt.events.clear()
ParamManager.init_param_col()

self.src_reader = PlainTextReader()
self.trg_reader = PlainTextReader()
self.src_reader = PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.ja.vocab"))
self.trg_reader = PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.en.vocab"))
self.src_data = list(self.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.trg_reader.read_sents("examples/data/head.en"))

Expand Down Expand Up @@ -163,8 +163,8 @@ def setUp(self):
xnmt.events.clear()
ParamManager.init_param_col()

self.src_reader = PlainTextReader()
self.trg_reader = PlainTextReader()
self.src_reader = PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.ja.vocab"))
self.trg_reader = PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.en.vocab"))
self.src_data = list(self.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.trg_reader.read_sents("examples/data/head.en"))

Expand Down Expand Up @@ -288,8 +288,8 @@ def test_train_dev_loss_equal(self):
train_args['src_file'] = "examples/data/head.ja"
train_args['trg_file'] = "examples/data/head.en"
train_args['loss_calculator'] = AutoRegressiveMLELoss()
train_args['model'] = DefaultTranslator(src_reader=PlainTextReader(),
trg_reader=PlainTextReader(),
train_args['model'] = DefaultTranslator(src_reader=PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.ja.vocab")),
trg_reader=PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.en.vocab")),
src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim,
Expand Down Expand Up @@ -330,8 +330,8 @@ def test_overfitting(self):
train_args['src_file'] = "examples/data/head.ja"
train_args['trg_file'] = "examples/data/head.en"
train_args['loss_calculator'] = AutoRegressiveMLELoss()
train_args['model'] = DefaultTranslator(src_reader=PlainTextReader(),
trg_reader=PlainTextReader(),
train_args['model'] = DefaultTranslator(src_reader=PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.ja.vocab")),
trg_reader=PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.en.vocab")),
src_embedder=SimpleWordEmbedder(vocab_size=100, emb_dim=layer_dim),
encoder=BiLSTMSeqTransducer(input_dim=layer_dim,
hidden_dim=layer_dim),
Expand Down
36 changes: 8 additions & 28 deletions xnmt/input_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ def read_sents(self, filename: str, filter_ids: Sequence[int] = None) -> Iterato
filter_ids: only read sentences with these ids (0-indexed)
Returns: iterator over sentences from filename
"""
if self.vocab is None:
self.vocab = Vocab()
return self.iterate_filtered(filename, filter_ids)

def count_sents(self, filename: str) -> int:
Expand All @@ -46,12 +44,6 @@ def count_sents(self, filename: str) -> int:
"""
raise RuntimeError("Input readers must implement the count_sents function")

def freeze(self) -> None:
"""
Freeze the data representation, e.g. by freezing the vocab.
"""
pass

def needs_reload(self) -> bool:
"""
Overwrite this method if data needs to be reload for each epoch
Expand Down Expand Up @@ -98,6 +90,12 @@ def iterate_filtered(self, filename, filter_ids=None):
if max_id is not None and sent_count > max_id:
break

def convert_int(x):
  """Parse a whitespace-delimited token as an integer id.

  Fallback conversion used by text readers when no ``vocab`` is configured:
  in that case every input token is expected to already be an integer word id.

  Args:
    x: token string expected to contain a base-10 integer.
  Returns:
    The parsed integer.
  Raises:
    ValueError: if ``x`` is not a valid integer literal, with a message
      explaining that a vocab was expected to be set.
  """
  try:
    return int(x)
  except ValueError as e:
    # Chain the original error so the traceback keeps the underlying cause.
    raise ValueError(f"Expecting integer tokens because no vocab was set. Got: '{x}'") from e

class PlainTextReader(BaseTextReader, Serializable):
"""
Handles the typical case of reading plain text files, with one sent per line.
Expand All @@ -113,25 +111,17 @@ class PlainTextReader(BaseTextReader, Serializable):
def __init__(self, vocab: Optional[Vocab] = None, read_sent_len: bool = False):
self.vocab = vocab
self.read_sent_len = read_sent_len
if vocab is not None:
self.vocab.freeze()
self.vocab.set_unk(Vocab.UNK_STR)

def read_sent(self, line):
if self.vocab:
self.convert_fct = self.vocab.convert
else:
self.convert_fct = int
self.convert_fct = convert_int
if self.read_sent_len:
return IntInput(len(line.strip().split()))
else:
return SimpleSentenceInput([self.convert_fct(word) for word in line.strip().split()] + [Vocab.ES])

def freeze(self):
self.vocab.freeze()
self.vocab.set_unk(Vocab.UNK_STR)
self.save_processed_arg("vocab", self.vocab)

def vocab_size(self):
return len(self.vocab)

Expand Down Expand Up @@ -166,8 +156,6 @@ def read_sents(self, filename: Union[str,Sequence[str]], filter_ids: Sequence[in
return
def count_sents(self, filename: str) -> int:
return self.readers[0].count_sents(filename if isinstance(filename,str) else filename[0])
def freeze(self) -> None:
for reader in self.readers: reader.freeze()
def needs_reload(self) -> bool:
return any(reader.needs_reload() for reader in self.readers)

Expand Down Expand Up @@ -199,9 +187,6 @@ def __init__(self, model_file, sample_train=False, l=-1, alpha=0.1, vocab=None):
self.alpha = alpha
self.vocab = vocab
self.train = False
if vocab is not None:
self.vocab.freeze()
self.vocab.set_unk(Vocab.UNK_STR)

@handle_xnmt_event
def on_set_train(self, val):
Expand All @@ -216,11 +201,6 @@ def read_sent(self, sentence):
return SimpleSentenceInput([self.vocab.convert(word) for word in words] + \
[self.vocab.convert(Vocab.ES_STR)])

def freeze(self):
self.vocab.freeze()
self.vocab.set_unk(Vocab.UNK_STR)
self.save_processed_arg("vocab", self.vocab)

def count_words(self, trg_words):
trg_cnt = 0
for x in trg_words:
Expand All @@ -236,7 +216,7 @@ def vocab_size(self):
class CharFromWordTextReader(PlainTextReader, Serializable):
yaml_tag = "!CharFromWordTextReader"
@serializable_init
def __init__(self, vocab=None, read_sent_len=False):
def __init__(self, vocab:Vocab=None, read_sent_len:bool=False):
super().__init__(vocab, read_sent_len)
def read_sent(self, sentence, filter_ids=None):
chars = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,6 @@ def __init__(self,
word_vocab = Vocab()
dict_entry = vocab_size
else:
word_vocab.freeze()
word_vocab.set_unk(word_vocab.UNK_STR)
dict_entry = len(word_vocab)
self.src_vocab = src_vocab
self.word_vocab = word_vocab
Expand Down Expand Up @@ -154,8 +152,6 @@ def __init__(self,
word_vocab = Vocab()
dict_entry = vocab_size
else:
word_vocab.freeze()
word_vocab.set_unk(word_vocab.UNK_STR)
dict_entry = len(word_vocab)

self.dict_entry = dict_entry
Expand Down

0 comments on commit 77e2baf

Please sign in to comment.