In [1]:
# Chapter 4 / global constants
# Hugging Face Hub id of the image dataset that a later cell loads with load_dataset.
repo_id = 'lansinuote/diffusion.2.textual_inversion'
# Hub id of the pretrained checkpoint; the tokenizer, text encoder, VAE, UNet
# and scheduler are each loaded from one of its subfolders in later cells.
checkpoint = 'runwayml/stable-diffusion-v1-5'

In [2]:
#第4章/加载数据集
from datasets import Dataset
import PIL.Image


def get_dataset():
    """Build a Dataset with a single 'image' column from local JPEG files.

    Opens images/0.jpeg through images/5.jpeg (paths relative to the working
    directory) and wraps them as six records of the form {'image': <PIL image>}.

    Returns:
        The Dataset created by Dataset.from_list from those six records.
    """
    records = []
    for index in range(6):
        path = 'images/%d.jpeg' % index
        records.append({'image': PIL.Image.open(path)})

    return Dataset.from_list(records)


# Build the dataset from the local files and show its summary as the cell
# output. The result is not stored; the next cell loads the data from the Hub.
get_dataset()

Dataset({
    features: ['image'],
    num_rows: 6
})

In [3]:
#第4章/在线加载数据集
from datasets import load_dataset

# Load the 'train' split of the dataset identified by repo_id.
dataset = load_dataset(path=repo_id, split='train')

# Show the dataset summary together with its first record.
dataset, dataset[0]

Using custom data configuration lansinuote--diffusion.2.textual_inversion-3f03ac153edd3f09
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.2.textual_inversion-3f03ac153edd3f09/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


(Dataset({
     features: ['image'],
     num_rows: 6
 }),
 {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1167x1010>})

In [4]:
# Chapter 4 / caption templates
# Every caption mentions the placeholder word '<cat-toy>'. The preprocessing
# function picks one of these at random with random.choice when encoding text.
texts = [
    'a photo of a <cat-toy>', 'a rendering of a <cat-toy>',
    'a cropped photo of the <cat-toy>', 'the photo of a <cat-toy>',
    'a photo of a clean <cat-toy>', 'a photo of a dirty <cat-toy>',
    'a dark photo of the <cat-toy>', 'a photo of my <cat-toy>',
    'a photo of the cool <cat-toy>', 'a close-up photo of a <cat-toy>',
    'a bright photo of the <cat-toy>', 'a cropped photo of a <cat-toy>',
    'a photo of the <cat-toy>', 'a good photo of the <cat-toy>',
    'a photo of one <cat-toy>', 'a close-up photo of the <cat-toy>',
    'a rendition of the <cat-toy>', 'a photo of the clean <cat-toy>',
    'a rendition of a <cat-toy>', 'a photo of a nice <cat-toy>',
    'a good photo of a <cat-toy>', 'a photo of the nice <cat-toy>',
    'a photo of the small <cat-toy>', 'a photo of the weird <cat-toy>',
    'a photo of the large <cat-toy>', 'a photo of a cool <cat-toy>',
    'a photo of a small <cat-toy>'
]

In [5]:
#第4章/数据集预处理
import torchvision
import random
import numpy as np
from transformers import CLIPTokenizer
import torch

# Text encoding: the CLIP tokenizer stored in the checkpoint's 'tokenizer' subfolder.
tokenizer = CLIPTokenizer.from_pretrained(checkpoint, subfolder='tokenizer')

# Data augmentation: flip each image horizontally with probability 0.5.
compose = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
])


def f(data):
    """Batch transform that turns raw images into model-ready tensors.

    Intended for Dataset.with_transform, which passes a dict of columns.

    Args:
        data: dict with an 'image' entry holding a list of PIL images.

    Returns:
        dict with
          'input_ids': int64 tensor [n, 77], one randomly chosen caption from
              `texts` per image, padded/truncated to 77 tokens;
          'pixel_values': list of n float32 tensors [3, 512, 512] with values
              scaled to [-1, 1].
    """
    images = data['image']

    # Encode text.
    # Draw one caption per image so 'input_ids' always has as many rows as
    # 'pixel_values'; a single shared caption only lines up for batches of 1.
    captions = [random.choice(texts) for _ in images]

    # 77 = tokenizer.model_max_length
    input_ids = tokenizer(captions,
                          padding='max_length',
                          truncation=True,
                          max_length=77,
                          return_tensors='pt')['input_ids']

    # Encode images.
    pixel_values = []
    for image in images:
        # The permute below expects an [H, W, 3] array, so normalise
        # grayscale / RGBA / palette images to RGB first.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Data augmentation (random horizontal flip).
        image = compose(image)

        # Resize to the 512x512 resolution used for training.
        image = image.resize((512, 512), resample=PIL.Image.Resampling.BICUBIC)

        # Map uint8 [0, 255] to float32 [-1, 1].
        image = np.array(image).astype(np.uint8)
        image = image / 127.5 - 1.0
        image = image.astype(np.float32)

        # Convert to a tensor with the channel dimension first.
        # [512, 512, 3] -> [3, 512, 512]
        image = torch.from_numpy(image).permute(2, 0, 1)

        pixel_values.append(image)

    return {'input_ids': input_ids, 'pixel_values': pixel_values}


