Minimum Trust Lamb (#1186)
Summary:
Pull Request resolved: #1186

Implementation of minimum trust Lamb as described in section 6.3 of https://arxiv.org/abs/1911.11423

Reviewed By: ArmenAg

Differential Revision: D18893828

fbshipit-source-id: 61b82be2377f388aa2132ded26ddd3f279902bb7
Akshat Shrivastava authored and facebook-github-bot committed Dec 10, 2019
1 parent d46a90f commit 060e3e2
Showing 1 changed file with 20 additions and 1 deletion.
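
For orientation before the diff: LAMB scales each layer's update by a trust ratio (the layer's weight norm divided by the norm of its Adam-style step), and this change adds an optional min_trust that acts as a lower bound on that ratio. Below is a minimal sketch of the clamp in isolation; the helper name clamp_trust_ratio is hypothetical and only the clamping logic mirrors the code added to step() in the diff.

# Sketch only: the helper is hypothetical; the clamping logic mirrors the diff below.
def clamp_trust_ratio(weight_norm, adam_norm, min_trust=None):
    # LAMB's layer-wise scaling factor; degenerate norms fall back to 1.
    if weight_norm == 0 or adam_norm == 0:
        trust_ratio = 1
    else:
        trust_ratio = weight_norm / adam_norm
    # Minimum trust LAMB (section 6.3 of https://arxiv.org/abs/1911.11423):
    # when min_trust is set, never let the per-layer scaling drop below it.
    if min_trust:
        trust_ratio = max(min_trust, trust_ratio)
    return trust_ratio
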
21 changes: 20 additions & 1 deletion pytext/optimizer/lamb.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
+from typing import Optional
 
 import torch
 from pytext.optimizer.optimizers import Optimizer
@@ -13,12 +14,17 @@ class Lamb(Optimizer, PT_Optimizer):
     https://github.com/cybertronai/pytorch-lamb
     It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`.
     https://arxiv.org/abs/1904.00962
+
+    Has the option for minimum trust LAMB as described in "Single Headed
+    Attention RNN: Stop Thinking With Your Head" section 6.3
+    https://arxiv.org/abs/1911.11423
     """
 
     class Config(Optimizer.Config):
         lr: float = 0.001
         weight_decay: float = 0.00001
         eps: float = 1e-8
+        min_trust: Optional[float] = None
 
     @classmethod
     def from_config(cls, config: Config, model: torch.nn.Module):
@@ -27,9 +33,18 @@ def from_config(cls, config: Config, model: torch.nn.Module):
             lr=config.lr,
             weight_decay=config.weight_decay,
             eps=config.eps,
+            min_trust=config.min_trust,
         )
 
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-6,
+        weight_decay=0,
+        min_trust=None,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -44,6 +59,8 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0
             {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay},
         )
 
+        self.min_trust = min_trust
+
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -104,6 +121,8 @@ def step(self, closure=None):
                     trust_ratio = 1
                 else:
                     trust_ratio = weight_norm / adam_norm
+                if self.min_trust:
+                    trust_ratio = max(self.min_trust, trust_ratio)
                 state["weight_norm"] = weight_norm
                 state["adam_norm"] = adam_norm
                 state["trust_ratio"] = trust_ratio
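A hedged usage sketch of the new option, assuming the module path pytext.optimizer.lamb and that the optimizer can be constructed directly outside PyText's config system; the min_trust value of 0.25 is only an illustrative choice, not one prescribed by this commit.

import torch

from pytext.optimizer.lamb import Lamb

# Toy module purely for illustration.
model = torch.nn.Linear(16, 4)

# min_trust is optional; when set, each layer's trust ratio is clamped from below.
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=1e-5, min_trust=0.25)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()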
