Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configs/models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
stable-diffusion-1.4:
config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/model.ckpt
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
description: Stable Diffusion inference model version 1.4
width: 512
height: 512
Expand Down
10 changes: 10 additions & 0 deletions ldm/invoke/model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import hashlib
import psutil
import transformers
import os
from sys import getrefcount
from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError
Expand Down Expand Up @@ -193,6 +194,7 @@ def _load_model(self, model_name:str):
mconfig = self.config[model_name]
config = mconfig.config
weights = mconfig.weights
vae = mconfig.get('vae',None)
width = mconfig.width
height = mconfig.height

Expand Down Expand Up @@ -222,9 +224,17 @@ def _load_model(self, model_name:str):
else:
print(' | Using more accurate float32 precision')

# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
if vae and os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}')
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False)

model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
model.cond_stage_model.device = self.device

model.eval()

for m in model.modules():
Expand Down
14 changes: 12 additions & 2 deletions scripts/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
new_config['config'] = input('Configuration file for this model: ')
done = os.path.exists(new_config['config'])

done = False
completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
while not done:
vae = input('VAE autoencoder file for this model [None]: ')
if os.path.exists(vae):
new_config['vae'] = vae
done = True
else:
done = len(vae)==0

completer.complete_extensions(None)

for field in ('width','height'):
Expand Down Expand Up @@ -537,8 +547,8 @@ def edit_config(model_name:str, gen, opt, completer):

conf = config[model_name]
new_config = {}
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
for field in ('description', 'weights', 'config', 'width','height'):
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
for field in ('description', 'weights', 'vae', 'config', 'width','height'):
completer.linebuffer = str(conf[field]) if field in conf else ''
new_value = input(f'{field}: ')
new_config[field] = int(new_value) if field in ('width','height') else new_value
Expand Down