Skip to content

Commit

Permalink
add an assert on torch version in order to use OpenAIDiscreteVAE
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 16, 2022
1 parent 2612a51 commit d03df6a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 9 additions & 0 deletions dalle_pytorch/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pathlib import Path
from tqdm import tqdm
from math import sqrt, log
from packaging import version

from omegaconf import OmegaConf
from taming.models.vqgan import VQModel, GumbelVQ
import importlib
Expand Down Expand Up @@ -98,11 +100,18 @@ def make_contiguous(module):
for param in module.parameters():
param.set_(param.contiguous())

# package versions

def get_pkg_version(pkg_name):
from pkg_resources import get_distribution
return get_distribution(pkg_name).version

# pretrained Discrete VAE from OpenAI

class OpenAIDiscreteVAE(nn.Module):
def __init__(self):
super().__init__()
assert version.parse(get_pkg_version('torch')) < version.parse('1.11.0'), 'torch version must be <= 1.10 in order to use OpenAI discrete vae'

self.enc = load_model(download(OPENAI_VAE_ENCODER_PATH))
self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'dalle-pytorch',
packages = find_packages(),
include_package_data = True,
version = '1.6.1',
version = '1.6.2',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit d03df6a

Please sign in to comment.