### 初始化参数

In [1]:
import argparse
import sys
sys.path.append('./')
from guided_diffusion import dist_util, logger
from guided_diffusion.image_datasets import load_data
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
from guided_diffusion.train_util import TrainLoop
import blobfile as bf

In [2]:
import torch as th

In [3]:
diffusion_defaults =  dict(
        learn_sigma=False,
        diffusion_steps=1000,
        noise_schedule="linear",
        timestep_respacing="",
        use_kl=False,
        predict_xstart=False,
        rescale_timesteps=False,
        rescale_learned_sigmas=False,
    )

In [4]:
model_and_diffusion_defaults = dict(
        image_size=64,
        num_channels=128,
        num_res_blocks=2,
        num_heads=4,
        num_heads_upsample=-1,
        num_head_channels=-1,
        attention_resolutions="16,8",
        channel_mult="",
        dropout=0.0,
        class_cond=True,
        use_checkpoint=False,
        use_scale_shift_norm=True,
        resblock_updown=False,
        use_fp16=False,
        use_new_attention_order=False,
    )
model_and_diffusion_defaults.update(diffusion_defaults)

In [5]:
defaults = dict(
        data_dir="datasets/ILSVRC2012/",
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=0,
        batch_size=1,
        microbatch=-1,  # -1 disables microbatches
        ema_rate="0.9999",  # comma-separated list of EMA values
        log_interval=10,
        save_interval=10000,
        resume_checkpoint="",
        use_fp16=False,
        fp16_scale_growth=1e-3,
    )
defaults.update(model_and_diffusion_defaults)

In [6]:
logger.configure('./log/')

Logging to ./log/openai-2023-01-08-12-52-56-336706


In [7]:
model, diffusion = create_model_and_diffusion(**model_and_diffusion_defaults)

In [8]:
model.to(th.device('cuda:6'))

UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
  )
  (label_emb): Embedding(1000, 512)
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=512, out_features=256, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0.0, inplace=False)
          (3): Conv2

In [9]:
defaults['schedule_sampler']

'uniform'

In [10]:
schedule_sampler = create_named_schedule_sampler(defaults['schedule_sampler'], diffusion)

In [11]:
logger.log("creating data loader...")
data = load_data(
    data_dir=defaults['data_dir'] ,
    batch_size=defaults['batch_size'],
    image_size=defaults['image_size'],
    class_cond=defaults['class_cond']
)

creating data loader...


In [12]:
from guided_diffusion.resample import LossAwareSampler, UniformSampler
from torch.optim import AdamW
from guided_diffusion.fp16_util import MixedPrecisionTrainer
import copy
import functools
from guided_diffusion.nn import update_ema
from guided_diffusion.train_util import find_resume_checkpoint, parse_resume_step_from_filename,find_ema_checkpoint,log_loss_dict,get_blob_logdir

