Small fixes, bit readmes (#70)
* Add parameter to control rank of decomposition (#28)

* ENH: allow controlling rank of approximation

* Training script accepts lora_rank

* feat : statefully monkeypatch different loras + example ipynb + readme

* Fix lora inject, added weight self apply lora (#39)

* Develop (#34)

* Add parameter to control rank of decomposition (#28)

* ENH: allow controlling rank of approximation

* Training script accepts lora_rank

* feat : statefully monkeypatch different loras + example ipynb + readme

Co-authored-by: brian6091 <brian6091@gmail.com>

* release : version 0.0.4, now able to tune rank, now add loras dynamically

* readme : add brain6091's discussions

* fix:inject lora in to_out module list

* feat: added weight self apply lora

* chore: add import copy

* fix: readded r

Co-authored-by: Simo Ryu <35953539+cloneofsimo@users.noreply.github.com>
Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: SimoRyu <cloneofsimo@korea.ac.kr>

* Revert "Fix lora inject, added weight self apply lora (#39)" (#40)

This reverts commit fececf3.

* fix : rank bug in monkeypatch

* fix cli fix

* visualizatio on effect of LR

* Fix save_steps, max_train_steps, and logging (#45)

* v 0.0.5 (#42)

* Add parameter to control rank of decomposition (#28)

* ENH: allow controlling rank of approximation

* Training script accepts lora_rank

* feat : statefully monkeypatch different loras + example ipynb + readme

* Fix lora inject, added weight self apply lora (#39)

* Develop (#34)

* Add parameter to control rank of decomposition (#28)

* ENH: allow controlling rank of approximation

* Training script accepts lora_rank

* feat : statefully monkeypatch different loras + example ipynb + readme

Co-authored-by: brian6091 <brian6091@gmail.com>

* release : version 0.0.4, now able to tune rank, now add loras dynamically

* readme : add brain6091's discussions

* fix:inject lora in to_out module list

* feat: added weight self apply lora

* chore: add import copy

* fix: readded r

Co-authored-by: Simo Ryu <35953539+cloneofsimo@users.noreply.github.com>
Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: SimoRyu <cloneofsimo@korea.ac.kr>

* Revert "Fix lora inject, added weight self apply lora (#39)" (#40)

This reverts commit fececf3.

* fix : rank bug in monkeypatch

* fix cli fix

* visualizatio on effect of LR

Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: Davide Paglieri <paglieridavide@gmail.com>

* Fix save_steps, max_train_steps, and logging

Corrected indenting so checking save_steps, max_train_steps, and updating logs are performed every step instead at the end of an epoch.

Co-authored-by: Simo Ryu <35953539+cloneofsimo@users.noreply.github.com>
Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: Davide Paglieri <paglieridavide@gmail.com>

* Enable resuming (#52)

* v 0.0.5 (#42)

* Add parameter to control rank of decomposition (#28)

* ENH: allow controlling rank of approximation

* Training script accepts lora_rank

* feat : statefully monkeypatch different loras + example ipynb + readme

* Fix lora inject, added weight self apply lora (#39)

* Develop (#34)

* Add parameter to control rank of decomposition (#28)

* ENH: allow controlling rank of approximation

* Training script accepts lora_rank

* feat : statefully monkeypatch different loras + example ipynb + readme

Co-authored-by: brian6091 <brian6091@gmail.com>

* release : version 0.0.4, now able to tune rank, now add loras dynamically

* readme : add brain6091's discussions

* fix:inject lora in to_out module list

* feat: added weight self apply lora

* chore: add import copy

* fix: readded r

Co-authored-by: Simo Ryu <35953539+cloneofsimo@users.noreply.github.com>
Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: SimoRyu <cloneofsimo@korea.ac.kr>

* Revert "Fix lora inject, added weight self apply lora (#39)" (#40)

This reverts commit fececf3.

* fix : rank bug in monkeypatch

* fix cli fix

* visualizatio on effect of LR

Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: Davide Paglieri <paglieridavide@gmail.com>

* Enable resume training unet/text encoder (#48)

* Enable resume training unet/text encoder

New flags --resume_text_encoder --resume_unet accept the paths to .pt files to resume.
Make sure to change the output directory from the previous training session, or else .pt files will be overwritten since training does not resume from previous global step.

* Load weights from .pt with inject_trainable_lora

Adds new loras argument to inject_trainable_lora function which accepts path to a .pt file containing previously trained weights.

Co-authored-by: Simo Ryu <35953539+cloneofsimo@users.noreply.github.com>
Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: Davide Paglieri <paglieridavide@gmail.com>

* feat : low-rank pivotal tuning

* feat :  pivotal tuning

* v 0.0.6

* Learning rate switching & fix indent (#57)

* Learning rate switching & fix indent

Make learning rates switch from training textual inversion to unet/text encoder after unfreeze_lora_step.
I think this is how it was explained in the paper linked(?)

Either way, it might be useful to add another parameter to activate unet/text encoder training at a certain step instead of at unfreeze_lora_step.
This would let the user have more control.

Also fix indenting to make save_steps and logging work properly.

* Fix indent

fix accelerator.wait_for_everyone() indent according to original dreambooth training

* Re:Fix indent (#58)

Fix indenting of accelerator.wait_for_everyone()
according to original dreambooth training

* ff now training default

* feat : dataset

* feat : utils to back training

* readme : more contents. citations, etc.

* fix : weight init

Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: Davide Paglieri <paglieridavide@gmail.com>
Co-authored-by: hdeezy <82070413+hdeezy@users.noreply.github.com>
4 people committed Dec 21, 2022
1 parent dc1c6d2 commit bae3ef2
Showing 18 changed files with 1,519 additions and 34 deletions.
71 changes: 68 additions & 3 deletions README.md
@@ -35,11 +35,11 @@
## Main Features

- Fine-tune Stable Diffusion models twice as fast as the Dreambooth method, using Low-Rank Adaptation
- Get insanely small end result (3MB for just unet, 6MB for both unet + clip), easy to share and download.
- Get insanely small end result (3MB for just unet, 4MB for both unet + clip + token embedding), easy to share and download.
- Easy to use, compatible with `diffusers`
- Sometimes _even better performance_ than full fine-tuning (but left as future work for extensive comparisons)
- Merge checkpoints + Build recipes by merging LoRAs together
- Fine-tune both CLIP & Unet to gain better results.
- Pipeline to fine-tune CLIP + Unet + token to gain better results.

# Web Demo

@@ -49,6 +49,12 @@

# UPDATES & Notes

### 2022/12/22

- Pivotal Tuning now available with

### 2022/12/10

- **You can now fine-tune text_encoder as well! Enabled with simple `--train_text_encoder`**
- **Converting to CKPT format for A1111's repo consumption!** (Thanks to [jachiam](https://github.com/jachiam)'s conversion script)
- Img2Img Examples added.
@@ -71,6 +77,21 @@ This is the key idea of LoRA. We can then fine-tune $A$ and $B$ instead of $W$.

Also, not all of the parameters need tuning: they found that often, $Q, K, V, O$ (i.e., attention layer) of the transformer model is enough to tune. (This is also the reason why the end result is so small). This repo will follow the same idea.
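
To make the size argument concrete, here is a small sketch of the arithmetic. It assumes a square 768x768 projection matrix and rank 4; both numbers are illustrative and not taken from this repo. The two low-rank factors have shapes (d_out x rank) and (rank x d_in).

```python
def full_params(d_out: int, d_in: int) -> int:
    """Number of entries in the full weight matrix W."""
    return d_out * d_in


def lora_params(d_out: int, d_in: int, rank: int) -> int:
    """Number of entries in the two low-rank factors combined."""
    # One factor is (d_out x rank), the other is (rank x d_in).
    return d_out * rank + rank * d_in


d = 768  # illustrative hidden size (assumption)
r = 4    # illustrative rank (assumption)

full = full_params(d, d)
low = lora_params(d, d, r)
ratio = full / low

print(f"full: {full}, low-rank: {low}, reduction: {ratio:.0f}x")
```

Under these assumptions, each adapted matrix needs roughly two orders of magnitude fewer trainable values than the full matrix, which is why the saved result can be so small.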

Now, how would we actually use this to update a diffusion model? First, we will use Stable Diffusion from [stability-ai](https://stability.ai/). Their model is nicely ported through the Hugging Face API, so this repo builds various fine-tuning methods around it. In detail, there are three subtle but important distinctions between the methods that make this work.

1. [Dreambooth](https://arxiv.org/abs/2208.12242)

First, there is LoRA applied to Dreambooth. The idea is to use prior-preservation class images to regularize the training process, and to use rarely occurring tokens. This keeps the model's generalization capability while maintaining high fidelity. If you turn off prior preservation and train the text encoder embedding as well, it becomes naive fine-tuning.

2. [Textual Inversion](https://arxiv.org/abs/2208.01618)

Second, there is Textual Inversion. There is no room to apply LoRA here, but it is worth mentioning. The idea is to instantiate a new token and learn its embedding via gradient descent. This is a very powerful method, and it is worth trying if your use case is focused less on fidelity and more on inverting conceptual ideas.

3. [Pivotal Tuning](https://arxiv.org/abs/2106.05744)

The last method (although originally proposed for GANs) takes the best of both worlds. You first apply textual inversion to get a matching token embedding. Then you use that token embedding together with prior-preserving class images to fine-tune the model. This two-stage nature makes it a strict generalization of both methods.
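
The two stages can be sketched as a simple step-based switch. This is only an illustration: `unfreeze_lora_step` is the name used in the commit message above, but the group names and the choice to freeze the token embedding after the switch are assumptions, and the actual training script may behave differently.

```python
def trainable_groups(step: int, unfreeze_lora_step: int) -> dict:
    """Return which (hypothetical) parameter groups receive updates at a step."""
    # Stage 1: textual inversion only, before the switch step.
    in_inversion_stage = step < unfreeze_lora_step
    return {
        "token_embedding": in_inversion_stage,
        # Stage 2: LoRA weights of the unet and text encoder are tuned.
        "lora_unet": not in_inversion_stage,
        "lora_text_encoder": not in_inversion_stage,
    }


early = trainable_groups(step=10, unfreeze_lora_step=500)
late = trainable_groups(step=500, unfreeze_lora_step=500)
print(early)
print(late)
```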

Enough of the lengthy introduction, let's get to the code.

# Installation
@@ -102,7 +123,7 @@ optimizer = optim.Adam(
)
```

An example of this can be found in `train_lora_dreambooth.py`. Run this example with
A working example of this, applied to [Dreambooth](https://arxiv.org/abs/2208.12242), can be found in `train_lora_dreambooth.py`. Run this example with

```bash
run_lora_db.sh
@@ -318,3 +339,47 @@ TODOS
- Adaptor-guidance
- Time-aware fine-tuning.
- Test alpha scheduling. I think it will be meaningful.

# References

This work was heavily influenced by, and originated from, the following research. I'm just applying it here.

```bibtex
@article{roich2022pivotal,
title={Pivotal tuning for latent-based editing of real images},
author={Roich, Daniel and Mokady, Ron and Bermano, Amit H and Cohen-Or, Daniel},
journal={ACM Transactions on Graphics (TOG)},
volume={42},
number={1},
pages={1--13},
year={2022},
publisher={ACM New York, NY}
}
```

```bibtex
@article{ruiz2022dreambooth,
title={Dreambooth: Fine tuning text-to-image diffusion models for subject-driven generation},
author={Ruiz, Nataniel and Li, Yuanzhen and Jampani, Varun and Pritch, Yael and Rubinstein, Michael and Aberman, Kfir},
journal={arXiv preprint arXiv:2208.12242},
year={2022}
}
```

```bibtex
@article{gal2022image,
title={An image is worth one word: Personalizing text-to-image generation using textual inversion},
author={Gal, Rinon and Alaluf, Yuval and Atzmon, Yuval and Patashnik, Or and Bermano, Amit H and Chechik, Gal and Cohen-Or, Daniel},
journal={arXiv preprint arXiv:2208.01618},
year={2022}
}
```

```bibtex
@article{hu2021lora,
title={Lora: Low-rank adaptation of large language models},
author={Hu, Edward J and Shen, Yelong and Wallis, Phillip and Allen-Zhu, Zeyuan and Li, Yuanzhi and Wang, Shean and Wang, Lu and Chen, Weizhu},
journal={arXiv preprint arXiv:2106.09685},
year={2021}
}
```
Binary file added contents/1e-5-krk-pt.jpg
Binary file added contents/2e-5-krk-pt.jpg
Binary file added contents/2e-6-krk-pt.jpg
Binary file added contents/5e-6-krk-pt.jpg
Binary file added contents/pt-krk-caption-rank1.jpg
Binary file added contents/pt-krk-caption-rank2.jpg
Binary file added contents/pt-krk-caption-rank4-we1e-3.jpg
Binary file added contents/pt-krk-caption-rank4.jpg
Binary file added contents/pt-krk-caption-rank8.jpg
1 change: 1 addition & 0 deletions lora_diffusion/__init__.py
@@ -1 +1,2 @@
from .lora import *
from .dataset import *
279 changes: 279 additions & 0 deletions lora_diffusion/dataset.py
@@ -0,0 +1,279 @@
from torch.utils.data import Dataset


from PIL import Image
from torchvision import transforms
from pathlib import Path

import random

imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]

imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]


def _randomset(lis):
    ret = []
    for i in range(len(lis)):
        if random.random() < 0.5:
            ret.append(lis[i])
    return ret


def _shuffle(lis):
    return random.sample(lis, len(lis))


class PivotalTuningDatasetTemplate(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        learnable_property,
        placeholder_token,
        stochastic_attribute,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
        color_jitter=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)

        self.placeholder_token = placeholder_token
        # Comma-separated attributes that are randomly appended to the prompt.
        self.stochastic_attribute = stochastic_attribute.split(",")

        self.templates = (
            imagenet_style_templates_small
            if learnable_property == "style"
            else imagenet_templates_small
        )

        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                ),
                transforms.CenterCrop(size)
                if center_crop
                else transforms.RandomCrop(size),
                transforms.ColorJitter(0.2, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        # Fill a random template with the placeholder token followed by a
        # random, shuffled subset of the stochastic attributes.
        text = random.choice(self.templates).format(
            ", ".join(
                [self.placeholder_token]
                + _shuffle(_randomset(self.stochastic_attribute))
            )
        )

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(
                self.class_images_path[index % self.num_class_images]
            )
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


class PivotalTuningDatasetCapation(Dataset):
    def __init__(
        self,
        instance_data_root,
        learnable_property,
        placeholder_token,
        stochastic_attribute,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
        color_jitter=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)

        self.placeholder_token = placeholder_token
        self.stochastic_attribute = stochastic_attribute.split(",")

        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                ),
                transforms.CenterCrop(size)
                if center_crop
                else transforms.RandomCrop(size),
                transforms.ColorJitter(0.2, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        # The caption is the image file name without its extension.
        text = self.instance_images_path[index % self.num_instance_images].stem

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(
                self.class_images_path[index % self.num_class_images]
            )
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example
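
The two dataset classes above differ only in how they build the instance prompt. The following is a minimal standalone sketch of that logic using only the standard library, without images or a tokenizer. The helper names, the `<krk>` token, the attribute string, and the file path are all hypothetical and chosen for illustration; only two of the templates are reproduced.

```python
import random
from pathlib import Path

# Two templates copied from imagenet_templates_small for brevity.
templates = ["a photo of a {}", "a rendering of a {}"]


def random_subset(items, rng):
    """Keep each item independently with probability 0.5."""
    return [item for item in items if rng.random() < 0.5]


def build_template_prompt(placeholder_token, stochastic_attribute, rng):
    """Template-style prompt: token plus a shuffled random attribute subset."""
    attributes = stochastic_attribute.split(",")
    chosen = random_subset(attributes, rng)
    chosen = rng.sample(chosen, len(chosen))  # shuffle without mutating
    return rng.choice(templates).format(", ".join([placeholder_token] + chosen))


def build_caption_prompt(image_path):
    """Caption-style prompt: the file name without its extension."""
    return Path(image_path).stem


rng = random.Random(0)
prompt = build_template_prompt("<krk>", "game character,3d render", rng)
caption = build_caption_prompt("data/a photo of krk in a forest.jpg")

print(prompt)
print(caption)
```

With the template variant the prompt changes on every draw, while the caption variant is fixed per image, so caption quality depends entirely on how the files are named.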