In [1]:
from diffusers import DiffusionPipeline
import torch

# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Load the pretrained pipeline (safety checker disabled) only to borrow
# its noise scheduler and its tokenizer.
pipeline = DiffusionPipeline.from_pretrained(
    'lansinuote/diffsion_from_scratch.params',
    safety_checker=None,
)
scheduler, tokenizer = pipeline.scheduler, pipeline.tokenizer

# The remaining pipeline components are not used in this notebook.
del pipeline

device, scheduler, tokenizer

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

The config attributes {'scaling_factor': 0.18215} were passed to AutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


('cuda',
 PNDMScheduler {
   "_class_name": "PNDMScheduler",
   "_diffusers_version": "0.12.1",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "trained_betas": null
 },
 CLIPTokenizer(name_or_path='/root/.cache/huggingface/diffusers/models--lansinuote--diffsion_from_scratch.params/snapshots/310d9345e14b3b625635041dd573676c008d83ea/tokenizer', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_tok

In [2]:
from datasets import load_dataset
import torchvision

# Load the training split of the dataset.
dataset = load_dataset(
    path='lansinuote/diffsion_from_scratch',
    split='train',
)

# Image preprocessing pipeline: resize to 512 (bilinear), center-crop to
# 512x512, convert to a tensor, then normalize with mean 0.5 / std 0.5.
compose = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        512,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    ),
    torchvision.transforms.CenterCrop(512),
    # Random horizontal flipping is intentionally left disabled:
    #torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),
])


def f(data):
    """Preprocess one batch of raw rows for ``dataset.map``.

    Args:
        data: batch with an 'image' column (passed item by item through the
            module-level ``compose``) and a 'text' column (passed as a whole
            to the module-level ``tokenizer``).

    Returns:
        dict with 'pixel_values' (one transformed image per row) and
        'input_ids' (token ids padded/truncated to 77 per row).
    """
    # Apply the image transform pipeline to every image in the batch.
    pixel_values = list(map(compose, data['image']))

    # Encode the captions, padding and truncating each one to 77 tokens.
    encoded = tokenizer.batch_encode_plus(
        data['text'],
        max_length=77,
        truncation=True,
        padding='max_length',
    )

    return {'pixel_values': pixel_values, 'input_ids': encoded.input_ids}


# Preprocess the dataset with `f` in batches of 100 rows (single process)
# and drop the raw columns once they have been converted.
dataset = dataset.map(
    f,
    remove_columns=['image', 'text'],
    num_proc=1,
    batch_size=100,
    batched=True,
)

# Make indexing return torch tensors.
dataset.set_format(type='torch')

dataset, dataset[0]

Using custom data configuration lansinuote--diffsion_from_scratch-34318fc75271f5a0
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffsion_from_scratch-34318fc75271f5a0/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffsion_from_scratch-34318fc75271f5a0/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-314ecfdae0cf40d9.arrow


