Skip to content

Commit

Permalink
start storing version of dalle-pytorch alongside model weights - also…
Browse files Browse the repository at this point in the history
… throw an error if trying to generate with a dalle-pytorch whose VAE is not the same as the type with which it was trained on
  • Loading branch information
lucidrains committed Dec 25, 2021
1 parent e1d10b9 commit 1c25f54
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 30 deletions.
3 changes: 3 additions & 0 deletions dalle_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE
from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE

from pkg_resources import get_distribution
__version__ = get_distribution('dalle_pytorch').version
16 changes: 14 additions & 2 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,16 @@ def exists(val):
assert dalle_path.exists(), 'trained DALL-E must exist'

load_obj = torch.load(str(dalle_path))
dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')
dalle_params, vae_params, weights, vae_class_name, version = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights'), load_obj.pop('vae_class_name', None), load_obj.pop('version', None)

dalle_params.pop('vae', None) # cleanup later
# friendly print

if exists(version):
print(f'Loading a model trained with DALLE-pytorch version {version}')
else:
print('You are loading a model trained on an older version of DALL-E pytorch - it may not be compatible with the most recent version')

# load VAE

if args.taming:
vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
Expand All @@ -90,6 +97,10 @@ def exists(val):
else:
vae = OpenAIDiscreteVAE()

assert not (exists(vae_class_name) and vae.__class__.__name__ != vae_class_name), f'you trained DALL-E using {vae_class_name} but are trying to generate with {vae.__class__.__name__} - please make sure you are passing in the correct paths and settings for the VAE to use for generation'

# reconstitute DALL-E

dalle = DALLE(vae = vae, **dalle_params).cuda()

dalle.load_state_dict(weights)
Expand Down Expand Up @@ -118,6 +129,7 @@ def exists(val):
outputs = torch.cat(outputs)

# save all images

file_name = text
outputs_dir = Path(args.outputs_dir) / file_name.replace(' ', '_')[:(100)]
outputs_dir.mkdir(parents = True, exist_ok = True)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'dalle-pytorch',
packages = find_packages(),
include_package_data = True,
version = '1.1.5',
version = '1.1.6',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down
60 changes: 33 additions & 27 deletions train_dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@
parser.add_argument('--image_text_folder', type=str, required=True,
help='path to your folder of images and text for learning the DALL-E')

parser.add_argument(
'--wds',
type = str,
default='',
help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.'
)
parser.add_argument('--wds', type = str, default='',
help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.')

parser.add_argument('--truncate_captions', dest='truncate_captions', action='store_true',
help='Captions passed in which exceed the max token length will be truncated if this is set.')
Expand All @@ -75,7 +71,7 @@


parser.add_argument('--amp', action='store_true',
help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.')
help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.')

parser.add_argument('--wandb_name', default='dalle_train_transformer',
help='Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`')
Expand Down Expand Up @@ -144,6 +140,10 @@ def exists(val):
def get_trainable_params(model):
return [params for params in model.parameters() if params.requires_grad]

def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle_pytorch').version

def cp_path_to_dir(cp_path, tag):
"""Convert a checkpoint path to a directory with `tag` inserted.
If `cp_path` is already a directory, return it unchanged.
Expand All @@ -157,6 +157,7 @@ def cp_path_to_dir(cp_path, tag):
return cp_dir

# constants

WEBDATASET_IMAGE_TEXT_COLUMNS = tuple(args.wds.split(','))
ENABLE_WEBDATASET = True if len(WEBDATASET_IMAGE_TEXT_COLUMNS) == 2 else False

Expand Down Expand Up @@ -232,6 +233,7 @@ def cp_path_to_dir(cp_path, tag):
tokenizer = ChineseTokenizer()

# reconstitute vae

if RESUME:
dalle_path = Path(DALLE_PATH)
if using_deepspeed:
Expand All @@ -249,15 +251,11 @@ def cp_path_to_dir(cp_path, tag):

if vae_params is not None:
vae = DiscreteVAE(**vae_params)
elif args.taming:
vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
else:
if args.taming:
vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
else:
vae = OpenAIDiscreteVAE()
vae = OpenAIDiscreteVAE()

dalle_params = dict(
**dalle_params
)
IMAGE_SIZE = vae.image_size
resume_epoch = loaded_obj.get('epoch', 0)
else:
Expand Down Expand Up @@ -311,7 +309,6 @@ def cp_path_to_dir(cp_path, tag):
if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
vae.enc.blocks.output.conv.use_float16 = True


# helpers

def group_weight(model):
Expand Down Expand Up @@ -388,17 +385,20 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
if not ENABLE_WEBDATASET:
print(f'{len(ds)} image-text pairs found for training')

# data sampler

data_sampler = None

if not is_shuffle:
data_sampler = torch.utils.data.distributed.DistributedSampler(
ds,
num_replicas=distr_backend.get_world_size(),
rank=distr_backend.get_rank()
)
else:
data_sampler = None

# WebLoader for WebDataset and DeepSpeed compatibility

if ENABLE_WEBDATASET:
# WebLoader for WebDataset and DeepSpeed compatibility
dl = wds.WebLoader(ds, batch_size=None, shuffle=False, num_workers=4) # optionally add num_workers=2 (n) argument
number_of_batches = DATASET_SIZE // (BATCH_SIZE * distr_backend.get_world_size())
dl = dl.slice(number_of_batches)
Expand All @@ -407,10 +407,10 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
# Regular DataLoader for image-text-folder datasets
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)


# initialize DALL-E

dalle = DALLE(vae=vae, **dalle_params)

if not using_deepspeed:
if args.fp16:
dalle = dalle.half()
Expand All @@ -422,9 +422,14 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
# optimizer

opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

if RESUME and opt_state:
opt.load_state_dict(opt_state)

# scheduler

scheduler = None

if LR_DECAY:
scheduler = ReduceLROnPlateau(
opt,
Expand All @@ -437,11 +442,10 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
)
if RESUME and scheduler_state:
scheduler.load_state_dict(scheduler_state)
else:
scheduler = None

# experiment tracker

if distr_backend.is_root_worker():
# experiment tracker

model_config = dict(
depth=DEPTH,
Expand Down Expand Up @@ -503,8 +507,10 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
config_params=deepspeed_config,
)
# Prefer scheduler in `deepspeed_config`.

if LR_DECAY and distr_scheduler is None:
distr_scheduler = scheduler

avoid_model_calls = using_deepspeed and args.fp16

if RESUME and using_deepspeed:
Expand All @@ -516,7 +522,10 @@ def save_model(path, epoch=0):
'hparams': dalle_params,
'vae_params': vae_params,
'epoch': epoch,
'version': get_pkg_version(),
'vae_class_name': vae.__class__.__name__
}

if using_deepspeed:
cp_dir = cp_path_to_dir(path, 'ds')

Expand Down Expand Up @@ -552,8 +561,9 @@ def save_model(path, epoch=0):
**save_obj,
'weights': dalle.state_dict(),
'opt_state': opt.state_dict(),
'scheduler_state': (scheduler.state_dict() if scheduler else None)
}
save_obj['scheduler_state'] = (scheduler.state_dict() if scheduler else None)

torch.save(save_obj, path)

# training
Expand Down Expand Up @@ -611,10 +621,6 @@ def save_model(path, epoch=0):
# CUDA index errors when we don't guard this
image = dalle.generate_images(text[:1], filter_thres=0.9) # topk sampling at 0.9


log = {
**log,
}
if not avoid_model_calls:
log['image'] = wandb.Image(image, caption=decoded_text)

Expand Down

0 comments on commit 1c25f54

Please sign in to comment.