In [1]:
#第6章/全局常量
repo_id = 'lansinuote/diffusion.4.text_to_image.book'
checkpoint = 'CompVis/stable-diffusion-v1-4'

In [2]:
#第6章/加载数据集
from datasets import load_dataset


def get_dataset():
    #加载数据集
    dataset = load_dataset('m1guelpf/nouns', split='train')

    #采样
    dataset = dataset.shuffle(seed=0).select(range(1000))

    return dataset


get_dataset()

Using custom data configuration m1guelpf--nouns-cc6819088b485316
Found cached dataset parquet (/root/.cache/huggingface/datasets/m1guelpf___parquet/m1guelpf--nouns-cc6819088b485316/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/m1guelpf___parquet/m1guelpf--nouns-cc6819088b485316/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-e94f99777e7dd79d.arrow


Dataset({
    features: ['image', 'text'],
    num_rows: 1000
})

In [3]:
#第6章/直接使用笔者处理好的数据集
dataset = load_dataset(path=repo_id, split='train')

dataset, dataset[0]

Using custom data configuration lansinuote--diffusion.4.text_to_image.book-22cef9f769873d9d
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.4.text_to_image.book-22cef9f769873d9d/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


(Dataset({
     features: ['image', 'text'],
     num_rows: 1000
 }),
 {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=320x320>,
  'text': 'a pixel art character with square orange glasses, a whale-shaped head and a teal-colored body on a warm background'})

In [4]:
#第6章/数据集预处理
import torchvision
from transformers import CLIPTokenizer

#图像增强
compose = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        512, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
    torchvision.transforms.CenterCrop(512),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),
])

#文字编码
tokenizer = CLIPTokenizer.from_pretrained(checkpoint, subfolder='tokenizer')


def f(data):
    #应用图像数据增强
    pixel_values = compose(data['image'][0]).unsqueeze(dim=0)

    #文字编码
    #77 = tokenizer.model_max_length
    input_ids = tokenizer.batch_encode_plus(data['text'],
                                            max_length=77,
                                            padding='max_length',
                                            truncation=True,
                                            return_tensors='pt').input_ids

    return {'pixel_values': pixel_values, 'input_ids': input_ids}


#因为图像增强在每一个epoch中是动态计算的,所以不能简单地用map处理
dataset = dataset.with_transform(f)

for k, v in dataset[0].items():
    print(k, v.shape, v.dtype)

dataset

pixel_values torch.Size([3, 512, 512]) torch.float32
input_ids torch.Size([77]) torch.int64


Dataset({
    features: ['image', 'text'],
    num_rows: 1000
})

In [5]:
#第6章/定义loader
import torch

loader = torch.utils.data.DataLoader(dataset,
                                     shuffle=True,
                                     collate_fn=None,
                                     batch_size=1)

for k, v in next(iter(loader)).items():
    print(k, v.shape, v.dtype)

len(loader)

pixel_values torch.Size([1, 3, 512, 512]) torch.float32
input_ids torch.Size([1, 77]) torch.int64


1000

In [6]:
#第6章/加载模型
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')
vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')
unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')

vae.requires_grad_(False)
encoder.requires_grad_(False)


def print_model_size(name, model):
    print(name, sum(i.numel() for i in model.parameters()) / 10000)


print_model_size('encoder', encoder)
print_model_size('vae', vae)
print_model_size('unet', unet)

encoder 12306.048
vae 8365.3863
unet 85952.0964


In [7]:
#第6章/初始化工具类
from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')

optimizer = torch.optim.AdamW(unet.parameters(),
                              lr=1e-5,
                              betas=(0.9, 0.999),
                              weight_decay=0.01,
                              eps=1e-8)

criterion = torch.nn.MSELoss()

scheduler, optimizer, criterion

(DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.15.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "clip_sample_range": 1.0,
   "dynamic_thresholding_ratio": 0.995,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "sample_max_value": 1.0,
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "thresholding": false,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     lr: 1e-05
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [8]:
#第6章/定义计算loss的函数
def get_loss(data):
    device = data['input_ids'].device

    #文字编码
    #[1, 77] -> [1, 77, 768]
    out_encoder = encoder(data['input_ids'])[0]

    #抽取图像特征图
    #[1, 3, 512, 512] -> [1, 4, 64, 64]
    out_vae = vae.encode(data['pixel_values']).latent_dist.sample()
    #0.18215 = vae.config.scaling_factor
    out_vae = out_vae * 0.18215

    #随机噪声
    noise = torch.randn_like(out_vae)

    #往特征图中添加噪声
    #1000 = scheduler.num_train_timesteps
    #1 = out_vae.shape[0]
    noise_step = torch.randint(0, 1000, (1, )).long()
    noise_step = noise_step.to(device)
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    #根据文字信息,把特征图中的噪声计算出来
    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample

    #计算mse loss
    #[1, 4, 64, 64],[1, 4, 64, 64]
    return criterion(out_unet, noise)


get_loss({
    'input_ids': torch.ones(1, 77).long(),
    'pixel_values': torch.randn(1, 3, 512, 512)
})

tensor(0.5090, grad_fn=<MseLossBackward0>)

In [9]:
#第6章/训练
from diffusers import StableDiffusionPipeline


def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    unet.to(device)
    encoder.to(device)
    vae.to(device)
    unet.train()

    loss_sum = 0
    for epoch in range(150):
        for i, data in enumerate(loader):
            for k in data.keys():
                data[k] = data[k].to(device)

            loss = get_loss(data) / 4
            loss.backward()
            loss_sum += loss.item()

            if (i + 1) % 4 == 0:
                torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

        if epoch % 10 == 0:
            print(epoch, loss_sum)
            loss_sum = 0

    #保存
    StableDiffusionPipeline.from_pretrained(
        checkpoint, text_encoder=encoder, vae=vae,
        unet=unet).save_pretrained('./save')


train()

0 8.22686066525057
10 67.46926167499623
20 60.65740441545495
30 55.43549637134129
40 50.09387398760009
50 45.60564994261949
60 43.586335998283175
70 40.77665254819658
80 37.13344954059721
90 34.85505028096668
100 32.70626295961847
110 31.372446675017272
120 30.28200590808774
130 28.05749295847636
140 26.114318155006913
