In [1]:
import torch

# Global configuration
# Read the Hugging Face access token inside a `with` block so the file handle
# is closed deterministically instead of being left to the garbage collector.
with open('/root/hub_token.txt') as token_file:
    hub_token = token_file.read().strip()
# Hub repository used both for the dataset and for the trained pipeline.
repo_id = 'lansinuote/diffusion.2.textual_inversion'
# When True, later cells upload the dataset, run training and push the result.
push_to_hub = True
# Base Stable Diffusion checkpoint every sub-model is loaded from.
checkpoint = 'runwayml/stable-diffusion-v1-5'

In [2]:
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer

# Load the noise scheduler and the tokenizer stored in the checkpoint's
# 'scheduler' and 'tokenizer' subfolders.
scheduler = DDPMScheduler.from_pretrained(
    checkpoint,
    subfolder='scheduler',
)
tokenizer = CLIPTokenizer.from_pretrained(
    checkpoint,
    subfolder='tokenizer',
)

# Echo both objects as the cell output.
(scheduler, tokenizer)

(DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.12.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 CLIPTokenizer(name_or_path='runwayml/stable-diffusion-v1-5', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': '<|endoftext|>'}))

In [3]:
from datasets import Dataset, DatasetDict, load_dataset
import PIL.Image
import numpy as np


def get_dataset():
    """Build a DatasetDict with a single 'train' split from the local images.

    Opens images/0.jpeg through images/5.jpeg with PIL and stores each one
    under the 'image' key of its own record.
    """
    records = []
    for index in range(6):
        picture = PIL.Image.open('images/%d.jpeg' % index)
        records.append({'image': picture})

    train_split = Dataset.from_list(records)
    return DatasetDict({'train': train_split})


if push_to_hub:
    # Build the dataset from the local images and upload it to the Hub.
    get_dataset().push_to_hub(repo_id=repo_id, token=hub_token)

# Use the already-processed dataset straight from the Hub.
dataset = load_dataset(path=repo_id, split='train')

# Echo the dataset summary and its first row.
(dataset, dataset[0])

Pushing split train to the Hub.


  0%|          | 0/1 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration lansinuote--diffusion.2.textual_inversion-3f03ac153edd3f09


Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.2.textual_inversion-3f03ac153edd3f09/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.66M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/6 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.2.textual_inversion-3f03ac153edd3f09/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


(Dataset({
     features: ['image'],
     num_rows: 6
 }),
 {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1167x1010>})

In [4]:
import torchvision
import random

# Caption texts; every one ends with the placeholder word '<cat-toy>'.
texts = [
    'a photo of a <cat-toy>',
    'a rendering of a <cat-toy>',
    'a cropped photo of the <cat-toy>',
    'the photo of a <cat-toy>',
    'a photo of a clean <cat-toy>',
    'a photo of a dirty <cat-toy>',
    'a dark photo of the <cat-toy>',
    'a photo of my <cat-toy>',
    'a photo of the cool <cat-toy>',
    'a close-up photo of a <cat-toy>',
    'a bright photo of the <cat-toy>',
    'a cropped photo of a <cat-toy>',
    'a photo of the <cat-toy>',
    'a good photo of the <cat-toy>',
    'a photo of one <cat-toy>',
    'a close-up photo of the <cat-toy>',
    'a rendition of the <cat-toy>',
    'a photo of the clean <cat-toy>',
    'a rendition of a <cat-toy>',
    'a photo of a nice <cat-toy>',
    'a good photo of a <cat-toy>',
    'a photo of the nice <cat-toy>',
    'a photo of the small <cat-toy>',
    'a photo of the weird <cat-toy>',
    'a photo of the large <cat-toy>',
    'a photo of a cool <cat-toy>',
    'a photo of a small <cat-toy>',
]

# Data augmentation: a horizontal flip applied with probability 0.5.
random_flip = torchvision.transforms.RandomHorizontalFlip(p=0.5)
compose = torchvision.transforms.Compose([random_flip])


def f(data):
    """Turn a batch of raw dataset rows into model inputs.

    Args:
        data: batch dict handed over by `Dataset.with_transform`;
            `data['image']` is treated as a sequence of PIL images.

    Returns:
        dict with
        'input_ids': token ids of one randomly chosen caption per image,
            padded/truncated to `tokenizer.model_max_length`
            (shape [len(images), max_length]);
        'pixel_values': list of float32 tensors of shape [3, 512, 512]
            with values in [-1, 1].
    """
    images = data['image']

    # Encode text. Draw one caption per image so that 'input_ids' always has
    # as many rows as 'pixel_values', whatever batch size the transform
    # receives (a single caption only lines up when the batch holds 1 image).
    captions = [random.choice(texts) for _ in images]
    input_ids = tokenizer(captions,
                          padding='max_length',
                          truncation=True,
                          max_length=tokenizer.model_max_length,
                          return_tensors='pt')['input_ids']

    # Encode images.
    pixel_values = []
    for image in images:
        # The permute below needs an [H, W, 3] array, so normalise the mode
        # first; grayscale/RGBA inputs would otherwise have the wrong rank or
        # channel count.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Data augmentation.
        image = compose(image)

        # Resize to the training resolution.
        image = image.resize((512, 512), resample=PIL.Image.Resampling.BICUBIC)

        # Map uint8 [0, 255] to float32 [-1, 1].
        image = np.array(image).astype(np.uint8)
        image = image / 127.5 - 1.0
        image = image.astype(np.float32)

        # To tensor, channels first: [512, 512, 3] -> [3, 512, 512]
        image = torch.from_numpy(image).permute(2, 0, 1)

        pixel_values.append(image)

    return {'input_ids': input_ids, 'pixel_values': pixel_values}


# Register `f` as the dataset's transform, then echo the dataset summary and
# one transformed row.
dataset = dataset.with_transform(f)

(dataset, dataset[0])

(Dataset({
     features: ['image'],
     num_rows: 6
 }),
 {'input_ids': tensor([49406,   320,   886,  1125,   539,   518,   283,  2368,   268,  5988,
            285, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
          49407, 49407, 49407, 49407, 49407, 49407, 49407]),
  'pixel_values': tensor([[[ 0.0196,  0.0196,  0.0196,  ..., -0.2157, -0.2314, -0.2392],
           [ 0.0196,  0.0196,  0.0196,  ..., -0.2078, -0.2314, -0.2392],
           [ 0.0196,  0.0196,  0.0196,  ..., -0.2078, -0.2157, -0.2235],
           ...,
           [ 0.5451,  0.5529,  0.5529,  ..., -0.1765, -0

In [5]:
# Shuffling loader that yields one sample per batch.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
)

# Echo the number of batches per epoch and one sample batch.
(len(loader), next(iter(loader)))

(6,
 {'input_ids': tensor([[49406,   320,  1125,   539,   518,  2442,   283,  2368,   268,  5988,
             285, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407]]),
  'pixel_values': tensor([[[[ 0.5529,  0.5451,  0.5451,  ...,  0.4902,  0.4902,  0.4824],
            [ 0.5529,  0.5529,  0.5451,  ...,  0.4745,  0.4745,  0.4902],
            [ 0.5608,  0.5529,  0.5529,  ...,  0.4745,  0.4745,  0.4824],
            ...,
            [ 0.6549,  0.6627,  0.6627,  ...,  0.6863,  0.6784,  0.6784],
            [ 0.6549,  0

In [6]:
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel

# Load the three pretrained sub-models from their checkpoint subfolders.
encoder = CLIPTextModel.from_pretrained(
    checkpoint,
    subfolder='text_encoder',
)
vae = AutoencoderKL.from_pretrained(
    checkpoint,
    subfolder='vae',
)
unet = UNet2DConditionModel.from_pretrained(
    checkpoint,
    subfolder='unet',
)


def print_model_size(name, model):
    """Print `name` followed by the model's parameter count divided by 10000.

    Args:
        name: label printed in front of the number.
        model: object whose `parameters()` yields items exposing `numel()`.
    """
    total_elements = 0
    for parameter in model.parameters():
        total_elements += parameter.numel()
    print(name, total_elements / 10000)


# Report the size of each sub-model, in the same order as they were loaded.
for label, module in (('encoder', encoder), ('vae', vae), ('unet', unet)):
    print_model_size(label, module)

encoder 12306.048
vae 8365.3863
unet 85952.0964


In [7]:
def init_new_word(new_token='<cat-toy>', init_token='toy'):
    """Register a new token and seed its embedding from an existing word.

    Args:
        new_token: placeholder token added to `tokenizer`.
        init_token: existing word whose embedding row is copied into the new
            token's row. It must encode to exactly one token id.

    Raises:
        ValueError: if `init_token` does not encode to a single token id.
    """
    # Add the new word to the vocabulary.
    tokenizer.add_tokens(new_token)

    # Grow the encoder's embedding matrix so the new id gets its own row.
    encoder.resize_token_embeddings(len(tokenizer))

    # Resolve the initializer by running it through the tokenizer instead of
    # a raw vocabulary lookup.
    # NOTE(review): CLIP's BPE vocabulary stores word-final pieces with an
    # end-of-word suffix, so looking up the bare string can resolve to a
    # different entry (or the unknown token) than the id the tokenizer
    # actually emits for the word - confirm against the vocabulary.
    init_ids = tokenizer.encode(init_token, add_special_tokens=False)
    if len(init_ids) != 1:
        raise ValueError(
            'initializer %r must map to exactly one token, got %d' %
            (init_token, len(init_ids)))
    old_id = init_ids[0]
    new_id = tokenizer.convert_tokens_to_ids(new_token)

    embed = encoder.get_input_embeddings().weight.data

    # Initialise the new word's embedding with the old word's embedding.
    embed[new_id] = embed[old_id]


# Add the new placeholder word to the tokenizer and initialise its embedding.
init_new_word()

In [8]:
# The VAE and the UNet never receive parameter updates. Inside the text
# encoder only text_model.embeddings.token_embedding stays trainable: the
# transformer stack and the final layer norm are frozen here as well.
for frozen_module in (vae, unet, encoder.text_model.encoder,
                      encoder.text_model.final_layer_norm):
    frozen_module.requires_grad_(False)

# Freeze the position embedding last; kept as a bare expression so the cell
# still echoes the returned module.
encoder.text_model.embeddings.position_embedding.requires_grad_(False)

Embedding(77, 768)

In [9]:
from diffusers.optimization import get_scheduler

# Optimise only the parameters of the text encoder's input embedding layer.
trainable_params = encoder.get_input_embeddings().parameters()
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=5e-4,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    eps=1e-8,
)

# Mean squared error between two tensors.
criterion = torch.nn.MSELoss()

# Echo both objects as the cell output.
(optimizer, criterion)

(AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     lr: 0.0005
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [10]:
def get_loss(data):
    """Compute the noise-prediction loss for one batch.

    Args:
        data: dict with 'input_ids' (token ids, [b, 77]) and 'pixel_values'
            (images, [b, 3, 512, 512]).

    Returns:
        Scalar tensor: `criterion` between the noise predicted by the UNet
        and the noise that was actually added to the latents.
    """
    device = data['input_ids'].device

    # Encode text.
    # [b, 77] -> [b, 77, 768]
    out_encoder = encoder(data['input_ids'])[0]

    # VAE latents, detached from the graph.
    # [b, 3, 512, 512] -> [b, 4, 64, 64]
    out_vae = vae.encode(data['pixel_values']).latent_dist.sample().detach()

    # 0.18215 = vae.config.scaling_factor
    out_vae = out_vae * 0.18215

    # Random noise with the latents' shape.
    # [b, 4, 64, 64]
    noise = torch.randn_like(out_vae)

    # One random timestep per sample, so batches larger than 1 do not share a
    # single timestep. The upper bound comes from the scheduler config
    # (1000 for this checkpoint) instead of a literal.
    batch_size = out_vae.shape[0]
    noise_step = torch.randint(0,
                               scheduler.config.num_train_timesteps,
                               (batch_size, ),
                               device=device).long()

    # Add the noise to the latents.
    # [b, 4, 64, 64]
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    # The UNet predicts the noise contained in the noisy latents.
    # [b, 4, 64, 64],[b, 77, 768] -> [b, 4, 64, 64]
    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample

    return criterion(out_unet, noise)


# Smoke test with dummy inputs: all-ones token ids and a random image tensor.
dummy_batch = {
    'input_ids': torch.ones(1, 77).long(),
    'pixel_values': torch.randn(1, 3, 512, 512),
}
get_loss(dummy_batch)

tensor(0.0532, grad_fn=<MseLossBackward0>)

In [11]:
from diffusers import StableDiffusionPipeline
from huggingface_hub import Repository, create_repo


def train():
    """Train the text encoder's embeddings and save the pipeline to ./save.

    Runs 800 epochs over `loader`, accumulating gradients over 4 micro-batches
    per optimizer step, prints the summed (already scaled) loss every 20
    epochs, then writes a StableDiffusionPipeline built from the trained
    encoder and the frozen VAE/UNet to './save'.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder.to(device)
    vae.to(device)
    unet.to(device)
    encoder.train()

    num_epochs = 800
    accumulation_steps = 4

    # NOTE(review): the optimizer was built over the whole input-embedding
    # layer, so every token row can change, not only the newly added one -
    # confirm that this is intended.

    # Start from clean gradients in case an earlier run left some behind.
    optimizer.zero_grad()
    pending_grads = False

    loss_sum = 0
    for epoch in range(num_epochs):
        for i, data in enumerate(loader):
            for k in data.keys():
                data[k] = data[k].to(device)

            # Scale so that the accumulated gradient is the mean over the
            # micro-batches of one optimizer step.
            loss = get_loss(data) / accumulation_steps
            loss.backward()
            pending_grads = True

            # Accumulated update: step once every `accumulation_steps`
            # micro-batches. The +1 makes the first step happen after a full
            # group; without it the step fires right after the first
            # micro-batch and the final group is never applied.
            global_step = epoch * len(loader) + i
            if (global_step + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                pending_grads = False

            loss_sum += loss.item()

        if epoch % 20 == 0:
            print(epoch, loss_sum)
            loss_sum = 0

    # Apply gradients of an incomplete final group, if any, so that no
    # computed gradient is silently dropped.
    if pending_grads:
        optimizer.step()
        optimizer.zero_grad()

    # Save the full pipeline with the trained text encoder.
    StableDiffusionPipeline.from_pretrained(
        checkpoint,
        text_encoder=encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer).save_pretrained('./save')


# Training, saving and uploading all happen only when push_to_hub is enabled.
if push_to_hub:
    # Create the target repository on the Hub; exist_ok=True tolerates an
    # already existing repository.
    create_repo(repo_id, exist_ok=True, token=hub_token)
    # Clone the Hub repository into ./save, the directory train() saves the
    # pipeline into, so the saved files can be pushed afterwards.
    # NOTE(review): git may echo the remote URL including the access token
    # while pushing - review/clear this cell's output before sharing.
    repo = Repository('./save', clone_from=repo_id, token=hub_token)
    train()
    repo.push_to_hub()

Cloning https://huggingface.co/lansinuote/diffusion.2.textual_inversion into local empty directory.


0 0.13851985079236329
20 3.809695530508179
40 4.4042913969024085
60 4.388479830056895
80 4.4977232606033795
100 4.548502518853638
120 3.693540668406058
140 4.156299739843234
160 3.7243479269673117
180 3.697449386701919
200 3.546344161964953
220 4.308455798542127
240 3.784453941101674
260 3.2318183617899194
280 4.367512361612171
300 3.0064080328447744
320 4.068744384741876
340 3.318315677694045
360 4.715249826433137
380 3.7152999051613733
400 3.955495126952883
420 4.009544232976623
440 4.505315208341926
460 4.091954343370162
480 4.5167977513046935
500 3.992238358419854
520 3.93575657781912
540 4.0084115316858515
560 3.725144908821676
580 4.794721863931045
600 4.073796798766125
620 4.574912750336807
640 4.365607368235942
660 4.622854062763508
680 4.699603373650461
700 3.90476599894464
720 4.237309526128229
740 4.308226598310284
760 4.023056301870383
780 3.4964061353821307


Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

Upload file unet/diffusion_pytorch_model.bin:   0%|          | 32.0k/3.20G [00:00<?, ?B/s]

Upload file text_encoder/pytorch_model.bin:   0%|          | 32.0k/470M [00:00<?, ?B/s]

Upload file safety_checker/pytorch_model.bin:   0%|          | 32.0k/1.13G [00:00<?, ?B/s]

Upload file vae/diffusion_pytorch_model.bin:   0%|          | 32.0k/319M [00:00<?, ?B/s]

remote: Scanning LFS files for validity...[K
remote: LFS file scan complete.[K
To https://user:***REDACTED***@huggingface.co/lansinuote/diffusion.2.textual_inversion
   fff28ba..d1d1389  main -> main

