Skip to content

Commit

Permalink
Generalize loss tracker (#406)
Browse files Browse the repository at this point in the history
* divide GeneratorModel into TrainableModel, GeneratorModel, and EventTrigger

* count_trg_words supports several input types

* move word counting to Input
  • Loading branch information
msperber authored and neubig committed May 25, 2018
1 parent 0713023 commit 8988979
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 44 deletions.
41 changes: 33 additions & 8 deletions xnmt/input.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,63 @@
from typing import Any, Sequence

import numpy as np

from xnmt import vocab

class Input(object):
  """
  A template class to represent a single input of any type.

  Subclasses must implement ``__len__``, ``len_unpadded``, ``__getitem__``
  and ``get_padded_sent``.
  """

  def __len__(self) -> int:
    """
    Return length of input, including padded tokens.

    Returns:
      length
    """
    raise NotImplementedError("__len__() must be implemented by Input subclasses")

  def len_unpadded(self) -> int:
    """
    Return length of input prior to applying any padding.

    Returns:
      unpadded length
    """
    # Fix: the original had only a docstring and silently returned None;
    # raise like the sibling abstract methods so subclasses must override it.
    raise NotImplementedError("len_unpadded() must be implemented by Input subclasses")

  def __getitem__(self, key):
    # Fix: accept the index/slice argument Python passes to __getitem__,
    # so inp[i] raises NotImplementedError instead of a misleading TypeError.
    raise NotImplementedError("__getitem__() must be implemented by Input subclasses")

  def get_padded_sent(self, token: Any, pad_len: int) -> 'Input':
    """
    Return padded version of the sentence.

    Args:
      token: padding token
      pad_len: number of tokens to append

    Returns:
      padded sentence
    """
    raise NotImplementedError("get_padded_sent() must be implemented by Input subclasses")

class SimpleSentenceInput(Input):
  """
  A simple sentence, represented as a list of tokens.

  Args:
    words: list of integer word ids
    vocab: vocabulary mapping the ids to strings, if available
  """

  def __init__(self, words: Sequence[int], vocab: vocab.Vocab = None) -> None:
    # Fix: the diff residue left two __init__ defs (the old one with an
    # empty body, a syntax error); keep the typed version only.
    self.words = words
    self.vocab = vocab

  def __len__(self) -> int:
    return len(self.words)

  def len_unpadded(self) -> int:
    # Padding tokens (Vocab.ES) do not count toward the unpadded length.
    return sum(x != vocab.Vocab.ES for x in self.words)

  def __getitem__(self, key):
    return self.words[key]

class ArrayInput(Input):
  """
  An input represented as a single numpy array.

  Args:
    nparr: numpy array
    padded_len: number of padding positions already appended to ``nparr``
  """

  def __init__(self, nparr: np.ndarray, padded_len: int = 0) -> None:
    # Fix: the diff residue left two __init__ defs; keep the typed version
    # that also tracks how much padding the array carries.
    self.nparr = nparr
    self.padded_len = padded_len

  def __len__(self) -> int:
    # For a 2D-or-higher array the second axis indexes tokens;
    # a 1D array is treated as a single token.
    return self.nparr.shape[1] if len(self.nparr.shape) >= 2 else 1

  def len_unpadded(self) -> int:
    return len(self) - self.padded_len

  def __getitem__(self, key):
    return self.nparr.__getitem__(key)

Expand All @@ -106,7 +131,7 @@ def get_padded_sent(self, token, pad_len):
new_nparr = np.append(self.nparr, np.zeros((self.nparr.shape[0], pad_len)), axis=1)
else:
raise NotImplementedError(f"currently only support 'None' or '0' as, but got '{token}'")
return ArrayInput(new_nparr)
return ArrayInput(new_nparr, padded_len=self.padded_len + pad_len)

def get_array(self) -> 'np.ndarray':
  """Return the raw numpy array underlying this input (no copy is made)."""
  return self.nparr
70 changes: 34 additions & 36 deletions xnmt/loss_tracker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Union
import time

import xnmt.loss
from xnmt.vocab import Vocab
from xnmt.events import register_xnmt_handler, handle_xnmt_event
from xnmt.util import format_time, log_readable_and_structured
import numpy as np

from xnmt import batcher, input, loss, vocab, events, util

class AccumTimeTracker(object):
def __init__(self):
Expand All @@ -28,11 +28,11 @@ class TrainLossTracker(object):
REPORT_TEMPLATE_ADDITIONAL = '- {loss_name} {loss:5.6f}'
REPORT_EVERY = 1000

@register_xnmt_handler
@events.register_xnmt_handler
def __init__(self, training_task):
self.training_task = training_task

self.epoch_loss = xnmt.loss.LossScalarBuilder()
self.epoch_loss = loss.LossScalarBuilder()
self.epoch_words = 0
self.last_report_sents_into_epoch = 0
self.last_report_sents_since_start = 0
Expand All @@ -43,7 +43,7 @@ def __init__(self, training_task):
self.start_time = time.time()
self.name = self.training_task.name

@handle_xnmt_event
@events.handle_xnmt_event
def on_new_epoch(self, training_task, num_sents):
if training_task is self.training_task:
self.epoch_loss.zero()
Expand All @@ -66,36 +66,33 @@ def report(self, trg, loss):
fractional_epoch = (self.training_task.training_state.epoch_num - 1) \
+ self.training_task.training_state.sents_into_epoch / self.training_task.cur_num_sentences()
accum_time = self.time_tracker.get_and_reset()
log_readable_and_structured(
util.log_readable_and_structured(
TrainLossTracker.REPORT_TEMPLATE_SPEED if accum_time else TrainLossTracker.REPORT_TEMPLATE,
{"key": "train_loss", "data": "train",
"epoch": fractional_epoch,
"loss": self.epoch_loss.sum() / self.epoch_words,
"words": self.epoch_words,
"words_per_sec": (self.epoch_words - self.last_report_words) / (
accum_time) if accum_time else "-",
"time": format_time(time.time() - self.start_time)},
"time": util.format_time(time.time() - self.start_time)},
task_name=self.name)

if len(self.epoch_loss) > 1:
for loss_name, loss_values in self.epoch_loss.items():
log_readable_and_structured(TrainLossTracker.REPORT_TEMPLATE_ADDITIONAL,
{"key": "additional_train_loss",
"loss_name": loss_name,
"loss": loss_values / self.epoch_words},
task_name=self.name)
util.log_readable_and_structured(TrainLossTracker.REPORT_TEMPLATE_ADDITIONAL,
{"key": "additional_train_loss",
"loss_name": loss_name,
"loss": loss_values / self.epoch_words},
task_name=self.name)

self.last_report_words = self.epoch_words
self.last_report_sents_since_start = self.training_task.training_state.sents_since_start

def count_trg_words(self, trg_words: Union[input.Input, batcher.Batch]) -> int:
  """
  Count the unpadded target words in a sentence or a batch of sentences.

  Fix: the diff residue left both the old hand-rolled counting loop and the
  new implementation back-to-back (the second def silently shadowed the
  first); keep only the new version, which delegates padding-aware counting
  to Input.len_unpadded().

  Args:
    trg_words: a single input sentence, or a batch of input sentences

  Returns:
    total number of words, excluding padding tokens
  """
  if isinstance(trg_words, batcher.Batch):
    # A batch iterates over its member sentences; sum their unpadded lengths.
    return sum(inp.len_unpadded() for inp in trg_words)
  else:
    return trg_words.len_unpadded()

class DevLossTracker(object):

Expand Down Expand Up @@ -138,20 +135,21 @@ def report(self):
self.fractional_epoch = (self.training_task.training_state.epoch_num - 1) \
+ self.training_task.training_state.sents_into_epoch / self.training_task.cur_num_sentences()
dev_time = self.time_tracker.get_and_reset()
log_readable_and_structured(DevLossTracker.REPORT_TEMPLATE_DEV,
{"key": "dev_loss",
"epoch": self.fractional_epoch,
"score": self.dev_score,
"words": self.dev_words,
"time": format_time(this_report_time - self.start_time)
},
task_name=self.name)
util.log_readable_and_structured(DevLossTracker.REPORT_TEMPLATE_DEV,
{"key": "dev_loss",
"epoch": self.fractional_epoch,
"score": self.dev_score,
"words": self.dev_words,
"time": util.format_time(this_report_time - self.start_time)
},
task_name=self.name)
for score in self.aux_scores:
log_readable_and_structured(DevLossTracker.REPORT_TEMPLATE_DEV_AUX,
{"key": "auxiliary_score", "epoch": self.fractional_epoch, "score": score},
task_name=self.name)
log_readable_and_structured(DevLossTracker.REPORT_TEMPLATE_TIME_NEEDED,
{"key": "dev_time_needed", "epoch": self.fractional_epoch, "time_needed": format_time(dev_time)},
task_name=self.name)
util.log_readable_and_structured(DevLossTracker.REPORT_TEMPLATE_DEV_AUX,
{"key": "auxiliary_score", "epoch": self.fractional_epoch, "score": score},
task_name=self.name)
util.log_readable_and_structured(DevLossTracker.REPORT_TEMPLATE_TIME_NEEDED,
{"key": "dev_time_needed", "epoch": self.fractional_epoch,
"time_needed": util.format_time(dev_time)},
task_name=self.name)
self.aux_scores = []

0 comments on commit 8988979

Please sign in to comment.