Skip to content
This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Commit

Permalink
Fixed torchvision version checking. (#450)
Browse files Browse the repository at this point in the history
  • Loading branch information
AniKar committed Oct 8, 2021
1 parent fe752a8 commit b9048eb
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions util/misc.py
Expand Up @@ -10,6 +10,7 @@
from collections import defaultdict, deque
import datetime
import pickle
from packaging import version
from typing import Optional, List

import torch
Expand All @@ -18,7 +19,7 @@

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if float(torchvision.__version__.split(".")[1]) < 7.0:
if version.parse(torchvision.__version__) < version.parse('0.7'):
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size

Expand Down Expand Up @@ -454,7 +455,7 @@ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corne
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if float(torchvision.__version__.split(".")[1]) < 7.0:
if version.parse(torchvision.__version__) < version.parse('0.7'):
if input.numel() > 0:
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
Expand Down

0 comments on commit b9048eb

Please sign in to comment.