foreach optimizers
Summary: Allow using the new `foreach` option on optimizers.

Reviewed By: shapovalov

Differential Revision: D39694843

fbshipit-source-id: 97109c245b669bc6edff0f246893f95b7ae71f90
bottler authored and facebook-github-bot committed Sep 22, 2022
1 parent db3c12a commit 209c160
Showing 3 changed files with 33 additions and 13 deletions.
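The core of the change is a small dispatch pattern: collect optimizer keyword arguments per "breed" and forward `foreach` only when the chosen optimizer's constructor actually accepts it, so older PyTorch releases keep working. The following is a minimal standalone sketch of that pattern; the `build_optimizer` helper is illustrative only and not part of this commit.

```python
import inspect
from typing import Any, Dict, Iterable, Optional

import torch


def build_optimizer(
    params: Iterable[torch.nn.Parameter],
    breed: str = "Adam",
    lr: float = 1e-3,
    foreach: Optional[bool] = True,
) -> torch.optim.Optimizer:
    # Pick the optimizer class for the requested "breed".
    optimizer_class = {
        "SGD": torch.optim.SGD,
        "Adagrad": torch.optim.Adagrad,
        "Adam": torch.optim.Adam,
    }[breed]
    kwargs: Dict[str, Any] = {"lr": lr}
    # Forward `foreach` only if this optimizer's constructor supports it
    # (e.g. Adam gained the argument in PyTorch 1.12.0).
    if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
        kwargs["foreach"] = foreach
    return optimizer_class(params, **kwargs)


# Example: multi-tensor SGD when available, plain SGD otherwise.
optimizer = build_optimizer(torch.nn.Linear(2, 2).parameters(), breed="SGD", lr=0.1)
```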
29 changes: 17 additions & 12 deletions projects/implicitron_trainer/impl/optimizer_factory.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import inspect
import logging
import os
from typing import Any, Dict, Optional, Tuple
@@ -61,6 +62,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
            increasing epoch indices at which the learning rate is modified.
        momentum: Momentum factor for SGD optimizer.
        weight_decay: The optimizer weight_decay (L2 penalty on model weights).
        foreach: Whether to use the new "foreach" implementation of the optimizer
            where available (e.g. requires PyTorch 1.12.0 for Adam).
    """

    betas: Tuple[float, ...] = (0.9, 0.999)
@@ -74,6 +77,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
    weight_decay: float = 0.0
    linear_exponential_lr_milestone: int = 200
    linear_exponential_start_gamma: float = 0.1
    foreach: Optional[bool] = True

    def __post_init__(self):
        run_auto_creation(self)
@@ -115,23 +119,24 @@ def __call__(
        p_groups = [{"params": allprm, "lr": self.lr}]

        # Initialize the optimizer
        optimizer_kwargs: Dict[str, Any] = {
            "lr": self.lr,
            "weight_decay": self.weight_decay,
        }
        if self.breed == "SGD":
            optimizer = torch.optim.SGD(
                p_groups,
                lr=self.lr,
                momentum=self.momentum,
                weight_decay=self.weight_decay,
            )
            optimizer_class = torch.optim.SGD
            optimizer_kwargs["momentum"] = self.momentum
        elif self.breed == "Adagrad":
            optimizer = torch.optim.Adagrad(
                p_groups, lr=self.lr, weight_decay=self.weight_decay
            )
            optimizer_class = torch.optim.Adagrad
        elif self.breed == "Adam":
            optimizer = torch.optim.Adam(
                p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
            )
            optimizer_class = torch.optim.Adam
            optimizer_kwargs["betas"] = self.betas
        else:
            raise ValueError(f"No such solver type {self.breed}")

        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
            optimizer_kwargs["foreach"] = self.foreach
        optimizer = optimizer_class(p_groups, **optimizer_kwargs)
        logger.info(f"Solver type = {self.breed}")

        # Load state from checkpoint
1 change: 1 addition & 0 deletions projects/implicitron_trainer/tests/experiment.yaml
@@ -406,6 +406,7 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
  weight_decay: 0.0
  linear_exponential_lr_milestone: 200
  linear_exponential_start_gamma: 0.1
  foreach: true
training_loop_ImplicitronTrainingLoop_args:
  evaluator_class_type: ImplicitronEvaluator
  evaluator_ImplicitronEvaluator_args:
16 changes: 15 additions & 1 deletion projects/implicitron_trainer/tests/test_experiment.py
@@ -9,13 +9,17 @@
import unittest
from pathlib import Path

import torch

from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf
from projects.implicitron_trainer.impl.optimizer_factory import (
    ImplicitronOptimizerFactory,
)

from .. import experiment
from .utils import interactive_testing_requested, intercept_logs


internal = os.environ.get("FB_TEST", False)


@@ -151,6 +155,16 @@ def test_load_configs(self):
            with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
                compose(file.name)

    def test_optimizer_factory(self):
        model = torch.nn.Linear(2, 2)

        adam, sched = ImplicitronOptimizerFactory(breed="Adam")(0, model)
        self.assertIsInstance(adam, torch.optim.Adam)
        sgd, sched = ImplicitronOptimizerFactory(breed="SGD")(0, model)
        self.assertIsInstance(sgd, torch.optim.SGD)
        adagrad, sched = ImplicitronOptimizerFactory(breed="Adagrad")(0, model)
        self.assertIsInstance(adagrad, torch.optim.Adagrad)


class TestNerfRepro(unittest.TestCase):
    @unittest.skip("This test runs full blender training.")
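For illustration only (not part of the diff): since the new `foreach` field defaults to True, opting out of the multi-tensor path would look like the snippet below, using the same factory call signature as the new test above.

```python
import torch

from projects.implicitron_trainer.impl.optimizer_factory import (
    ImplicitronOptimizerFactory,
)

# Disable the multi-tensor ("foreach") path explicitly; the field defaults to True.
factory = ImplicitronOptimizerFactory(breed="Adam", foreach=False)
optimizer, scheduler = factory(0, torch.nn.Linear(2, 2))
```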
