In [1]:
import torch

#全局变量
hub_token = open('/root/hub_token.txt').read().strip()
repo_id = 'lansinuote/diffusion.3.dream_booth'
push_to_hub = True
checkpoint = 'CompVis/stable-diffusion-v1-4'

In [2]:
from datasets import Dataset, DatasetDict, load_dataset
import torchvision
import PIL.Image


def get_dataset():
    images = [{
        'image': PIL.Image.open('images/%d.jpeg' % i),
        'text': 'a photo of little dog',
    } for i in range(5)]

    dataset = Dataset.from_list(images)

    return DatasetDict({'train': dataset})


if push_to_hub:
    dataset = get_dataset()
    dataset.push_to_hub(repo_id=repo_id, token=hub_token)

#直接使用我处理好的数据集
dataset = load_dataset(path=repo_id, split='train')

dataset, dataset[0]

Using custom data configuration lansinuote--diffusion.3.dream_booth-4344a7d7501be7f1
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.3.dream_booth-4344a7d7501be7f1/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


(Dataset({
     features: ['image', 'text'],
     num_rows: 5
 }),
 {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2469x2558>,
  'text': 'a photo of little dog'})

In [3]:
import torchvision
from transformers import AutoTokenizer

#数据增强
compose = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        512, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
    torchvision.transforms.RandomCrop(512),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),
])

#文字编码
tokenizer = AutoTokenizer.from_pretrained(checkpoint,
                                          subfolder='tokenizer',
                                          use_fast=False)


def f(data):
    pixel_values = compose(data['image'][0]).unsqueeze(dim=0)

    #全程都是一句话,其实只需要编码一次即可
    #77 = tokenizer.model_max_length
    tokens = tokenizer(data['text'][0],
                       truncation=True,
                       padding='max_length',
                       max_length=77,
                       return_tensors='pt')

    return {
        'pixel_values': pixel_values,
        'input_ids': tokens.input_ids,
        'attention_mask': tokens.attention_mask
    }


dataset = dataset.with_transform(f)

dataset, dataset[0]

