In [1]:
#Chapter 14 / global constants
#dataset repository id, passed to load_dataset when the dataset is loaded
repo_id = 'lansinuote/diffusion.8.instruct_pix2pix'
#pretrained checkpoint id, passed to every from_pretrained call in this notebook
checkpoint = 'runwayml/stable-diffusion-v1-5'

In [2]:
#Chapter 14 / load the dataset
from datasets import load_dataset
from transformers import CLIPTokenizer
import torchvision

#tokenizer stored in the checkpoint's 'tokenizer' subfolder
tokenizer = CLIPTokenizer.from_pretrained(checkpoint, subfolder='tokenizer')

#image preprocessing pipeline: resize, convert to a tensor, then apply x*2-1
#(ToTensor is documented to scale PIL image values to [0, 1], so the
#final values land in [-1, 1])
compose = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.ToTensor(),
    lambda x: (x * 2) - 1,
])

#reposted from fusing/instructpix2pix-1000-samples
dataset = load_dataset(path=repo_id, split='train')


def f(data):
    """Batch transform: preprocess both image columns and tokenize the text.

    Args:
        data: a batch dict with the columns 'input', 'output' and 'text';
            the two image columns are iterated and each item is passed
            through ``compose``.

    Returns:
        dict with keys 'input' and 'output' (lists of ``compose`` results)
        and 'text' (the ``input_ids`` tensor produced by the tokenizer).
    """
    #image encoding, source images first and target images second
    source_images = list(map(compose, data['input']))
    target_images = list(map(compose, data['output']))

    #text encoding
    #77 = tokenizer.model_max_length
    encoding = tokenizer.batch_encode_plus(
        data['text'],
        max_length=77,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    token_ids = encoding.input_ids

    return {'input': source_images, 'output': target_images, 'text': token_ids}


#register f as the transform of the dataset
dataset = dataset.with_transform(f)

#print name, shape and dtype of every field of the first sample
for k, v in dataset[0].items():
    print(k, v.shape, v.dtype)

#display the dataset summary as the cell output
dataset

Using custom data configuration lansinuote--diffusion.8.instruct_pix2pix-5db27ce94b1e4a9e
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.8.instruct_pix2pix-5db27ce94b1e4a9e/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


input torch.Size([3, 256, 256]) torch.float32
output torch.Size([3, 256, 256]) torch.float32
text torch.Size([77]) torch.int64


Dataset({
    features: ['input', 'text', 'output'],
    num_rows: 1000
})

In [3]:
#Chapter 14 / define the loader
import torch

#shuffled loader with 4 samples per batch; collate_fn=None keeps the
#DataLoader's default collation
loader = torch.utils.data.DataLoader(dataset,
                                     shuffle=True,
                                     collate_fn=None,
                                     batch_size=4)

#print name, shape and dtype of every field of one batch
for k, v in next(iter(loader)).items():
    print(k, v.shape, v.dtype)

#number of batches per epoch
len(loader)

input torch.Size([4, 3, 256, 256]) torch.float32
output torch.Size([4, 3, 256, 256]) torch.float32
text torch.Size([4, 77]) torch.int64


250

In [4]:
#Chapter 14 / load the models
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel

#load the 3 models, each from its own subfolder of the checkpoint
encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')
vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')
unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')

#change the shape of the unet.conv_in layer so that it accepts 8 input channels
#(4 noisy-latent channels + 4 channels of the encoded input image)
unet.register_to_config(in_channels=8)
with torch.no_grad():
    old_conv_in = unet.conv_in

    #same geometry as the layer being replaced, only the input width changes
    new_conv_in = torch.nn.Conv2d(8,
                                  old_conv_in.out_channels,
                                  old_conv_in.kernel_size,
                                  old_conv_in.stride,
                                  old_conv_in.padding,
                                  bias=old_conv_in.bias is not None)

    #the added channels start at zero, the first 4 reuse the pretrained weights
    new_conv_in.weight.zero_()
    new_conv_in.weight[:, :4, :, :].copy_(old_conv_in.weight)

    #also carry over the pretrained bias; otherwise the new layer keeps a
    #randomly initialised bias and no longer reproduces the pretrained
    #layer's output at the start of fine-tuning
    if old_conv_in.bias is not None:
        new_conv_in.bias.copy_(old_conv_in.bias)

    unet.conv_in = new_conv_in
print(unet.config.in_channels)


def print_model_size(name, model):
    """Print ``name`` followed by the model's parameter count divided by 10000.

    Args:
        name: label printed in front of the number.
        model: object exposing ``parameters()``; every parameter's
            ``numel()`` is added up.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()

    #reported in units of ten thousand parameters
    print(name, total / 10000)


#report the parameter count of each of the three models (in units of 10000)
print_model_size('encoder', encoder)
print_model_size('vae', vae)
print_model_size('unet', unet)

8
encoder 12306.048
vae 8365.3863
unet 85953.2484


In [5]:
#Chapter 14 / initialise the utility objects
from diffusers import DDPMScheduler

#noise scheduler loaded from the checkpoint's 'scheduler' subfolder
scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')

#only the unet's parameters are handed to the optimizer
optimizer = torch.optim.AdamW(unet.parameters(),
                              lr=5e-5,
                              betas=(0.9, 0.999),
                              weight_decay=0.01,
                              eps=1e-8)

#mean squared error between two tensors
criterion = torch.nn.MSELoss()

#display the three objects as the cell output
scheduler, optimizer, criterion

(DDPMScheduler {
   "_class_name": "DDPMScheduler",
   "_diffusers_version": "0.15.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "clip_sample_range": 1.0,
   "dynamic_thresholding_ratio": 0.995,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "sample_max_value": 1.0,
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "thresholding": false,
   "trained_betas": null,
   "variance_type": "fixed_small"
 },
 AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     lr: 5e-05
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [6]:
#第14章/定义dropout部分数据的辅助函数
def dropout_data(out_encoder, input):
    """Randomly drop the text and/or image conditioning of each sample.

    One uniform random number r is drawn per sample and shared by both masks:
        r <  0.05          -> text replaced by the empty prompt, image kept
        0.05 <= r <= 0.10  -> text replaced and image zeroed
        0.10 <  r <= 0.15  -> text kept, image zeroed
        r >  0.15          -> both kept

    The batch size is read from ``out_encoder`` instead of being fixed to 4.

    Args:
        out_encoder: text-encoder hidden states, [B, 77, 768].
        input: input images to encode with the VAE, [B, 3, H, W].

    Returns:
        Tuple ``(out_encoder, out_vae_input)``: the mixed text encoding
        [B, 77, 768] and the masked input-image latents
        (e.g. [B, 4, 32, 32] for 256x256 images).
    """
    batch_size = out_encoder.shape[0]

    #encode the input images
    #[B, 3, 256, 256] -> [B, 4, 32, 32]
    out_vae_input = vae.encode(input).latent_dist.mode()

    #build the masks, one random number per sample
    r = torch.rand(batch_size, device=out_encoder.device)
    #[B, 1, 1], True where the real prompt is kept
    mask_text = (r > 0.1).reshape(batch_size, 1, 1)
    #[B, 1, 1, 1], 1 where the image latents are kept, 0 where they are zeroed
    #the mask follows the latent dtype so the product does not change dtype
    mask_image = torch.logical_or(r < 0.05, r > 0.15)
    mask_image = mask_image.to(out_vae_input.dtype).reshape(
        batch_size, 1, 1, 1)

    #encode the negative (empty) prompt
    #[1, 77] -> [1, 77, 768], broadcast over the batch by torch.where below
    out_encoder_neg = tokenizer.batch_encode_plus(
        [''],
        max_length=77,
        padding='max_length',
        truncation=True,
        return_tensors='pt').input_ids.to(out_encoder.device)
    out_encoder_neg = encoder(out_encoder_neg)[0]

    #mix positive and negative encodings with the mask
    #the text keeps its positive encoding with high probability
    #[B, 77, 768]
    out_encoder = torch.where(mask_text, out_encoder, out_encoder_neg)

    #the image latents are zeroed with low probability
    #[B, 4, 32, 32]
    out_vae_input = mask_image * out_vae_input

    return out_encoder, out_vae_input


#smoke test of dropout_data with random tensors in place of real data
out = dropout_data(out_encoder=torch.randn(4, 77, 768),
                   input=torch.randn(4, 3, 256, 256))

#shapes of the returned text encoding and image latents
out[0].shape, out[1].shape

(torch.Size([4, 77, 768]), torch.Size([4, 4, 32, 32]))

In [7]:
#第14章/定义计算loss的函数
def get_loss(data):
    """Compute the noise-prediction loss for one batch.

    The target image is encoded to latents and noised at a random timestep.
    The input-image latents are concatenated to the noisy latents on the
    channel axis, and the unet's prediction is compared with the noise that
    was added.

    The batch size, the number of training timesteps and the VAE scaling
    factor are read from the data and the model configs rather than being
    hard-coded (4, 1000 and 0.18215 respectively).

    Args:
        data: dict with 'text' (token ids, [B, 77]), 'input' and 'output'
            (image tensors, [B, 3, H, W]).

    Returns:
        Scalar loss tensor: ``criterion`` applied to the unet output and
        the sampled noise.
    """
    #text encoding
    #[B, 77] -> [B, 77, 768]
    out_encoder = encoder(data['text'])[0]

    #encode the output (target) images
    #[B, 3, 256, 256] -> [B, 4, 32, 32]
    out_vae_output = vae.encode(data['output']).latent_dist.sample()
    out_vae_output = out_vae_output * vae.config.scaling_factor

    #random noise
    #[B, 4, 32, 32]
    noise = torch.randn_like(out_vae_output)

    #one random timestep per sample, then add the matching amount of noise
    batch_size = out_vae_output.shape[0]
    num_train_timesteps = scheduler.config.num_train_timesteps
    noise_step = torch.randint(0, num_train_timesteps, (batch_size, )).long()
    noise_step = noise_step.to(out_encoder.device)
    #[B, 4, 32, 32]
    out_vae_noise = scheduler.add_noise(out_vae_output, noise, noise_step)

    #mix positive/negative text encodings with the dropout mask
    #and encode the input images
    #[B, 77, 768],[B, 4, 32, 32]
    out_encoder, out_vae_input = dropout_data(out_encoder=out_encoder,
                                              input=data['input'])

    #append the input-image latents to the noisy latents on the channel axis
    #[B, 4+4, 32, 32] -> [B, 8, 32, 32]
    out_vae_noise = torch.cat([out_vae_noise, out_vae_input], dim=1)

    #predict the noise in the latents, conditioned on the text
    #[B, 4, 32, 32]
    out_unet = unet(out_vae_noise,
                    noise_step,
                    encoder_hidden_states=out_encoder).sample

    #compute the loss
    #[B, 4, 32, 32],[B, 4, 32, 32]
    return criterion(out_unet, noise)


#smoke test of get_loss with dummy token ids and random images
get_loss({
    'text': torch.ones(4, 77).long(),
    'input': torch.randn(4, 3, 256, 256),
    'output': torch.randn(4, 3, 256, 256),
})

tensor(0.0932, grad_fn=<MseLossBackward0>)

In [8]:
#第14章/训练
from diffusers import StableDiffusionPipeline


def train(epochs=4000, accumulation_steps=4, save_every=20,
          save_dir='./save'):
    """Fine-tune the unet with gradient accumulation and periodic saving.

    The vae and the text encoder are frozen; only the unet is put in
    training mode. The optimizer steps once per ``accumulation_steps``
    micro-batches, and once more at the end of an epoch if a partial group
    is left, so no gradient leaks into the next epoch. The previous
    condition ``i % 4 == 0`` stepped after the very first micro-batch alone
    and carried the last micro-batch of every epoch over to the next one.

    Args:
        epochs: number of passes over ``loader``.
        accumulation_steps: micro-batches accumulated per optimizer step;
            each micro-batch loss is divided by this value.
        save_every: print the accumulated loss and save the pipeline every
            this many epochs.
        save_dir: directory given to ``save_pretrained``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    unet.to(device)
    encoder.to(device)
    vae.to(device)

    vae.requires_grad_(False)
    encoder.requires_grad_(False)
    unet.train()

    #start the first accumulation window from clean gradients
    optimizer.zero_grad()

    num_batches = len(loader)
    loss_sum = 0
    for epoch in range(epochs):
        for i, data in enumerate(loader):
            for k in data.keys():
                data[k] = data[k].to(device)

            loss = get_loss(data) / accumulation_steps
            loss.backward()
            loss_sum += loss.item()

            #step after every full accumulation window, and flush a
            #trailing partial window on the last batch of the epoch
            #(a partial window keeps the 1/accumulation_steps scaling,
            #so its gradient is proportionally smaller)
            is_window_end = (i + 1) % accumulation_steps == 0
            is_last_batch = (i + 1) == num_batches
            if is_window_end or is_last_batch:
                torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

        if (epoch + 1) % save_every == 0:
            print(epoch, loss_sum)
            loss_sum = 0

            #save
            #NOTE(review): the unet now has 8 input channels; confirm that
            #StableDiffusionPipeline is the intended pipeline class here
            StableDiffusionPipeline.from_pretrained(
                checkpoint, text_encoder=encoder, vae=vae,
                unet=unet).save_pretrained(save_dir)


#train()