Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 13 additions & 27 deletions examples/controlnet/train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import os
import random
from pathlib import Path
from typing import Optional

import accelerate
import numpy as np
Expand All @@ -31,7 +30,7 @@
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
Expand Down Expand Up @@ -661,16 +660,6 @@ def collate_fn(examples):
}


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"


def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)

Expand Down Expand Up @@ -704,22 +693,14 @@ def main(args):

# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id

# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
Expand Down Expand Up @@ -1053,7 +1034,12 @@ def load_model_hook(models, input_dir):
controlnet.save_pretrained(args.output_dir)

if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)

accelerator.end_training()

Expand Down
46 changes: 16 additions & 30 deletions examples/controlnet/train_controlnet_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import os
import random
from pathlib import Path
from typing import Optional

import jax
import jax.numpy as jnp
Expand All @@ -33,7 +32,7 @@
from flax.core.frozen_dict import unfreeze
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from huggingface_hub import create_repo, upload_folder
from PIL import Image
from torch.utils.data import IterableDataset
from torchvision import transforms
Expand Down Expand Up @@ -148,7 +147,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
return image_logs


def save_model_card(repo_name, image_logs=None, base_model=str, repo_folder=None):
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
img_str = ""
for i, log in enumerate(image_logs):
images = log["images"]
Expand All @@ -174,7 +173,7 @@ def save_model_card(repo_name, image_logs=None, base_model=str, repo_folder=None
---
"""
model_card = f"""
# controlnet- {repo_name}
# controlnet- {repo_id}

These are controlnet weights trained on {base_model} with new type of conditioning. You can find some example images in the following. \n
{img_str}
Expand Down Expand Up @@ -612,16 +611,6 @@ def collate_fn(examples):
return batch


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"


def get_params_to_save(params):
return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))

Expand Down Expand Up @@ -656,22 +645,14 @@ def main():

# Handle the repository creation
if jax.process_index() == 0:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo_url = create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_url, token=args.hub_token)

with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id

# Load the tokenizer and add the placeholder token as a additional special token
if args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
Expand Down Expand Up @@ -1020,12 +1001,17 @@ def cumul_grad_step(grad_idx, loss_grad_rng):

if args.push_to_hub:
save_model_card(
repo_name,
repo_id,
image_logs=image_logs,
base_model=args.pretrained_model_name_or_path,
repo_folder=args.output_dir,
)
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)


if __name__ == "__main__":
Expand Down
40 changes: 13 additions & 27 deletions examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import os
import warnings
from pathlib import Path
from typing import Optional

import accelerate
import numpy as np
Expand All @@ -32,7 +31,7 @@
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
Expand Down Expand Up @@ -575,16 +574,6 @@ def __getitem__(self, index):
return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"


def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)

Expand Down Expand Up @@ -677,22 +666,14 @@ def main(args):

# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id

# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
Expand Down Expand Up @@ -1043,7 +1024,12 @@ def load_model_hook(models, input_dir):
pipeline.save_pretrained(args.output_dir)

if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)

accelerator.end_training()

Expand Down
47 changes: 16 additions & 31 deletions examples/dreambooth/train_dreambooth_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import os
import warnings
from pathlib import Path
from typing import Optional

import numpy as np
import torch
Expand All @@ -30,7 +29,7 @@
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
Expand Down Expand Up @@ -59,7 +58,7 @@
logger = get_logger(__name__)


def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
Expand All @@ -80,7 +79,7 @@ def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_fol
---
"""
model_card = f"""
# LoRA DreamBooth - {repo_name}
# LoRA DreamBooth - {repo_id}

These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}
Expand Down Expand Up @@ -528,16 +527,6 @@ def __getitem__(self, index):
return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"


def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)

Expand Down Expand Up @@ -625,23 +614,14 @@ def main(args):

# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id

create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id

# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
Expand Down Expand Up @@ -1027,13 +1007,18 @@ def main(args):

if args.push_to_hub:
save_model_card(
repo_name,
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
prompt=args.instance_prompt,
repo_folder=args.output_dir,
)
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)

accelerator.end_training()

Expand Down
Loading