Skip to content

Commit

Permalink
Fix torch version comparison (#10934)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kuro96 committed Sep 18, 2023
1 parent ba358bc commit 02526bc
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .dev_scripts/gather_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import yaml
from mmengine.config import Config
from mmengine.fileio import dump
from mmengine.utils import mkdir_or_exist, scandir
from mmengine.utils import digit_version, mkdir_or_exist, scandir


def ordered_yaml_dump(data, stream=None, Dumper=yaml.SafeDumper, **kwds):
Expand Down Expand Up @@ -45,7 +45,7 @@ def process_checkpoint(in_file, out_file):

# if it is necessary to remove some sensitive data in checkpoint['meta'],
# add the code here.
if torch.__version__ >= '1.6':
if digit_version(torch.__version__) >= digit_version('1.6'):
torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, out_file)
Expand Down
3 changes: 2 additions & 1 deletion mmdet/models/layers/normed_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.utils import digit_version
from torch import Tensor

from mmdet.registry import MODELS
Expand Down Expand Up @@ -91,7 +92,7 @@ def forward(self, x: Tensor) -> Tensor:
if hasattr(self, 'conv2d_forward'):
x_ = self.conv2d_forward(x_, weight_)
else:
if torch.__version__ >= '1.8':
if digit_version(torch.__version__) >= digit_version('1.8'):
x_ = self._conv_forward(x_, weight_, self.bias)
else:
x_ = self._conv_forward(x_, weight_)
Expand Down
3 changes: 2 additions & 1 deletion tools/analysis_tools/get_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import Runner
from mmengine.utils import digit_version

from mmdet.registry import MODELS

Expand Down Expand Up @@ -44,7 +45,7 @@ def parse_args():


def inference(args, logger):
if str(torch.__version__) < '1.12':
if digit_version(torch.__version__) < digit_version('1.12'):
logger.warning(
'Some config files, such as configs/yolact and configs/detectors,'
'may have compatibility issues with torch.jit when torch<1.12. '
Expand Down
3 changes: 2 additions & 1 deletion tools/model_converters/publish_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
from mmengine.logging import print_log
from mmengine.utils import digit_version


def parse_args():
Expand Down Expand Up @@ -37,7 +38,7 @@ def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']):

# if it is necessary to remove some sensitive data in checkpoint['meta'],
# add the code here.
if torch.__version__ >= '1.6':
if digit_version(torch.__version__) >= digit_version('1.6'):
torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, out_file)
Expand Down
3 changes: 2 additions & 1 deletion tools/model_converters/upgrade_ssd_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
from mmengine import Config
from mmengine.utils import digit_version


def parse_config(config_strings):
Expand Down Expand Up @@ -39,7 +40,7 @@ def convert(in_file, out_file):
out_state_dict[new_key] = value
checkpoint['state_dict'] = out_state_dict

if torch.__version__ >= '1.6':
if digit_version(torch.__version__) >= digit_version('1.6'):
torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, out_file)
Expand Down

0 comments on commit 02526bc

Please sign in to comment.