Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support controlnet #46

Merged
merged 3 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ def main(config):
dataloader = dataloder_cls(dataset, collate_fn=collate_fn, **config.dataloader.args)

trainer.prepare_modules_for_training()

trainer.prepare_network(config.network)
trainer.prepare_controlnet(config.controlnet)
trainer.apply_module_settings()

trainer.prepare_optimizer()

Expand Down Expand Up @@ -63,7 +64,7 @@ def main(config):
if current_step % save_interval == 0 or current_step == total_steps - 1:
trainer.save_model(config.main.output_path)
if current_step % sample_interval == 0 or current_step == total_steps - 1:
images = trainer.sample_validation(current_step)
images = trainer.sample_validation(batch)
if wandb_run is not None:
images = [wandb.Image(image, caption=config.trainer.validation_args.prompt) for image in images]
wandb_run.log({'images': images}, step=current_step)
Expand Down
14 changes: 13 additions & 1 deletion modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class TrainerConfig:
lr_scheduler: str = "constant"
gradient_checkpointing: bool = False
optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
merging_loras: Optional[List[str]] = None
validation_num_samples: int = 4
validation_seed: int = 4545
validation_args: Dict[str, Any] = field(default_factory=dict)
Expand All @@ -53,6 +54,7 @@ class DatasetArgs:
caption: Optional[str] = "captions"
image: Optional[str] = None
text_emb: Optional[str] = None
control: Optional[str] = None
prompt: Optional[str] = None
prefix: str = ""
shuffle: bool = False
Expand Down Expand Up @@ -84,13 +86,23 @@ class NetworkArgs:

@dataclass
class NetworkConfig:
    # Configuration for an auxiliary trainable network attached to the model
    # (presumably a LoRA-style adapter, given TrainerConfig.merging_loras — confirm).
    module: str = "networks.manager.NetworkManager"  # dotted import path of the manager class
    resume: Optional[str] = None  # checkpoint path to resume the network from; None = fresh init
    train: bool = MISSING  # required field: whether the network's weights are trained
    args: NetworkArgs = field(default_factory=NetworkArgs)  # forwarded to the network module

@dataclass
class ControlNetArgs:
    # Settings for creating/training a ControlNet alongside the UNet
    # (consumed by DiffusionModel.create_controlnet).
    train: bool = MISSING  # required field: whether the controlnet weights are trained
    resume: Optional[str] = None  # pretrained ControlNet path; None = initialize from the UNet
    transformer_layers_per_block: Optional[List[int]] = None  # per-down-block depth override; 0 -> plain DownBlock2D
    global_average_pooling: bool = False  # written to controlnet.config.global_pool_conditions

@dataclass
class Config:
    """Top-level training configuration composed of the per-component configs."""
    main: MainConfig = field(default_factory=MainConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig)
    # FIX: the `network` field was declared twice (copy/diff artifact); the
    # duplicate annotation was redundant — keep a single declaration.
    network: Optional[NetworkConfig] = None  # optional auxiliary network (see NetworkConfig)
    controlnet: Optional[ControlNetArgs] = None  # optional ControlNet settings (see ControlNetArgs)
24 changes: 24 additions & 0 deletions modules/controlnet/canny_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from modules.dataset import BaseDataset
import cv2
import os
import torch
import numpy as np
from torchvision import transforms

class CannyDataset(BaseDataset):
    def get_control(self, samples, dir="control"):
        """Load control images and convert them to Canny edge maps.

        Args:
            samples: iterable of sample basenames (file stem without extension).
            dir: subdirectory under ``self.path`` containing the source images.

        Returns:
            Float tensor of shape (batch, 3, H, W); the single-channel edge
            map is repeated across the three RGB channels.
        """
        images = []
        transform = transforms.ToTensor()
        for sample in samples:
            # Auto-Canny: derive hysteresis thresholds from the image median.
            # ref https://qiita.com/kotai2003/items/662c33c15915f2a8517e
            image = cv2.imread(os.path.join(self.path, dir, sample + ".png"))
            med_val = np.median(image)
            sigma = 0.33
            min_val = int(max(0, (1.0 - sigma) * med_val))
            # BUG FIX: was max(255, ...), which pinned the upper threshold to
            # at least 255 regardless of the median. The heuristic clamps it
            # DOWN to the valid 8-bit range with min(255, ...).
            max_val = int(min(255, (1.0 + sigma) * med_val))
            image = cv2.Canny(image, threshold1=min_val, threshold2=max_val)
            image = image[:, :, None]  # add a channel axis
            image = np.concatenate([image] * 3, axis=2)  # grayscale -> RGB
            images.append(transform(image))
        images_tensor = torch.stack(images).to(memory_format=torch.contiguous_format).float()
        return images_tensor
