-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Configure how to combine batched losses (#411)
* Add documentation and clean up the loss builders. * Make the batch combination method (`batch_comb`) configurable. * Fix a unit test.
- Loading branch information
Showing
13 changed files
with
177 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,70 +1,140 @@ | ||
import dynet as dy | ||
from typing import Optional, Dict, List, Tuple | ||
import collections | ||
|
||
class FactoredLossExpr(object):
  """
  Loss consisting of (possibly batched) DyNet expressions, with one expression per loss factor.

  Used to represent losses within a training step.

  Args:
    init_loss: initial loss values, as a mapping from loss-factor name to
               (possibly batched) DyNet expression.
  """

  def __init__(self, init_loss: Optional[Dict[str, "dy.Expression"]] = None) -> None:
    # defaultdict with a zero scalar so "+=" works the first time a factor is seen.
    self.expr_factors = collections.defaultdict(lambda: dy.scalarInput(0))
    if init_loss is not None:
      for key, val in init_loss.items():
        self.expr_factors[key] = val

  def add_loss(self, loss_name: str, loss_expr: Optional["dy.Expression"]) -> None:
    """
    Accumulate a loss expression under the given factor name.

    Args:
      loss_name: name of the loss factor.
      loss_expr: expression to add; ``None`` is silently ignored.
    """
    # Explicit None check instead of truthiness: a falsy-but-valid expression
    # must not be silently dropped.
    if loss_expr is not None:
      self.expr_factors[loss_name] += loss_expr

  def add_factored_loss_expr(self, factored_loss_expr: Optional["FactoredLossExpr"]) -> None:
    """
    Accumulate all factors of another factored loss into this one.

    Args:
      factored_loss_expr: other factored loss; ``None`` is silently ignored.
    """
    # Explicit None check: FactoredLossExpr defines __len__, so an empty (but
    # non-None) instance would otherwise be skipped via truthiness. Iterating
    # an empty instance is a harmless no-op either way.
    if factored_loss_expr is not None:
      for loss_name, loss in factored_loss_expr.expr_factors.items():
        self.expr_factors[loss_name] += loss

  def compute(self, comb_method: str = "sum") -> "dy.Expression":
    """
    Compute loss as DyNet expression by summing over factors and batch elements.

    Args:
      comb_method: method for combining loss across batch elements ('sum' or 'avg').

    Returns:
      Scalar DyNet expression.

    Raises:
      ValueError: if ``comb_method`` is not 'sum' or 'avg'.
    """
    return self._combine_batches(dy.esum(list(self.expr_factors.values())), comb_method)

  def value(self) -> List[float]:
    """
    Get list of per-batch-element loss values, summed over factors.

    Returns:
      List of same length as batch-size.
    """
    return dy.esum(list(self.expr_factors.values())).value()

  def __getitem__(self, loss_name: str) -> "dy.Expression":
    return self.expr_factors[loss_name]

  def get_factored_loss_val(self, comb_method: str = "sum") -> "FactoredLossVal":
    """
    Create factored loss values by calling ``.value()`` for each DyNet loss expression
    and applying batch combination.

    Args:
      comb_method: method for combining loss across batch elements ('sum' or 'avg').

    Returns:
      Factored loss values.

    Raises:
      ValueError: if ``comb_method`` is not 'sum' or 'avg'.
    """
    return FactoredLossVal({k: self._combine_batches(v, comb_method).value()
                            for k, v in self.expr_factors.items()})

  def _combine_batches(self, batched_expr: "dy.Expression", comb_method: str = "sum") -> "dy.Expression":
    """
    Combine a batched expression into a single batch element.

    Args:
      batched_expr: possibly batched DyNet expression.
      comb_method: 'sum' to sum over the batch, 'avg' to average over it.

    Returns:
      Expression with a single batch element.

    Raises:
      ValueError: if ``comb_method`` is not 'sum' or 'avg'.
    """
    if comb_method == "sum":
      return dy.sum_batches(batched_expr)
    elif comb_method == "avg":
      # dim()[1] is the expression's batch size.
      return dy.sum_batches(batched_expr) * (1.0 / batched_expr.dim()[1])
    else:
      # Bug fix: original message ended with a stray quote ("'avg'.'").
      raise ValueError(f"Unknown batch combination method '{comb_method}', expected 'sum' or 'avg'.")

  def __len__(self) -> int:
    # Number of loss factors, not the batch size.
    return len(self.expr_factors)

  def __repr__(self) -> str:
    loss_str = ", ".join(f"{loss_name} {dy.sum_batches(loss_value).value()}"
                         for loss_name, loss_value in self.expr_factors.items())
    return f"{{Loss Builder: {loss_str}}}"
|
||
class FactoredLossVal(object):
  """
  Loss consisting of (unbatched) float values, with one value per loss factor.

  Used to represent losses accumulated across several training steps.

  Args:
    loss_dict: initial mapping from loss-factor name to float value.
  """

  def __init__(self, loss_dict: Optional[Dict[str, float]] = None) -> None:
    # No mutable default argument: each instance gets its own dict when
    # loss_dict is omitted.
    self._loss_dict = {} if loss_dict is None else loss_dict

  def __iadd__(self, other: 'FactoredLossVal') -> 'FactoredLossVal':
    """
    Implements += operator, adding up factors individually.

    Args:
      other: other factored float loss

    Returns:
      self
    """
    # dict.get with a 0.0 default replaces the explicit membership branch.
    for name, value in other._loss_dict.items():
      self._loss_dict[name] = self._loss_dict.get(name, 0.0) + value
    return self

  def sum_factors(self) -> float:
    """
    Return the sum of all loss factors.

    Returns:
      A float value.
    """
    return sum(self._loss_dict.values())

  def items(self) -> List[Tuple[str, float]]:
    """
    Get name/value tuples for loss factors.

    Returns:
      Name/value tuples.
    """
    return self._loss_dict.items()

  def __len__(self) -> int:
    # Number of loss factors currently accumulated.
    return len(self._loss_dict)

  def clear(self) -> None:
    """
    Clears all loss factors.
    """
    self._loss_dict.clear()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.