In [1]:
#第5章/全局常量
repo_id = 'lansinuote/diffusion.3.dream_booth'
checkpoint = 'CompVis/stable-diffusion-v1-4'

In [2]:
#第5章/加载数据集
from datasets import Dataset
import PIL.Image


def get_dataset():
    images = [{
        'image': PIL.Image.open('images/%d.jpeg' % i),
        'text': 'a photo of little dog',
    } for i in range(5)]

    return Dataset.from_list(images)


get_dataset()

Dataset({
    features: ['image', 'text'],
    num_rows: 5
})

In [3]:
#第5章/在线加载数据集
from datasets import load_dataset

dataset = load_dataset(path=repo_id, split='train')

dataset, dataset[0]

Using custom data configuration lansinuote--diffusion.3.dream_booth-4344a7d7501be7f1
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.3.dream_booth-4344a7d7501be7f1/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


(Dataset({
     features: ['image', 'text'],
     num_rows: 5
 }),
 {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2469x2558>,
  'text': 'a photo of little dog'})

In [4]:
#第5章/数据集预处理
import torchvision
from transformers import AutoTokenizer

#数据增强
compose = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        512, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
    torchvision.transforms.RandomCrop(512),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),
])

#文字编码
tokenizer = AutoTokenizer.from_pretrained(checkpoint,
                                          subfolder='tokenizer',
                                          use_fast=False)


def f(data):
    #图像编码
    pixel_values = compose(data['image'][0]).unsqueeze(dim=0)

    #文字编码
    #77 = tokenizer.model_max_length
    tokens = tokenizer(data['text'][0],
                       truncation=True,
                       padding='max_length',
                       max_length=77,
                       return_tensors='pt')

    return {
        'pixel_values': pixel_values,
        'input_ids': tokens.input_ids,
        'attention_mask': tokens.attention_mask
    }


dataset = dataset.with_transform(f)

for k, v in dataset[0].items():
    print(k, v.shape, v.dtype)

dataset

pixel_values torch.Size([3, 512, 512]) torch.float32
input_ids torch.Size([77]) torch.int64
attention_mask torch.Size([77]) torch.int64


Dataset({
    features: ['image', 'text'],
    num_rows: 5
})

In [5]:
#第5章/定义loader
import torch

loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=1,
                                     shuffle=True,
                                     collate_fn=None)

for k, v in next(iter(loader)).items():
    print(k, v.shape, v.dtype)

len(loader)

pixel_values torch.Size([1, 3, 512, 512]) torch.float32
input_ids torch.Size([1, 77]) torch.int64
attention_mask torch.Size([1, 77]) torch.int64


5

In [6]:
#第5章/加载模型
from transformers.models.clip.modeling_clip import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel

encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')
vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')
unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')

vae.requires_grad_(False)
encoder.requires_grad_(False)


def print_model_size(name, model):
    print(name, sum(i.numel() for i in model.parameters()) / 10000)


print_model_size('encoder', encoder)
print_model_size('vae', vae)
print_model_size('unet', unet)

encoder 12306.048
vae 8365.3863
unet 85952.0964


In [7]:
#第5章/初始化工具类
from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')

optimizer = torch.optim.AdamW(unet.parameters(),
                              lr=5e-6,
                              betas=(0.9, 0.999),
                              weight_decay=0.01,
                              eps=1e-8)

criterion = torch.nn.MSELoss()

scheduler, optimizer, criterion

(DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.15.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "clip_sample_range": 1.0,
   "dynamic_thresholding_ratio": 0.995,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "sample_max_value": 1.0,
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "thresholding": false,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     lr: 5e-06
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [8]:
#第5章/定义计算loss的函数
def get_loss(data):
    #编码文字
    #[1, 77] -> [1, 77, 768]
    out_encoder = encoder(input_ids=data['input_ids'],
                          attention_mask=data['attention_mask'])[0]

    #计算特征图
    #[1, 3, 512, 512] -> [1, 4, 64, 64]
    out_vae = vae.encode(data['pixel_values']).latent_dist.sample()

    #0.18215 = vae.config.scaling_factor
    out_vae = out_vae * 0.18215

    #随机噪声
    #[1, 4, 64, 64]
    noise = torch.randn_like(out_vae)

    #随机噪声步
    #1000 = scheduler.config.num_train_timesteps
    #1 = b
    noise_step = torch.randint(0, 1000, (1, ),
                               device=data['input_ids'].device).long()

    #添加噪声
    #[1, 4, 64, 64]
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    #从噪声图中把噪声计算出来
    #[1, 4, 64, 64],[1, 77, 768] -> [1, 4, 64, 64]
    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample

    return criterion(out_unet, noise)


get_loss({
    'pixel_values': torch.randn(1, 3, 512, 512),
    'input_ids': torch.ones(1, 77).long(),
    'attention_mask': torch.ones(1, 77).long()
})

tensor(0.0022, grad_fn=<MseLossBackward0>)

In [9]:
#第5章/训练
from diffusers import DiffusionPipeline


def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    unet.to(device)
    vae.to(device)
    encoder.to(device)
    unet.train()

    loss_sum = 0
    for epoch in range(200):
        for i, data in enumerate(loader):
            for k in data.keys():
                data[k] = data[k].to(device)

            loss = get_loss(data)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            loss_sum += loss.item()

        if epoch % 10 == 0:
            print(epoch, loss_sum)
            loss_sum = 0

    DiffusionPipeline.from_pretrained(
        checkpoint, unet=unet, text_encoder=encoder).save_pretrained('./save')


train()

0 0.8123962143436074
10 4.12263014621567
20 3.356302598491311
30 4.225471537210979
40 4.843276363797486
50 3.3777126917848364
60 4.564837389625609
70 4.9124050753889605
80 3.9872994840843603
90 3.7599874258739874
100 4.364669724949636
110 3.3703145849285647
120 2.4453093906631693
130 3.1313629028154537
140 3.1495733208721504
150 3.0935851570684463
160 2.7813034405699
170 3.6852261961903423
180 2.1988481080625206
190 2.6699363642837852
