Commit: Using non-JIT by default; compat fix with 1.8+

jongwook committed Jul 19, 2021
1 parent cfcffb9 commit db20393

Showing 2 changed files with 26 additions and 7 deletions.
29 changes: 24 additions & 5 deletions clip/clip.py
@@ -12,6 +12,17 @@
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

+try:
+    from torchvision.transforms import InterpolationMode
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+
+if torch.__version__.split(".") < ["1", "7", "1"]:
+    warnings.warn("PyTorch version 1.7.1 or higher is recommended")
+
+
__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
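
Note on the version check added above: splitting torch.__version__ on "." compares lists of strings lexicographically, so "1.10.0" would compare as older than "1.7.1" ("10" < "7" as strings) and trigger the warning spuriously. A more robust check, as a sketch that is not part of this commit (assuming the third-party packaging module, which ships with pip/setuptools, is available):

    from packaging import version
    import torch
    import warnings

    if version.parse(torch.__version__) < version.parse("1.7.1"):
        warnings.warn("PyTorch version 1.7.1 or higher is recommended")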

@@ -57,7 +68,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):

def _transform(n_px):
    return Compose([
-        Resize(n_px, interpolation=Image.BICUBIC),
+        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
@@ -70,7 +81,7 @@ def available_models() -> List[str]:
    return list(_MODELS.keys())


-def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
"""Load a CLIP model
Parameters
@@ -82,7 +93,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
        The device to put the loaded model

    jit : bool
-        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    Returns
    -------
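
Since jit now defaults to False, existing call sites get the non-JIT model without any code change. A minimal usage sketch of the new default (model name as listed by available_models()):

    import torch
    import clip

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)  # non-JIT by default now
    # To keep the old behavior, request the TorchScript model explicitly:
    # model, preprocess = clip.load("ViT-B/32", device=device, jit=True)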
@@ -121,7 +132,11 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
-        graphs = [module.graph] if hasattr(module, "graph") else []
+        try:
+            graphs = [module.graph] if hasattr(module, "graph") else []
+        except RuntimeError:
+            graphs = []
+
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

@@ -141,7 +156,11 @@ def patch_device(module):
        float_node = float_input.node()

        def patch_float(module):
-            graphs = [module.graph] if hasattr(module, "graph") else []
+            try:
+                graphs = [module.graph] if hasattr(module, "graph") else []
+            except RuntimeError:
+                graphs = []
+
            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

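These try/except blocks are the "compat fix with 1.8+" from the commit title: on newer PyTorch, reading .graph on a scripted submodule that has no compiled forward raises a RuntimeError, which hasattr() (it only swallows AttributeError) does not mask. The guard, factored into a standalone helper as a sketch (the helper name _collect_graphs is ours, not the commit's):

    def _collect_graphs(module):
        # module.graph can raise RuntimeError on PyTorch 1.8+ when the
        # scripted module has no compiled forward method
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []
        # traced modules may carry a second compiled method, forward1
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)
        return graphs
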
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
ftfy
regex
tqdm
-torch~=1.7.1
-torchvision~=0.8.2
+torch
+torchvision
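
With the pins removed, a fresh install pulls the latest torch and torchvision. Anyone needing the previously pinned stack can still request it explicitly (a sketch of the old constraints, not a recommendation):

    pip install "torch~=1.7.1" "torchvision~=0.8.2"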

1 comment on commit db20393

@julienbelangerunity
Behavior Changed

import clip
clip.load("ViT-B/32", device="cpu")

yields

'CLIP' object has no attribute 'input_resolution'

Fixed by:

clip.load("ViT-B/32", device="cpu", jit=True)

Versions: torch==1.7.1, torchvision==0.8.2
