In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
class TimeEmbedding(nn.Module):
    """
    Map integer diffusion timesteps t to vectors.

    The timestep is first encoded with the sinusoidal (Transformer
    positional-encoding style) scheme and then passed through a two-layer
    MLP. The output shape is (batch_size, n_channels).
    """

    def __init__(self, n_channels: int):
        """
        Params:
            n_channels: size of the produced time embedding (time_channel).
                Must be at least 8 so that at least one sin/cos frequency
                pair exists.
        Raises:
            ValueError: if n_channels < 8.
        """
        super().__init__()
        if n_channels < 8:
            raise ValueError(f"n_channels must be >= 8, got {n_channels}")
        self.n_channels = n_channels
        # Number of sinusoid frequencies. The sin/cos concatenation in forward()
        # has 2 * half_dim features, so lin1's input size is derived from it.
        # The previous n_channels // 4 disagreed with 2 * (n_channels // 8)
        # whenever n_channels was not a multiple of 8.
        self.half_dim = n_channels // 8
        self.lin1 = nn.Linear(2 * self.half_dim, self.n_channels)
        self.act = nn.ReLU()
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """
        Params:
            t: tensor of shape (batch_size,) holding integer timesteps.
        Return:
            Tensor of shape (batch_size, n_channels).
        """
        half_dim = self.half_dim
        # Frequencies decay geometrically from 1 to 1/10000, as in the
        # Transformer positional encoding. max(..., 1) avoids dividing by zero
        # when half_dim == 1; with a single frequency, exp(0) == 1 regardless
        # of the scale.
        scale = math.log(10_000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        # (batch_size, 1) * (1, half_dim) -> (batch_size, half_dim)
        args = t[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=1)

        # Transform with the MLP.
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        # Shape: (batch_size, n_channels)
        return emb
    
# Smoke test: embed a batch of 32 random integer timesteps drawn from [0, 100)
# and show that the embedding has shape (batch_size, n_channels).
Timeebd = TimeEmbedding(n_channels=32)
t = torch.randint(low=0, high=100, size=(32,))
print(t)
print(Timeebd(t).shape)

tensor([69,  5, 61, 75, 19, 48, 90,  2, 13, 95, 30,  7, 51, 63,  6, 31, 98,  9,
        33, 98, 84, 90, 96, 60, 79, 63, 49, 79,  7, 39, 77, 11])
torch.Size([32, 32])


In [None]:
def gather(alpha_bar, t):
    """
    Pick the schedule value(s) belonging to timestep(s) t.

    Params:
        alpha_bar: schedule tensor indexed along its first dimension
            (in this notebook, the 1-D tensors built from linspace/cumprod).
        t: an integer index or an integer index tensor.
    Return:
        alpha_bar[t]
    """
    selected = alpha_bar[t]
    return selected

class DenoiseDiffusion:
    """
    Denoising diffusion (DDPM) helper.

    Holds the noise schedule and implements the forward (noising) process
    q(x_t | x_0), one reverse sampling step, and the noise-prediction MSE
    training loss around a user-supplied eps_model.
    """

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        """
        Params:
            eps_model: UNet denoising model; called as eps_model(xt, t) and
                used as the prediction of the noise contained in xt.
            n_steps: total number of diffusion steps T.
            device: device on which the schedule tensors are stored.
        """
        super().__init__()
        self.eps_model = eps_model
        # Hand-picked linear beta schedule that increases with t, moved to `device`.
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        # alpha_t = 1 - beta_t
        self.alpha = 1. - self.beta
        # alpha_bar_t = prod_{s <= t} alpha_s
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        # Total number of diffusion steps.
        self.n_steps = n_steps
        # sigma_t^2 used in sampling (the sigma_t^2 = beta_t choice).
        self.sigma2 = self.beta

    @staticmethod
    def _extract(values: torch.Tensor, t, x: torch.Tensor) -> torch.Tensor:
        """
        Look up values[t] and reshape it so it broadcasts per sample against x.

        values[t] has the shape of t, e.g. (batch_size,). Multiplying that
        directly with an image batch (batch_size, C, H, W) would align it with
        the LAST axis instead of the batch axis. That is a runtime error when
        W != batch_size and silently wrong coefficients when W == batch_size.
        Appending singleton dims gives (batch_size, 1, ..., 1).

        Indexes directly rather than via the module-level `gather`, so the
        class does not depend on which notebook cell last (re)defined it.
        """
        out = values[t]
        return out.reshape(tuple(out.shape) + (1,) * (x.dim() - out.dim()))

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor):
        """
        Mean and variance of the Gaussian q(x_t | x_0).

        Since x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps with
        eps ~ N(0, I), x_t ~ N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) I).

        Params:
            x0: clean images from the training data.
            t: timestep(s), shape (batch_size,) or a scalar.
        Return:
            mean: same shape as x0.
            var: per-sample variance, shaped to broadcast against x0
                (e.g. (batch_size, 1, 1, 1) for 4-D x0).
        """
        alpha_bar = self._extract(self.alpha_bar, t, x0)
        mean = alpha_bar ** 0.5 * x0
        var = 1 - alpha_bar
        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps=None):
        """
        Sample x_t from q(x_t | x_0): x_t = mean + sqrt(var) * eps.

        Params:
            x0: clean images from the training data.
            t: timestep(s), shape (batch_size,) or a scalar.
            eps: optional noise with the shape of x0; drawn from N(0, I) if None.
        Return:
            xt: the noised images at step t.
        """
        if eps is None:
            eps = torch.randn_like(x0)

        mean, var = self.q_xt_x0(x0, t)
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        """
        One reverse (sampling) step: given x_t and t, draw x_{t-1}.

        Params:
            xt: images at step t.
            t: timestep(s), shape (batch_size,) or a scalar.
        Return:
            x_{t-1}: images at step t-1.
        """
        # Noise predicted by the trained model for step t.
        eps_theta = self.eps_model(xt, t)

        alpha_bar = self._extract(self.alpha_bar, t, xt)
        alpha = self._extract(self.alpha, t, xt)
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        var = self._extract(self.sigma2, t, xt)
        eps = torch.randn(xt.shape, device=xt.device)

        # DDPM sampling (Ho et al. 2020, Algorithm 2) adds no noise on the
        # final step. Timesteps here are 0-based (see `loss`), so that is
        # t == 0, where the result is the mean itself.
        t_idx = torch.as_tensor(t, device=xt.device)
        add_noise = (t_idx != 0).to(mean.dtype)
        add_noise = add_noise.reshape(tuple(add_noise.shape) + (1,) * (xt.dim() - add_noise.dim()))
        return mean + add_noise * (var ** .5) * eps

    def loss(self, x0: torch.Tensor, noise=None):
        """
        Simplified DDPM training loss.

        1. Sample a random timestep t per batch element.
        2. Run the diffusion process (q_sample) with noise epsilon ~ N(0, I)
           to get x_t from x0, t and epsilon.
        3. Run eps_model on (x_t, t) to get the predicted noise epsilon_theta.
        4. Return mse_loss(epsilon, epsilon_theta).

        MSE is only one of many possible loss choices.

        Params:
            x0: clean images from the training data.
            noise: optional diffusion noise epsilon; drawn from N(0, I) if None.
        Return:
            loss: scalar MSE between true and predicted noise.
        """
        batch_size = x0.shape[0]
        # One random timestep in [0, n_steps) per batch element.
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)

        # If no noise was passed in, sample it from N(0, I).
        if noise is None:
            noise = torch.randn_like(x0)

        # Diffusion process: compute x_t.
        xt = self.q_sample(x0, t, eps=noise)
        # Denoising model: predicted noise epsilon_theta.
        eps_theta = self.eps_model(xt, t)

        return F.mse_loss(noise, eps_theta)

