test: Increased coverage of holocron.optim (#160)
* docs: Added docstring examples for wrappers

* test: Increased coverage of holocron.optim

* test: Speeds up test by switching to mobilenets

* test: Fixed import in test

* test: Fixed unittest

* test: Fixed typo in unittests

* test: Fixed unittests

* test: Fixed unittest
frgfm authored Nov 6, 2021
1 parent 7d6e415 · commit ee903c4
Showing 3 changed files with 48 additions and 29 deletions.
holocron/optim/wrapper.py: 18 changes (16 additions, 2 deletions)
@@ -16,6 +16,13 @@ class Lookahead(Optimizer):
"""Implements the Lookahead optimizer wrapper from `"Lookahead Optimizer: k steps forward, 1 step back"
<https://arxiv.org/pdf/1907.08610.pdf>`_.
+Example::
+    >>> from torch.optim import AdamW
+    >>> from holocron.optim.wrapper import Lookahead
+    >>> model = ...
+    >>> opt = AdamW(model.parameters(), lr=3e-4)
+    >>> opt_wrapper = Lookahead(opt)
Args:
base_optimizer (torch.optim.optimizer.Optimizer): base parameter optimizer
sync_rate (int, optional): rate of weight synchronization
@@ -26,7 +33,7 @@ def __init__(
self,
base_optimizer: torch.optim.Optimizer,
sync_rate=0.5,
-sync_period=6
+sync_period=6,
) -> None:
if sync_rate < 0 or sync_rate > 1:
    raise ValueError(f'expected positive float lower than 1 as sync_rate, received: {sync_rate}')
@@ -140,6 +147,13 @@ class Scout(Optimizer):
"""Implements a new optimizer wrapper based on `"Lookahead Optimizer: k steps forward, 1 step back"
<https://arxiv.org/pdf/1907.08610.pdf>`_.
+Example::
+    >>> from torch.optim import AdamW
+    >>> from holocron.optim.wrapper import Scout
+    >>> model = ...
+    >>> opt = AdamW(model.parameters(), lr=3e-4)
+    >>> opt_wrapper = Scout(opt)
Args:
base_optimizer (torch.optim.optimizer.Optimizer): base parameter optimizer
sync_rate (int, optional): rate of weight synchronization
@@ -150,7 +164,7 @@ def __init__(
self,
base_optimizer: torch.optim.Optimizer,
sync_rate=0.5,
-sync_period=6
+sync_period=6,
) -> None:
if sync_rate < 0 or sync_rate > 1:
    raise ValueError(f'expected positive float lower than 1 as sync_rate, received: {sync_rate}')
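Note: the new docstring examples leave the model as a placeholder. A minimal usage sketch, assuming a toy linear model and assuming the wrapper forwards zero_grad()/step() to its base optimizer (the behaviour the tests below exercise); Scout exposes the same constructor arguments, so the pattern is identical:

import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import AdamW

from holocron.optim.wrapper import Lookahead

# Toy stand-in for the `model = ...` placeholder in the docstring example
model = nn.Linear(16, 10)

# Wrap a base optimizer; sync_rate/sync_period shown here are the defaults above
base_opt = AdamW(model.parameters(), lr=3e-4)
opt_wrapper = Lookahead(base_opt, sync_rate=0.5, sync_period=6)

x = torch.rand(8, 16)
target = torch.zeros(8, dtype=torch.long)

for _ in range(10):
    opt_wrapper.zero_grad()   # assumed to delegate to the base optimizer
    loss = F.cross_entropy(model(x), target)
    loss.backward()
    opt_wrapper.step()        # fast-weight update plus periodic synchronization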
test/test_optim.py: 26 changes (14 additions, 12 deletions)
@@ -3,28 +3,30 @@
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

+from typing import Any

import torch
from torch.nn import functional as F
-from torchvision.models import resnet18
+from torchvision.models import mobilenet_v3_small

from holocron import optim


-def _test_optimizer(name: str) -> None:
+def _test_optimizer(name: str, **kwargs: Any) -> None:

lr = 1e-4
input_shape = (3, 224, 224)
num_batches = 4
# Get model and optimizer
-model = resnet18(num_classes=10)
-for n, m in model.named_children():
-    if n != 'fc':
-        for p in m.parameters():
-            p.requires_grad_(False)
-optimizer = optim.__dict__[name](model.fc.parameters(), lr=lr)
+model = mobilenet_v3_small(num_classes=10)
+for p in model.parameters():
+    p.requires_grad_(False)
+for p in model.classifier[3].parameters():
+    p.requires_grad_(True)
+optimizer = optim.__dict__[name](model.classifier[3].parameters(), lr=lr, **kwargs)

# Save param value
-_p = model.fc.weight
+_p = model.classifier[3].weight
p_val = _p.data.clone()

# Random inputs
@@ -44,15 +46,15 @@ def _test_optimizer(name: str) -> None:


def test_lars():
-_test_optimizer('Lars')
+_test_optimizer('Lars', momentum=0.9, weight_decay=2e-5)


def test_lamb():
-_test_optimizer('Lamb')
+_test_optimizer('Lamb', weight_decay=2e-5)


def test_ralars():
-_test_optimizer('RaLars')
+_test_optimizer('RaLars', weight_decay=2e-5)


def test_tadam():
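For reference, the freeze-everything-then-unfreeze-the-head pattern introduced in _test_optimizer can be used on its own roughly as sketched below; the choice of Lars and the hyperparameter values simply mirror test_lars() above, and classifier[3] being the final Linear layer is a property of torchvision's mobilenet_v3_small:

from torchvision.models import mobilenet_v3_small

from holocron import optim

model = mobilenet_v3_small(num_classes=10)

# Freeze every parameter, then re-enable only the classification head
for p in model.parameters():
    p.requires_grad_(False)
for p in model.classifier[3].parameters():
    p.requires_grad_(True)

# Optimize only the head, with the same settings test_lars() passes
optimizer = optim.Lars(model.classifier[3].parameters(), lr=1e-4, momentum=0.9, weight_decay=2e-5)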
test/test_optim_wrapper.py: 33 changes (18 additions, 15 deletions)
@@ -6,7 +6,7 @@
import torch
from torch.nn import functional as F
from torch.optim import SGD
-from torchvision.models import resnet18
+from torchvision.models import mobilenet_v3_small

from holocron.optim import wrapper

@@ -17,15 +17,15 @@ def _test_wrapper(name: str) -> None:
input_shape = (3, 224, 224)
num_batches = 4
# Get model, optimizer and criterion
-model = resnet18(num_classes=10)
-for n, m in model.named_children():
-    if n != 'fc':
-        for p in m.parameters():
-            p.requires_grad_(False)
+model = mobilenet_v3_small(num_classes=10)
+for p in model.parameters():
+    p.requires_grad_(False)
+for p in model.classifier[3].parameters():
+    p.requires_grad_(True)
# Pick an optimizer whose update is easy to verify
-optimizer = SGD(model.fc.parameters(), lr=lr)
+optimizer = SGD(model.classifier[3].parameters(), lr=lr)

# Wrap the optimizer
opt_wrapper = wrapper.__dict__[name](optimizer)

# Check gradient reset
@@ -36,21 +36,24 @@ def _test_wrapper(name: str) -> None:
assert torch.all(p.grad == 0.)

# Check update step
-_p = model.fc.weight
+_p = model.classifier[3].weight
p_val = _p.data.clone()

# Random inputs
input_t = torch.rand((num_batches, *input_shape), dtype=torch.float32)
target = torch.zeros(num_batches, dtype=torch.long)

# Update
-output = model(input_t)
-loss = F.cross_entropy(output, target)
-loss.backward()
-opt_wrapper.step()

+for _ in range(10):
+    output = model(input_t)
+    loss = F.cross_entropy(output, target)
+    loss.backward()
+    opt_wrapper.step()
# Check update rule
-assert not torch.equal(_p.data, p_val - lr * _p.grad)
+assert not torch.equal(_p.data, p_val) and not torch.equal(_p.data, p_val - lr * _p.grad)

# Repr
assert len(repr(opt_wrapper).split('\n')) == len(repr(optimizer).split('\n')) + 4


def test_lookahead():
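The strengthened check in _test_wrapper runs several optimization steps through the wrapper and then asserts both that the head's weights moved and that the result differs from a single plain SGD update. A standalone sketch of that check, with an arbitrary linear layer and batch standing in for the MobileNet head:

import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import SGD

from holocron.optim import wrapper

lr = 1e-4
layer = nn.Linear(8, 2)   # arbitrary stand-in for model.classifier[3]
opt_wrapper = wrapper.Lookahead(SGD(layer.parameters(), lr=lr))

p = layer.weight
p_init = p.data.clone()

x = torch.rand(4, 8)
target = torch.zeros(4, dtype=torch.long)

for _ in range(10):
    loss = F.cross_entropy(layer(x), target)
    loss.backward()
    opt_wrapper.step()

# Mirror the updated assertion: the weights changed, and the final value is not
# what a single plain SGD step (p_init - lr * grad) would produce.
assert not torch.equal(p.data, p_init)
assert not torch.equal(p.data, p_init - lr * p.grad)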
