Skip to content

Commit

Permalink
Hyperparams as numbers.Real (#523)
Browse files Browse the repository at this point in the history
* make hyper param an instance of numbers.Real, and update type annotations to use numbers.Real instead of float

* fix some issues in DefinedSequence

* fix unit tests
  • Loading branch information
msperber committed Aug 21, 2018
1 parent 23e0463 commit 8af53ea
Show file tree
Hide file tree
Showing 16 changed files with 212 additions and 171 deletions.
11 changes: 6 additions & 5 deletions xnmt/batchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random
from abc import ABC, abstractmethod
from functools import lru_cache
import numbers

import numpy as np
import dynet as dy
Expand Down Expand Up @@ -538,7 +539,7 @@ class WordSortBatcher(SortBatcher):
pad_src_to_multiple: pad source sentences so that their length is a multiple of this integer.
"""

def __init__(self, words_per_batch: Optional[int], avg_batch_size: Optional[Union[int,float]], sort_key: Callable,
def __init__(self, words_per_batch: Optional[int], avg_batch_size: Optional[numbers.Real], sort_key: Callable,
break_ties_randomly: bool = True, pad_src_to_multiple: int = 1) -> None:
# Sanity checks
if words_per_batch and avg_batch_size:
Expand Down Expand Up @@ -566,7 +567,7 @@ class WordSrcBatcher(WordSortBatcher, Serializable):
yaml_tag = "!WordSrcBatcher"

@serializable_init
def __init__(self, words_per_batch:Optional[int]=None, avg_batch_size:Optional[Union[int,float]]=None,
def __init__(self, words_per_batch:Optional[int]=None, avg_batch_size:Optional[numbers.Real]=None,
break_ties_randomly:bool=True, pad_src_to_multiple:int=1) -> None:
super().__init__(words_per_batch, avg_batch_size, sort_key=lambda x: x[0].sent_len(),
break_ties_randomly=break_ties_randomly,
Expand All @@ -592,7 +593,7 @@ class WordTrgBatcher(WordSortBatcher, Serializable):
yaml_tag = "!WordTrgBatcher"

@serializable_init
def __init__(self, words_per_batch:Optional[int]=None, avg_batch_size:Optional[Union[int,float]]=None,
def __init__(self, words_per_batch:Optional[int]=None, avg_batch_size:Optional[numbers.Real]=None,
break_ties_randomly:bool=True, pad_src_to_multiple:int=1) -> None:
super().__init__(words_per_batch, avg_batch_size, sort_key=lambda x: x[1].sent_len(),
break_ties_randomly=break_ties_randomly,
Expand All @@ -618,7 +619,7 @@ class WordSrcTrgBatcher(WordSortBatcher, Serializable):
yaml_tag = "!WordSrcTrgBatcher"

@serializable_init
def __init__(self, words_per_batch: Optional[int] = None, avg_batch_size: Optional[Union[int, float]] = None,
def __init__(self, words_per_batch: Optional[int] = None, avg_batch_size: Optional[numbers.Real] = None,
break_ties_randomly: bool = True, pad_src_to_multiple: bool = 1) -> None:
super().__init__(words_per_batch, avg_batch_size, sort_key=lambda x: x[0].sent_len() + 1.0e-6 * x[1].sent_len(),
break_ties_randomly=break_ties_randomly,
Expand All @@ -644,7 +645,7 @@ class WordTrgSrcBatcher(WordSortBatcher, Serializable):
yaml_tag = "!WordTrgSrcBatcher"

@serializable_init
def __init__(self, words_per_batch: Optional[int] = None, avg_batch_size: Optional[Union[int, float]] = None,
def __init__(self, words_per_batch: Optional[int] = None, avg_batch_size: Optional[numbers.Real] = None,
break_ties_randomly: bool = True, pad_src_to_multiple: int = 1) -> None:
super().__init__(words_per_batch, avg_batch_size, sort_key=lambda x: x[1].sent_len() + 1.0e-6 * x[0].sent_len(),
break_ties_randomly=break_ties_randomly,
Expand Down
21 changes: 14 additions & 7 deletions xnmt/eval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import math
import subprocess
from typing import List, Sequence, Dict, Tuple, Union, Any, Optional
import numbers

import yaml
import numpy as np
Expand Down Expand Up @@ -125,8 +126,8 @@ class LossScore(EvalScore, Serializable):
yaml_tag = "!LossScore"

@serializable_init
def __init__(self, loss: float, loss_stats: Dict[str, float] = None, num_ref_words: Optional[int] = None,
desc: Any = None) -> None:
def __init__(self, loss: numbers.Real, loss_stats: Dict[str, numbers.Real] = None,
num_ref_words: Optional[int] = None, desc: Any = None) -> None:
super().__init__(desc=desc)
self.loss = loss
self.loss_stats = loss_stats
Expand Down Expand Up @@ -159,8 +160,14 @@ class BLEUScore(EvalScore, Serializable):
yaml_tag = "!BLEUScore"

@serializable_init
def __init__(self, bleu: float, frac_score_list: Sequence[float] = None, brevity_penalty_score: float = None,
hyp_len: int = None, ref_len: int = None, ngram: int = 4, desc: Any = None) -> None:
def __init__(self,
bleu: numbers.Real,
frac_score_list: Sequence[numbers.Real] = None,
brevity_penalty_score: numbers.Real = None,
hyp_len: int = None,
ref_len: int = None,
ngram: int = 4,
desc: Any = None) -> None:
self.bleu = bleu
self.frac_score_list = frac_score_list
self.brevity_penalty_score = brevity_penalty_score
Expand Down Expand Up @@ -291,7 +298,7 @@ class RecallScore(SentenceLevelEvalScore, Serializable):
yaml_tag = "!RecallScore"

@serializable_init
def __init__(self, recall: float, hyp_len: int, ref_len: int, nbest: int = 5, desc: Any = None) -> None:
def __init__(self, recall: numbers.Real, hyp_len: int, ref_len: int, nbest: int = 5, desc: Any = None) -> None:
self.recall = recall
self.hyp_len = hyp_len
self.ref_len = ref_len
Expand Down Expand Up @@ -327,7 +334,7 @@ class ExternalScore(EvalScore, Serializable):
yaml_tag = "!ExternalScore"

@serializable_init
def __init__(self, value: float, higher_is_better: bool = True, desc: Any = None) -> None:
def __init__(self, value: numbers.Real, higher_is_better: bool = True, desc: Any = None) -> None:
self.value = value
self.higher_is_better = higher_is_better
self.desc = desc
Expand Down Expand Up @@ -479,7 +486,7 @@ class FastBLEUEvaluator(SentenceLevelEvaluator, Serializable):
yaml_tag = "!FastBLEUEvaluator"

@serializable_init
def __init__(self, ngram:int = 4, smooth:float = 1):
def __init__(self, ngram:int = 4, smooth:numbers.Real = 1):
self.ngram = ngram
self.weights = (1 / ngram) * np.ones(ngram, dtype=np.float32)
self.smooth = smooth
Expand Down
5 changes: 3 additions & 2 deletions xnmt/experiments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, List, Optional
import numbers

from xnmt.param_initializers import ParamInitializer, GlorotInitializer, ZeroInitializer
from xnmt.settings import settings
Expand Down Expand Up @@ -38,8 +39,8 @@ class ExpGlobal(Serializable):
def __init__(self,
model_file: str = settings.DEFAULT_MOD_PATH,
log_file: str = settings.DEFAULT_LOG_PATH,
dropout: float = 0.3,
weight_noise: float = 0.0,
dropout: numbers.Real = 0.3,
weight_noise: numbers.Real = 0.0,
default_layer_dim: int = 512,
param_init: ParamInitializer = bare(GlorotInitializer),
bias_init: ParamInitializer = bare(ZeroInitializer),
Expand Down
51 changes: 30 additions & 21 deletions xnmt/hyper_params.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import typing
import numbers

from xnmt.events import register_xnmt_handler, handle_xnmt_event
from xnmt.persistence import serializable_init, Serializable

Expand All @@ -8,26 +11,29 @@ class Scalar(Serializable):
Args:
initial: The value held by the scalar.
update: Is the epoch number.
times_updated: The number of times the value has been updated (incremented once per new epoch).
"""

yaml_tag = "!Scalar"

@serializable_init
@register_xnmt_handler
def __init__(self, initial=0.0, update=0):
self.value = initial
self.update = update
def __init__(self, initial=0.0, times_updated=0):
self.initial = initial
self.times_updated = times_updated
self.value = self.get_curr_value()

@handle_xnmt_event
def on_new_epoch(self, *args, **kwargs):
self.value = self.update_value()
self.update += 1
self.save_processed_arg("initial", self.value)
self.save_processed_arg("update", self.update)
self.value = self.get_curr_value()
self.times_updated += 1
self.save_processed_arg("times_updated", self.times_updated)

def update_value(self):
return self.value
def get_curr_value(self):
return self.initial

def __repr__(self):
return f"{self.__class__.__name__}[curr={self.get_curr_value()}]"

# Operators
def __lt__(a, b): return a.value < b
Expand All @@ -45,30 +51,33 @@ def __pow__(a, b): return a.value ** b
def __truediv__(a, b): return a.value / b
def __floordiv__(a, b): return a.value // b

class DefinedSequence(Scalar):
class DefinedSequence(Scalar, Serializable):
"""
Class that represents a fixed defined sequence from config files.
If more updates have been made than the sequence has elements, the last element of the sequence is returned instead.
x = DefinedSequence([0.1, 0.5, 1])
# Epoch 1: 0+x = 0.1
# Epoch 2: 0+x = 0.5
# Epoch 3: 0+x = 1
# Epoch 1: 0.1
# Epoch 2: 0.5
# Epoch 3: 1
# Epoch 4: 1
# ...
Args:
sequence: A list of numbers
initial: The current value or the value.
update: The epoch number
times_updated: The epoch number
"""

yaml_tag = "!DefinedSequence"

@serializable_init
def __init__(self, sequence=None, initial=0.0, update=0):
super().__init__(initial, update)
def __init__(self, sequence: typing.Sequence[numbers.Real], times_updated: int = 0):
self.sequence = sequence
if len(sequence)==0: raise ValueError("DefinedSequence initialized with empty sequence")
super().__init__(times_updated=times_updated)

def update_value(self):
return self.sequence[min(len(self.sequence)-1, self.update)]
def get_curr_value(self):
return self.sequence[min(len(self.sequence) - 1, self.times_updated)]

numbers.Real.register(Scalar)
7 changes: 4 additions & 3 deletions xnmt/inferences.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections.abc
from typing import List, Optional, Tuple, Sequence, Union
import numbers

from xnmt.settings import settings

Expand Down Expand Up @@ -93,7 +94,7 @@ def perform_inference(self, generator: 'models.GeneratorModel', src_file: str =
def _generate_output(self, generator: 'models.GeneratorModel', src_corpus: Sequence[sent.Sentence],
trg_file: str, batcher: Optional[batchers.Batcher] = None, max_src_len: Optional[int] = None,
forced_ref_corpus: Optional[Sequence[sent.Sentence]] = None,
assert_scores: Optional[Sequence[float]] = None) -> None:
assert_scores: Optional[Sequence[numbers.Real]] = None) -> None:
"""
Generate outputs and write them to file.
Expand Down Expand Up @@ -155,7 +156,7 @@ def _conclude_report(self):
for reporter in self.reporter:
reporter.conclude_report()

def _compute_losses(self, generator, ref_corpus, src_corpus, max_num_sents) -> List[float]:
def _compute_losses(self, generator, ref_corpus, src_corpus, max_num_sents) -> List[numbers.Real]:
batched_src, batched_ref = self.batcher.pack(src_corpus, ref_corpus)
ref_scores = []
for sent_count, (src, ref) in enumerate(zip(batched_src, batched_ref)):
Expand All @@ -171,7 +172,7 @@ def _compute_losses(self, generator, ref_corpus, src_corpus, max_num_sents) -> L


@staticmethod
def _write_rescored_output(ref_scores: Sequence[float], ref_file: str, trg_file: str) -> None:
def _write_rescored_output(ref_scores: Sequence[numbers.Real], ref_file: str, trg_file: str) -> None:
"""
Write scored sequences and scores to file when mode=='score'.
Expand Down
8 changes: 4 additions & 4 deletions xnmt/length_norm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from numbers import Real
import numbers
from typing import Sequence, Optional

import numpy as np
Expand Down Expand Up @@ -57,7 +57,7 @@ class AdditiveNormalization(LengthNormalization, Serializable):
yaml_tag = '!AdditiveNormalization'

@serializable_init
def __init__(self, penalty: Real = -0.1, apply_during_search: bool = False):
def __init__(self, penalty: numbers.Real = -0.1, apply_during_search: bool = False):
self.penalty = penalty
self.apply_during_search = apply_during_search

Expand All @@ -78,7 +78,7 @@ class PolynomialNormalization(LengthNormalization, Serializable):
yaml_tag = '!PolynomialNormalization'

@serializable_init
def __init__(self, m: Real = 1, apply_during_search: bool = False):
def __init__(self, m: numbers.Real = 1, apply_during_search: bool = False):
self.m = m
self.apply_during_search = apply_during_search
self.pows = []
Expand Down Expand Up @@ -175,7 +175,7 @@ class EosBooster(Serializable):
"""
yaml_tag = "!EosBooster"
@serializable_init
def __init__(self, boost_val: float):
def __init__(self, boost_val: numbers.Real):
self.boost_val = boost_val
def __call__(self, scores:np.ndarray) -> None:
scores[Vocab.ES] += self.boost_val

0 comments on commit 8af53ea

Please sign in to comment.