Skip to content

Commit

Permalink
Fix bad import with PyTorch <= 1.4.1 (#8237)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger committed Nov 2, 2020
1 parent 3c8d401 commit d1ad4bf
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import numpy as np
import torch
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
from packaging import version
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler

Expand All @@ -34,6 +34,11 @@
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm

if version.parse(torch.__version__) <= version.parse("1.4.1"):
SAVE_STATE_WARNING = ""
else:
from torch.optim.lr_scheduler import SAVE_STATE_WARNING

logger = logging.get_logger(__name__)


Expand Down

0 comments on commit d1ad4bf

Please sign in to comment.