Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 306 additions & 0 deletions torchrec/optim/tests/test_warmup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,160 @@ def step(self, closure: Any) -> None:
pass # Override NotImplementedError.


class TestGetMultiplier(unittest.TestCase):
"""Tests for the _get_multiplier function with TRANSFORMER policy."""

def test_transformer_warmup_at_step_one(self) -> None:
# Setup: Create TRANSFORMER warmup stage with warm_steps=4000
stage = WarmupStage(
policy=WarmupPolicy.TRANSFORMER,
max_iters=4000,
lr_scale=1.0,
)

# Execute: Get multiplier at iteration 0 (step 1 internally)
from torchrec.optim.warmup import _get_multiplier

multiplier = _get_multiplier(stage, iter=0)

# Assert: At step 1, multiplier should be min(1, 1/4000^1.5) ≈ 0.0000158
# step^(-0.5) = 1^(-0.5) = 1.0
# step * warm_steps^(-1.5) = 1 * 4000^(-1.5) ≈ 0.0000158
expected = min(1.0, 1 * (4000 ** (-1.5)))
self.assertAlmostEqual(multiplier, expected, places=8)
self.assertLess(multiplier, 0.00002)

def test_transformer_warmup_at_warmup_steps(self) -> None:
# Setup: Create TRANSFORMER warmup stage with warm_steps=4000
stage = WarmupStage(
policy=WarmupPolicy.TRANSFORMER,
max_iters=4000,
lr_scale=1.0,
)

# Execute: Get multiplier at iteration 3999 (step 4000 internally)
from torchrec.optim.warmup import _get_multiplier

multiplier = _get_multiplier(stage, iter=3999)

# Assert: At step=warm_steps, both terms are equal
# step^(-0.5) = 4000^(-0.5) ≈ 0.0158
# step * warm_steps^(-1.5) = 4000 * 4000^(-1.5) ≈ 0.0158
step = 4000
expected = min(step ** (-0.5), step * (4000 ** (-1.5)))
self.assertAlmostEqual(multiplier, expected, places=8)
self.assertAlmostEqual(multiplier, 0.0158114, places=6)

def test_transformer_warmup_after_warmup_steps(self) -> None:
# Setup: Create TRANSFORMER warmup stage with warm_steps=4000
stage = WarmupStage(
policy=WarmupPolicy.TRANSFORMER,
max_iters=4000,
lr_scale=1.0,
)

# Execute: Get multiplier at iteration 7999 (step 8000 internally)
from torchrec.optim.warmup import _get_multiplier

multiplier = _get_multiplier(stage, iter=7999)

# Assert: After warmup, step^(-0.5) dominates (is smaller)
# step^(-0.5) = 8000^(-0.5) ≈ 0.0112
# step * warm_steps^(-1.5) = 8000 * 4000^(-1.5) ≈ 0.0316
step = 8000
inv_sqrt = step ** (-0.5)
warmup_term = step * (4000 ** (-1.5))
self.assertAlmostEqual(multiplier, inv_sqrt, places=8)
self.assertLess(inv_sqrt, warmup_term)
self.assertAlmostEqual(multiplier, 0.0111803, places=6)

def test_transformer_warmup_with_lr_scale(self) -> None:
# Setup: Create TRANSFORMER warmup stage with lr_scale=2.0
stage = WarmupStage(
policy=WarmupPolicy.TRANSFORMER,
max_iters=4000,
lr_scale=2.0,
)

# Execute: Get multiplier at iteration 3999 (step 4000 internally)
from torchrec.optim.warmup import _get_multiplier

multiplier = _get_multiplier(stage, iter=3999)

# Assert: lr_scale is applied as a multiplier
step = 4000
base_multiplier = min(step ** (-0.5), step * (4000 ** (-1.5)))
expected = base_multiplier * 2.0
self.assertAlmostEqual(multiplier, expected, places=8)

def test_transformer_warmup_formula_correctness(self) -> None:
# Setup: Create TRANSFORMER warmup stage with warm_steps=1000
stage = WarmupStage(
policy=WarmupPolicy.TRANSFORMER,
max_iters=1000,
lr_scale=1.0,
)

# Execute: Test multiple iterations to verify formula
from torchrec.optim.warmup import _get_multiplier

