Skip to content

Commit

Permalink
Move latent scale factor from VAE to model.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jun 23, 2023
1 parent 30a3861 commit 8607c2d
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 33 deletions.
16 changes: 16 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

class LatentFormat:
def process_in(self, latent):
return latent * self.scale_factor

def process_out(self, latent):
return latent / self.scale_factor

class SD15(LatentFormat):
def __init__(self, scale_factor=0.18215):
self.scale_factor = scale_factor

class SDXL(LatentFormat):
def __init__(self):
self.scale_factor = 0.13025

27 changes: 18 additions & 9 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import numpy as np

class BaseModel(torch.nn.Module):
def __init__(self, unet_config, v_prediction=False):
def __init__(self, model_config, v_prediction=False):
super().__init__()

unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config)
self.v_prediction = v_prediction
Expand Down Expand Up @@ -75,9 +77,16 @@ def load_model_weights(self, sd, unet_prefix=""):
del to_load
return self

def process_latent_in(self, latent):
return self.latent_format.process_in(latent)

def process_latent_out(self, latent):
return self.latent_format.process_out(latent)


class SD21UNCLIP(BaseModel):
def __init__(self, unet_config, noise_aug_config, v_prediction=True):
super().__init__(unet_config, v_prediction)
def __init__(self, model_config, noise_aug_config, v_prediction=True):
super().__init__(model_config, v_prediction)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

def encode_adm(self, **kwargs):
Expand Down Expand Up @@ -112,13 +121,13 @@ def encode_adm(self, **kwargs):
return adm_out

class SDInpaint(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
self.concat_keys = ("mask", "masked_image")

class SDXLRefiner(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
self.embedder = Timestep(256)

def encode_adm(self, **kwargs):
Expand All @@ -144,8 +153,8 @@ def encode_adm(self, **kwargs):
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SDXL(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
self.embedder = Timestep(256)

def encode_adm(self, **kwargs):
Expand Down
5 changes: 4 additions & 1 deletion comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,9 @@ def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=N
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")

if latent_image is not None:
latent_image = self.model.process_latent_in(latent_image)

extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

cond_concat = None
Expand Down Expand Up @@ -672,4 +675,4 @@ def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=N
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)

return samples.to(torch.float32)
return self.model.process_latent_out(samples.to(torch.float32))
32 changes: 19 additions & 13 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def encode(self, text):


class VAE:
def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
def __init__(self, ckpt_path=None, device=None, config=None):
if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
Expand All @@ -550,7 +550,6 @@ def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=Non
sd = diffusers_convert.convert_vae_state_dict(sd)
self.first_stage_model.load_state_dict(sd, strict=False)

self.scale_factor = scale_factor
if device is None:
device = model_management.get_torch_device()
self.device = device
Expand All @@ -561,7 +560,7 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = utils.ProgressBar(steps)

decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0)
output = torch.clamp((
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
Expand All @@ -575,7 +574,7 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = utils.ProgressBar(steps)

encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor
encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample()
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
Expand All @@ -593,7 +592,7 @@ def decode(self, samples_in):
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(1. / self.scale_factor * samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
Expand All @@ -620,7 +619,7 @@ def encode(self, pixel_samples):
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu()

except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
Expand Down Expand Up @@ -958,6 +957,7 @@ def load_gligen(ckpt_path):
return model

def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
#TODO: this function is a mess and should be removed eventually
if config is None:
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
Expand Down Expand Up @@ -992,12 +992,20 @@ class WeightsLoader(torch.nn.Module):
if state_dict is None:
state_dict = utils.load_torch_file(ckpt_path)

class EmptyClass:
pass

model_config = EmptyClass()
model_config.unet_config = unet_config
from . import latent_formats
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)

if config['model']["target"].endswith("LatentInpaintDiffusion"):
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
model = model_base.BaseModel(model_config, v_prediction=v_prediction)

if fp16:
model = model.half()
Expand All @@ -1006,14 +1014,12 @@ class WeightsLoader(torch.nn.Module):

if output_vae:
w = WeightsLoader()
vae = VAE(scale_factor=scale_factor, config=vae_config)
vae = VAE(config=vae_config)
w.first_stage_model = vae.first_stage_model
load_model_weights(w, state_dict)

if output_clip:
w = WeightsLoader()
class EmptyClass:
pass
clip_target = EmptyClass()
clip_target.params = clip_config.get("params", {})
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
Expand Down Expand Up @@ -1055,7 +1061,7 @@ class WeightsLoader(torch.nn.Module):
model.load_model_weights(sd, "model.diffusion_model.")

if output_vae:
vae = VAE(scale_factor=model_config.vae_scale_factor)
vae = VAE()
w = WeightsLoader()
w.first_stage_model = vae.first_stage_model
load_model_weights(w, sd)
Expand Down
13 changes: 7 additions & 6 deletions comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from . import sdxl_clip

from . import supported_models_base
from . import latent_formats

class SD15(supported_models_base.BASE):
unet_config = {
Expand All @@ -21,7 +22,7 @@ class SD15(supported_models_base.BASE):
"num_head_channels": -1,
}

vae_scale_factor = 0.18215
latent_format = latent_formats.SD15

def process_clip_state_dict(self, state_dict):
k = list(state_dict.keys())
Expand All @@ -48,7 +49,7 @@ class SD20(supported_models_base.BASE):
"adm_in_channels": None,
}

vae_scale_factor = 0.18215
latent_format = latent_formats.SD15

def v_prediction(self, state_dict):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
Expand Down Expand Up @@ -97,10 +98,10 @@ class SDXLRefiner(supported_models_base.BASE):
"transformer_depth": [0, 4, 4, 0],
}

vae_scale_factor = 0.13025
latent_format = latent_formats.SDXL

def get_model(self, state_dict):
return model_base.SDXLRefiner(self.unet_config)
return model_base.SDXLRefiner(self)

def process_clip_state_dict(self, state_dict):
keys_to_replace = {}
Expand All @@ -124,10 +125,10 @@ class SDXL(supported_models_base.BASE):
"adm_in_channels": 2816
}

