In [1]:
import torch

# Global settings
# Hub access token, read from a local file so it is not hard-coded in the
# notebook. The context manager closes the file handle right after reading.
with open('/root/hub_token.txt') as token_file:
    hub_token = token_file.read().strip()
# Hub repository used both for the processed dataset and for the trained pipeline.
repo_id = 'lansinuote/diffusion.4.text_to_image'
# When True, the dataset is rebuilt and uploaded, and the model is trained and uploaded.
push_to_hub = True
# Pretrained checkpoint that the scheduler, tokenizer and models are loaded from.
checkpoint = 'CompVis/stable-diffusion-v1-4'

In [2]:
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer

# Load the two helper objects stored alongside the checkpoint:
# the CLIP tokenizer for the captions and the DDPM noise scheduler.
tokenizer = CLIPTokenizer.from_pretrained(checkpoint, subfolder='tokenizer')
scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')

# Cell result: show both configurations.
(scheduler, tokenizer)

(DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.12.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 CLIPTokenizer(name_or_path='CompVis/stable-diffusion-v1-4', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': '<|endoftext|>'}))

In [3]:
from datasets import load_dataset


def get_dataset():
    """Download the Pokemon caption dataset and tokenize its captions.

    Loads 'lambdalabs/pokemon-blip-captions', prints the dataset summary and
    its first training example, then replaces the 'text' column with an
    'input_ids' column produced by the module-level ``tokenizer``.

    Returns:
        The mapped dataset object (same splits as loaded), whose rows carry
        'input_ids' padded/truncated to ``tokenizer.model_max_length``.
    """
    dataset = load_dataset('lambdalabs/pokemon-blip-captions')

    print(dataset, dataset['train'][0])

    def tokenize(batch):
        """Tokenize one batch of captions into fixed-length id lists."""
        # Calling the tokenizer directly is the supported replacement for the
        # deprecated batch_encode_plus. The length comes from the tokenizer
        # itself instead of a hard-coded 77.
        encoded = tokenizer(batch['text'],
                            max_length=tokenizer.model_max_length,
                            padding='max_length',
                            truncation=True,
                            return_tensors=None)

        batch['input_ids'] = encoded['input_ids']

        return batch

    # The raw captions are dropped once their token ids have been stored.
    return dataset.map(tokenize,
                       batched=True,
                       batch_size=100,
                       num_proc=1,
                       remove_columns=['text'])


# Optionally rebuild the tokenized dataset and upload it to the Hub.
if push_to_hub:
    get_dataset().push_to_hub(repo_id=repo_id, token=hub_token)

# Use the already-processed dataset straight from the Hub (train split only).
dataset = load_dataset(repo_id, split='train')

print(dataset, dataset[0])

Using custom data configuration lambdalabs--pokemon-blip-captions-10e3527a764857bd
Found cached dataset parquet (/root/.cache/huggingface/datasets/lambdalabs___parquet/lambdalabs--pokemon-blip-captions-10e3527a764857bd/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'text'],
        num_rows: 833
    })
}) {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1280x1280 at 0x7FCFD1469C40>, 'text': 'a drawing of a green pokemon with red eyes'}


Loading cached processed dataset at /root/.cache/huggingface/datasets/lambdalabs___parquet/lambdalabs--pokemon-blip-captions-10e3527a764857bd/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-2d5136ee1c738262.arrow
Pushing split train to the Hub.


  0%|          | 0/1 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration lansinuote--diffusion.4.text_to_image-ef2bcab260392e0b


Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.4.text_to_image-ef2bcab260392e0b/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/99.7M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/833 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.4.text_to_image-ef2bcab260392e0b/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.
Dataset({
    features: ['image', 'input_ids'],
    num_rows: 833
}) {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1280x1280 at 0x7FCFD13F5D90>, 'input_ids': [49406, 320, 3610, 539, 320, 1901, 9528, 593, 736, 3095, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]}


In [4]:
import torchvision

# Image preprocessing / augmentation pipeline applied to every training image.
compose = torchvision.transforms.Compose([
    # Resize with bilinear interpolation (an int size targets the shorter side).
    torchvision.transforms.Resize(
        size=512,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    ),
    # Keep the central 512x512 region.
    torchvision.transforms.CenterCrop(size=512),
    # Random augmentation: mirror the image with the default probability.
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    # PIL image -> float tensor.
    torchvision.transforms.ToTensor(),
    # Shift and scale each channel with mean 0.5 and std 0.5.
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
])


