CyclicLR throws ZeroDivisionError when finetuning with a single batch. #174

Closed
LuoXueling opened this issue Jul 5, 2023 · 2 comments · Fixed by #177

LuoXueling (Collaborator)

I get an error when I finetune a model. Here is the minimal code to reproduce:

from pytorch_widedeep.models import WideDeep, TabMlp
from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import TabPreprocessor
from sklearn.datasets import make_regression
import pandas as pd
import numpy as np

data = make_regression(n_samples=1024, n_features=100, n_targets=1, random_state=0)
cont_cols = [f"cont_{i}" for i in range(100)]
df = pd.DataFrame(
    data=np.hstack((data[0], data[1].reshape(-1, 1))), columns=cont_cols + ["target"]
)
tab_preprocessor = TabPreprocessor(continuous_cols=cont_cols)
X_tab_train = tab_preprocessor.fit_transform(df)
tab_model = TabMlp(column_idx=tab_preprocessor.column_idx, continuous_cols=cont_cols)
model = WideDeep(deeptabular=tab_model)
trainer = Trainer(model, objective="regression")

trainer.fit(
    X_train={"X_tab": X_tab_train, "target": df["target"]},
    n_epochs=2,
    batch_size=1024,
    finetune=False,
)

# Note that batch_size=1024 means len(loader)=1
trainer.fit(
    X_train={"X_tab": X_tab_train, "target": df["target"]},
    n_epochs=2,
    batch_size=1024,
    finetune=True,
)
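
As a quick sanity check (an illustrative sketch with a plain torch DataLoader, not pytorch-widedeep's internal loader), 1024 rows with batch_size=1024 do indeed yield a loader of length one:

import torch
from torch.utils.data import DataLoader, TensorDataset

# 1024 rows with batch_size=1024 -> a single batch per epoch
ds = TensorDataset(torch.zeros(1024, 100), torch.zeros(1024))
loader = DataLoader(ds, batch_size=1024)
print(len(loader))  # -> 1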

You can see that when fine-tuning the model, I am using a single batch (batch_size==len(df)). This code throws an error:

  File "/home/xlluo/hdd/ML-fracture/debug/widedeep/scheduler.py", line 26, in <module>
    trainer.fit(
  File "/home/xlluo/anaconda3/envs/mlfatigue/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py", line 61, in __call__
    return wrapped(*args, **kwargs)
  File "/home/xlluo/anaconda3/envs/mlfatigue/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py", line 470, in fit
    self._finetune(train_loader, **finetune_args)
  File "/home/xlluo/anaconda3/envs/mlfatigue/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py", line 61, in __call__
    return wrapped(*args, **kwargs)
  File "/home/xlluo/anaconda3/envs/mlfatigue/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py", line 61, in __call__
    return wrapped(*args, **kwargs)
  File "/home/xlluo/anaconda3/envs/mlfatigue/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py", line 909, in _finetune
    finetuner.finetune_all(
  File "/home/xlluo/anaconda3/envs/mlfatigue/lib/python3.8/site-packages/pytorch_widedeep/training/_finetune.py", line 108, in finetune_all
    scheduler = torch.optim.lr_scheduler.CyclicLR(
  File "/home/xlluo/anaconda3/envs/mlfatigue/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 1160, in __init__
    super(CyclicLR, self).__init__(optimizer, last_epoch, verbose)
  File "/home/xlluo/anaconda3/envs/mlfatigue/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 77, in __init__
    self.step()
  File "/home/xlluo/anaconda3/envs/mlfatigue/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 154, in step
    values = self.get_lr()
  File "/home/xlluo/anaconda3/envs/mlfatigue/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 1197, in get_lr
    scale_factor = x / self.step_ratio
ZeroDivisionError: float division by zero

This is caused by these lines in pytorch_widedeep/training/_finetune.py:

step_size_up, step_size_down = self._steps_up_down(len(loader), n_epochs)
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=max_lr / 10.0,
    max_lr=max_lr,
    step_size_up=step_size_up,
    step_size_down=step_size_down,
    cycle_momentum=False,
)

where _steps_up_down computes:

up = round((steps * n_epochs) * 0.1)
down = (steps * n_epochs) - up
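
A worked check of this arithmetic (not code from the library, just the expressions above evaluated in plain Python):

steps, n_epochs = 1, 5                # len(loader) == 1, default finetune epochs
up = round((steps * n_epochs) * 0.1)  # round(0.5) == 0 (Python rounds half to even)
down = (steps * n_epochs) - up        # 5
print(up, down)                       # -> 0 5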

Here len(loader) == 1 and n_epochs == 5 (the default), so step_size_up == 0, which is an illegal input for CyclicLR: its step_ratio becomes zero and get_lr divides by it. The same error can be reproduced in isolation:

import torch

torch.optim.lr_scheduler.CyclicLR(
    torch.optim.Adam(torch.nn.Linear(5, 5).parameters()),
    base_lr=0.1 / 10.0,
    max_lr=0.1,
    step_size_up=0,
    step_size_down=5,
    cycle_momentum=False,
)

I think changing the calculation of up to up = max([round((steps * n_epochs) * 0.1), 1]) would help. For now, my only workaround is to pass finetune_epochs=10 to fit.
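
A minimal sketch of that suggested change (my own illustration of the guard, not the library's actual _steps_up_down implementation):

def steps_up_down(steps: int, n_epochs: int = 5):
    # Same expressions as quoted above, with the proposed max(..., 1)
    # guard so CyclicLR always gets step_size_up >= 1.
    up = max([round((steps * n_epochs) * 0.1), 1])
    down = (steps * n_epochs) - up
    return up, down

print(steps_up_down(1))  # single batch, default epochs -> (1, 4) instead of (0, 5)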

jrzaurin added the bug label Jul 7, 2023
jrzaurin added a commit that referenced this issue Jul 7, 2023
jrzaurin (Owner) commented Jul 7, 2023

An upcoming PR from the fix_restore_best_weights branch will fix this issue.

LuoXueling (Collaborator, Author)
Thanks a lot. Closing the issue.