In [None]:
# Inspect the linear beta schedule on a short chain and print alpha_bar_t,
# the cumulative product used in the closed form of q(x_t | x_0).
# `gather` is already defined in the DenoiseDiffusion cell above; the identical
# duplicate definition that used to live here is removed so the helper has a
# single definition.
DEMO_N_STEPS = 121
BETA_START, BETA_END = 0.0001, 0.02

beta = torch.linspace(BETA_START, BETA_END, DEMO_N_STEPS)
# alpha_t = 1 - beta_t
alpha = 1. - beta
# alpha_bar_t = prod_{s <= t} alpha_s
alpha_bar = torch.cumprod(alpha, dim=0)
print(alpha_bar)

tensor([0.9999, 0.9996, 0.9992, 0.9986, 0.9978, 0.9969, 0.9958, 0.9946, 0.9932,
        0.9916, 0.9898, 0.9879, 0.9859, 0.9836, 0.9812, 0.9787, 0.9760, 0.9732,
        0.9702, 0.9670, 0.9637, 0.9603, 0.9567, 0.9529, 0.9490, 0.9450, 0.9408,
        0.9365, 0.9321, 0.9275, 0.9228, 0.9180, 0.9130, 0.9079, 0.9027, 0.8974,
        0.8919, 0.8864, 0.8807, 0.8749, 0.8690, 0.8630, 0.8569, 0.8507, 0.8444,
        0.8380, 0.8316, 0.8250, 0.8184, 0.8116, 0.8048, 0.7979, 0.7910, 0.7839,
        0.7768, 0.7697, 0.7624, 0.7552, 0.7478, 0.7404, 0.7330, 0.7255, 0.7180,
        0.7104, 0.7028, 0.6951, 0.6875, 0.6798, 0.6720, 0.6643, 0.6565, 0.6487,
        0.6409, 0.6331, 0.6252, 0.6174, 0.6095, 0.6017, 0.5939, 0.5860, 0.5782,
        0.5704, 0.5625, 0.5547, 0.5470, 0.5392, 0.5315, 0.5237, 0.5160, 0.5084,
        0.5007, 0.4931, 0.4856, 0.4780, 0.4705, 0.4631, 0.4556, 0.4483, 0.4409,
        0.4337, 0.4264, 0.4192, 0.4121, 0.4050, 0.3980, 0.3910, 0.3841, 0.3773,
        0.3705, 0.3637, 0.3571, 0.3504, 