In [1]:
import torch

# Global settings shared by the rest of the notebook.
# Read the Hub access token inside a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
with open('/root/hub_token.txt') as token_file:
    hub_token = token_file.read().strip()
# Hub repository used both for the mirrored dataset and the trained model.
repo_id = 'lansinuote/diffusion.7.control_net'
# When True, later cells mirror the dataset, train, and push to the Hub.
push_to_hub = True
# Base Stable Diffusion checkpoint the sub-models are loaded from.
checkpoint = 'runwayml/stable-diffusion-v1-5'

In [2]:
from datasets import load_dataset

if push_to_hub:
    # Mirror the source dataset into our own Hub repository.
    # Note: the samples are not pre-encoded with `map` here, because the
    # encoded data would be too large to handle.
    load_dataset('fusing/fill50k').push_to_hub(repo_id=repo_id,
                                               token=hub_token)

# Work from the mirrored copy of the dataset.
dataset = load_dataset(path=repo_id, split='train')

print(dataset, dataset[0])

Using custom data configuration lansinuote--diffusion.7.control_net-93b816e0abd9e137
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.7.control_net-93b816e0abd9e137/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


Dataset({
    features: ['image', 'conditioning_image', 'text'],
    num_rows: 50000
}) {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x7F9695DF5CD0>, 'conditioning_image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x7F9695DF5DF0>, 'text': 'pale golden rod circle with old lace background'}


In [3]:
from transformers import CLIPTokenizer
import torchvision

# Tokenizer stored in the `tokenizer` subfolder of the base checkpoint.
tokenizer = CLIPTokenizer.from_pretrained(checkpoint, subfolder='tokenizer')

# Image pipeline: resize (bilinear), center-crop to 512, convert to a tensor.
compose = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        size=512,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    ),
    torchvision.transforms.CenterCrop(size=512),
    torchvision.transforms.ToTensor(),
])

# Extra normalization with mean 0.5 / std 0.5; collate_fn applies it to the
# target image only, not to the conditioning image.
norm = torchvision.transforms.Normalize(mean=[0.5], std=[0.5])


# Batch collation used by the DataLoader.
def collate_fn(data):
    """Turn a list of dataset samples into one batch of tensors.

    Each sample is a mapping with the keys 'text', 'image' and
    'conditioning_image'. The captions are tokenized to a fixed length of 77
    (noted in the original as tokenizer.model_max_length); the images go
    through `compose`, and the target image is additionally passed through
    `norm`.

    Returns a dict with 'input_ids', 'pixel_values' and
    'conditioning_pixel_values', each stacked along a new leading batch
    dimension.
    """
    # Caption encoding: pad / truncate every caption to 77 tokens.
    captions = [sample['text'] for sample in data]
    encoded = tokenizer.batch_encode_plus(captions,
                                          padding='max_length',
                                          truncation=True,
                                          max_length=77,
                                          return_tensors=None)

    # Image encoding: only the target image is normalized.
    target_images = torch.stack(
        [norm(compose(sample['image'])) for sample in data])
    condition_images = torch.stack(
        [compose(sample['conditioning_image']) for sample in data])

    return {
        'input_ids': torch.LongTensor(encoded.input_ids),
        'pixel_values': target_images,
        'conditioning_pixel_values': condition_images
    }


# One sample per batch, reshuffled every epoch, collated by collate_fn.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)

# Show the number of batches together with one example batch.
len(loader), next(iter(loader))