(Dataset({
     features: ['image', 'text'],
     num_rows: 5
 }),
 {'pixel_values': tensor([[[ 0.7725,  0.7804,  0.7804,  ...,  0.7647,  0.7725,  0.7647],
           [ 0.7725,  0.7882,  0.7725,  ...,  0.7647,  0.7647,  0.7647],
           [ 0.7804,  0.7804,  0.7725,  ...,  0.7647,  0.7804,  0.7647],
           ...,
           [ 0.7098,  0.7176,  0.7176,  ...,  0.6941,  0.7098,  0.7020],
           [ 0.7098,  0.7255,  0.7255,  ...,  0.6941,  0.6863,  0.7098],
           [ 0.7020,  0.7176,  0.7176,  ...,  0.6941,  0.7020,  0.7098]],
  
          [[-0.1529, -0.1451, -0.1451,  ..., -0.1373, -0.1373, -0.1373],
           [-0.1529, -0.1451, -0.1529,  ..., -0.1451, -0.1529, -0.1373],
           [-0.1373, -0.1451, -0.1608,  ..., -0.1529, -0.1451, -0.1373],
           ...,
           [-0.2549, -0.2471, -0.2471,  ..., -0.2471, -0.2314, -0.2392],
           [-0.2627, -0.2471, -0.2471,  ..., -0.2471, -0.2549, -0.2314],
           [-0.2706, -0.2549, -0.2471,  ..., -0.2471, -0.2392, -0.2314]],
  
 

In [4]:
loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=1,
                                     shuffle=True,
                                     collate_fn=None)

len(loader), next(iter(loader))

(5,
 {'pixel_values': tensor([[[[ 0.7647,  0.7490,  0.7490,  ...,  0.7569,  0.7412,  0.7490],
            [ 0.7490,  0.7647,  0.7412,  ...,  0.7412,  0.7490,  0.7333],
            [ 0.7647,  0.7569,  0.7569,  ...,  0.7333,  0.7333,  0.7412],
            ...,
            [ 0.6863,  0.6863,  0.6627,  ...,  0.6784,  0.6392,  0.6784],
            [ 0.6863,  0.6784,  0.6706,  ...,  0.6627,  0.6627,  0.6627],
            [ 0.6863,  0.6549,  0.6627,  ...,  0.6706,  0.6706,  0.6706]],
  
           [[-0.1608, -0.1765, -0.1765,  ..., -0.1843, -0.1922, -0.1765],
            [-0.1765, -0.1608, -0.1843,  ..., -0.1843, -0.1765, -0.1922],
            [-0.1608, -0.1686, -0.1686,  ..., -0.1922, -0.1843, -0.1922],
            ...,
            [-0.2549, -0.2549, -0.2863,  ..., -0.2392, -0.2706, -0.2314],
            [-0.2627, -0.2706, -0.2784,  ..., -0.2549, -0.2471, -0.2471],
            [-0.2627, -0.2941, -0.2863,  ..., -0.2392, -0.2392, -0.2392]],
  
           [[-0.6392, -0.6549, -0.6549,  ..., -0.6

In [5]:
from transformers.models.clip.modeling_clip import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel

encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')
vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')
unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')

vae.requires_grad_(False)
encoder.requires_grad_(False)


def print_model_size(name, model):
    print(name, sum(i.numel() for i in model.parameters()) / 10000)


print_model_size('encoder', encoder)
print_model_size('vae', vae)
print_model_size('unet', unet)

encoder 12306.048
vae 8365.3863
unet 85952.0964


In [6]:
from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')

optimizer = torch.optim.AdamW(unet.parameters(),
                              lr=5e-6,
                              betas=(0.9, 0.999),
                              weight_decay=0.01,
                              eps=1e-8)

criterion = torch.nn.MSELoss()

scheduler, optimizer, criterion

(DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.15.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "clip_sample_range": 1.0,
   "dynamic_thresholding_ratio": 0.995,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "sample_max_value": 1.0,
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "thresholding": false,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     lr: 5e-06
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [7]:
def get_loss(data):
    #编码文字,由于encoder不训练,其实这一步也可以只运算一次
    #[1, 77] -> [1, 77, 768]
    out_encoder = encoder(input_ids=data['input_ids'],
                          attention_mask=data['attention_mask'])[0]

    #vae计算特征图
    #[1, 3, 512, 512] -> [1, 4, 64, 64]
    out_vae = vae.encode(data['pixel_values']).latent_dist.sample()

    #0.18215 = vae.config.scaling_factor
    out_vae = out_vae * 0.18215

    #随机噪声
    #[1, 4, 64, 64]
    noise = torch.randn_like(out_vae)

    #随机噪声步
    #1000 = scheduler.config.num_train_timesteps
    #1 = b
    noise_step = torch.randint(0, 1000, (1, ),
                               device=data['input_ids'].device).long()

    #添加噪声
    #[1, 4, 64, 64]
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    #unet从噪声图中把噪声计算出来
    #[1, 4, 64, 64],[1, 77, 768] -> [1, 4, 64, 64]
    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample

    return criterion(out_unet, noise)


get_loss({
    'pixel_values': torch.randn(1, 3, 512, 512),
    'input_ids': torch.ones(1, 77).long(),
    'attention_mask': torch.ones(1, 77).long()
})

tensor(0.0311, grad_fn=<MseLossBackward0>)

In [8]:
from diffusers import DiffusionPipeline
from huggingface_hub import Repository, create_repo


def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    unet.to(device)
    vae.to(device)
    encoder.to(device)
    unet.train()

    loss_sum = 0
    for epoch in range(200):
        for i, data in enumerate(loader):
            for k in data.keys():
                data[k] = data[k].to(device)

            loss = get_loss(data)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            loss_sum += loss.item()

        if epoch % 10 == 0:
            print(epoch, loss_sum)
            loss_sum = 0

    DiffusionPipeline.from_pretrained(
        checkpoint, unet=unet, text_encoder=encoder).save_pretrained('./save')


if push_to_hub:
    create_repo(repo_id, exist_ok=True, token=hub_token)
    repo = Repository('./save', clone_from=repo_id, token=hub_token)
    train()
    repo.push_to_hub()

0 0.15169205842539668
10 5.945965459337458
20 3.8129762182943523
30 3.295396795729175
40 3.654667090624571
50 6.476580709917471
60 4.5216363104991615
70 3.3012475325958803
80 3.4042327782372013
90 4.213736498611979
100 4.350477983767632
110 4.34992079436779
120 3.7817602755967528
130 3.6707470717374235
140 2.6561128611210734
150 1.8198281603399664
160 2.432444425066933
170 3.2916298899799585
180 3.0428609114605933
190 3.1719170592259616