test_iters = [0, 99, 499, 999, 1999] # steps 1, 100, 500, 1000, 2000
for iter_val in test_iters:
multiplier = _get_multiplier(stage, iter=iter_val)
step = iter_val + 1

# Assert: Multiplier matches the Transformer formula
expected = min(step ** (-0.5), step * (1000 ** (-1.5)))
self.assertAlmostEqual(
multiplier,
expected,
places=8,
msg=f"Failed at iteration {iter_val} (step {step})",
)

def test_transformer_warmup_monotonic_increase_during_warmup(self) -> None:
# Setup: Create TRANSFORMER warmup stage with warm_steps=1000
stage = WarmupStage(
policy=WarmupPolicy.TRANSFORMER,
max_iters=1000,
lr_scale=1.0,
)

# Execute: Get multipliers during warmup phase
from torchrec.optim.warmup import _get_multiplier

multipliers = [_get_multiplier(stage, iter=i) for i in range(0, 1000)]

# Assert: Multipliers should increase monotonically during warmup
for idx in range(len(multipliers) - 1):
self.assertLess(
multipliers[idx],
multipliers[idx + 1],
msg=f"Multiplier should increase at iteration {idx}",
)

def test_transformer_warmup_monotonic_decrease_after_warmup(self) -> None:
# Setup: Create TRANSFORMER warmup stage with warm_steps=1000
stage = WarmupStage(
policy=WarmupPolicy.TRANSFORMER,
max_iters=1000,
lr_scale=1.0,
)

# Execute: Get multipliers after warmup phase
from torchrec.optim.warmup import _get_multiplier

multipliers = [_get_multiplier(stage, iter=i) for i in range(1000, 2000)]

# Assert: Multipliers should decrease monotonically after warmup
for i in range(len(multipliers) - 1):
self.assertGreater(
multipliers[i],
multipliers[i + 1],
msg=f"Multiplier should decrease at iteration {i + 1000}",
)


class TestWarmupOptimizer(unittest.TestCase):
def test_load_state_dict(self) -> None:
def get_optimizer() -> WarmupOptimizer:
Expand Down Expand Up @@ -72,3 +226,155 @@ def get_optimizer() -> WarmupOptimizer:
warmup_optimizer_1.state_dict()["state"]["__warmup"],
warmup_optimizer_2.state_dict()["state"]["__warmup"],
)

def test_transformer_warmup_integration(self) -> None:
# Setup: Create optimizer with TRANSFORMER warmup policy
param = Variable(torch.tensor([1.0, 2.0]))
keyed_optimizer = DummyKeyedOptimizer(
{"param": param}, defaultdict(dict), [{"params": [param]}]
)

base_lr = 0.001
warm_steps = 100

warmup_optimizer = WarmupOptimizer(
keyed_optimizer,
stages=[
WarmupStage(
policy=WarmupPolicy.TRANSFORMER,
max_iters=100, # Stage ends at iteration 100
lr_scale=1.0,
),
],
lr=base_lr,
)

# Execute: Run optimizer through warmup steps
learning_rates = []
current_lr = 0.0
for _ in range(100): # Only iterate through the TRANSFORMER stage
for param_group in warmup_optimizer.param_groups:
current_lr = param_group["lr"]
learning_rates.append(current_lr)
warmup_optimizer.step()

# Assert: Verify learning rate follows Transformer schedule during warmup
# At step 1 (iteration 0)
step_1 = 1
expected_lr_1 = base_lr * min(step_1 ** (-0.5), step_1 * (warm_steps ** (-1.5)))
self.assertAlmostEqual(learning_rates[0], expected_lr_1, places=10)

# At step 50 (iteration 49) - mid-warmup
step_50 = 50
expected_lr_50 = base_lr * min(
step_50 ** (-0.5), step_50 * (warm_steps ** (-1.5))
)
self.assertAlmostEqual(learning_rates[49], expected_lr_50, places=10)

# At step 100 (iteration 99) - warmup completion
step_100 = 100
expected_lr_100 = base_lr * min(
step_100 ** (-0.5), step_100 * (warm_steps ** (-1.5))
)
self.assertAlmostEqual(learning_rates[99], expected_lr_100, places=10)