16 changes: 15 additions & 1 deletion modules/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
caption: Optional[str] = "captions",
image: Optional[str] = None,
text_emb: Optional[str] = None,
control: Optional[str] = None,
prompt: Optional[str] = None,
prefix: str = "",
shuffle: bool = False,
Expand All @@ -46,6 +47,7 @@ def __init__(
self.caption = caption
self.image = image
self.text_emb = text_emb
self.control = control
self.prompt = prompt # 全ての画像のcaptionをpromptにする
self.prefix = prefix # captionのprefix
self.shuffle = shuffle # バッチの取り出し方をシャッフルするかどうか(データローダー側でシャッフルした方が良い^^)
Expand Down Expand Up @@ -89,6 +91,9 @@ def __getitem__(self, i):
else:
batch["captions"] = self.get_captions(samples, self.caption)

if self.control:
batch["controlnet_hint"] = self.get_control(samples, self.control if isinstance(self.control, str) else "control")

return batch

# バッチの取り出し方を初期化するメソッド
Expand Down Expand Up @@ -180,4 +185,13 @@ def get_text_embeddings(self, samples, dir="text_emb"):
for sample in samples
])
pooled_outputs.to(memory_format=torch.contiguous_format).float()
return encoder_hidden_states, pooled_outputs
return encoder_hidden_states, pooled_outputs

def get_control(self, samples, dir="control"):
    """Load per-sample control images as one float tensor batch.

    Reads ``<self.path>/<dir>/<sample>.png`` for each sample, converts it to
    RGB, and stacks the results into a contiguous (batch, 3, H, W) tensor.
    """
    to_tensor = transforms.ToTensor()
    tensors = [
        to_tensor(Image.open(os.path.join(self.path, dir, name + f".png")).convert("RGB"))
        for name in samples
    ]
    return torch.stack(tensors).to(memory_format=torch.contiguous_format).float()
54 changes: 48 additions & 6 deletions modules/diffusion_model.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from diffusers import UNet2DConditionModel, ControlNetModel

class DiffusionModel(nn.Module):
def __init__(
    self,
    unet: UNet2DConditionModel,
    controlnet: ControlNetModel = None,
    sdxl: bool = False,
):
    """Wrap a denoising UNet and an optional ControlNet for training.

    Args:
        unet: the diffusers UNet producing the noise prediction.
        controlnet: optional ControlNet whose residuals condition the UNet;
            None disables ControlNet conditioning in forward().
        sdxl: whether the model is SDXL (forward() then builds the extra
            added-cond inputs from pooled_output and the latent size).
    """
    super().__init__()
    self.unet = unet
    self.controlnet = controlnet
    self.sdxl = sdxl

def forward(self, latents, timesteps, encoder_hidden_states, pooled_output, size_condition=None):
def forward(self, latents, timesteps, encoder_hidden_states, pooled_output, size_condition=None, controlnet_hint=None):
if self.sdxl:
if size_condition is None:
h, w = latents.shape[2] * 8, latents.shape[3] * 8
Expand All @@ -22,18 +24,58 @@ def forward(self, latents, timesteps, encoder_hidden_states, pooled_output, size
else:
added_cond_kwargs = None

if self.controlnet is not None:
assert controlnet_hint is not None, "controlnet_hint is required when controlnet is enabled"
down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
latents,
timesteps,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=controlnet_hint,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)
else:
down_block_additional_residuals = None
mid_block_additional_residual = None

model_output = self.unet(
latents,
timesteps,
encoder_hidden_states,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
).sample

return model_output

def enable_gradient_checkpointing(self, enable:bool=True):
if enable:
self.unet.enable_gradient_checkpointing()
def create_controlnet(self, config):
    """Build ``self.controlnet`` from a checkpoint or from the wrapped UNet.

    Args:
        config: ControlNetArgs-like object providing ``resume``,
            ``transformer_layers_per_block`` and ``global_average_pooling``.
    """
    if config.resume is not None:
        # Start from a pretrained ControlNet checkpoint.
        pre_controlnet = ControlNetModel.from_pretrained(config.resume)
    else:
        # Initialize the ControlNet weights from the UNet's encoder.
        pre_controlnet = ControlNetModel.from_unet(self.unet)

    if config.transformer_layers_per_block is not None:
        # Rebuild with a custom transformer depth per down block; a depth of
        # 0 means a plain DownBlock2D without cross-attention.
        down_block_types = tuple(["DownBlock2D" if l == 0 else "CrossAttnDownBlock2D" for l in config.transformer_layers_per_block])
        transformer_layers_per_block = tuple([int(x) for x in config.transformer_layers_per_block])
        self.controlnet = ControlNetModel.from_config(
            pre_controlnet.config,
            down_block_types=down_block_types,
            transformer_layers_per_block=transformer_layers_per_block,
        )
        # strict=False: blocks reshaped/removed above won't match the source state dict.
        self.controlnet.load_state_dict(pre_controlnet.state_dict(), strict=False)
        del pre_controlnet
    else:
        # FIX: a stray leftover line from the replaced
        # enable_gradient_checkpointing ("self.unet.disable_gradient_checkpointing()")
        # had been interleaved into this branch by the diff; it does not belong
        # here and would silently disable checkpointing as a side effect.
        self.controlnet = pre_controlnet

    self.controlnet.config.global_pool_conditions = config.global_average_pooling


def enable_gradient_checkpointing(self, enable: bool = True):
    """Toggle gradient checkpointing on the UNet and, when present, the ControlNet."""
    for module in (self.unet, self.controlnet):
        if module is None:
            continue  # no ControlNet attached
        if enable:
            module.enable_gradient_checkpointing()
        else:
            module.disable_gradient_checkpointing()