# Attach f as an on-the-fly transform; it runs whenever records are read.
dataset = dataset.with_transform(f)

# Sanity check: print the shape and dtype of every field of the first record.
for k, v in dataset[0].items():
    print(k, v.shape, v.dtype)

# Show the dataset summary as the cell output.
dataset

input_ids torch.Size([77]) torch.int64
pixel_values torch.Size([3, 512, 512]) torch.float32


Dataset({
    features: ['image'],
    num_rows: 6
})

In [6]:
# Chapter 4 / define the data loader
# One record per batch, reshuffled on every pass over the dataset.
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

# Sanity check: print the shape and dtype of every field of one batch.
for k, v in next(iter(loader)).items():
    print(k, v.shape, v.dtype)

# Number of batches per epoch.
len(loader)

input_ids torch.Size([1, 77]) torch.int64
pixel_values torch.Size([1, 3, 512, 512]) torch.float32


6

In [7]:
#第4章/加载模型
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel

# Load the three pretrained sub-models from their subfolders of the checkpoint:
# the CLIP text encoder, the VAE, and the conditional UNet.
encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')
vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')
unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')


def print_model_size(name, model):
    """Print `name` followed by the model's parameter count in units of 10,000.

    Args:
        name: label printed in front of the number.
        model: any object whose .parameters() yields tensors with .numel().
    """
    element_count = 0
    for parameter in model.parameters():
        element_count += parameter.numel()

    print(name, element_count / 10000)


# Report the parameter count of each sub-model (in units of 10,000 parameters).
print_model_size('encoder', encoder)
print_model_size('vae', vae)
print_model_size('unet', unet)

encoder 12306.048
vae 8365.3863
unet 85952.0964


In [8]:
#第4章/添加新词
def init_new_word():
    """Register the placeholder token '<cat-toy>' and seed its embedding.

    Adds '<cat-toy>' to the tokenizer, grows the text encoder's input
    embedding table to the new vocabulary size, and copies the embedding of
    the existing word 'toy' into the new row so training starts from a
    related concept.

    Raises:
        ValueError: if the initializer word does not tokenize to exactly one
            token, in which case there is no single embedding row to copy.
    """
    # Add the new word to the vocabulary.
    tokenizer.add_tokens('<cat-toy>')

    # Grow the encoder's embedding layer so there is a row for the new word.
    encoder.resize_token_embeddings(len(tokenizer))

    # Look up the ids of the old and the new word.
    # The initializer goes through encode() so the tokenizer's own word
    # splitting is applied and we get the id the word has inside a real
    # prompt. convert_tokens_to_ids() is a raw vocabulary lookup, which for a
    # plain word can hit a different sub-word entry or the unknown-token id.
    init_ids = tokenizer.encode('toy', add_special_tokens=False)
    if len(init_ids) != 1:
        raise ValueError(
            'initializer word must map to exactly one token, got %d' %
            len(init_ids))
    old_id = init_ids[0]
    # The placeholder was added verbatim as a token, so a raw lookup is right.
    new_id = tokenizer.convert_tokens_to_ids('<cat-toy>')

    embed = encoder.get_input_embeddings().weight.data

    # Initialise the new word's embedding from the old word's embedding.
    embed[new_id] = embed[old_id]


# Register '<cat-toy>' in the tokenizer and the text encoder (mutates both).
init_new_word()

In [9]:
# Chapter 4 / freeze part of the model parameters
# These two models never receive parameter updates.
vae.requires_grad_(False)
unet.requires_grad_(False)

# Train only encoder.text_model.embeddings.token_embedding; freeze the
# transformer layers, the final layer norm and the position embedding.
encoder.text_model.encoder.requires_grad_(False)
encoder.text_model.final_layer_norm.requires_grad_(False)
# As the last expression, this module is what the cell displays.
encoder.text_model.embeddings.position_embedding.requires_grad_(False)

Embedding(77, 768)

In [10]:
#第4章/初始化工具类
from diffusers import DDPMScheduler

# Noise scheduler stored in the checkpoint's 'scheduler' subfolder.
scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')

# The optimizer receives every parameter of the encoder's input-embedding
# layer, so the whole token-embedding table (not only the '<cat-toy>' row) is
# subject to updates and to the 0.01 weight decay.
optimizer = torch.optim.AdamW(encoder.get_input_embeddings().parameters(),
                              lr=5e-4,
                              betas=(0.9, 0.999),
                              weight_decay=0.01,
                              eps=1e-8)

# Mean-squared error, used between the UNet output and the sampled noise.
criterion = torch.nn.MSELoss()

# Show all three objects as the cell output.
scheduler, optimizer, criterion

(DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.15.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "clip_sample_range": 1.0,
   "dynamic_thresholding_ratio": 0.995,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "sample_max_value": 1.0,
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "thresholding": false,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     lr: 0.0005
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [11]:
#第4章/定义计算loss的函数
def get_loss(data):
    """Compute the noise-prediction loss for one batch.

    Args:
        data: dict with
            'input_ids': int64 token ids, [b, 77];
            'pixel_values': float32 images, [b, 3, 512, 512].
            Both tensors must live on the same device as the models.

    Returns:
        Scalar tensor: `criterion` (MSE) between the UNet output and the noise
        that was added to the image latents.
    """
    device = data['input_ids'].device
    batch_size = data['pixel_values'].shape[0]

    # Encode the text.
    # [b, 77] -> [b, 77, 768]
    out_encoder = encoder(data['input_ids'])[0]

    # Compute the latent feature map; detached, so no gradient reaches the VAE.
    # [b, 3, 512, 512] -> [b, 4, 64, 64]
    out_vae = vae.encode(data['pixel_values']).latent_dist.sample().detach()

    # Scale the latents (0.18215 for this checkpoint).
    out_vae = out_vae * vae.config.scaling_factor

    # Random noise with the same shape as the latents.
    # [b, 4, 64, 64]
    noise = torch.randn_like(out_vae)

    # Random noise step, drawn independently for every sample in the batch so
    # batches larger than 1 do not all share a single timestep.
    # Upper bound is 1000 for this checkpoint.
    noise_step = torch.randint(0,
                               scheduler.config.num_train_timesteps,
                               (batch_size, ),
                               device=device).long()

    # Add the noise to the latents.
    # [b, 4, 64, 64]
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    # Predict the noise back from the noisy latents.
    # [b, 4, 64, 64],[b, 77, 768] -> [b, 4, 64, 64]
    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample

    return criterion(out_unet, noise)


# Smoke test: run get_loss once on dummy inputs (all-ones token ids, random
# pixels) to confirm the forward pass works and yields a scalar loss.
get_loss({
    'input_ids': torch.ones(1, 77).long(),
    'pixel_values': torch.randn(1, 3, 512, 512)
})

tensor(0.0791, grad_fn=<MseLossBackward0>)

In [12]:
#第4章/训练
from diffusers import StableDiffusionPipeline


def train():
    """Train the text encoder's token embeddings and save the pipeline.

    Moves the three models to CUDA when available (otherwise CPU), runs 800
    passes over `loader` with gradients accumulated over 4 micro-batches per
    optimizer step, prints the accumulated loss every 20 epochs, and finally
    writes a StableDiffusionPipeline containing the trained encoder and the
    extended tokenizer to './save'.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder.to(device)
    vae.to(device)
    unet.to(device)
    encoder.train()

    accumulation_steps = 4

    # Start from clean gradients, e.g. when the cell is re-run after an
    # interrupted training run.
    optimizer.zero_grad()

    loss_sum = 0
    pending = 0  # micro-batches back-propagated since the last optimizer step
    for epoch in range(800):
        for data in loader:
            data = {k: v.to(device) for k, v in data.items()}

            # Divide so the accumulated gradient is the mean over the window.
            loss = get_loss(data) / accumulation_steps
            loss.backward()
            pending += 1

            # Accumulated update: step only once a full window of
            # micro-batches has contributed. Testing `step % 4 == 0` instead
            # would fire on the very first micro-batch and leave the last
            # three of the run unapplied.
            if pending == accumulation_steps:
                optimizer.step()
                optimizer.zero_grad()
                pending = 0

            loss_sum += loss.item()

        # The first report covers epoch 0 only; later ones cover 20 epochs each.
        if epoch % 20 == 0:
            print(epoch, loss_sum)
            loss_sum = 0

    # Apply whatever is left if the total number of micro-batches is not a
    # multiple of the accumulation window.
    if pending:
        optimizer.step()
        optimizer.zero_grad()

    # Save the pipeline with the trained encoder and the extended tokenizer.
    StableDiffusionPipeline.from_pretrained(
        checkpoint,
        text_encoder=encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer).save_pretrained('./save')


# Run the full training loop; prints the accumulated loss every 20 epochs and
# writes the resulting pipeline to './save'.
train()

0 0.05298610217869282
20 3.669440200086683
40 3.6395892064902
60 4.406068226147909
80 4.021411922876723
100 4.375712881417712
120 3.691248642571736
140 4.046981384977698
160 4.32251713424921
180 4.481215654464904
200 4.141445596702397
220 3.5479576015495695
240 4.255890477390494
260 4.373657913354691
280 3.9518783968524076
300 3.48222059692489
320 3.856919998244848
340 4.329844517749734
360 3.967784309934359
380 3.713769310968928
400 4.741978614707477
420 3.2556250613415614
440 3.7640564047615044
460 4.055520819092635
480 3.800076044688467
500 4.761352627770975
520 3.5383732234477066
540 3.6530803244677372
560 3.705913124314975
580 3.817220114695374
600 4.2431890980224125
620 3.727043268852867
640 4.312535111501347
660 4.316632353293244
680 3.5848469483316876
700 4.623237420630176
720 3.859215323231183
740 3.680146592843812
760 4.15289314155234
780 3.849531917832792