# Verify learning rate increases monotonically during warmup
for idx in range(warm_steps - 1):
self.assertLess(
learning_rates[idx],
learning_rates[idx + 1],
msg=f"LR should increase during warmup at step {idx + 1}",
)
# Verify formula correctness at this step
step = idx + 1
expected_lr_at_idx = base_lr * min(
step ** (-0.5), step * (warm_steps ** (-1.5))
)
self.assertAlmostEqual(
learning_rates[idx],
expected_lr_at_idx,
places=10,
msg=f"LR mismatch at step {step}",
)

def test_transformer_warmup_with_extended_stage(self) -> None:
# Setup: Create optimizer with TRANSFORMER stage to test warmup and decay
param = Variable(torch.tensor([1.0, 2.0]))
keyed_optimizer = DummyKeyedOptimizer(
{"param": param}, defaultdict(dict), [{"params": [param]}]
)

base_lr = 0.001
# In the TRANSFORMER policy, max_iters acts as warm_steps in the formula
max_iters = 8000 # Stage runs for 8000 iterations

warmup_optimizer = WarmupOptimizer(
keyed_optimizer,
stages=[
WarmupStage(
policy=WarmupPolicy.TRANSFORMER,
max_iters=max_iters, # Stage runs for 8000 iterations
lr_scale=1.0,
),
],
lr=base_lr,
)

# Execute: Run optimizer through warmup and decay phases
current_lr = 0.0
learning_rates = []
for _ in range(max_iters):
for param_group in warmup_optimizer.param_groups:
current_lr = param_group["lr"]
learning_rates.append(current_lr)
warmup_optimizer.step()

# Assert: Verify the formula uses max_iters as warm_steps
# At step 1, verify the formula: min(step^(-0.5), step * max_iters^(-1.5))
step_1 = 1
expected_lr_1 = base_lr * min(step_1 ** (-0.5), step_1 * (max_iters ** (-1.5)))
self.assertAlmostEqual(
learning_rates[0],
expected_lr_1,
places=10,
msg=f"LR at step 1 should match formula with warm_steps={max_iters}",
)

# At step 4000, verify with max_iters=8000
step_4000 = 4000
expected_lr_4000 = base_lr * min(
step_4000 ** (-0.5), step_4000 * (max_iters ** (-1.5))
)
self.assertAlmostEqual(
learning_rates[3999],
expected_lr_4000,
places=10,
msg=f"LR at step 4000 should match formula with warm_steps={max_iters}",
)

# At step max_iters (8000), both terms should be equal
step_max = max_iters
inv_sqrt = step_max ** (-0.5)
warmup_term = step_max * (max_iters ** (-1.5))
self.assertAlmostEqual(
inv_sqrt,
warmup_term,
places=10,
msg=f"At step={max_iters}, both formula terms should be equal",
)

expected_lr_max = base_lr * min(inv_sqrt, warmup_term)
self.assertAlmostEqual(
learning_rates[max_iters - 1],
expected_lr_max,
places=10,
msg=f"LR at step {max_iters} should match formula",
)

# Verify learning rate increases before max_iters
for idx in range(max_iters - 1):
self.assertLess(
learning_rates[idx],
learning_rates[idx + 1],
msg=f"LR should increase at step {idx + 1} (before max_iters={max_iters})",
)
10 changes: 10 additions & 0 deletions torchrec/optim/warmup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class WarmupPolicy(Enum):
STEP = "step"
INVSQRT = "inv_sqrt" # inverse square root
COSINE_ANNEALING_WARM_RESTARTS = "cosine_annealing_warm_restarts"
TRANSFORMER = (
"transformer" # Transformer warmup: min(step^(-0.5), step * warm_steps^(-1.5))
)


@dataclass
Expand Down Expand Up @@ -86,6 +89,13 @@ def _get_multiplier(stage: WarmupStage, iter: int) -> float:
t_cur = iter % t_0
cos_iter = 0.5 * (1 + math.cos(math.pi * t_cur / t_0))
multiplier = eta_min + (1.0 - eta_min) * cos_iter
elif stage.policy == WarmupPolicy.TRANSFORMER:
# Transformer warmup from "Attention is All You Need" (Vaswani et al., 2017)
# Formula: lr_scale = min(step^(-0.5), step * warm_steps^(-1.5))
# where warm_steps = max_iters
# Add 1 to iter to make it 1-indexed and avoid division by zero
step = iter + 1
multiplier = min(step ** (-0.5), step * stage.max_iters ** (-1.5))
return multiplier * stage.lr_scale


Expand Down
Loading