vae_scale_factor = 0.13025
latent_format = latent_formats.SDXL

def get_model(self, state_dict):
return model_base.SDXL(self.unet_config)
return model_base.SDXL(self)

def process_clip_state_dict(self, state_dict):
keys_to_replace = {}
Expand Down
7 changes: 4 additions & 3 deletions comfy/supported_models_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,17 @@ def inpaint_model(self):

def __init__(self, unet_config):
self.unet_config = unet_config
self.latent_format = self.latent_format()
for x in self.unet_extra_config:
self.unet_config[x] = self.unet_extra_config[x]

def get_model(self, state_dict):
if self.inpaint_model():
return model_base.SDInpaint(self.unet_config, v_prediction=self.v_prediction(state_dict))
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
elif self.noise_aug_config is not None:
return model_base.SD21UNCLIP(self.unet_config, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
else:
return model_base.BaseModel(self.unet_config, v_prediction=self.v_prediction(state_dict))
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))

def process_clip_state_dict(self, state_dict):
return state_dict
Expand Down
6 changes: 5 additions & 1 deletion nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=No

output = {}
output["latent_tensor"] = samples["samples"]
output["latent_format_version_0"] = torch.tensor([])

safetensors.torch.save_file(output, file, metadata=metadata)

Expand All @@ -305,7 +306,10 @@ def INPUT_TYPES(s):
def load(self, latent):
latent_path = folder_paths.get_annotated_filepath(latent)
latent = safetensors.torch.load_file(latent_path, device="cpu")
samples = {"samples": latent["latent_tensor"].float()}
multiplier = 1.0
if "latent_format_version_0" not in latent:
multiplier = 1.0 / 0.18215
samples = {"samples": latent["latent_tensor"].float() * multiplier}
return (samples, )

@classmethod
Expand Down

0 comments on commit 8607c2d

Please sign in to comment.