diff --git a/main.py b/main.py
index 0cf5565..2028c23 100644
--- a/main.py
+++ b/main.py
@@ -26,8 +26,9 @@ def main(config):
     dataloader = dataloder_cls(dataset, collate_fn=collate_fn, **config.dataloader.args)
 
     trainer.prepare_modules_for_training()
-    trainer.prepare_network(config.network)
+    trainer.prepare_controlnet(config.controlnet)
+    trainer.apply_module_settings()
     trainer.prepare_optimizer()
@@ -63,7 +64,7 @@ def main(config):
         if current_step % save_interval == 0 or current_step == total_steps - 1:
             trainer.save_model(config.main.output_path)
         if current_step % sample_interval == 0 or current_step == total_steps - 1:
-            images = trainer.sample_validation(current_step)
+            images = trainer.sample_validation(batch)
             if wandb_run is not None:
                 images = [wandb.Image(image, caption=config.trainer.validation_args.prompt) for image in images]
                 wandb_run.log({'images': images}, step=current_step)
diff --git a/modules/config.py b/modules/config.py
index 8496084..33593d6 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -38,6 +38,7 @@ class TrainerConfig:
     lr_scheduler: str = "constant"
     gradient_checkpointing: bool = False
     optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
+    merging_loras: Optional[List[str]] = None
     validation_num_samples: int = 4
     validation_seed: int = 4545
     validation_args: Dict[str, Any] = field(default_factory=dict)
@@ -53,6 +54,7 @@ class DatasetArgs:
     caption: Optional[str] = "captions"
     image: Optional[str] = None
     text_emb: Optional[str] = None
+    control: Optional[str] = None
    prompt: Optional[str] = None
     prefix: str = ""
     shuffle: bool = False
@@ -84,13 +86,23 @@ class NetworkArgs:
 
 @dataclass
 class NetworkConfig:
+    module: str = "networks.manager.NetworkManager"
+    resume: Optional[str] = None
     train: bool = MISSING
     args: NetworkArgs = field(default_factory=NetworkArgs)
 
+@dataclass
+class ControlNetArgs:
+    train: bool = MISSING
+    resume: Optional[str] = None
+    transformer_layers_per_block: Optional[List[int]] = None
+    global_average_pooling: bool = False
+
 @dataclass
 class Config:
     main: MainConfig = field(default_factory=MainConfig)
     trainer: TrainerConfig = field(default_factory=TrainerConfig)
     dataset: DatasetConfig = field(default_factory=DatasetConfig)
     dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig)
-    network: Optional[NetworkConfig] = None
\ No newline at end of file
+    network: Optional[NetworkConfig] = None
+    controlnet: Optional[ControlNetArgs] = None
\ No newline at end of file
diff --git a/modules/controlnet/canny_dataset.py b/modules/controlnet/canny_dataset.py
new file mode 100644
index 0000000..4519791
--- /dev/null
+++ b/modules/controlnet/canny_dataset.py
@@ -0,0 +1,24 @@
+from modules.dataset import BaseDataset
+import cv2
+import os
+import torch
+import numpy as np
+from torchvision import transforms
+
+class CannyDataset(BaseDataset):
+    def get_control(self, samples, dir="control"):
+        images = []
+        transform = transforms.ToTensor()
+        for sample in samples:
+            # ref https://qiita.com/kotai2003/items/662c33c15915f2a8517e
+            image = cv2.imread(os.path.join(self.path, dir, sample + ".png"))
+            med_val = np.median(image)
+            sigma = 0.33
+            min_val = int(max(0, (1.0 - sigma) * med_val))
+            max_val = int(min(255, (1.0 + sigma) * med_val))
+            image = cv2.Canny(image, threshold1=min_val, threshold2=max_val)
+            image = image[:, :, None] # add channel axis
+            image = np.concatenate([image] * 3, axis=2) # grayscale to RGB
+            images.append(transform(image))
+        images_tensor = torch.stack(images).to(memory_format=torch.contiguous_format).float()
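The thresholding in `CannyDataset.get_control` above is the common auto-Canny recipe: both Canny thresholds are derived from the image median, with the lower bound floored at 0 and the upper bound capped at 255. A standalone sketch of the computation (the input path is hypothetical):

```python
import cv2
import numpy as np

# Auto-Canny as used in CannyDataset.get_control; "hint.png" is a placeholder.
image = cv2.imread("hint.png")
med_val = np.median(image)
sigma = 0.33
min_val = int(max(0, (1.0 - sigma) * med_val))    # lower threshold, floored at 0
max_val = int(min(255, (1.0 + sigma) * med_val))  # upper threshold, capped at 255
edges = cv2.Canny(image, threshold1=min_val, threshold2=max_val)

# e.g. med_val = 120 gives thresholds (80, 159); darker or brighter images
# get a proportionally shifted band instead of one fixed pair for all inputs.
```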
+        return images_tensor
\ No newline at end of file
diff --git a/modules/dataset.py b/modules/dataset.py
index 629aa27..81a5b99 100644
--- a/modules/dataset.py
+++ b/modules/dataset.py
@@ -24,6 +24,7 @@ def __init__(
         caption: Optional[str] = "captions",
         image: Optional[str] = None,
         text_emb: Optional[str] = None,
+        control: Optional[str] = None,
         prompt: Optional[str] = None,
         prefix: str = "",
         shuffle: bool = False,
@@ -46,6 +47,7 @@ def __init__(
         self.caption = caption
         self.image = image
         self.text_emb = text_emb
+        self.control = control
         self.prompt = prompt # use this prompt as the caption for every image
         self.prefix = prefix # prefix prepended to each caption
         self.shuffle = shuffle # whether to shuffle how batches are drawn (better to shuffle on the dataloader side)
@@ -89,6 +91,9 @@ def __getitem__(self, i):
         else:
             batch["captions"] = self.get_captions(samples, self.caption)
 
+        if self.control:
+            batch["controlnet_hint"] = self.get_control(samples, self.control if isinstance(self.control, str) else "control")
+
         return batch
 
     # method that initializes how batches are drawn
@@ -180,4 +185,13 @@ def get_text_embeddings(self, samples, dir="text_emb"):
             for sample in samples
         ])
         pooled_outputs.to(memory_format=torch.contiguous_format).float()
-        return encoder_hidden_states, pooled_outputs
\ No newline at end of file
+        return encoder_hidden_states, pooled_outputs
+
+    def get_control(self, samples, dir="control"):
+        images = []
+        transform = transforms.ToTensor()
+        for sample in samples:
+            image = Image.open(os.path.join(self.path, dir, sample + ".png")).convert("RGB")
+            images.append(transform(image))
+        images_tensor = torch.stack(images).to(memory_format=torch.contiguous_format).float()
+        return images_tensor
\ No newline at end of file
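With the change above, setting the `control` argument makes `__getitem__` attach a `controlnet_hint` tensor to each batch, loaded from one PNG per sample name. A minimal sketch of the tensor contract `get_control` produces (directory and file names are hypothetical):

```python
import torch
from PIL import Image
from torchvision import transforms

# One hint image per sample, e.g. dataset_root/control/0001.png, 0002.png, ...
paths = ["dataset_root/control/0001.png", "dataset_root/control/0002.png"]
to_tensor = transforms.ToTensor()  # uint8 HWC -> float CHW in [0, 1]
hints = torch.stack([to_tensor(Image.open(p).convert("RGB")) for p in paths])
print(hints.shape)  # (B, 3, H, W), matching batch["controlnet_hint"]
```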
diff --git a/modules/diffusion_model.py b/modules/diffusion_model.py
index ccc2e52..c41c249 100644
--- a/modules/diffusion_model.py
+++ b/modules/diffusion_model.py
@@ -1,18 +1,20 @@
 import torch
 import torch.nn as nn
-from diffusers import UNet2DConditionModel
+from diffusers import UNet2DConditionModel, ControlNetModel
 
 class DiffusionModel(nn.Module):
     def __init__(
         self,
         unet:UNet2DConditionModel,
+        controlnet:ControlNetModel=None,
         sdxl:bool=False,
     ):
         super().__init__()
         self.unet = unet
+        self.controlnet = controlnet
         self.sdxl = sdxl
 
-    def forward(self, latents, timesteps, encoder_hidden_states, pooled_output, size_condition=None):
+    def forward(self, latents, timesteps, encoder_hidden_states, pooled_output, size_condition=None, controlnet_hint=None):
         if self.sdxl:
             if size_condition is None:
                 h, w = latents.shape[2] * 8, latents.shape[3] * 8
@@ -22,18 +24,58 @@ def forward(self, latents, timesteps, encoder_hidden_states, pooled_output, size
         else:
             added_cond_kwargs = None
 
+        if self.controlnet is not None:
+            assert controlnet_hint is not None, "controlnet_hint is required when controlnet is enabled"
+            down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
+                latents,
+                timesteps,
+                encoder_hidden_states=encoder_hidden_states,
+                controlnet_cond=controlnet_hint,
+                added_cond_kwargs=added_cond_kwargs,
+                return_dict=False,
+            )
+        else:
+            down_block_additional_residuals = None
+            mid_block_additional_residual = None
+
         model_output = self.unet(
             latents,
             timesteps,
             encoder_hidden_states,
             added_cond_kwargs=added_cond_kwargs,
+            down_block_additional_residuals=down_block_additional_residuals,
+            mid_block_additional_residual=mid_block_additional_residual,
         ).sample
 
         return model_output
 
-    def enable_gradient_checkpointing(self, enable:bool=True):
-        if enable:
-            self.unet.enable_gradient_checkpointing()
+    def create_controlnet(self, config):
+        if config.resume is not None:
+            pre_controlnet = ControlNetModel.from_pretrained(config.resume)
+        else:
+            pre_controlnet = ControlNetModel.from_unet(self.unet)
+
+        if config.transformer_layers_per_block is not None:
+            down_block_types = tuple(["DownBlock2D" if l == 0 else "CrossAttnDownBlock2D" for l in config.transformer_layers_per_block])
+            transformer_layers_per_block = tuple([int(x) for x in config.transformer_layers_per_block])
+            self.controlnet = ControlNetModel.from_config(
+                pre_controlnet.config,
+                down_block_types=down_block_types,
+                transformer_layers_per_block=transformer_layers_per_block,
+            )
+            self.controlnet.load_state_dict(pre_controlnet.state_dict(), strict=False)
+            del pre_controlnet
         else:
-            self.unet.disable_gradient_checkpointing()
+            self.controlnet = pre_controlnet
+
+        self.controlnet.config.global_pool_conditions = config.global_average_pooling
+
+
+    def enable_gradient_checkpointing(self, enable:bool=True):
+        for model in [self.unet, self.controlnet]:
+            if model is not None:
+                if enable:
+                    model.enable_gradient_checkpointing()
+                else:
+                    model.disable_gradient_checkpointing()
\ No newline at end of file
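For context, `DiffusionModel.forward` above relies on the standard diffusers plumbing: the ControlNet consumes the noisy latents plus the conditioning image and returns one residual per down block and one for the mid block, which the UNet adds to its skip connections. A minimal sketch against plain diffusers models (the SD 1.5 repo id and tensor shapes are illustrative):

```python
import torch
from diffusers import UNet2DConditionModel, ControlNetModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
controlnet = ControlNetModel.from_unet(unet)  # copies the UNet encoder config

latents = torch.randn(1, 4, 64, 64)              # 512x512 image -> 64x64 latents
timestep = torch.tensor([500])
encoder_hidden_states = torch.randn(1, 77, 768)  # CLIP text embeddings
hint = torch.rand(1, 3, 512, 512)                # conditioning image in [0, 1]

# one residual per down block, plus one for the mid block
down_res, mid_res = controlnet(
    latents, timestep,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=hint,
    return_dict=False,
)
noise_pred = unet(
    latents, timestep, encoder_hidden_states,
    down_block_additional_residuals=down_res,
    mid_block_additional_residual=mid_res,
).sample  # same shape as latents
```

`added_cond_kwargs` is omitted here because SD 1.5 does not use it; the diff only passes it for the SDXL case.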
diff --git a/modules/trainer.py b/modules/trainer.py
index 5cfa2df..5c48b14 100644
--- a/modules/trainer.py
+++ b/modules/trainer.py
@@ -23,6 +23,7 @@
     "trained",
     "trained/models",
     "trained/networks",
+    "trained/controlnet",
 ]
 
 for directory in DIRECTORIES:
@@ -39,6 +40,15 @@ def __init__(self, config, diffusion:DiffusionModel, text_model:TextModel, vae:A
         self.scheduler = BaseScheduler(scheduler.config.prediction_type == "v_prediction")
         self.sdxl = text_model.sdxl
 
+        if config is not None and config.merging_loras:
+            for lora in config.merging_loras:
+                NetworkManager(
+                    text_model=self.text_model,
+                    unet=self.diffusion.unet,
+                    file_name=lora,
+                    mode="merge"
+                )
+
     @torch.no_grad()
     def decode_latents(self, latents):
         self.vae.to("cuda")
@@ -116,13 +126,6 @@ def prepare_modules_for_training(self, device="cuda"):
         self.text_model.requires_grad_(config.train_text_encoder)
         self.vae.eval()
 
-        if config.gradient_checkpointing:
-            self.diffusion.enable_gradient_checkpointing()
-            self.text_model.enable_gradient_checkpointing()
-            self.diffusion.unet.train() # gradient checkpointing only takes effect in train mode
-            self.text_model.train()
-            logger.info("Enabled gradient checkpointing!")
-
     def prepare_network(self, config):
         if config is None:
             self.network = None
@@ -130,11 +133,15 @@
             logger.info("No network configured, skipping.")
             return
 
-        self.network = NetworkManager(
+        manager_cls = get_attr_from_config(config.module)
+        self.network = manager_cls(
             text_model=self.text_model,
             unet=self.diffusion.unet,
             **config.args
         )
+
+        if config.resume:
+            self.network.load_weights(config.resume)
 
         self.network_train = config.train
 
@@ -143,6 +150,30 @@
         self.network.requires_grad_(self.network_train)
         logger.info("Network created!")
 
+    def prepare_controlnet(self, config):
+        if config is None:
+            self.controlnet = None
+            self.controlnet_train = False
+            logger.info("No ControlNet configured, skipping.")
+            return
+
+        self.diffusion.create_controlnet(config)
+        self.controlnet_train = config.train
+
+        self.diffusion.controlnet.to(self.device, self.train_dtype if self.controlnet_train else self.weight_dtype)
+        self.diffusion.controlnet.train(self.controlnet_train)
+        self.diffusion.controlnet.requires_grad_(self.controlnet_train)
+
+        logger.info("ControlNet created!")
+
+    def apply_module_settings(self):
+        if self.config.gradient_checkpointing:
+            self.diffusion.enable_gradient_checkpointing()
+            self.text_model.enable_gradient_checkpointing()
+            self.diffusion.train() # gradient checkpointing only takes effect in train mode
+            self.text_model.train()
+            logger.info("Enabled gradient checkpointing!")
 
     def prepare_network_from_file(self, file_name):
         self.network = NetworkManager(
@@ -162,8 +193,10 @@ def prepare_optimizer(self):
             params += [{"params":self.diffusion.unet.parameters(), "lr":unet_lr}]
         if self.config.train_text_encoder:
             params += [{"params":self.text_model.parameters(), "lr":text_lr}]
-        if self.network:
+        if self.network_train:
             params += self.network.prepare_optimizer_params(text_lr, unet_lr)
+        if self.controlnet_train:
+            params += [{"params":self.diffusion.controlnet.parameters(), "lr":unet_lr}]
 
         optimizer_cls = get_attr_from_config(self.config.optimizer.module)
         self.optimizer = optimizer_cls(params, **self.config.optimizer.args or {})
@@ -204,12 +237,19 @@ def loss(self, batch):
         else:
             size_condition = None
 
+        if "controlnet_hint" in batch:
+            controlnet_hint = batch["controlnet_hint"].to(self.device)
+            if hasattr(self.network, "set_controlnet_hint"):
+                self.network.set_controlnet_hint(controlnet_hint)
+        else:
+            controlnet_hint = None
+
         timesteps = torch.randint(0, 1000, (self.batch_size,), device=latents.device)
         noise = torch.randn_like(latents)
         noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
 
         with torch.autocast("cuda", dtype=self.autocast_dtype):
-            model_output = self.diffusion(noisy_latents, timesteps, encoder_hidden_states, pooled_output, size_condition)
+            model_output = self.diffusion(noisy_latents, timesteps, encoder_hidden_states, pooled_output, size_condition, controlnet_hint)
 
         target = self.scheduler.get_target(latents, noise, timesteps) # for v_prediction this is the velocity
@@ -240,7 +280,21 @@ def step(self, batch):
         return logs
 
     @torch.no_grad()
-    def sample(self, prompt="", negative_prompt="", batch_size=1, height=768, width=768, num_inference_steps=30, guidance_scale=7.0, denoise=1.0, seed=4545, images=None):
+    def sample(
+        self,
+        prompt="",
+        negative_prompt="",
+        batch_size=1,
+        height=768,
+        width=768,
+        num_inference_steps=30,
+        guidance_scale=7.0,
+        denoise=1.0,
+        seed=4545,
+        images=None,
+        controlnet_hint=None,
+        **kwargs
+    ):
         rng_state = torch.get_rng_state()
         cuda_rng_state = torch.cuda.get_rng_state()
@@ -271,12 +325,23 @@ def sample(self, prompt="", negative_prompt="", batch_size=1, height=768, width=
             encoder_hidden_states, pooled_output = self.text_model(prompt)
         self.text_model.to(self.te_device)
 
+        if controlnet_hint is not None:
+            if isinstance(controlnet_hint, str):
+                controlnet_hint = Image.open(controlnet_hint).convert("RGB")
+                controlnet_hint = transforms.ToTensor()(controlnet_hint).unsqueeze(0)
+            controlnet_hint = controlnet_hint.to(self.device)
+            if guidance_scale != 1.0:
+                controlnet_hint = torch.cat([controlnet_hint] * 2)
+
+            if hasattr(self.network, "set_controlnet_hint"):
+                self.network.set_controlnet_hint(controlnet_hint)
+
         progress_bar = tqdm(timesteps, desc="Sampling", leave=False, total=len(timesteps))
         for i, t in enumerate(timesteps):
             with torch.autocast("cuda", dtype=self.autocast_dtype):
                 latents_input = torch.cat([latents] * (2 if guidance_scale != 1.0 else 1), dim=0)
-                model_output = self.diffusion(latents_input, t, encoder_hidden_states, pooled_output)
+                model_output = self.diffusion(latents_input, t, encoder_hidden_states, pooled_output, controlnet_hint=controlnet_hint)
 
                 if guidance_scale != 1.0:
                     uncond, cond = model_output.chunk(2)
@@ -302,6 +367,8 @@ def save_model(self, output_path):
         self.save_pretrained(os.path.join("trained/models", output_path))
         if self.network_train:
             self.network.save_weights(os.path.join("trained/networks", output_path))
+        if self.controlnet_train:
+            self.diffusion.controlnet.save_pretrained(os.path.join("trained/controlnet", output_path))
 
     def sample_validation(self, step):
         logger.info("Generating samples!")
@@ -342,6 +409,6 @@ def from_pretrained(cls, path, sdxl, clip_skip=None, config=None, network=None):
         if clip_skip is None:
             clip_skip = -2 if sdxl else -1
         text_model, vae, unet, scheduler = load_model(path, sdxl, clip_skip)
-        diffusion = DiffusionModel(unet, sdxl)
+        diffusion = DiffusionModel(unet, sdxl=sdxl)
         return cls(config, diffusion, text_model, vae, scheduler, network)
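Putting the trainer changes together, ControlNet training is switched on entirely from the config: `prepare_controlnet` builds or resumes the model, `apply_module_settings` then handles gradient checkpointing, and `prepare_optimizer` picks up the ControlNet parameters at the UNet learning rate. A sketch of the new config fields in code (all values are illustrative, and the LoRA path is hypothetical):

```python
from modules.config import Config, ControlNetArgs

config = Config()
config.controlnet = ControlNetArgs(
    train=True,
    resume=None,  # or a pretrained ControlNet to continue training from
    # one entry per down block; 0 selects a plain DownBlock2D with no transformer layers
    transformer_layers_per_block=[0, 2, 2],
    global_average_pooling=False,
)
# merge finished LoRAs into the base weights before training starts
config.trainer.merging_loras = ["loras/style.safetensors"]
```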
diff --git a/networks/lortnoc/manager.py b/networks/lortnoc/manager.py
new file mode 100644
index 0000000..c4c0b93
--- /dev/null
+++ b/networks/lortnoc/manager.py
@@ -0,0 +1,111 @@
+from networks.manager import NetworkManager
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def zero_module(module):
+    for p in module.parameters():
+        nn.init.zeros_(p)
+    return module
+
+# https://github.com/huggingface/diffusers/blob/687bc2772721af584d649129f8d2a28ca56a9ad8/src/diffusers/models/controlnet.py#L66
+class ControlNetConditioningEmbedding(nn.Module):
+    def __init__(
+        self,
+        conditioning_embedding_channels: int,
+        conditioning_channels: int = 3,
+        block_out_channels = (16, 32, 96, 256),
+    ):
+        super().__init__()
+
+        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+        self.blocks = nn.ModuleList([])
+
+        for i in range(len(block_out_channels) - 1):
+            channel_in = block_out_channels[i]
+            channel_out = block_out_channels[i + 1]
+            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+        self.conv_out = zero_module(
+            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+        )
+
+    def forward(self, conditioning):
+        embedding = self.conv_in(conditioning)
+        embedding = F.silu(embedding)
+
+        for block in self.blocks:
+            embedding = block(embedding)
+            embedding = F.silu(embedding)
+
+        embedding = self.conv_out(embedding)
+
+        return embedding
+
+class LoRTnoCManager(NetworkManager):
+    def __init__(
+        self,
+        text_model,
+        unet,
+        module=None,
+        file_name=None,
+        module_args=None,
+        unet_key_filters=None,
+        conv_module_args=None,
+        text_module_args=None,
+        multiplier=1.0,
+        mode="apply", # select "apply" or "merge"
+        in_channels=3,
+    ):
+        super().__init__(
+            text_model,
+            unet,
+            module,
+            file_name,
+            module_args,
+            unet_key_filters,
+            conv_module_args,
+            text_module_args,
+            multiplier,
+            mode,
+        )
+        self.in_channels = in_channels
+        self.hidden_channels = unet.conv_in.out_channels
+
+        self.conditioning_embedding = ControlNetConditioningEmbedding(self.hidden_channels, in_channels)
+        self.org_conv_in = [unet.conv_in]
+
+        if mode == "apply":
+            self.org_conv_in_forward = self.org_conv_in[0].forward
+            self.org_conv_in[0].forward = self.forward_hook(self.org_conv_in_forward)
+        elif mode == "merge":
+            raise NotImplementedError("merge mode is not supported yet for LoRTnoCManager.")
+        else:
+            raise ValueError(f"mode {mode} is not supported.")
+
+    def set_controlnet_hint(self, hint):
+        self.hint = hint
+
+    def forward_hook(self, forward):
+        def hook(x):
+            hint = self.conditioning_embedding(self.hint)
+            return forward(x) + hint
+        return hook
+
+    def apply_to(self, multiplier=None):
+        super().apply_to(multiplier)
+        if hasattr(self, "org_conv_in"):
+            if not hasattr(self, "org_conv_in_forward"):
+                self.org_conv_in_forward = self.org_conv_in[0].forward
+            self.org_conv_in[0].forward = self.forward_hook(self.org_conv_in_forward)
+
+    def unapply_to(self):
+        super().unapply_to()
+        self.org_conv_in[0].forward = self.org_conv_in_forward
+
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
+        optimizer_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr)
+        optimizer_params += [{"params": self.conditioning_embedding.parameters(), "lr": unet_lr}]
+        return optimizer_params
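`LoRTnoCManager` injects its hint without a separate ControlNet: `unet.conv_in.forward` is monkey-patched so the hint, passed through the zero-initialized `ControlNetConditioningEmbedding` (three stride-2 convs, an 8x reduction down to latent resolution), is added to `conv_in`'s output. A toy reproduction of the hook pattern, with a single zero conv standing in for the embedding and the hint already at latent resolution:

```python
import torch
import torch.nn as nn

conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)  # stand-in for unet.conv_in
embed = nn.Conv2d(3, 320, kernel_size=3, padding=1)    # stand-in for the conditioning embedding
nn.init.zeros_(embed.weight)
nn.init.zeros_(embed.bias)  # zero init: the hook is a no-op at step 0

state = {"hint": None}      # set_controlnet_hint equivalent
orig_forward = conv_in.forward

def hooked(x):
    return orig_forward(x) + embed(state["hint"])

conv_in.forward = hooked    # instance attribute shadows the class method

state["hint"] = torch.rand(1, 3, 64, 64)
out = conv_in(torch.randn(1, 4, 64, 64))
print(out.shape)  # torch.Size([1, 320, 64, 64])
```

Because the patched module is reached through `nn.Module.__call__`, restoring the saved `forward` (as `unapply_to` does) cleanly removes the hook.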