Merge pull request #519 from neulab/lru_vocab
LRU vocab for learning vocabularies on the fly using RL
philip30 committed Sep 27, 2018
2 parents ed908ef + 2b2b47e commit b2db445
Showing 14 changed files with 324 additions and 88 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,4 +10,5 @@ numpy
 Unidecode>=1.0.22
 beautifulsoup4
 nltk
+pylru
 tensorboardX
8 changes: 6 additions & 2 deletions test/config/seg_debug.yaml
@@ -13,12 +13,16 @@ report: !Experiment
       policy_learning: !PolicyGradient
         output_dim: 2
         conf_penalty: !ConfidencePenalty {}
-      reporter: !SegmentingReporter {}
+      compute_report: True
+      reporter: !SegmentPLLogger
+        report_path: examples/output/report
       final_transducer: !BiLSTMSeqTransducer {}
   train: !SimpleTrainingRegimen
     run_for_epochs: 1
     src_file: examples/data/head.ja
-    trg_file: examples/data/head.en
+    trg_file: examples/data/head.en
+    loss_calculator: !CompositeLoss
+      losses: [!FeedbackLoss {child_loss: !MLELoss {}}]
   evaluate:
   - !AccuracyEvalTask
     eval_metrics: bleu,wer
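The new !CompositeLoss/!FeedbackLoss pairing is what couples translation quality to the segmenter's policy-gradient updates. A self-contained sketch of the presumed mechanism (plain Python; feedback_loss and its arguments are illustrative names, not xnmt API):

# Presumed mechanism of !FeedbackLoss {child_loss: !MLELoss {}} (sketch, not
# the committed code): the child MLE loss trains the translator as usual, and
# its per-sentence value doubles as the REINFORCE reward for the segmenter.
def feedback_loss(mle_loss_per_sent, logp_sampled_segmentation):
  rewards = [-l for l in mle_loss_per_sent]      # low translation loss => high reward
  policy_loss = sum(-r * lp for r, lp in zip(rewards, logp_sampled_segmentation))
  return sum(mle_loss_per_sent) + policy_loss

print(feedback_loss([0.9, 0.2], [-1.5, -0.7]))   # toy numbers; roughly -0.39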
48 changes: 48 additions & 0 deletions test/config/seg_learn_debug.yaml
@@ -0,0 +1,48 @@
+seg_learn_debug: !Experiment
+  exp_global: !ExpGlobal
+    default_layer_dim: 64
+    compute_report: True
+  model: !DefaultTranslator
+    src_reader: !CharFromWordTextReader
+      vocab: !Vocab {vocab_file: examples/data/head.ja.charvocab}
+    trg_reader: !PlainTextReader
+      vocab: !Vocab {vocab_file: examples/data/head.en.vocab}
+    encoder: !SegmentingSeqTransducer
+      segment_composer: !CharNGramComposer
+        vocab_size: 1000
+        ngram_size: 4
+      policy_learning: !PolicyGradient
+        weight: !DefinedSequence {sequence: [1.0]}
+        output_dim: 2
+        use_baseline: True
+        z_normalization: True
+        conf_penalty: !ConfidencePenalty {}
+      compute_report: True
+      reporter: !SegmentPLLogger
+        report_path: examples/output/report
+      final_transducer: !BiLSTMSeqTransducer {}
+  train: !SimpleTrainingRegimen
+    run_for_epochs: 3
+    src_file: examples/data/head.ja
+    trg_file: examples/data/head.en
+    loss_calculator: !CompositeLoss
+      losses: [!FeedbackLoss {child_loss: !MLELoss {}}]
+    dev_tasks:
+    - !LossEvalTask
+      src_file: examples/data/head.ja
+      ref_file: examples/data/head.en
+    - !AccuracyEvalTask
+      eval_metrics: bleu
+      src_file: examples/data/head.ja
+      ref_file: examples/data/head.en
+      hyp_file: examples/output/dev_hyp
+  evaluate:
+  - !AccuracyEvalTask
+    eval_metrics: bleu,wer
+    src_file: examples/data/debug.ja
+    ref_file: examples/data/debug.en
+    hyp_file: test/tmp/{EXP}.test_hyp
+  - !LossEvalTask
+    src_file: examples/data/debug.ja
+    ref_file: examples/data/debug.en

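For orientation, ngram_size: 4 and vocab_size: 1000 above control which character n-grams the composer extracts and how many LRU-managed embedding slots they share. A minimal sketch of the extraction, consistent with the expectations in test_segmenting.py below (the committed composer additionally maps each n-gram to one of vocab_size LRU slots):

# Sketch of the char-n-gram decomposition behind ngram_size (assumed behavior,
# matching the id comments in test_chargram_composer_learn below).
def char_ngrams(word, ngram_size):
  return [word[i:j] for i in range(len(word))
                    for j in range(i + 1, min(i + ngram_size, len(word)) + 1)]

assert char_ngrams("abc", 2) == ["a", "ab", "b", "bc", "c"]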
7 changes: 4 additions & 3 deletions test/config/seg_report.yaml
@@ -24,11 +24,12 @@ report: !Experiment
   evaluate:
   - !AccuracyEvalTask
     eval_metrics: bleu,wer
-    src_file: examples/data/head.ja
-    ref_file: examples/data/head.en
+    src_file: examples/data/dev.ja
+    ref_file: examples/data/dev.en
     hyp_file: test/tmp/{EXP}.test_hyp
     inference: !AutoRegressiveInference
-      reporter: !SegmentationReporter {}
+      reporter: !SegmentationReporter
+        report_path: examples/output/test.segment
   - !LossEvalTask
     src_file: examples/data/head.ja
     ref_file: examples/data/head.en
86 changes: 76 additions & 10 deletions test/test_segmenting.py
@@ -19,7 +19,7 @@
 from xnmt.loss_calculators import MLELoss, FeedbackLoss, GlobalFertilityLoss, CompositeLoss
 from xnmt.specialized_encoders.segmenting_encoder.segmenting_encoder import *
 from xnmt.specialized_encoders.segmenting_encoder.segmenting_composer import *
-from xnmt.specialized_encoders.segmenting_encoder.reporter import SegmentingReporter
+from xnmt.specialized_encoders.segmenting_encoder.reporter import SegmentPLLogger
 from xnmt.specialized_encoders.segmenting_encoder.length_prior import PoissonLengthPrior
 from xnmt.specialized_encoders.segmenting_encoder.priors import PoissonPrior, GoldInputPrior
 from xnmt.modelparts.transforms import AuxNonLinear, Linear
@@ -36,7 +36,7 @@ def setUp(self):
     # Seeding
     numpy.random.seed(2)
     random.seed(2)
-    layer_dim = 64
+    layer_dim = 4
     xnmt.events.clear()
     ParamManager.init_param_col()
     self.segment_encoder_bilstm = BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim)
@@ -163,15 +163,15 @@ def test_policy_gold(self):
     self.calc_loss_single_batch()
 
   def test_reporter(self):
-    self.model.encoder.reporter = SegmentingReporter("test/tmp/seg-report.log", self.model.src_reader.vocab)
+    self.model.encoder.reporter = SegmentPLLogger("test/tmp/seg-report.log", self.model.src_reader.vocab)
     self.calc_loss_single_batch()
 
 class TestComposing(unittest.TestCase):
   def setUp(self):
     # Seeding
     numpy.random.seed(2)
     random.seed(2)
-    layer_dim = 64
+    layer_dim = 4
     xnmt.events.clear()
     ParamManager.init_param_col()
     self.segment_composer = SumComposer()
@@ -218,7 +218,7 @@ def test_lookup_composer(self):
     word_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
     enc.segment_composer = LookupComposer(
       word_vocab = word_vocab,
-      src_vocab = self.src_reader.vocab,
+      char_vocab = self.src_reader.vocab,
       hidden_dim = self.layer_dim
     )
     enc.transduce(self.inp_emb(0))
@@ -228,22 +228,88 @@ def test_charngram_composer(self):
     word_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
     enc.segment_composer = CharNGramComposer(
       word_vocab = word_vocab,
-      src_vocab = self.src_reader.vocab,
+      char_vocab = self.src_reader.vocab,
       hidden_dim = self.layer_dim
     )
     enc.transduce(self.inp_emb(0))
+
+  def test_lookup_composer_learn(self):
+    enc = self.segmenting_encoder
+    char_vocab = Vocab(i2w=['a', 'b', 'c', 'd'])
+    enc.segment_composer = LookupComposer(
+      word_vocab = None,
+      char_vocab = char_vocab,
+      hidden_dim = self.layer_dim,
+      vocab_size = 4
+    )
+    event_trigger.set_train(True)
+    enc.segment_composer.set_word((0, 1, 2)) # abc -> 0
+    enc.segment_composer.transduce([])
+    enc.segment_composer.set_word((0, 2, 1)) # acb -> 1
+    enc.segment_composer.transduce([])
+    enc.segment_composer.set_word((0, 3, 2)) # adc -> 2
+    enc.segment_composer.transduce([])
+    enc.segment_composer.set_word((0, 1, 2)) # abc -> 0
+    enc.segment_composer.transduce([])
+    enc.segment_composer.set_word((1, 3, 2)) # bdc -> 3
+    enc.segment_composer.transduce([])
+    enc.segment_composer.set_word((3, 3, 3)) # ddd -> 1, acb is the oldest entry
+    enc.segment_composer.transduce([])
+    act = dict(enc.segment_composer.lrucache.items())
+    exp = {'abc': 0, 'ddd': 1, 'adc': 2, 'bdc': 3}
+    self.assertDictEqual(act, exp)
+
+    enc.segment_composer.set_word((0, 2, 1))
+    enc.segment_composer.transduce([])
+    enc.segment_composer.set_word((0, 3, 2))
+    enc.segment_composer.transduce([])
+    enc.segment_composer.set_word((0, 1, 2)) # abc
+    enc.segment_composer.transduce([])
+    enc.segment_composer.set_word((1, 3, 2)) # bdc
+    enc.segment_composer.transduce([])
+    enc.segment_composer.set_word((3, 3, 3))
+    enc.segment_composer.transduce([])
+    enc.segment_composer.set_word((0, 3, 1))
+    enc.segment_composer.transduce([])
+
+    event_trigger.set_train(False)
+    enc.segment_composer.set_word((3, 3, 2))
+    enc.segment_composer.transduce([])
+
+  def test_chargram_composer_learn(self):
+    enc = self.segmenting_encoder
+    char_vocab = Vocab(i2w=['a', 'b', 'c', 'd'])
+    enc.segment_composer = CharNGramComposer(
+      word_vocab = None,
+      char_vocab = char_vocab,
+      hidden_dim = self.layer_dim,
+      ngram_size = 2,
+      vocab_size = 5,
+    )
+    event_trigger.set_train(True)
+    enc.segment_composer.set_word((0, 1, 2)) # a: 0, ab: 1, b: 2, bc: 3, c: 4
+    enc.segment_composer.transduce([])
+    act = dict(enc.segment_composer.lrucache.items())
+    exp = {'a': 0, 'ab': 1, 'b': 2, 'bc': 3, 'c': 4}
+    self.assertDictEqual(act, exp)
+
+    enc.segment_composer.set_word((2, 3)) # c, cd, d
+    enc.segment_composer.transduce([])
+    act = dict(enc.segment_composer.lrucache.items())
+    exp = {'cd': 0, 'd': 1, 'b': 2, 'bc': 3, 'c': 4}
+    self.assertDictEqual(act, exp)
+
   def test_add_multiple_segment_composer(self):
     enc = self.segmenting_encoder
     word_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
     enc.segment_composer = SumMultipleComposer(
       composers = [
         LookupComposer(word_vocab = word_vocab,
-                       src_vocab = self.src_reader.vocab,
-                       hidden_dim = self.layer_dim),
+                       char_vocab = self.src_reader.vocab,
+                       hidden_dim = self.layer_dim),
         CharNGramComposer(word_vocab = word_vocab,
-                          src_vocab = self.src_reader.vocab,
-                          hidden_dim = self.layer_dim)
+                          char_vocab = self.src_reader.vocab,
+                          hidden_dim = self.layer_dim)
       ]
     )
     enc.transduce(self.inp_emb(0))
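The id assignments asserted in test_lookup_composer_learn above follow from a fixed pool of embedding slots fronted by pylru: a new word takes a free slot while one remains, and afterwards recycles the slot of the least-recently-used entry. A self-contained sketch of that policy (LRUVocab is illustrative, not an xnmt class):

# Sketch of the LRU slot-assignment policy exercised by the test above.
import pylru

class LRUVocab:
  def __init__(self, vocab_size):
    self.capacity = vocab_size
    self.lrucache = pylru.lrucache(vocab_size)

  def word_id(self, word):
    if word in self.lrucache:
      return self.lrucache[word]                 # hit: lookup refreshes recency
    if len(self.lrucache) < self.capacity:
      slot = len(self.lrucache)                  # a free slot remains
    else:
      _, slot = self.lrucache.peek_last_item()   # recycle the LRU entry's slot
    self.lrucache[word] = slot                   # insertion evicts the LRU key
    return slot

vocab = LRUVocab(4)
assert [vocab.word_id(w) for w in ["abc", "acb", "adc", "abc", "bdc"]] == [0, 1, 2, 0, 3]
assert vocab.word_id("ddd") == 1                 # "acb" was oldest; its slot is reused
assert dict(vocab.lrucache.items()) == {"abc": 0, "ddd": 1, "adc": 2, "bdc": 3}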
11 changes: 8 additions & 3 deletions xnmt/expression_seqs.py
@@ -43,9 +43,11 @@ def __init__(self, expr_list: Optional[Sequence[dy.Expression]] = None, expr_ten
         if e.dim() != expr_list[0].dim():
           raise AssertionError()
     if expr_tensor:
-      if not isinstance(expr_tensor,dy.Expression): raise ValueError("expr_tensor must be dynet expression, was:", type(expr_tensor))
+      if not isinstance(expr_tensor,dy.Expression):
+        raise ValueError("expr_tensor must be dynet expression, was:", type(expr_tensor))
     if expr_transposed_tensor:
-      if not isinstance(expr_transposed_tensor,dy.Expression): raise ValueError("expr_transposed_tensor must be dynet expression, was:", type(expr_transposed_tensor))
+      if not isinstance(expr_transposed_tensor,dy.Expression):
+        raise ValueError("expr_transposed_tensor must be dynet expression, was:", type(expr_transposed_tensor))
 
   def __len__(self):
     """Return length.
@@ -104,7 +106,10 @@ def as_tensor(self):
     the whole sequence as a tensor expression where each column is one of the embeddings.
     """
     if self.expr_tensor is None:
-      self.expr_tensor = dy.concatenate_cols(self.expr_list) if self.expr_list else dy.transpose(self.expr_transposed_tensor)
+      if self.expr_list:
+        self.expr_tensor = dy.concatenate_cols(self.expr_list)
+      else:
+        self.expr_tensor = dy.transpose(self.expr_transposed_tensor)
     return self.expr_tensor

   def has_tensor(self):
4 changes: 2 additions & 2 deletions xnmt/hyper_params.py
@@ -18,7 +18,7 @@ class Scalar(Serializable):

   @serializable_init
   @register_xnmt_handler
-  def __init__(self, initial=0.0, times_updated=0):
+  def __init__(self, initial:numbers.Integral = 0.0, times_updated:numbers.Integral = 0):
     self.initial = initial
     self.times_updated = times_updated
     self.value = self.get_curr_value()
@@ -80,4 +80,4 @@ def __init__(self, sequence: typing.Sequence[numbers.Real], times_updated: numbe
   def get_curr_value(self):
     return self.sequence[min(len(self.sequence) - 1, self.times_updated)]
 
-numbers.Real.register(Scalar)
\ No newline at end of file
+numbers.Real.register(Scalar)
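Since get_curr_value clamps the index to the last element, a !DefinedSequence holds its final value once times_updated runs past the end of the schedule; the one-element sequence [1.0] in the new seg_learn_debug config is therefore a constant weight. A worked illustration of the clamping:

# Worked illustration of the clamped lookup in get_curr_value above.
sequence = [1.0, 0.5, 0.1]
values = [sequence[min(len(sequence) - 1, t)] for t in range(5)]
assert values == [1.0, 0.5, 0.1, 0.1, 0.1]   # sticks at the last entry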
4 changes: 2 additions & 2 deletions xnmt/persistence.py
@@ -1462,15 +1462,15 @@ def check_type(obj, desired_type):
   elif desired_type.__class__.__name__ == "_Union":
     return any(
       subtype.__class__.__name__ == "_ForwardRef" or check_type(obj, subtype) for subtype in desired_type.__args__)
-  elif issubclass(desired_type, collections.abc.MutableMapping):
+  elif issubclass(desired_type.__class__, collections.abc.MutableMapping):
     if not isinstance(obj, collections.abc.MutableMapping): return False
     if desired_type.__args__:
       return (desired_type.__args__[0].__class__.__name__ == "_ForwardRef" or all(
         check_type(key, desired_type.__args__[0]) for key in obj.keys())) and (
         desired_type.__args__[1].__class__.__name__ == "_ForwardRef" or all(
         check_type(val, desired_type.__args__[1]) for val in obj.values()))
     else: return True
-  elif issubclass(desired_type, collections.abc.Sequence):
+  elif issubclass(desired_type.__class__, collections.abc.Sequence):
     if not isinstance(obj, collections.abc.Sequence): return False
     if desired_type.__args__ and desired_type.__args__[0].__class__.__name__ != "_ForwardRef":
       return all(check_type(item, desired_type.__args__[0]) for item in obj)
8 changes: 4 additions & 4 deletions xnmt/reports.py
@@ -445,16 +445,16 @@ def __init__(self, report_path: str=settings.DEFAULT_REPORT_PATH):

   def create_sent_report(self, segment_actions, src, **kwargs):
     if self.report_fp is None:
-      report_path = os.path.join(self.report_path, "segment.txt")
-      utils.make_parent_dir(report_path)
-      self.report_fp = open(report_path, "w")
+      utils.make_parent_dir(self.report_path)
+      self.report_fp = open(self.report_path, "w")
 
     actions = segment_actions[0]
     src = src.str_tokens()
     words = []
     start = 0
     for end in actions:
-      words.append("".join(str(src[start:end+1])))
+      if start < end+1:
+        words.append("".join(map(str, src[start:end+1])))
       start = end+1
     print(" ".join(words), file=self.report_fp)
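For reference, the loop above turns the policy's segmentation decisions into surface words: each entry of actions is the inclusive end index of one segment over the character tokens, and the new guard drops empty segments. A standalone restatement:

# Standalone restatement of the segmentation loop above.
def segment(chars, actions):
  words, start = [], 0
  for end in actions:
    if start < end + 1:                       # skip empty segments, as in the fix
      words.append("".join(map(str, chars[start:end + 1])))
    start = end + 1
  return " ".join(words)

assert segment(list("watashiwa"), [6, 8]) == "watashi wa"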

2 changes: 1 addition & 1 deletion xnmt/rl/policy_gradient.py
@@ -92,7 +92,7 @@ def calc_loss(self, reward):
     if self.z_normalization:
       mean_batches = dy.mean_batches(reward)
       std_batches = dy.std_batches(reward)
-      reward = dy.cdiv(reward-mean_batches, std_batches)
+      reward = dy.cdiv(reward-mean_batches, std_batches+1e-6)
     ## Calculate baseline
     if self.baseline is not None:
       pred_reward, baseline_loss = self.calc_baseline_loss(reward)
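The added 1e-6 guards the degenerate batch where every reward is identical, which makes the batch standard deviation zero. A numpy stand-in for the dynet ops:

# Why the +1e-6 matters (numpy stand-in for the dynet ops above): a batch of
# identical rewards has std 0, so the plain division would produce NaNs.
import numpy as np
reward = np.array([0.7, 0.7, 0.7])
z = (reward - reward.mean()) / (reward.std() + 1e-6)
print(z)   # [0. 0. 0.] rather than [nan nan nan]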
2 changes: 1 addition & 1 deletion xnmt/specialized_encoders/segmenting_encoder/priors.py
@@ -16,7 +16,7 @@

 class Prior(object):
   def log_ll(self, event): raise NotImplementedError()
-  def sample(self, size): raise NotImplementedError()
+  def sample(self, batch_size, size): raise NotImplementedError()
 
 class PoissonPrior(Prior, Serializable):
   """ The poisson prior """
14 changes: 10 additions & 4 deletions xnmt/specialized_encoders/segmenting_encoder/reporter.py
@@ -8,8 +8,8 @@
 from xnmt.persistence import serializable_init, Serializable, Ref, Path
 from xnmt.specialized_encoders.segmenting_encoder.segmenting_encoder import SegmentingSeqTransducer
 
-class SegmentingReporter(Serializable):
-  yaml_tag = "!SegmentingReporter"
+class SegmentPLLogger(Serializable):
+  yaml_tag = "!SegmentPLLogger"
 
   @serializable_init
   @register_xnmt_handler
@@ -40,10 +40,16 @@ def report_process(self, encoder: SegmentingSeqTransducer):
     format.append("{:>5}")
     table.append(["ACT"] + sample_dense)
     if encoder.policy_learning is not None:
-      policy_lls = [encoder.policy_learning.policy_lls[j].npvalue().transpose()[self.idx] for j in range(src_len)]
+      if self.src_sent.batch_size() == 1:
+        policy_lls = [encoder.policy_learning.policy_lls[j].npvalue().transpose() for j in range(src_len)]
+      else:
+        policy_lls = [encoder.policy_learning.policy_lls[j].npvalue().transpose()[self.idx] for j in range(src_len)]
       table.append(["LLS"] + ["{:.4f}".format(math.exp(policy_lls[j][sample_dense[j]])) for j in range(src_len)])
       self.pad_last(table)
-      valid_pos = [1 if self.idx in x else 0 for x in encoder.policy_learning.valid_pos]
+      if encoder.policy_learning.valid_pos is not None:
+        valid_pos = [1 if self.idx in x else 0 for x in encoder.policy_learning.valid_pos]
+      else:
+        valid_pos = [1 for _ in range(src.sent_len())]
       table.append(["MSK"] + valid_pos)
       format.append("{:>8}")
       format.append("{:>5}")
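The new batch_size() == 1 branch accounts for a dynet convention: npvalue() returns an array without a batch axis for single-element batches, so the [self.idx] batch lookup only applies in the batched case. Illustrative shapes (numpy stand-ins, assumed convention):

# Shape convention the branch above works around: dynet's npvalue() appends a
# batch axis only when the batch has more than one element.
import numpy as np
batched = np.zeros((2, 3))   # (output_dim, batch_size=3): transpose()[idx] picks one sentence
single = np.zeros((2,))      # batch_size=1: no batch axis, so no [idx] lookup is needed
print(batched.transpose()[0].shape, single.transpose().shape)  # (2,) (2,)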
