In [1]:
#第7章/全局常量
repo_id = 'lansinuote/diffusion.3.dream_booth'
checkpoint = 'runwayml/stable-diffusion-v1-5'

In [2]:
#第7章/加载数据集
from datasets import load_dataset

dataset = load_dataset(path=repo_id, split='train')

dataset, dataset[0]

Using custom data configuration lansinuote--diffusion.3.dream_booth-4344a7d7501be7f1
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.3.dream_booth-4344a7d7501be7f1/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


(Dataset({
     features: ['image', 'text'],
     num_rows: 5
 }),
 {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2469x2558>,
  'text': 'a photo of little dog'})

In [3]:
#第7章/数据集预处理
import torchvision
from transformers import AutoTokenizer

#数据增强
compose = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        512, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
    torchvision.transforms.RandomCrop(512),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),
])

#文字编码
tokenizer = AutoTokenizer.from_pretrained(checkpoint,
                                          subfolder='tokenizer',
                                          use_fast=False)


def f(data):
    #图像编码
    pixel_values = compose(data['image'][0]).unsqueeze(dim=0)

    #文字编码
    #77 = tokenizer.model_max_length
    tokens = tokenizer(data['text'][0],
                       truncation=True,
                       padding='max_length',
                       max_length=77,
                       return_tensors='pt')

    return {
        'pixel_values': pixel_values,
        'input_ids': tokens.input_ids,
        'attention_mask': tokens.attention_mask
    }


dataset = dataset.with_transform(f)

for k, v in dataset[0].items():
    print(k, v.shape, v.dtype)

dataset

pixel_values torch.Size([3, 512, 512]) torch.float32
input_ids torch.Size([77]) torch.int64
attention_mask torch.Size([77]) torch.int64


Dataset({
    features: ['image', 'text'],
    num_rows: 5
})

In [4]:
#第7章/定义loader
import torch

loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=1,
                                     shuffle=True,
                                     collate_fn=None)

for k, v in next(iter(loader)).items():
    print(k, v.shape, v.dtype)

len(loader)

pixel_values torch.Size([1, 3, 512, 512]) torch.float32
input_ids torch.Size([1, 77]) torch.int64
attention_mask torch.Size([1, 77]) torch.int64


5

In [5]:
#第7章/加载模型
from transformers.models.clip.modeling_clip import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel

encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')
vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')
unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')

vae.requires_grad_(False)
encoder.requires_grad_(False)
unet.requires_grad_(False)


def print_model_size(name, model):
    print(name, sum(i.numel() for i in model.parameters()) / 10000)


print_model_size('encoder', encoder)
print_model_size('vae', vae)
print_model_size('unet', unet)

encoder 12306.048
vae 8365.3863
unet 85952.0964


In [6]:
#第7章/创建LoRA神经网络层
from diffusers.models.cross_attention import LoRACrossAttnProcessor


def set_processors():
    processors = {}

    #遍历unet的所有层,找出所有有set_processor属性的层,每一个都组装成lora层
    for name in unet.attn_processors.keys():
        #768 = unet.config.cross_attention_dim
        cross_attention_dim = 768
        if name.endswith('attn1.processor'):
            cross_attention_dim = None

        #1280 = unet.config.block_out_channels[-1]
        hidden_size = 1280

        if name.startswith('up_blocks'):
            #取层名字中的第一个数字
            #p_blocks.1.attentions.0.transformer_blocks.0.attn1.processor -> 1
            block_id = int(name[10])
            hidden_size = list(reversed(
                unet.config.block_out_channels))[block_id]

        if name.startswith('down_blocks'):
            #取层名字中的第一个数字
            #down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor -> 2
            block_id = int(name[12])
            hidden_size = unet.config.block_out_channels[block_id]

        processors[name] = LoRACrossAttnProcessor(hidden_size,
                                                  cross_attention_dim)

        print(name, hidden_size, cross_attention_dim)

    #把上面组装好的字典,设置到unet的层当中
    unet.set_attn_processor(processors)

    return torch.nn.ModuleList(unet.attn_processors.values())


lora_layers = set_processors()

len(lora_layers)

down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor 320 None
down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor 320 768
down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor 320 None
down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor 320 768
down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor 640 None
down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor 640 768
down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor 640 None
down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor 640 768
down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor 1280 None
down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor 1280 768
down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor 1280 None
down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor 1280 768
up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor 1280 None
up_blocks.1.attentions.0.transformer_blocks.0.attn2.pr

  deprecate(
  deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)


32

In [7]:
#第7章/初始化工具类
from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')

optimizer = torch.optim.AdamW(lora_layers.parameters(),
                              lr=1e-4,
                              betas=(0.9, 0.999),
                              weight_decay=0.01,
                              eps=1e-8)

criterion = torch.nn.MSELoss()

scheduler, optimizer, criterion

(DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.15.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "clip_sample_range": 1.0,
   "dynamic_thresholding_ratio": 0.995,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "sample_max_value": 1.0,
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "thresholding": false,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     lr: 0.0001
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [8]:
#第7章/定义计算loss的函数
def get_loss(data):
    device = data['input_ids'].device

    #编码文字
    #[1, 77] -> [1, 77, 768]
    out_encoder = encoder(input_ids=data['input_ids'],
                          attention_mask=data['attention_mask'])[0]

    #计算特征图
    #[1, 3, 512, 512] -> [1, 4, 64, 64]
    out_vae = vae.encode(data['pixel_values']).latent_dist.sample().detach()

    #0.18215 = vae.config.scaling_factor
    out_vae = out_vae * 0.18215

    #随机噪声
    #[1, 4, 64, 64]
    noise = torch.randn_like(out_vae)

    #随机噪声步
    #1000 = scheduler.config.num_train_timesteps
    #1 = b
    noise_step = torch.randint(0, 1000, (1, ), device=device).long()

    #添加噪声
    #[1, 4, 64, 64]
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    #从噪声图中把噪声计算出来
    #[1, 4, 64, 64],[1, 77, 768] -> [1, 4, 64, 64]
    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample

    return criterion(out_unet, noise)


get_loss({
    'pixel_values': torch.randn(1, 3, 512, 512),
    'input_ids': torch.ones(1, 77).long(),
    'attention_mask': torch.ones(1, 77).long()
})

tensor(0.2792, grad_fn=<MseLossBackward0>)

In [9]:
#第7章/训练
from diffusers import DiffusionPipeline


def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    unet.to(device)
    vae.to(device)
    encoder.to(device)
    unet.train()

    loss_sum = 0
    for epoch in range(200):
        for i, data in enumerate(loader):
            for k in data.keys():
                data[k] = data[k].to(device)
            loss = get_loss(data)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(lora_layers.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            loss_sum += loss.item()

        if epoch % 10 == 0:
            print(epoch, loss_sum)
            loss_sum = 0

    unet.save_attn_procs('./save')


train()

0 1.5372274816036224
10 4.841767566744238
20 4.612118934630416
30 5.884213687619194
40 5.1643907791003585
50 6.224851356353611
60 4.4865004296880215
70 2.892019562306814
80 3.9639462111517787
90 2.890541360131465
100 4.777392436284572
110 4.6514521923381835
120 4.407747785095125
130 3.77918853063602
140 4.638857806101441
150 4.5095627416158095
160 6.594769516959786
170 5.017870861222036
180 5.188574061263353
190 4.35909217922017
