In [1]:
from datasets import load_dataset

# Load the author's already-prepared image dataset directly from the Hub.
dataset = load_dataset(
    'lansinuote/diffusion.3.dream_boothimages',
    split='train',
)

# Show the dataset summary together with its first record.
dataset, dataset[0]

Using custom data configuration lansinuote--diffusion.3.dream_boothimages-73d13cf74f9b46f0
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.3.dream_boothimages-73d13cf74f9b46f0/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


(Dataset({
     features: ['image'],
     num_rows: 5
 }),
 {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2469x2558>})

In [2]:
import torchvision

# Data augmentation / preprocessing pipeline applied to every image.
compose = torchvision.transforms.Compose([
    # Resize to 512 using bilinear interpolation.
    torchvision.transforms.Resize(
        size=512,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    ),
    # Take a random 512x512 crop.
    torchvision.transforms.RandomCrop(size=512),
    # PIL image -> tensor.
    torchvision.transforms.ToTensor(),
    # Normalise with mean 0.5 and std 0.5.
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
])


def f(data):
    """Run every image in data['image'] through `compose`.

    Args:
        data: mapping whose 'image' entry is an iterable of images.

    Returns:
        dict with a single 'image' key holding the list of transformed images.
    """
    return {'image': list(map(compose, data['image']))}


# Register f as the dataset's transform so records come back preprocessed.
dataset = dataset.with_transform(f)

# Show the dataset summary and the first (now transformed) record.
dataset, dataset[0]