class TrainLoop:
    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        self.model = model
        self.diffusion = diffusion
        self.data = data
        # 训练超参数相关
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        # 设置模型保存相关参数
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        # self.global_batch = self.batch_size * dist.get_world_size()
        self.global_batch = self.batch_size

        self.sync_cuda = th.cuda.is_available()
        # 在各个设备上同步参数
        # self._load_and_sync_parameters()
        
        # model 再包了一层 trainer 
        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
        )

        self.opt = AdamW(
            self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
        )
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.mp_trainer.master_params)
                for _ in range(len(self.ema_rate))
            ]

        # if th.cuda.is_available():
        #     self.use_ddp = True
        #     self.ddp_model = DDP(
        #         self.model,
        #         device_ids=[dist_util.dev()],
        #         output_device=dist_util.dev(),
        #         broadcast_buffers=False,
        #         bucket_cap_mb=128,
        #         find_unused_parameters=False,
        #     )
        # else:
        #     if dist.get_world_size() > 1:
        #         logger.warn(
        #             "Distributed training requires CUDA. "
        #             "Gradients will not be synchronized properly!"
        #         )
        #     self.use_ddp = False
        #     self.ddp_model = self.model
        self.use_ddp = False
        self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            # if dist.get_rank() == 0:
            #     logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
            #     self.model.load_state_dict(
            #         dist_util.load_state_dict(
            #             resume_checkpoint, map_location=dist_util.dev()
            #         )
            #     )

        # dist_util.sync_params(self.model.parameters())

    # def _load_ema_parameters(self, rate):
    #     ema_params = copy.deepcopy(self.mp_trainer.master_params)

    #     main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
    #     ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
    #     if ema_checkpoint:
    #         if dist.get_rank() == 0:
    #             logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
    #             state_dict = dist_util.load_state_dict(
    #                 ema_checkpoint, map_location=dist_util.dev()
    #             )
    #             ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)

    #     dist_util.sync_params(ema_params)
    #     return ema_params

    # def _load_optimizer_state(self):
    #     main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
    #     opt_checkpoint = bf.join(
    #         bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
    #     )
    #     if bf.exists(opt_checkpoint):
    #         logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
    #         state_dict = dist_util.load_state_dict(
    #             opt_checkpoint, map_location=th.device('cuda:6')
    #         )
    #         self.opt.load_state_dict(state_dict)

    def run_loop(self):
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            self.run_step(batch, cond)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0:
                self.save()
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        self.forward_backward(batch, cond)
        took_step = self.mp_trainer.optimize(self.opt)
        if took_step:
            self._update_ema()
        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, cond):
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to( th.device('cuda:6'))
            micro_cond = {
                k: v[i : i + self.microbatch].to(th.device('cuda:6'))
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            # 这里 t, weights 是什么东西
            t, weights = self.schedule_sampler.sample(micro.shape[0], th.device('cuda:6'))

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            self.mp_trainer.backward(loss)

    def _update_ema(self):
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.mp_trainer.master_params, rate=rate)

    def _anneal_lr(self):
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)

    def save(self):
        def save_checkpoint(rate, params):
            state_dict = self.mp_trainer.master_params_to_state_dict(params)
            
            logger.log(f"saving model {rate}...")
            if not rate:
                filename = f"model{(self.step+self.resume_step):06d}.pt"
            else:
                filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
            with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                th.save(state_dict, f)

        save_checkpoint(0, self.mp_trainer.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)
        with bf.BlobFile(
            bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
            "wb",
        ) as f:
            th.save(self.opt.state_dict(), f)


In [13]:
logger.log("training...")
TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    batch_size=defaults['batch_size'],
    microbatch=defaults['microbatch'],
    lr=defaults['lr'],
    ema_rate=defaults['ema_rate'],
    log_interval=defaults['log_interval'],
    save_interval=defaults['save_interval'],
    resume_checkpoint=defaults['resume_checkpoint'],
    use_fp16=defaults['use_fp16'],
    fp16_scale_growth=defaults['fp16_scale_growth'],
    schedule_sampler=schedule_sampler,
    weight_decay=defaults['weight_decay'],
    lr_anneal_steps=defaults['lr_anneal_steps'],
).run_loop()

training...
-------------------------
| grad_norm  | 0.807    |
| loss       | 0.996    |
| loss_q0    | 0.996    |
| mse        | 0.996    |
| mse_q0     | 0.996    |
| param_norm | 743      |
| samples    | 1        |
| step       | 0        |
-------------------------
saving model 0...
saving model 0.9999...
-------------------------
| grad_norm  | 3.44     |
| loss       | 0.917    |
| loss_q0    | 0.893    |
| loss_q1    | 0.939    |
| loss_q2    | 0.911    |
| loss_q3    | 0.915    |
| mse        | 0.917    |
| mse_q0     | 0.893    |
| mse_q1     | 0.939    |
| mse_q2     | 0.911    |
| mse_q3     | 0.915    |
| param_norm | 743      |
| samples    | 11       |
| step       | 10       |
-------------------------
-------------------------
| grad_norm  | 3.17     |
| loss       | 0.732    |
| loss_q0    | 0.724    |
| loss_q1    | 0.759    |
| loss_q2    | 0.749    |
| loss_q3    | 0.707    |
| mse        | 0.732    |
| mse_q0     | 0.724    |
| mse_q1     | 0.759    |
| mse_q2   

/pytorch/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [0,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [0,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [0,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [0,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [0,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [0,0,0], thread: [5,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIn

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.