* Add parameter to control rank of decomposition (#28): allow controlling rank of approximation; training script accepts lora_rank
* feat: statefully monkeypatch different loras + example ipynb + readme
* release: version 0.0.4, now able to tune rank and add loras dynamically
* readme: add brian6091's discussions
* fix: inject lora into to_out module list
* Fix lora inject, added weight self apply lora (#39); reverted by (#40), which reverts commit fececf3
* fix: rank bug in monkeypatch
* fix: cli
* visualization of effect of LR
* Fix save_steps, max_train_steps, and logging (#45): corrected indenting so checking save_steps and max_train_steps and updating logs are performed every step instead of at the end of an epoch
* Enable resuming (#52): enable resume training of unet/text encoder (#48). New flags --resume_text_encoder and --resume_unet accept paths to .pt files to resume from. Make sure to change the output directory from the previous training session, or else the .pt files will be overwritten, since training does not resume from the previous global step. inject_trainable_lora gains a new loras argument which accepts a path to a .pt file containing previously trained weights.
* feat: low-rank pivotal tuning
* v 0.0.6
* Learning rate switching & fix indent (#57): make learning rates switch from training textual inversion to unet/text encoder after unfreeze_lora_step. It may be useful to add another parameter that activates unet/text-encoder training at a chosen step instead of at unfreeze_lora_step, giving the user more control. Also fixes indenting so save_steps and logging work properly, and fixes the accelerator.wait_for_everyone() indent according to the original dreambooth training (#58)
* ff now training default
* feat: dataset
* feat: utils to back training
* readme: more contents, citations, etc.
* fix: weight init

Co-authored-by: Simo Ryu <35953539+cloneofsimo@users.noreply.github.com>
Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: Davide Paglieri <paglieridavide@gmail.com>
Co-authored-by: hdeezy <82070413+hdeezy@users.noreply.github.com>
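The lora_rank flag added in #28 sets the rank r of the low-rank update that LoRA trains in place of a full weight update: the merged weight is W + scale * (up @ down), where up is (d_out, r) and down is (r, d_in), so only r * (d_in + d_out) numbers are learned instead of d_in * d_out. The following is a minimal pure-Python sketch of that arithmetic, for illustration only; it is not the repository's implementation, which monkeypatches the model's linear layers.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for i in range(rows)
    ]


def lora_delta(down, up):
    """Low-rank update delta_W = up @ down, with up: (d_out, r), down: (r, d_in)."""
    return matmul(up, down)


def apply_lora(weight, down, up, scale=1.0):
    """Return weight + scale * (up @ down), i.e. the merged LoRA weight."""
    delta = lora_delta(down, up)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Rank-1 example on a 2x2 weight: 2 + 2 numbers parameterize the whole update.
weight = [[1.0, 0.0], [0.0, 1.0]]
down = [[1.0, 2.0]]   # (r=1, d_in=2)
up = [[1.0], [3.0]]   # (d_out=2, r=1)
merged = apply_lora(weight, down, up)  # [[2.0, 2.0], [3.0, 7.0]]
```

Larger lora_rank lets the update span more directions in weight space at the cost of more trainable parameters; rank tuning is exactly the trade-off this commit exposes.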
1 parent dc1c6d2 · commit bae3ef2
Showing 18 changed files with 1,519 additions and 34 deletions.
@@ -1 +1,2 @@
from .lora import *
from .dataset import *
@@ -0,0 +1,279 @@
from torch.utils.data import Dataset

from PIL import Image
from torchvision import transforms
from pathlib import Path

import random


imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]


def _randomset(lis):
    # Keep each element independently with probability 0.5.
    ret = []
    for i in range(len(lis)):
        if random.random() < 0.5:
            ret.append(lis[i])
    return ret


def _shuffle(lis):
    return random.sample(lis, len(lis))


class PivotalTuningDatasetTemplate(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for
    fine-tuning the model. It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        learnable_property,
        placeholder_token,
        stochastic_attribute,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
        color_jitter=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)

        self.placeholder_token = placeholder_token
        self.stochastic_attribute = stochastic_attribute.split(",")

        self.templates = (
            imagenet_style_templates_small
            if learnable_property == "style"
            else imagenet_templates_small
        )

        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                ),
                transforms.CenterCrop(size)
                if center_crop
                else transforms.RandomCrop(size),
                transforms.ColorJitter(0.2, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        text = random.choice(self.templates).format(
            ", ".join(
                [self.placeholder_token]
                + _shuffle(_randomset(self.stochastic_attribute))
            )
        )

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(
                self.class_images_path[index % self.num_class_images]
            )
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example
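The instance prompt built in __getitem__ above joins the placeholder token with a random, shuffled subset of the comma-separated stochastic attributes, then drops the result into a random template. The composition can be exercised in isolation like this (the helper logic mirrors the file above; the token and attribute values are made-up examples):

```python
import random


def _randomset(lis):
    # Keep each element independently with probability 0.5.
    return [x for x in lis if random.random() < 0.5]


def _shuffle(lis):
    return random.sample(lis, len(lis))


# Hypothetical placeholder token and attribute string, for illustration only.
placeholder_token = "<sks>"
stochastic_attribute = "red,fluffy,small".split(",")

template = "a photo of a {}"
text = template.format(
    ", ".join([placeholder_token] + _shuffle(_randomset(stochastic_attribute)))
)
# e.g. "a photo of a <sks>, small, red" -- the attribute subset and order vary per call
```

Randomizing which attributes appear, and in what order, keeps the learned token from binding to any one attribute phrasing.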


class PivotalTuningDatasetCapation(Dataset):
    """
    Like PivotalTuningDatasetTemplate, but takes each instance prompt from the
    image filename (its stem) rather than from a template.
    """

    def __init__(
        self,
        instance_data_root,
        learnable_property,
        placeholder_token,
        stochastic_attribute,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
        color_jitter=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)

        self.placeholder_token = placeholder_token
        self.stochastic_attribute = stochastic_attribute.split(",")

        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                ),
                transforms.CenterCrop(size)
                if center_crop
                else transforms.RandomCrop(size),
                transforms.ColorJitter(0.2, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        text = self.instance_images_path[index % self.num_instance_images].stem

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(
                self.class_images_path[index % self.num_class_images]
            )
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example
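The caption-style class above takes each training prompt from the image's filename via Path.stem, so per-image captions travel with the data instead of coming from a template. A small sketch with hypothetical filenames:

```python
from pathlib import Path

# Hypothetical image paths; the dataset uses the stem (name minus extension)
# of each file as that image's prompt text.
paths = [
    Path("data/a photo of sks dog.png"),
    Path("data/sks dog on the beach.jpg"),
]
captions = [p.stem for p in paths]
# ["a photo of sks dog", "sks dog on the beach"]
```

Naming files after their desired captions is therefore the whole labeling workflow for this dataset class.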