(Dataset({
     features: ['image'],
     num_rows: 5
 }),
 {'image': tensor([[[ 0.7804,  0.7647,  0.7725,  ...,  0.7647,  0.7725,  0.7725],
           [ 0.7569,  0.7725,  0.7725,  ...,  0.7647,  0.7569,  0.7569],
           [ 0.7725,  0.7725,  0.7647,  ...,  0.7725,  0.7647,  0.7647],
           ...,
           [ 0.7098,  0.6941,  0.7020,  ...,  0.7098,  0.7098,  0.7020],
           [ 0.7020,  0.7176,  0.7098,  ...,  0.7020,  0.7176,  0.7098],
           [ 0.7176,  0.7333,  0.7098,  ...,  0.7020,  0.7176,  0.7020]],
  
          [[-0.1451, -0.1608, -0.1529,  ..., -0.1529, -0.1373, -0.1373],
           [-0.1686, -0.1529, -0.1529,  ..., -0.1529, -0.1529, -0.1529],
           [-0.1529, -0.1529, -0.1608,  ..., -0.1216, -0.1294, -0.1451],
           ...,
           [-0.2627, -0.2784, -0.2706,  ..., -0.2314, -0.2314, -0.2392],
           [-0.2706, -0.2549, -0.2627,  ..., -0.2392, -0.2235, -0.2314],
           [-0.2549, -0.2392, -0.2627,  ..., -0.2392, -0.2235, -0.2314]],
  
          [[-0.6

In [3]:
import torch

def collate_fn(datas):
    """Combine a list of samples into a single batch tensor.

    Args:
        datas: sequence of dicts, each holding a tensor under the 'image' key.

    Returns:
        The 'image' tensors stacked along a new leading dimension.
    """
    batch = []
    for sample in datas:
        batch.append(sample['image'])
    return torch.stack(batch)


# Serve the dataset one image per batch, in shuffled order, batching with collate_fn.
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=1,
)

# Number of batches and the shape of one sampled batch.
len(loader), next(iter(loader)).shape

(5, torch.Size([1, 3, 512, 512]))

In [4]:
from transformers.models.clip.modeling_clip import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel

# Hub id of the pretrained checkpoint; every component below is read from one
# of its subfolders.
checkpoint = 'runwayml/stable-diffusion-v1-5'

# Text encoder, VAE and UNet of the checkpoint.
encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')
vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')
unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')

# Freeze all pretrained weights: none of the loaded parameters receive gradients.
vae.requires_grad_(False)
encoder.requires_grad_(False)
unet.requires_grad_(False)


def print_model_size(name, model):
    """Print `name` followed by the model's parameter count divided by 10000.

    Args:
        name: label printed in front of the number.
        model: object exposing `parameters()`, whose items provide `numel()`.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    print(name, total / 10000)


# Report the size of each loaded model (parameter count divided by 10000).
print_model_size('encoder', encoder)
print_model_size('vae', vae)
print_model_size('unet', unet)

encoder 12306.048
vae 8365.3863
unet 85952.0964


In [5]:
#这个类不想测了,总之就是在做注意力的调整
from diffusers.models.cross_attention import LoRACrossAttnProcessor


def set_processors():
    """Replace every attention processor of `unet` with a LoRA processor.

    For each processor name reported by `unet.attn_processors`, a
    `LoRACrossAttnProcessor` is built with the hidden size of the block the
    layer lives in and the matching cross-attention dimension, then the whole
    mapping is installed with `unet.set_attn_processor`.

    All sizes are read from `unet.config` rather than hard-coded, and the
    block index is parsed from the full name component so that indices with
    more than one digit are handled.

    Side effects:
        Prints `name hidden_size cross_attention_dim` for every processor and
        mutates the module-level `unet`.
    """
    processors = {}

    block_out_channels = unet.config.block_out_channels

    # Walk over every layer of the unet that carries an attention processor
    # and build a LoRA processor for each of them.
    for name in unet.attn_processors.keys():
        # attn1 processors get no cross-attention dim (presumably these are
        # the self-attention layers); all others use the unet's configured one.
        if name.endswith('attn1.processor'):
            cross_attention_dim = None
        else:
            cross_attention_dim = unet.config.cross_attention_dim

        if name.startswith('up_blocks'):
            # The first number in the layer name is the block index, e.g.
            # up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor -> 1
            # Up blocks walk the channel list in reverse order.
            block_id = int(name.split('.')[1])
            hidden_size = list(reversed(block_out_channels))[block_id]
        elif name.startswith('down_blocks'):
            # down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor -> 2
            block_id = int(name.split('.')[1])
            hidden_size = block_out_channels[block_id]
        else:
            # Any other layer (e.g. the mid block) uses the last channel count.
            hidden_size = block_out_channels[-1]

        processors[name] = LoRACrossAttnProcessor(hidden_size,
                                                  cross_attention_dim)

        print(name, hidden_size, cross_attention_dim)

    # Install the assembled mapping on the unet's layers.
    unet.set_attn_processor(processors)


set_processors()

down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor 320 None
down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor 320 768
down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor 320 None
down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor 320 768
down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor 640 None
down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor 640 768
down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor 640 None
down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor 640 768
down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor 1280 None
down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor 1280 768
down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor 1280 None
down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor 1280 768
up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor 1280 None
up_blocks.1.attentions.0.transformer_blocks.0.attn2.pr

In [6]:
# Gather every attention processor of the unet into one ModuleList so that
# their parameters can be addressed as a single group.
lora_layers = torch.nn.ModuleList(unet.attn_processors.values())

# Number of collected processors.
len(lora_layers)

32

In [7]:
from diffusers import DDPMScheduler
from transformers import AutoTokenizer

# Tokenizer stored with the checkpoint (the non-fast implementation).
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    use_fast=False,
    subfolder='tokenizer',
)

# Noise scheduler stored with the checkpoint.
scheduler = DDPMScheduler.from_pretrained(
    checkpoint,
    subfolder='scheduler',
)

# Optimise only the parameters held in lora_layers.
optimizer = torch.optim.AdamW(
    lora_layers.parameters(),
    lr=1e-4,
    weight_decay=0.01,
    betas=(0.9, 0.999),
    eps=1e-8,
)

# Mean-squared-error loss.
criterion = torch.nn.MSELoss()

# Display all four objects.
tokenizer, scheduler, optimizer, criterion

(CLIPTokenizer(name_or_path='runwayml/stable-diffusion-v1-5', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': '<|endoftext|>'}),
 DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.12.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9

In [8]:
def get_loss(data):
    """Compute the noise-prediction loss for one batch of images.

    The images are encoded by the VAE, noise is added at a random timestep
    per sample, and the unet (conditioned on the fixed prompt
    'a photo of sks dog') predicts that noise. The loss is `criterion`
    applied to the prediction and the true noise.

    Works for any batch size: one timestep is drawn per sample and the single
    prompt embedding is repeated across the batch. For a batch of one the
    result is identical to using a single timestep and a single embedding.

    Args:
        data: image batch tensor of shape [b, 3, H, W] (the notebook feeds
            [1, 3, 512, 512]); everything is computed on `data.device`.

    Returns:
        Scalar loss tensor.
    """
    device = data.device
    batch_size = data.shape[0]

    # Only the input ids are needed.
    # The prompt never changes, so it would be enough to tokenize it once.
    input_ids = tokenizer('a photo of sks dog',
                          truncation=True,
                          padding='max_length',
                          max_length=tokenizer.model_max_length,
                          return_tensors='pt')['input_ids']

    # Encode the text. The encoder is not trained, so this step could also be
    # computed only once.
    # [1, 77] -> [1, 77, 768]
    out_encoder = encoder(input_ids.to(device))[0]

    # Give every sample of the batch its own copy of the prompt embedding.
    # [1, 77, 768] -> [b, 77, 768]
    out_encoder = out_encoder.repeat(batch_size, 1, 1)

    # Latent feature map from the vae.
    # [b, 3, 512, 512] -> [b, 4, 64, 64]
    out_vae = vae.encode(data).latent_dist.sample().detach()

    # 0.18215 = vae.config.scaling_factor
    out_vae = out_vae * 0.18215

    # Random noise.
    # [b, 4, 64, 64]
    noise = torch.randn_like(out_vae)

    # Random noise step, one per sample.
    # [b]
    noise_step = torch.randint(0,
                               scheduler.config.num_train_timesteps,
                               (batch_size, ),
                               device=device).long()

    # Add the noise.
    # [b, 4, 64, 64]
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    # The unet computes the noise back out of the noisy latents.
    # [b, 4, 64, 64],[b, 77, 768] -> [b, 4, 64, 64]
    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample

    return criterion(out_unet, noise)


get_loss(torch.randn(1, 3, 512, 512))

tensor(0.0348, grad_fn=<MseLossBackward0>)

In [9]:
from diffusers import DiffusionPipeline


def train(epochs=400, log_every=20, save_dir='./save'):
    """Train the LoRA layers and save the attention processors.

    Moves `unet`, `vae` and `encoder` to CUDA when available (CPU otherwise),
    then for every batch of `loader` computes `get_loss`, back-propagates,
    clips the gradient norm of `lora_layers` to 1.0 and steps `optimizer`.

    Logging: the summed loss is printed and reset whenever
    `epoch % log_every == 0`. The first line (epoch 0) therefore covers a
    single epoch, while every later line covers `log_every` epochs.

    Args:
        epochs: number of passes over `loader`.
        log_every: logging interval in epochs; must be a positive integer.
        save_dir: directory handed to `unet.save_attn_procs`.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    unet.to(device)
    vae.to(device)
    encoder.to(device)
    unet.train()

    loss_sum = 0
    for epoch in range(epochs):
        for data in loader:
            loss = get_loss(data.to(device))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(lora_layers.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            loss_sum += loss.item()

        if epoch % log_every == 0:
            print(epoch, loss_sum)
            loss_sum = 0

    unet.save_attn_procs(save_dir)


train()

0 0.3670004606246948
20 10.45627648581285
40 9.188310690922663
60 7.809588600997813
80 9.303583267959766
100 11.15790241828654
120 7.788689624285325
140 10.951385688968003
160 7.679351598257199
180 8.150683714775369
200 10.188504346180707
220 8.693295795586891
240 6.545399873633869
260 8.477286474080756
280 10.255770185613073
300 7.34097321680747
320 7.778350460110232
340 7.739824238466099
360 8.069359936169349
380 6.185638143564574
