Skip to content

Commit

Permalink
Adapted textual inversion distillation for quantization example to la…
Browse files Browse the repository at this point in the history
…test transformers and diffusers packages (#1471)
  • Loading branch information
XinyuYe-Intel committed Apr 15, 2024
1 parent 733697c commit 0ec83b1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
datasets
diffusers==0.4.1
accelerate<=0.19.0
diffusers
accelerate
torchvision
ftfy
tensorboard
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,9 @@ def generate_images(
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
safety_checker=None,
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.safety_checker = lambda images, clip_input: (images, False)
int8_model_path = os.path.join(args.pretrained_model_name_or_path, "pytorch_model.bin")
if os.path.exists(int8_model_path):
unet = load(int8_model_path, model=unet)
Expand All @@ -120,7 +119,7 @@ def generate_images(
grid.save(
os.path.join(args.pretrained_model_name_or_path, "{}.png".format("_".join(args.caption.split())))
)
dirname = os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split()))
os.makedirs(shlex.quote(dirname), exist_ok=True)
dirname = shlex.quote(os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split())))
os.makedirs(dirname, exist_ok=True)
for idx, image in enumerate(images):
image.save(os.path.join(dirname, "{}.png".format(idx+1)))
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def main():
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with="tensorboard",
logging_dir=logging_dir,
project_dir=logging_dir,
)

# If passed along, set the training seed now.
Expand Down Expand Up @@ -782,18 +782,18 @@ def train_func(model):
layer_mappings=[
[['conv_in', ]],
[['time_embedding', ]],
[['down_blocks.0.attentions.0',]],
[['down_blocks.0.attentions.1',]],
[['down_blocks.0.attentions.0', '0']],
[['down_blocks.0.attentions.1', '0']],
[['down_blocks.0.resnets.0',]],
[['down_blocks.0.resnets.1',]],
[['down_blocks.0.downsamplers.0',]],
[['down_blocks.1.attentions.0',]],
[['down_blocks.1.attentions.1',]],
[['down_blocks.1.attentions.0', '0']],
[['down_blocks.1.attentions.1', '0']],
[['down_blocks.1.resnets.0',]],
[['down_blocks.1.resnets.1',]],
[['down_blocks.1.downsamplers.0',]],
[['down_blocks.2.attentions.0',]],
[['down_blocks.2.attentions.1',]],
[['down_blocks.2.attentions.0', '0']],
[['down_blocks.2.attentions.1', '0']],
[['down_blocks.2.resnets.0',]],
[['down_blocks.2.resnets.1',]],
[['down_blocks.2.downsamplers.0',]],
Expand All @@ -803,27 +803,27 @@ def train_func(model):
[['up_blocks.0.resnets.1', ]],
[['up_blocks.0.resnets.2', ]],
[['up_blocks.0.upsamplers.0', ]],
[['up_blocks.1.attentions.0', ]],
[['up_blocks.1.attentions.1', ]],
[['up_blocks.1.attentions.2', ]],
[['up_blocks.1.attentions.0', '0']],
[['up_blocks.1.attentions.1', '0']],
[['up_blocks.1.attentions.2', '0']],
[['up_blocks.1.resnets.0', ]],
[['up_blocks.1.resnets.1', ]],
[['up_blocks.1.resnets.2', ]],
[['up_blocks.1.upsamplers.0', ]],
[['up_blocks.2.attentions.0', ]],
[['up_blocks.2.attentions.1', ]],
[['up_blocks.2.attentions.2', ]],
[['up_blocks.2.attentions.0', '0']],
[['up_blocks.2.attentions.1', '0']],
[['up_blocks.2.attentions.2', '0']],
[['up_blocks.2.resnets.0', ]],
[['up_blocks.2.resnets.1', ]],
[['up_blocks.2.resnets.2', ]],
[['up_blocks.2.upsamplers.0', ]],
[['up_blocks.3.attentions.0', ]],
[['up_blocks.3.attentions.1', ]],
[['up_blocks.3.attentions.2', ]],
[['up_blocks.3.attentions.0', '0']],
[['up_blocks.3.attentions.1', '0']],
[['up_blocks.3.attentions.2', '0']],
[['up_blocks.3.resnets.0', ]],
[['up_blocks.3.resnets.1', ]],
[['up_blocks.3.resnets.2', ]],
[['mid_block.attentions.0', ]],
[['mid_block.attentions.0', '0']],
[['mid_block.resnets.0', ]],
[['mid_block.resnets.1', ]],
[['conv_out', ]],
Expand Down

0 comments on commit 0ec83b1

Please sign in to comment.