In [1]:
#第3章/全局常量
repo_id = 'lansinuote/diffusion.1.unconditional'

In [2]:
#第3章/加载数据集
from datasets import load_dataset

load_dataset('huggan/flowers-102-categories', split='train')

Using custom data configuration huggan--flowers-102-categories-2ab3d0588f2a8da7
Found cached dataset parquet (/root/.cache/huggingface/datasets/huggan___parquet/huggan--flowers-102-categories-2ab3d0588f2a8da7/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


Dataset({
    features: ['image'],
    num_rows: 8189
})

In [3]:
#第3章/使用笔者转载的数据集
dataset = load_dataset(path=repo_id, split='train')

dataset, dataset[0]

Using custom data configuration lansinuote--diffusion.1.unconditional-a9e8dd8d646db18f
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.1.unconditional-a9e8dd8d646db18f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


(Dataset({
     features: ['image'],
     num_rows: 8189
 }),
 {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=752x500>})

In [4]:
#第3章/数据集预处理
import torchvision

#图像增强和编码
compose = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        64, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
    torchvision.transforms.RandomCrop(64),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5])
])


def f(data):
    image = [compose(i) for i in data['image']]
    return {'image': image}


#因为图像增强在每个epoch中要动态计算,所以不能简单地用map处理
dataset.set_transform(f)

dataset, dataset[0]['image'].shape

(Dataset({
     features: ['image'],
     num_rows: 8189
 }),
 torch.Size([3, 64, 64]))

In [5]:
#第3章/定义loader
import torch

loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=16,
                                     shuffle=True,
                                     drop_last=False)

len(loader), next(iter(loader))['image'].shape

(512, torch.Size([16, 3, 64, 64]))

In [6]:
#第3章/定义模型
from diffusers import UNet2DModel

#定义模型,随机初始化参数
model = UNet2DModel(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        'DownBlock2D',
        'DownBlock2D',
        'DownBlock2D',
        'DownBlock2D',
        'AttnDownBlock2D',
        'DownBlock2D',
    ),
    up_block_types=(
        'UpBlock2D',
        'AttnUpBlock2D',
        'UpBlock2D',
        'UpBlock2D',
        'UpBlock2D',
        'UpBlock2D',
    ),
)

sum(i.numel() for i in model.parameters()) / 10000

11367.3219

In [7]:
#第3章/初始化工具类
from diffusers import DDPMScheduler
from diffusers.optimization import get_scheduler

scheduler = DDPMScheduler(num_train_timesteps=1000,
                          beta_schedule='linear',
                          prediction_type='epsilon')

optimizer = torch.optim.AdamW(model.parameters(),
                              lr=1e-4,
                              betas=(0.95, 0.999),
                              weight_decay=1e-6,
                              eps=1e-8)

scheduler_lr = get_scheduler('cosine',
                             optimizer=optimizer,
                             num_warmup_steps=500,
                             num_training_steps=len(loader) * 100)

criterion = torch.nn.MSELoss()

scheduler, optimizer, scheduler_lr, criterion

(DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.15.1",
   "beta_end": 0.02,
   "beta_schedule": "linear",
   "beta_start": 0.0001,
   "clip_sample": true,
   "clip_sample_range": 1.0,
   "dynamic_thresholding_ratio": 0.995,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "sample_max_value": 1.0,
   "thresholding": false,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.95, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     initial_lr: 0.0001
     lr: 0.0
     maximize: False
     weight_decay: 1e-06
 ),
 <torch.optim.lr_scheduler.LambdaLR at 0x7fdb1003b280>,
 MSELoss())

In [8]:
#第3章/定义计算loss的函数
def get_loss(image):
    device = image.device

    #随机噪声
    #[b, 3, 64, 64]
    noise = torch.randn(image.shape).to(device)

    #随机b个噪声步数
    #1000 = scheduler.config.num_train_timesteps
    #[b]
    noise_step = torch.randint(0, 1000, (image.shape[0], ),
                               device=device).long()

    #往图片当中添加噪声
    #[b, 3, 64, 64]
    image_noise = scheduler.add_noise(image, noise, noise_step)

    #把图片里的噪声计算出来
    #[b, 3, 64, 64]
    out = model(image_noise, noise_step).sample

    #求mse loss
    return criterion(out, noise)


get_loss(torch.randn(16, 3, 64, 64))

tensor(1.2061, grad_fn=<MseLossBackward0>)

In [9]:
#第3章/训练
from diffusers import DDPMPipeline


def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.train()

    loss_sum = 0
    for epoch in range(10):
        for i, data in enumerate(loader):
            loss = get_loss(data['image'].to(device))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler_lr.step()
            optimizer.zero_grad()

            loss_sum += loss.item()

        if epoch % 1 == 0:
            print(epoch, loss_sum)
            loss_sum = 0

    #save
    DDPMPipeline(unet=model, scheduler=scheduler).save_pretrained('./save')


train()

0 95.75697771087289
1 21.93589099915698
2 19.333377947565168
3 18.43381611024961
4 16.366059875115752
5 16.564373218920082
6 16.268689731601626
7 16.676569791976362
8 16.546459552831948
9 16.88732399744913
