In [1]:
#第13章/全局常量
repo_id = 'lansinuote/diffusion.7.control_net'
checkpoint = 'runwayml/stable-diffusion-v1-5'

In [2]:
#第13章/加载数据集
from datasets import load_dataset

dataset = load_dataset('fusing/fill50k', split='train')

dataset, dataset[0]

Using the latest cached version of the module from /root/.cache/huggingface/modules/datasets_modules/datasets/fusing--fill50k/f23b778406682a796a540934e7163495e1b8a88fefc76ca08f7e5a79ddcd668b (last modified on Fri May  5 18:07:46 2023) since it couldn't be found locally at fusing/fill50k., or remotely on the Hugging Face Hub.
No config specified, defaulting to: fill50k/default
Found cached dataset fill50k (/root/.cache/huggingface/datasets/fusing___fill50k/default/0.0.2/f23b778406682a796a540934e7163495e1b8a88fefc76ca08f7e5a79ddcd668b)


(Dataset({
     features: ['image', 'conditioning_image', 'text'],
     num_rows: 50000
 }),
 {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512>,
  'conditioning_image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512>,
  'text': 'pale golden rod circle with old lace background'})

In [3]:
#第13章/使用笔者转存的数据集
dataset = load_dataset(path=repo_id, split='train')

dataset, dataset[0]

Using custom data configuration lansinuote--diffusion.7.control_net-93b816e0abd9e137
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.7.control_net-93b816e0abd9e137/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


(Dataset({
     features: ['image', 'conditioning_image', 'text'],
     num_rows: 50000
 }),
 {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512>,
  'conditioning_image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512>,
  'text': 'pale golden rod circle with old lace background'})

In [4]:
#第13章/数据集预处理
from transformers import CLIPTokenizer
import torchvision

#文字编码
tokenizer = CLIPTokenizer.from_pretrained(checkpoint, subfolder='tokenizer')

#图像编码
compose = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        512, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
    torchvision.transforms.CenterCrop(512),
    torchvision.transforms.ToTensor(),
])

#图像数据增强
norm = torchvision.transforms.Normalize([0.5], [0.5])


def f(data):
    #文字编码
    #77 = tokenizer.model_max_length
    input_ids = tokenizer.batch_encode_plus(data['text'],
                                            max_length=77,
                                            padding='max_length',
                                            truncation=True,
                                            return_tensors='pt').input_ids

    #图像编码
    pixel_values = norm(compose(data['image'][0])).unsqueeze(dim=0)
    conditioning_pixel_values = compose(
        data['conditioning_image'][0]).unsqueeze(dim=0)

    return {
        'input_ids': input_ids,
        'pixel_values': pixel_values,
        'conditioning_pixel_values': conditioning_pixel_values
    }


dataset = dataset.with_transform(f)

for k, v in dataset[0].items():
    print(k, v.shape, v.dtype)

dataset

input_ids torch.Size([77]) torch.int64
pixel_values torch.Size([3, 512, 512]) torch.float32
conditioning_pixel_values torch.Size([3, 512, 512]) torch.float32


Dataset({
    features: ['image', 'conditioning_image', 'text'],
    num_rows: 50000
})

In [5]:
#第13章/定义loader
import torch

loader = torch.utils.data.DataLoader(dataset,
                                     shuffle=True,
                                     collate_fn=None,
                                     batch_size=1)

for k, v in next(iter(loader)).items():
    print(k, v.shape, v.dtype)

len(loader)

input_ids torch.Size([1, 77]) torch.int64
pixel_values torch.Size([1, 3, 512, 512]) torch.float32
conditioning_pixel_values torch.Size([1, 3, 512, 512]) torch.float32


50000

In [6]:
#第13章/加载模型
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from transformers import PretrainedConfig

encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')
vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')
unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')

#定义controlnet
%run controlnet_model.ipynb
controlnet = ControlNet(PretrainedConfig())
load_params(controlnet, unet)

vae.requires_grad_(False)
encoder.requires_grad_(False)
unet.requires_grad_(False)


def print_model_size(name, model):
    print(name, sum(i.numel() for i in model.parameters()) / 10000)