# Apply the image augmentation to a batch of examples.
def f(data):
    """Turn a batch dict with 'image' and 'input_ids' into model inputs.

    Every entry of ``data['image']`` is passed through the module-level
    ``compose`` pipeline; ``data['input_ids']`` is forwarded unchanged.
    """
    augmented = []
    for image in data['image']:
        augmented.append(compose(image))

    return {'pixel_values': augmented, 'input_ids': data['input_ids']}


# The augmentation is recomputed dynamically in every epoch (it contains a
# random flip), so it cannot simply be precomputed once with map; it is
# attached with with_transform instead.
dataset = dataset.with_transform(f)

# Show the dataset summary and the first transformed example.
print(dataset, dataset[0])

Dataset({
    features: ['image', 'input_ids'],
    num_rows: 833
}) {'pixel_values': tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]]), 'input_ids': [49406, 320, 3610, 539, 320, 1901, 9528, 593, 736, 3095, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 4

In [5]:
# Batch assembly for the data loader.
def collate_fn(data):
    """Merge a list of per-sample dicts into one batch dict of tensors.

    'pixel_values' tensors are stacked along a new leading dimension and the
    'input_ids' lists are converted into a single LongTensor.
    """
    images = []
    token_rows = []
    for sample in data:
        images.append(sample['pixel_values'])
        token_rows.append(sample['input_ids'])

    return {
        'pixel_values': torch.stack(images),
        'input_ids': torch.LongTensor(token_rows),
    }


# Training loader: shuffled, one sample per batch, assembled by collate_fn.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)

# Cell result: number of batches and one sample batch.
(len(loader), next(iter(loader)))

(833,
 {'pixel_values': tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            ...,
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.]],
  
           [[1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            ...,
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.]],
  
           [[1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            ...,
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.]]]]),
  'input_ids': tensor([[49406,   320,  3610,   539,   320,  7651,  9465,  5050,   320,  4987,
           49407, 49407, 4940

In [6]:
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

# Load the three pretrained sub-models from the checkpoint.
unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')
vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')
encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')

# The text encoder and the VAE are frozen, so they receive no gradients.
encoder.requires_grad_(False)
vae.requires_grad_(False)

# Gradient checkpointing is switched on for the UNet.
unet.enable_gradient_checkpointing()


def print_model_size(name, model):
    """Print ``name`` followed by the model's parameter count.

    The count is divided by 10000, i.e. reported in units of ten thousand
    parameters.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()

    print(name, total / 10000)


# Report the parameter count (in units of ten thousand) of each loaded model.
print_model_size('encoder', encoder)
print_model_size('vae', vae)
print_model_size('unet', unet)

encoder 12306.048
vae 8365.3863
unet 85952.0964


In [7]:
# Loss: mean squared error, used by get_loss to compare the UNet output with the noise.
criterion = torch.nn.MSELoss()

# Optimizer: AdamW over the UNet parameters only.
optimizer = torch.optim.AdamW(
    unet.parameters(),
    lr=1e-5,
    weight_decay=0.01,
    betas=(0.9, 0.999),
    eps=1e-8,
)

# Cell result: show both objects.
(optimizer, criterion)

(AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     lr: 1e-05
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [8]:
def get_loss(data):
    """Compute the noise-prediction MSE loss for one batch.

    Uses the module-level ``encoder``, ``vae``, ``unet``, ``scheduler`` and
    ``criterion``.

    Args:
        data: dict with 'input_ids' (token id tensor) and 'pixel_values'
            (image tensor), both already on the target device.

    Returns:
        Scalar tensor: MSE between the UNet output and the sampled noise.
    """
    device = data['input_ids'].device

    # Encode the caption.
    # Per the original notes: [B, 77] -> [B, 77, 768]
    out_encoder = encoder(data['input_ids'])[0]

    # Encode the image into latent feature maps.
    # Per the original notes: [B, 3, 512, 512] -> [B, 4, 64, 64]
    out_vae = vae.encode(data['pixel_values']).latent_dist.sample()
    # Latent scaling constant (the original notes equate it with
    # vae.config.scaling_factor).
    out_vae = out_vae * 0.18215

    # Random noise: the regression target of the UNet.
    noise = torch.randn_like(out_vae)

    # Draw one timestep per sample instead of a single hard-coded one, so
    # batches larger than 1 get independent noise levels. The upper bound is
    # read from the scheduler config rather than hard-coded as 1000.
    batch_size = out_vae.shape[0]
    num_train_timesteps = scheduler.config.num_train_timesteps
    noise_step = torch.randint(0, num_train_timesteps, (batch_size, ))
    noise_step = noise_step.to(device)

    # Add the noise to the latents at the sampled timesteps.
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    # Conditioned on the text features, predict the noise in the latents.
    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample

    # MSE loss between two tensors of the latent shape.
    return criterion(out_unet, noise)


# Smoke-test get_loss on a dummy batch: token ids that are all 1 and a
# random tensor standing in for one 512x512 RGB image.
get_loss({
    'input_ids': torch.ones(1, 77).long(),
    'pixel_values': torch.randn(1, 3, 512, 512)
})

tensor(0.0850, grad_fn=<MseLossBackward0>)

In [9]:
from diffusers import StableDiffusionPipeline
from huggingface_hub import Repository, create_repo


def train(num_epochs=150, accumulation_steps=4):
    """Fine-tune the UNet and save the resulting pipeline to './save'.

    Uses the module-level ``unet``, ``encoder``, ``vae``, ``loader``,
    ``optimizer``, ``get_loss`` and ``checkpoint``.

    Args:
        num_epochs: number of passes over ``loader``.
        accumulation_steps: number of micro-batches whose gradients are
            accumulated before each optimizer step; each loss is divided by
            this value.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    unet.to(device)
    encoder.to(device)
    vae.to(device)
    unet.train()

    def _apply_update():
        """Clip the accumulated gradients, step the optimizer, reset gradients."""
        torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

    # Start from clean gradients in case something ran backward earlier.
    optimizer.zero_grad()

    loss_sum = 0
    micro_batches = 0
    for epoch in range(num_epochs):
        for data in loader:
            data = {key: value.to(device) for key, value in data.items()}

            loss = get_loss(data) / accumulation_steps
            loss.backward()
            loss_sum += loss.item()
            micro_batches += 1

            # Step only after a full group of micro-batches. The previous
            # `index % 4 == 0` test fired on the very first batch, after a
            # single accumulated gradient.
            if micro_batches % accumulation_steps == 0:
                _apply_update()

        # Printed on even epochs only, so after epoch 0 each value covers
        # two epochs of (scaled) losses.
        if epoch % 2 == 0:
            print(epoch, loss_sum)
            loss_sum = 0

    # Apply gradients left over from an incomplete final group.
    if micro_batches % accumulation_steps != 0:
        _apply_update()

    # Save the full pipeline with the trained UNet.
    StableDiffusionPipeline.from_pretrained(
        checkpoint, text_encoder=encoder, vae=vae,
        unet=unet).save_pretrained('./save')


# Training and uploading only happen when push_to_hub is enabled.
if push_to_hub:
    # Create the Hub repository; exist_ok=True tolerates one that already exists.
    create_repo(repo_id, exist_ok=True, token=hub_token)
    # Clone it into ./save first: train() writes the pipeline into that
    # same directory, so the clone has to happen before training.
    repo = Repository('./save', clone_from=repo_id, token=hub_token)
    train()
    # Push the contents of the local clone (now holding the saved pipeline).
    repo.push_to_hub()

Cloning https://huggingface.co/lansinuote/diffusion.4.text_to_image into local empty directory.


0 11.485293690289836
2 22.174073058733484
4 21.349158763390733
6 21.735034728772007
8 21.165746794169536
10 19.808746901675477
12 21.422606183303287
14 19.779738241253654
16 20.875592649244936
18 20.723903047357453
20 20.24909785448108
22 21.206063293473562
24 20.186157436284702
26 20.302006809739396
28 19.989948059926974
30 20.634496434227913
32 20.226739437202923
34 19.87426323920954
36 20.088639282184886
38 20.3696921497758
40 19.518499594909372
42 19.35646905520116
44 19.544753054447938
46 18.9495884028438
48 19.509514296558336
50 20.509902719059028
52 19.261830052666483
54 19.733987627958413
56 19.839617400284624
58 19.995629528842983
60 20.015139478753554
62 19.416784361950704
64 18.512044332295773
66 19.14228089810058
68 18.59874345047865
70 19.604682393954135
72 18.631113044335507
74 18.933612779015675
76 19.089062765357085
78 18.24768439900072
80 18.213059518297086
82 19.6351735486096
84 17.98689945116348
86 18.266969905249425
88 17.443939730350394
90 19.13027249123843
92 18.5

Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

Downloading (…)0b28/vae/config.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

Upload file unet/diffusion_pytorch_model.bin:   0%|          | 32.0k/3.20G [00:00<?, ?B/s]

Upload file vae/diffusion_pytorch_model.bin:   0%|          | 32.0k/319M [00:00<?, ?B/s]

Upload file safety_checker/pytorch_model.bin:   0%|          | 32.0k/1.13G [00:00<?, ?B/s]

Upload file text_encoder/pytorch_model.bin:   0%|          | 32.0k/470M [00:00<?, ?B/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