(50000,
 {'input_ids': tensor([[49406, 16033,  7117,   593, 16776,  5994, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407]]),
  'pixel_values': tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
            ...,
            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
            [1.0000, 1.0000, 1.0000,  ..., 1

In [4]:
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from transformers import PretrainedConfig

# Load the three pretrained sub-models from the base checkpoint.
encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')
vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')
unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')

# Define the controlnet.
# The %run magic executes controlnet_model.ipynb in this namespace.
# NOTE(review): ControlNet and load_params are not defined in this notebook,
# so they presumably come from that notebook - confirm.
%run controlnet_model.ipynb
controlnet = ControlNet(PretrainedConfig())
# NOTE(review): presumably initializes controlnet from the unet's weights -
# verify against load_params in controlnet_model.ipynb.
load_params(controlnet, unet)

# Disable gradients on the three pretrained models; only the controlnet's
# parameters are handed to the optimizer later in the notebook.
vae.requires_grad_(False)
encoder.requires_grad_(False)
unet.requires_grad_(False)


def print_model_size(name, model):
    """Print `name` followed by the model's parameter count.

    The count is reported in units of 10,000 parameters (total number of
    elements over `model.parameters()` divided by 10000).
    """
    total_elements = 0
    for parameter in model.parameters():
        total_elements += parameter.numel()
    print(name, total_elements / 10000)


# Report the size of every model, in the same order as before.
for _label, _module in (
    ('encoder', encoder),
    ('vae', vae),
    ('unet', unet),
    ('controlnet', controlnet),
):
    print_model_size(_label, _module)

encoder 12306.048
vae 8365.3863
unet 85952.0964
controlnet 36127.912


In [5]:
from diffusers import DDPMScheduler

# Noise scheduler configuration stored alongside the base checkpoint.
scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')

# Only the controlnet's parameters are optimized.
optimizer = torch.optim.AdamW(
    controlnet.parameters(),
    lr=1e-5,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
)

# Mean-squared error between the predicted and the true noise.
criterion = torch.nn.MSELoss()

scheduler, optimizer, criterion

(DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.15.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "clip_sample_range": 1.0,
   "dynamic_thresholding_ratio": 0.995,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "sample_max_value": 1.0,
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "thresholding": false,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     lr: 1e-05
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [6]:
def get_loss(data):
    """Compute the noise-prediction MSE loss for one batch.

    Args:
        data: dict with the tensors 'input_ids', 'pixel_values' and
            'conditioning_pixel_values', all already on the same device as
            the models.

    Returns:
        Scalar tensor: MSE between the unet's output and the noise that was
        added to the latents.
    """
    device = data['input_ids'].device

    # Text encoding.
    # Per the original notes: [1, 77] -> [1, 77, 768]
    out_encoder = encoder(data['input_ids'])[0]

    # Encode the target image into latents.
    # Per the original notes: [1, 3, 512, 512] -> [1, 4, 64, 64]
    out_vae = vae.encode(data['pixel_values']).latent_dist.sample()
    # 0.18215 = vae.config.scaling_factor
    out_vae = out_vae * 0.18215

    # Random noise; this is the regression target of the unet.
    noise = torch.randn_like(out_vae)

    # Add noise to the latents at a random timestep.
    # 1000 = scheduler.num_train_timesteps
    # One timestep is drawn per sample, so the batch dimension follows the
    # latents instead of being hard-coded to 1 (identical for batch size 1).
    # NOTE(review): assumes the custom controlnet accepts a [batch] timestep
    # tensor the same way it accepted a [1] tensor - confirm.
    batch_size = out_vae.shape[0]
    noise_step = torch.randint(0, 1000, (batch_size, )).long()
    noise_step = noise_step.to(device)
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    # The controlnet produces extra residuals for the unet's down and mid
    # blocks. The unet still runs its own down and mid computation and the
    # residuals are passed in as additional inputs, which injects the
    # conditioning-image information.
    out_control_down, out_control_mid = controlnet(
        out_vae_noise,
        noise_step,
        out_encoder,
        condition=data['conditioning_pixel_values'])

    # Predict the noise in the latents, guided by the text embedding and the
    # controlnet residuals.
    out_unet = unet(out_vae_noise,
                    noise_step,
                    encoder_hidden_states=out_encoder,
                    down_block_additional_residuals=out_control_down,
                    mid_block_additional_residual=out_control_mid).sample

    # MSE loss between prediction and true noise.
    # Per the original notes: [1, 4, 64, 64] vs [1, 4, 64, 64]
    return criterion(out_unet, noise)


# Smoke test: evaluate the loss once on a dummy batch of size 1.
get_loss({
    'input_ids': torch.ones((1, 77), dtype=torch.long),
    'pixel_values': torch.randn((1, 3, 512, 512)),
    'conditioning_pixel_values': torch.randn((1, 3, 512, 512)),
})

tensor(0.2469, grad_fn=<MseLossBackward0>)

In [7]:
from diffusers import StableDiffusionPipeline


def train():
    """Train the controlnet for one pass over `loader` and save it locally.

    Gradients are accumulated over 4 micro-batches before each optimizer
    step. Every 2000 iterations the accumulated (already 1/4-scaled) loss is
    printed and reset. When the loop ends the whole model object is written
    to './save/controlnet.model' after being moved to the CPU.
    """
    import os

    accumulation_steps = 4
    log_interval = 2000

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    unet.to(device)
    encoder.to(device)
    vae.to(device)
    controlnet.to(device)
    controlnet.train()

    def apply_accumulated_gradients():
        # Clip, step, then clear the gradients for the next window.
        torch.nn.utils.clip_grad_norm_(controlnet.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

    # Start from clean gradients in case an earlier run was interrupted.
    optimizer.zero_grad()

    loss_sum = 0
    pending = 0  # micro-batches accumulated since the last optimizer step
    for i, data in enumerate(loader):
        for k in data.keys():
            data[k] = data[k].to(device)

        # Scale so the summed gradients average over the window.
        loss = get_loss(data) / accumulation_steps
        loss.backward()
        loss_sum += loss.item()
        pending += 1

        # Step only once a full window has been accumulated. The previous
        # `i % 4 == 0` check stepped at i == 0 after a single micro-batch,
        # leaving every later window misaligned.
        if pending == accumulation_steps:
            apply_accumulated_gradients()
            pending = 0

        if i % log_interval == 0:
            print(i, loss_sum)
            loss_sum = 0

    # Apply a trailing partial window so its gradients are not dropped.
    if pending:
        apply_accumulated_gradients()

    # Save locally; create the target directory first so torch.save does not
    # fail on a fresh checkout.
    os.makedirs('./save', exist_ok=True)
    torch.save(controlnet.cpu(), './save/controlnet.model')


if push_to_hub:
    # Train first, then publish the trained controlnet to the Hub.
    train()
    controlnet.push_to_hub(
        repo_id=repo_id,
        use_auth_token=hub_token,
    )

0 0.0026870747096836567
2000 3.461793488706462
4000 3.049983833108854
6000 2.88362369704555
8000 3.005527748755412
10000 2.7510610005374474
12000 2.7980021113398834
14000 2.4720400111327763
16000 2.62037534315823
18000 2.375480920953123
20000 2.5328880913584726
22000 2.540079788388539
24000 2.484456200958448
26000 2.3819653876125813
28000 2.3381337551472825
30000 2.383196011167456
32000 2.33385793850357
34000 2.2283505542582134
36000 2.0874762790735986
38000 2.211367720625276
40000 2.3298714868487878
42000 2.2381660530427325
44000 2.1768586739563034
46000 2.202571557407282
48000 2.176738911841312