(Dataset({
     features: ['pixel_values', 'input_ids'],
     num_rows: 833
 }),
 {'pixel_values': tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           ...,
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.]],
  
          [[1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           ...,
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.]],
  
          [[1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           ...,
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.]]]),
  'input_ids': tensor([49406,   320,  3610,   539,   320,  1901,  9528

In [3]:
# Define the loader.
def collate_fn(data):
    """Stack a list of samples into one batch on the module-level ``device``.

    Args:
        data: list of samples, each a dict holding 'pixel_values' and
            'input_ids' tensors.

    Returns:
        dict with the same two keys, each value being the per-sample tensors
        stacked along a new leading dimension and moved to ``device``.
    """
    batch = {}
    for key in ('pixel_values', 'input_ids'):
        per_sample = [sample[key] for sample in data]
        batch[key] = torch.stack(per_sample).to(device)

    return batch


# Shuffled loader yielding one sample per batch, collated by `collate_fn`.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)

# Show the number of batches and one example batch.
len(loader), next(iter(loader))

(833,
 {'pixel_values': tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            ...,
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.]],
  
           [[1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            ...,
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.]],
  
           [[1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            ...,
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.],
            [1., 1., 1.,  ..., 1., 1., 1.]]]], device='cuda:0'),
  'input_ids': tensor([[49406,   320,  3610,   539,   320,  7651,  4009,   530,  3360,   537,
            

In [4]:
#加载模型
%run 1.encoder.ipynb
%run 2.vae.ipynb
%run 3.unet.ipynb

#准备训练
encoder.requires_grad_(False)
vae.requires_grad_(False)
unet.requires_grad_(True)

encoder.eval()
vae.eval()
unet.train()

encoder.to(device)
vae.to(device)
unet.to(device)

optimizer = torch.optim.AdamW(unet.parameters(),
                              lr=1e-5,
                              betas=(0.9, 0.999),
                              weight_decay=0.01,
                              eps=1e-8)

criterion = torch.nn.MSELoss()

optimizer, criterion

The config attributes {'scaling_factor': 0.18215} were passed to AutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.


(AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     eps: 1e-08
     foreach: None
     lr: 1e-05
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [5]:
def get_loss(data):
    """Compute the noise-prediction MSE loss for one batch.

    The text encoder and the VAE run under ``torch.no_grad()``; only the
    UNet forward pass is tracked for gradients.

    Args:
        data: dict with 'input_ids' (token ids) and 'pixel_values' (image
            tensors), both already on ``device``.

    Returns:
        The value of ``criterion(unet_prediction, noise)``.
    """
    with torch.no_grad():
        # Text encoding.
        # [B, 77] -> [B, 77, 768] (shapes as noted by the original author)
        out_encoder = encoder(data['input_ids'])

        # Extract the image latents.
        # [B, 3, 512, 512] -> [B, 4, 64, 64]
        out_vae = vae.encoder(data['pixel_values'])
        out_vae = vae.sample(out_vae)

        # 0.18215 = vae.config.scaling_factor
        out_vae = out_vae * 0.18215

    # Random noise; this is the target the UNet has to predict.
    noise = torch.randn_like(out_vae)

    # Draw one timestep per sample. The batch size and the number of training
    # timesteps are read from the data/scheduler instead of being hard-coded
    # to 1 and 1000, so the function also works when the loader batch size
    # is changed. torch.randint already returns int64, so no cast is needed.
    batch_size = out_vae.shape[0]
    num_train_timesteps = scheduler.config.num_train_timesteps
    noise_step = torch.randint(0, num_train_timesteps,
                               (batch_size, )).to(device)

    # Add noise to the latents according to the sampled timesteps.
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    # Predict the noise in the latents, conditioned on the text encoding.
    out_unet = unet(out_vae=out_vae_noise,
                    out_encoder=out_encoder,
                    time=noise_step)

    # MSE between the predicted noise and the true noise.
    # [B, 4, 64, 64], [B, 4, 64, 64]
    return criterion(out_unet, noise)


# get_loss({
#     'input_ids': torch.ones(1, 77, device=device).long(),
#     'pixel_values': torch.randn(1, 3, 512, 512, device=device)
# })

In [6]:
def train(epochs=400, accumulation_steps=4, log_every=10, max_grad_norm=1.0):
    """Train the UNet with gradient accumulation.

    Each micro-batch loss is divided by ``accumulation_steps`` and
    back-propagated; the optimizer is stepped once ``accumulation_steps``
    micro-batches have been accumulated. The original condition
    ``(epoch * len(loader) + i) % 4 == 0`` fired on the very first
    micro-batch, so the first update used a single quarter-scaled gradient,
    every later window was shifted by one, and the final three micro-batches
    of the run were never applied. An explicit counter fixes that.

    Args:
        epochs: number of passes over ``loader``.
        accumulation_steps: micro-batches accumulated per optimizer step.
        log_every: print the accumulated loss whenever
            ``epoch % log_every == 0``.
        max_grad_norm: gradient-norm clipping threshold applied before
            every optimizer step.

    Uses the module-level ``loader``, ``get_loss``, ``unet`` and
    ``optimizer``. Returns nothing; progress is printed.
    """

    def apply_update():
        # Clip the accumulated gradients, update the weights, reset the grads.
        torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()

    loss_sum = 0
    pending = 0  # micro-batches back-propagated since the last optimizer step

    for epoch in range(epochs):
        for data in loader:
            loss = get_loss(data) / accumulation_steps
            loss.backward()
            loss_sum += loss.item()
            pending += 1

            if pending == accumulation_steps:
                apply_update()
                pending = 0

        # NOTE: the epoch-0 line covers a single epoch, while later lines
        # cover `log_every` epochs each (log format kept as in the original).
        if epoch % log_every == 0:
            print(epoch, loss_sum)
            loss_sum = 0

    # Apply gradients left over from an incomplete final accumulation window.
    if pending:
        apply_update()

    # Optional local checkpoint, left disabled:
    #torch.save(unet.to('cpu'), 'saves/unet.model')


# Run the training loop defined above; it prints the accumulated loss
# every 10 epochs.
train()

0 11.7118999005761
10 105.27776907754014
20 101.45478522218764
30 97.96161541804031
40 95.7652038520173
50 92.64628775657911
60 91.62508884524868
70 88.90302349776903
80 84.6358380591555
90 82.70271758512536
100 81.53195204613439
110 76.3927595877758
120 74.14106083381193
130 71.42537906522921
140 69.16221529991162
150 65.47076485656726
160 62.1360088881047
170 60.89056803673884
180 57.985315461344726
190 54.73302427918679
200 50.69724302080431
210 48.59712202517403
220 46.407517315681616
230 44.99496047659704
240 44.07751854383969
250 39.62402399040002
260 37.051896732489695
270 36.89249631060375
280 35.71413582353853
290 33.45783720578038
300 33.08240255239798
310 30.282505852694158
320 29.86848702972202
330 29.363934024146147
340 29.187604612583527
350 27.543819716789585
360 26.130621485815936
370 25.465440133120865
380 25.48384229660587
390 24.789676978944044


In [7]:
from transformers import PreTrainedModel, PretrainedConfig


# Wrapper class.
class Model(PreTrainedModel):
    """Minimal ``PreTrainedModel`` wrapper around the trained UNet.

    It exists so the module-level ``unet`` can be uploaded with
    ``push_to_hub``.
    """

    # A plain PretrainedConfig is used; no model-specific config is defined.
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        # Move the module-level `unet` to the CPU and keep it on the wrapper.
        cpu_unet = unet.to('cpu')
        self.unet = cpu_unet

# Save to the hub.
# Read the access token inside a context manager so the file handle is closed
# deterministically; a bare open(...).read() leaves it to the garbage collector.
with open('/root/hub_token.txt') as token_file:
    hub_token = token_file.read().strip()

Model(PretrainedConfig()).push_to_hub(
    repo_id='lansinuote/diffsion_from_scratch.unet',
    use_auth_token=hub_token)

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

pytorch_model.bin:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/lansinuote/diffsion_from_scratch.unet/commit/32f5e4163edb6d1a3fa1d8265ad2cdf0406cb425', commit_message='Upload model', commit_description='', oid='32f5e4163edb6d1a3fa1d8265ad2cdf0406cb425', pr_url=None, pr_revision=None, pr_num=None)