
Update lr_finder.py #42

Closed
wants to merge 1 commit into from

Conversation

yongduek

I was looking for a PyTorch LR finder. Very nice, and thank you for sharing it.

When lr_finder.history['lr'] was printed out, it did not start with the initial learning rate set in the optimizer declaration. So two things seem to need modification to make that happen.

  1. Order change: get the lr and append it to the history before calling step().
  2. Function change: use get_last_lr() to get the latest lr. This seems to be a recent addition in PyTorch.
    self.history["lr"].append(lr_schedule.get_last_lr()[0])
    lr_schedule.step()
  3. self.last_epoch does not need to be incremented in ExponentialLR and LinearLR:
class ExponentialLR(_LRScheduler):
    def get_lr(self):
        curr_iter = self.last_epoch # + 1
        r = curr_iter / self.num_iter
        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
  • From inspection, super().__init__() calls step() within _LRScheduler, and step() then calls get_lr() to update the values. The results are saved in attributes of the base class (_LRScheduler).
  • The last updated learning rate (or list of lrs) is retrieved by get_last_lr(). This is reportedly the recent PyTorch way of using it; see the short illustration below.
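
For reference, here is a tiny standalone illustration of get_last_lr() (just a sketch using PyTorch's built-in ExponentialLR rather than this package's scheduler; it assumes PyTorch >= 1.4.0):

import torch
from torch import optim

model = torch.nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)

# get_last_lr() returns the lrs most recently computed by the scheduler;
# right after construction these are the initial lrs set in the optimizer.
print(scheduler.get_last_lr())  # [0.001]
scheduler.step()
print(scheduler.get_last_lr())  # [0.0001]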

@davidtvs
Owner

Thanks for the PR!

  1. The order of the calls to get_lr and step is definitely wrong.
  2. get_last_lr was introduced in v1.4.0 of PyTorch. I would prefer not to use it unless it's absolutely needed so that we can keep supporting older versions of PyTorch (currently this package runs on v0.4.1 onwards).
  3. I tried this and it didn't seem to work. I ran the experiment with these settings:
# what's not visible here comes from the lrfinder_mnist example
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.5)
lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
lr_finder.range_test(trainloader, end_lr=10, num_iter=5, step_mode="exp")

With curr_iter = self.last_epoch + 1, print(lr_finder.history["lr"]) returns:

[0.001, 0.006309573444801934, 0.039810717055349734, 0.25118864315095796, 1.584893192461114]

With curr_iter = self.last_epoch, print(lr_finder.history["lr"]) returns:

[0.00015848931924611134, 0.001, 0.006309573444801934, 0.039810717055349734, 0.25118864315095796]

So the logic is definitely wrong in both cases; the history should start at 0.001 and end at 10 😕

@NaleRaphael
Contributor

NaleRaphael commented May 21, 2020

@yongduek, @davidtvs Sorry for replying late.

When lr_finder.history['lr'] was printed out, it did not start with the initial learning rate set in the optimizer declaration.

For this statement, I want to mention an implementation detail in torch.optim.lr_scheduler._LRScheduler.

In _LRScheduler, step() is called right after the scheduler is initialized (since PyTorch v0.2.0); that's why the first value retrieved from either the scheduler or the optimizer will not be the initial learning rate we set.

Besides, the reason the history did not start with the initial learning rate we set is that the history is used to save the values changed after lr_schedule.step() is called. Hence it is reasonable that the initial learning rate of the optimizer is not saved in the history.

So it is correct to change the order of the following two lines, as you said:

lr_schedule.step()
self.history["lr"].append(lr_schedule.get_lr()[0])

However, we should keep the line that sets curr_iter unchanged.

In conclusion, with the example provided by @davidtvs, the lr history you expected actually looks like this:

[0.001, 0.00630957, 0.03981071, 0.25118864, 1.58489319, 10.0]
 ^ initial lr (before `lr_scheduler.step()` is called)
        <-      lr changed by `lr_scheduler.step()`       ->

while the history recorded by the current implementation of this package is:

[0.00630957, 0.03981071, 0.25118864, 1.58489319, 10.0]
 <-      lr changed by `lr_scheduler.step()`       ->

Regarding point 2 you mentioned, I agree with @davidtvs's decision; we should care about backward compatibility.

@davidtvs
Owner

I vaguely remember getting pretty confused by how _LRScheduler works when I was writing the schedulers for this package, and even now I still find the logic hard to follow for something that seems pretty simple. Anyway, rant over; it does look much better after 1.4.0.

There's still something wrong besides the order of those two lines: I fixed the order locally for my experiment and the learning rate never reaches 10. The history @NaleRaphael posted as the expected result has 6 elements instead of 5; I would definitely expect 5 elements. It seems to me that the computation of the learning rates is incorrect.

@NaleRaphael
Contributor

NaleRaphael commented May 21, 2020

@davidtvs Maybe we should modify the line r = curr_iter / self.num_iter? It's kind of like how np.linspace() works: given 2 endpoints (base_lr, end_lr) and a number of points (num_iter), the expected result should be:

[base_lr, x1, x2, ... , end_lr]
 <-    length: num_iter     ->

If this is what we want, then the following revision should make it work:

class ExponentialLR(_LRScheduler):
    def get_lr(self):
        # Note that we should handle the case when given `num_iter` is 1,
        # it would trigger `ZeroDivisionError` here.
        r = self.last_epoch / (self.num_iter - 1)
        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]

A quick explanation:

# --- The first step, `last_epoch`: 0 ---
r
= last_epoch / (num_iter - 1)
= 0 / (num_iter - 1)
= 0.

base_lr * (end_lr / base_lr) ** r
= base_lr * (end_lr / base_lr) ** 0.0
= base_lr * 1.0
= base_lr


# --- The last step, `last_epoch`: num_iter - 1 ---
r
= last_epoch / (num_iter - 1)
= (num_iter - 1) / (num_iter - 1)
= 1.0

base_lr * (end_lr / base_lr) ** r
= base_lr * (end_lr / base_lr) ** 1.0
= base_lr * (end_lr / base_lr)
= end_lr

With the execution order like this:

for i in range(num_iter):
    train()
    eval()
    optimizer.step()
    scheduler.step()

the optimizer's lr in the first iteration will be exactly the initial value we set, so we can ensure the lr used in each iteration is exactly the same as the one saved in the history.
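
A quick standalone sanity check of the endpoints (plain Python; exp_lr_schedule is a throwaway helper name for this sketch, not part of the package):

def exp_lr_schedule(base_lr, end_lr, num_iter):
    # Assumes num_iter > 1, otherwise the denominator becomes zero.
    return [base_lr * (end_lr / base_lr) ** (i / (num_iter - 1)) for i in range(num_iter)]

print(exp_lr_schedule(0.001, 10, 5))
# -> [0.001, 0.01, 0.1, 1.0, 10.0] (up to floating-point error),
#    i.e. it starts at base_lr and ends at end_lr as intended.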

@NaleRaphael
Contributor

The history @NaleRaphael posted as the expected result has 6 elements instead of 5

Sorry for the unclear description. What I wanted to explain in that comment is that the history in LRFinder is used to save the lrs changed after scheduler.step() is called, rather than to trace the variation of the lr (including its initial state) during the lr-searching task. That's why I added the initial lr directly to that list instead of recalculating a list of length 5.

Besides, there is one thing I was considering while writing that comment. As we know, step() is called right after a scheduler is created. Therefore, putting cases like StepLR and others aside, the lr expected before the first iteration would be a value calculated/changed by scheduler.get_lr() rather than the initial value. So I was wondering whether changing this behavior would break PyTorch users' expectations of how an lr scheduler works, even though the current design of the official lr scheduler is somewhat weird...

davidtvs added a commit that referenced this pull request May 22, 2020
Two things can be concluded:
* The current computation of the exponential learning rate is incorrect in all versions of PyTorch above 0.4.1
* PyTorch 1.4.0 introduced a different design for the LRScheduler; the fix for the exponential learning rate for version 0.4.1 is not the same as the one for 1.4.0

Setup for the above:
1. torch==0.4.1 torchvision==0.2.1
2. torch==1.4.0 torchvision==0.5.0
@davidtvs
Owner

I made a test that replicates the issue and ran it with PyTorch v0.4.1 and v1.4.0, here's what I observed:

  • the current computation of the exponential learning rate is wrong for all versions of PyTorch (at least the ones after v0.4.1)
  • with PyTorch v0.4.1 the fix for this issue is a mix of the current implementation and what @NaleRaphael suggested:
def get_lr(self):
    curr_iter = self.last_epoch + 1
    r = curr_iter / (self.num_iter - 1)
    return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
  • with PyTorch v1.4.0 the fix is what @NaleRaphael suggested above:
def get_lr(self):
    r = self.last_epoch / (self.num_iter - 1)
    return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]

I'm thinking this means that there will be one last release with this fixed for v0.4.1 and after that, this package will support only PyTorch v1.4.0+.

@NaleRaphael
Contributor

NaleRaphael commented May 23, 2020

@davidtvs Thanks for your effort on this; I only tested it on PyTorch 1.3.0 before. 🙇

I found this issue has been addressed in PyTorch #7889, and it has been fixed in PyTorch v1.1.0 (see also this commit).

In step(), self.last_epoch is updated according to the argument epoch if it is given; otherwise it is incremented by 1. However, in v0.4 there is an additional assignment, self.last_epoch = last_epoch, at the last line of _LRScheduler.__init__(). Therefore, self.last_epoch is reset to -1. It's like:

# source: https://github.com/pytorch/pytorch/blob/3749c58/torch/optim/lr_scheduler.py#L22-L23
# Given `last_epoch` is the default value -1.
self.step(last_epoch + 1)     # after this line is executed, self.last_epoch = 0
self.last_epoch = last_epoch  # after this line is executed, self.last_epoch = -1
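
A quick way to observe this quirk (illustrative only; the printed value depends on the installed PyTorch version):

import torch
from torch import optim

model = torch.nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

# PyTorch 0.4.x: prints -1 (reset by the extra assignment in __init__)
# PyTorch 1.1.0+: prints 0 (step() already ran once during __init__)
print(scheduler.last_epoch)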

In my opinion, we can keep supporting at least PyTorch v1.1.0+ on the master branch with the second approach, and perhaps create a new branch for PyTorch v0.4 with the first approach.

UPDATE: One more thing about the fix. Since the denominator of r is self.num_iter - 1 in these fixes, do you think it's better to raise a warning or an error when num_iter <= 1 is given while initializing the scheduler, or just use an if-else to handle it in get_lr()?

@davidtvs
Owner

Ya, I think we can just raise an exception when num_iter <= 1; it doesn't make sense to set it to a value like that anyway.

As for how we handle the learning rate computation, I thought about using the packaging package to parse the PyTorch version in use and then have an if statement switch the way get_lr() computes the learning rate. This way we keep only one branch, which is easier to maintain, and the added complexity is small.
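
Roughly what I have in mind, as a sketch only (illustrative names, not the final implementation):

import torch
from packaging import version
from torch.optim.lr_scheduler import _LRScheduler

PYTORCH_VERSION = version.parse(torch.__version__)

class ExponentialLR(_LRScheduler):
    def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
        if num_iter <= 1:
            raise ValueError("`num_iter` must be larger than 1")
        self.end_lr = end_lr
        self.num_iter = num_iter
        super(ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Before v1.1.0, `last_epoch` is reset to -1 at the end of
        # `_LRScheduler.__init__()`, so the counter lags one step behind.
        if PYTORCH_VERSION < version.parse("1.1.0"):
            curr_iter = self.last_epoch + 1
        else:
            curr_iter = self.last_epoch
        r = curr_iter / (self.num_iter - 1)
        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]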

@NaleRaphael
Contributor

Great, so I think it's time to keep moving forward with this PR.
@yongduek Feel free to let us know if there is anything we can help with.

@yongduek
Author

Thanks for asking, and thanks for your effort on open source. I would like to join in, but I have to rely on you because of a tight schedule until the end of June. One thing I would like to mention: because of the smoothing inside the LRFinder class, the loss-lr graph shows a somewhat lagged version of the one without smoothing, and I am not sure how much this affects the actual learning process afterwards. Well, that is probably beyond the scope of this project.
Many thanks again for sharing your knowledge and effort ^^

@yongduek yongduek closed this May 28, 2020
@davidtvs
Owner

@yongduek thanks for the PR; at least now we know about the issue and it'll get fixed. Since I had already created a branch for this and made some changes there, I'll finish it there so that this issue is fixed for the next release.

It also brought to my attention that different PyTorch versions can break things silently, so I'll start running different versions of PyTorch in the CI.

davidtvs added a commit that referenced this pull request May 29, 2020
Two things can be concluded:
* The current computation of the exponential learning rate is incorrect in all versions of PyTorch above 0.4.1
* PyTorch 1.4.0 introduced a different design for the LRScheduler; the fix for the exponential learning rate for version 0.4.1 is not the same as the one for 1.4.0

Setup for the above:
1. torch==0.4.1 torchvision==0.2.1
2. torch==1.4.0 torchvision==0.5.0
davidtvs added a commit that referenced this pull request May 30, 2020
…#43, #42)

* Add unit test related to #42

Two things can be concluded:
* The current computation of the exponential learning rate is incorrect in all versions of PyTorch above 0.4.1
* PyTorch 1.4.0 introduced a different design for the LRScheduler; the fix for the exponential learning rate for version 0.4.1 is not the same as the one for 1.4.0

Setup for the above:
1. torch==0.4.1 torchvision==0.2.1
2. torch==1.4.0 torchvision==0.5.0

* Fix learning rate computation in schedulers

* Add PyTorch matrix to the CI job

Also, added caching to make the CI job faster.

* Raise ValueError for num_iter<=1

* Fix syntax error in CI yaml

* Fix syntax error in CI yaml v2

* Fix CI job

* The combo of py3.7 and torch 0.4.1 breaks type inference for torch.tensor with np.int64

* Allow CI to be skipped