print_model_size('encoder', encoder)
print_model_size('vae', vae)
print_model_size('unet', unet)
print_model_size('controlnet', controlnet)

encoder 12306.048
vae 8365.3863
unet 85952.0964
controlnet 36127.912


In [7]:
#第13章/初始化工具类
from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')

optimizer = torch.optim.AdamW(controlnet.parameters(),
                              lr=1e-5,
                              betas=(0.9, 0.999),
                              weight_decay=0.01,
                              eps=1e-8)

criterion = torch.nn.MSELoss()

scheduler, optimizer, criterion

(DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.15.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "clip_sample_range": 1.0,
   "dynamic_thresholding_ratio": 0.995,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "sample_max_value": 1.0,
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "thresholding": false,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     lr: 1e-05
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [8]:
#第13章/定义计算loss的函数
def get_loss(data):
    device = data['input_ids'].device

    #文字编码
    #[1, 77] -> [1, 77, 768]
    out_encoder = encoder(data['input_ids'])[0]

    #抽取图像特征图
    #[1, 3, 512, 512] -> [1, 4, 64, 64]
    out_vae = vae.encode(data['pixel_values']).latent_dist.sample()
    #0.18215 = vae.config.scaling_factor
    out_vae = out_vae * 0.18215

    #随机噪声
    noise = torch.randn_like(out_vae)

    #往特征图中添加噪声
    #1000 = scheduler.num_train_timesteps
    #1 = out_vae.shape[0]
    noise_step = torch.randint(0, 1000, (1, )).long()
    noise_step = noise_step.to(device)
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    #使用ControlNet计算U-Net模型的down和mid部分的数据
    #在U-Net中依然会计算down和mid,但会和ControlNet的计算结果相加
    #相当于为U-Net的down和mid的计算结果添加额外信息
    out_control_down, out_control_mid = controlnet(
        out_vae_noise,
        noise_step,
        out_encoder,
        condition=data['conditioning_pixel_values'])

    #根据文字信息,把特征图中的噪声计算出来
    out_unet = unet(out_vae_noise,
                    noise_step,
                    encoder_hidden_states=out_encoder,
                    down_block_additional_residuals=out_control_down,
                    mid_block_additional_residual=out_control_mid).sample

    #计算MSELoss
    #[1, 4, 64, 64],[1, 4, 64, 64]
    return criterion(out_unet, noise)


get_loss({
    'input_ids': torch.ones(1, 77).long(),
    'pixel_values': torch.randn(1, 3, 512, 512),
    'conditioning_pixel_values': torch.randn(1, 3, 512, 512),
})

tensor(0.4101, grad_fn=<MseLossBackward0>)

In [9]:
#第13章/训练
from diffusers import StableDiffusionPipeline


def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    unet.to(device)
    encoder.to(device)
    vae.to(device)
    controlnet.to(device)
    controlnet.train()

    loss_sum = 0
    for i, data in enumerate(loader):
        for k in data.keys():
            data[k] = data[k].to(device)

        loss = get_loss(data) / 4
        loss.backward()
        loss_sum += loss.item()

        if i % 4 == 0:
            torch.nn.utils.clip_grad_norm_(controlnet.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        if i % 2000 == 0:
            print(i, loss_sum)
            loss_sum = 0

    #保存到本地
    torch.save(controlnet.cpu(), './save/controlnet.model')


train()

0 0.0018208031542599201
2000 3.6985999149037525
4000 3.1040094250274706
6000 2.883207921491703
8000 2.6912338031579566
10000 2.5750201834653126
12000 2.6205025968265545
14000 2.766968385942164
16000 2.5338104366965126
18000 2.4684665279964975
20000 2.43693045443797
22000 2.3107143878332863
24000 2.244754715113231
26000 2.346211364132614
28000 2.4493323729584517
30000 2.2317955002490635
32000 2.24456998439382
34000 2.1874302507603716
36000 2.2547806691163714
38000 2.3454841869133816
40000 2.247590599119576
42000 2.16816546260452
44000 2.2492195993963833
46000 2.380007522253436
48000 2.2661790897145693
