diff --git a/pytext/optimizer/optimizers.py b/pytext/optimizer/optimizers.py
index d1ed3a3c4..58682eeb6 100644
--- a/pytext/optimizer/optimizers.py
+++ b/pytext/optimizer/optimizers.py
@@ -34,13 +34,14 @@ class Adam(torch.optim.Adam, Optimizer):
     class Config(Optimizer.Config):
         lr: float = 0.001
         weight_decay: float = 0.00001
+        eps: float = 1e-8
 
-    def __init__(self, parameters, lr, weight_decay):
-        super().__init__(parameters, lr=lr, weight_decay=weight_decay)
+    def __init__(self, parameters, lr, weight_decay, eps):
+        super().__init__(parameters, lr=lr, weight_decay=weight_decay, eps=eps)
 
     @classmethod
     def from_config(cls, config: Config, model: torch.nn.Module):
-        return cls(model.parameters(), config.lr, config.weight_decay)
+        return cls(model.parameters(), config.lr, config.weight_decay, config.eps)
 
 
 class SGD(torch.optim.SGD, Optimizer):