In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
from torch.distributions import Normal,Uniform
from torch.distributions.bernoulli import Bernoulli as torch_Bernoulli
from torch import Tensor ,optim 
from typing import Optional, Union, Tuple
from abc import abstractmethod, ABC 
import functools
import math
from torch.distributions.categorical import Categorical as torch_Categorical
import numpy as np

In [None]:
def idx_to_float(idx: np.ndarray, num_bins: int):
    """Map zero-based bin indices ``k`` to the centres ``k_c`` of their bins.

    The original note on this function states that the paper indexes bins
    1..K, whereas here ``k`` runs over ``[0, K-1]``, giving
    ``k_c = (2k + 1) / K - 1``.
    """
    # Centre of the bin on the unit interval ...
    unit_centre = (idx + 0.5) / num_bins
    # ... stretched by 2 and shifted by -1.
    stretched = 2.0 * unit_centre
    return stretched - 1.0


def float_to_idx(flt: np.ndarray, num_bins: int):
    """Return the zero-based bin index ``k`` for each value in ``flt``.

    This is the inverse of ``idx_to_float``. Each value is halved and shifted
    by 0.5, scaled by ``num_bins`` and floored; the result is clamped to
    ``[0, num_bins - 1]`` and cast to ``long``. The body uses torch ops, so
    ``flt`` has to be a ``torch.Tensor`` despite the annotation.
    """
    unit = (flt / 2.0) + 0.5
    raw_idx = torch.floor(unit * num_bins)
    return raw_idx.clamp(min=0, max=num_bins - 1).long()


def quantize(flt, num_bins: int):
    """Replace each value by the centre ``k_c`` of the bin it falls into."""
    bin_idx = float_to_idx(flt, num_bins)
    return idx_to_float(bin_idx, num_bins)


def rgb_image_transform(x, num_bins=256):
    """Discretise an RGB image ``x`` (documented as taking values in [0, 1]).

    The image is rescaled with ``x * 2 - 1``, quantised to bin centres, and
    its axes are reordered with ``permute(1, 2, 0)``.
    """
    rescaled = (x * 2) - 1
    quantised = quantize(rescaled, num_bins)
    return quantised.permute(1, 2, 0).contiguous()
def sandwich(x: Tensor):
    """Reshape ``x`` to three axes: first axis, everything in between, last axis."""
    leading, trailing = x.size(0), x.size(-1)
    return x.reshape(leading, -1, trailing)

## Output distribution

In [None]:
CONST_log_min = 1e-10


def safe_log(data: Tensor):
    """Natural log of ``data`` after flooring its entries at ``CONST_log_min``."""
    floored = torch.clamp(data, min=CONST_log_min)
    return torch.log(floored)


class CtsDistribution:
    """Interface for distributions over continuous data.

    NOTE: this class does not derive from ``ABC``, so the ``@abstractmethod``
    markers are not enforced at instantiation; the base methods return ``None``.
    """

    @abstractmethod
    def log_prob(self, x):
        """Return the log-probability of ``x``; meant to be overridden."""

    @abstractmethod
    def sample(self):
        """Draw a sample; meant to be overridden."""


class DiscreteDistribution:
    """Interface for distributions over a finite set of classes.

    NOTE: this class does not derive from ``ABC``, so the ``@abstractmethod``
    markers are not enforced; un-overridden members evaluate to ``None``.
    """

    @property
    @abstractmethod
    def probs(self):
        """Class probabilities; meant to be overridden."""

    @functools.cached_property
    def log_probs(self):
        """Log of ``probs`` via ``safe_log``, computed once and cached."""
        class_probs = self.probs
        return safe_log(class_probs)

    @functools.cached_property
    def mean(self):
        """Mean of the distribution; ``None`` unless overridden."""

    @functools.cached_property
    def mode(self):
        """Mode of the distribution; ``None`` unless overridden."""

    @abstractmethod
    def log_prob(self, x):
        """Return the log-probability of ``x``; meant to be overridden."""

    @abstractmethod
    def sample(self):
        """Draw a sample; meant to be overridden."""

In [None]:
class DiscretizedDistribution(DiscreteDistribution):
    """Discrete distribution over ``num_bins`` equal-width bins covering [-1, 1].

    Subclasses supply ``probs``; this class derives the bin geometry and the
    mean / mode from it.
    """

    def __init__(self, num_bins, device):
        self.device = device
        # Number of bins, K.
        self.num_bins = num_bins
        # The range [-1, 1] has length 2, so each of the K bins is 2/K wide.
        self.bin_width = 2.0 / num_bins
        self.half_bin_width = self.bin_width / 2.0

    @functools.cached_property
    def class_centres(self):
        """Bin centres, starting at ``-1 + 1/K`` and stepping by ``2/K`` below 1."""
        first_centre = self.half_bin_width - 1
        return torch.arange(first_centre, 1, self.bin_width, device=self.device)

    @functools.cached_property
    def class_boundaries(self):
        """Interior bin edges, starting at ``-1 + 2/K`` (K - 1 of them per the original note)."""
        first_edge = self.bin_width - 1
        stop = 1 - self.half_bin_width
        return torch.arange(first_edge, stop, self.bin_width, device=self.device)

    @functools.cached_property
    def mean(self):
        """Sum over the last axis of the bin centres weighted by ``probs``."""
        weighted_centres = self.probs * self.class_centres
        return weighted_centres.sum(-1)

    @functools.cached_property
    def mode(self):
        """Centre of the highest-probability bin for every leading position."""
        probs = self.probs
        # class_centres is 1-D, so index it with flattened argmax indices and
        # then restore the leading shape of probs.
        flat_idx = probs.argmax(-1).flatten()
        picked = self.class_centres[flat_idx]
        return picked.reshape(probs.shape[:-1])

In [None]:
class DiscretizedCtsDistribution(DiscretizedDistribution):
    """Discretise a continuous distribution onto ``num_bins`` bins over [-1, 1].

    Bin probabilities are differences of the wrapped distribution's CDF.
    With ``clip=True`` the CDF is forced to 0 at the left end and 1 at the
    right end, so mass outside [-1, 1] lands in the first / last bin. With
    ``clip=False`` bin masses are divided by ``CDF(1) - CDF(-1)``, i.e. they
    are conditioned on the value lying in [-1, 1].
    """
    
    def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, min_prob=1e-5):
        super().__init__(num_bins, device)

        # The continuous distribution that is being discretised.
        self.cts_dist = cts_dist
        # log(2/K)
        self.log_bin_width = np.log(self.bin_width)
        # Number of batch dimensions B; used in ``probs`` to add B singleton
        # axes to the bin boundaries before evaluating the CDF.
        self.batch_dims = batch_dims
        
        # Whether to clip the CDF of the continuous distribution (see class docstring).
        self.clip = clip
        # Threshold below which a probability is treated as negligible.
        self.min_prob = min_prob

    @functools.cached_property
    def probs(self):
        """Probability of the data falling into each bin (bin axis last)."""

        # Boundaries are reshaped to [K-1] + [1] * B before the CDF call.
        bdry_cdfs = self.cts_dist.cdf(self.class_boundaries.reshape([-1] + ([1] * self.batch_dims)))
        # First slice along the boundary axis; used only as a shape/dtype template.
        bdry_slice = bdry_cdfs[:1]
        
        if self.clip:
            # English rendering of the string statement below: clip the CDF so
            # that it is 0 at the left edge of the first bin and 1 at the right
            # edge of the last bin.
            '''对原来连续型分布的 CDF 做截断: 小于第一个区间的左端概率置0、小于等于最后一个区间右端的概率置1.'''
            
            cdf_min = torch.zeros_like(bdry_slice)
            cdf_max = torch.ones_like(bdry_slice)
            # K+1 CDF values along axis 0 after adding both ends.
            bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)

            # CDF(k_r) - CDF(k_l) gives the mass of each bin; the bin axis is
            # then moved to the last position.
            return (bdry_cdfs[1:] - bdry_cdfs[:-1]).moveaxis(0, -1)
        else:
            # English rendering of the string statement below: treat the bin
            # masses as conditional probabilities given that the value lies in
            # [-1, 1]; CDF(1) - CDF(-1) is the mass inside that range and is
            # used to rescale every bin.
            '''以条件概率的思想来计算数据位于各区间的概率，其中的条件就是数据位于 [-1,1] 取值范围内.
            先计算原连续型分布在 1 和 -1 处的 CDF 值，将两者作差从而得到位于 [-1,1] 内的概率，以此作为条件对各区间的概率进行缩放.'''

            # CDF(-1)
            cdf_min = self.cts_dist.cdf(torch.zeros_like(bdry_slice) - 1)
            # CDF(1)
            cdf_max = self.cts_dist.cdf(torch.ones_like(bdry_slice))
            # K+1 CDF values along axis 0 after adding both ends.
            bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)

            # p(-1 < x <= 1)
            cdf_range = cdf_max - cdf_min
            cdf_mask = cdf_range < self.min_prob
            # Where the in-range mass is below min_prob, divide by 1 instead to
            # avoid blowing up the quotient.
            cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)

            # Per-bin mass, rescaled; bin axis is still axis 0 here.
            probs = (bdry_cdfs[1:] - bdry_cdfs[:-1]) / cdf_range
            # Where the in-range mass was negligible, fall back to the uniform value 1/K.
            probs = torch.where(cdf_mask, (probs * 0) + (1 / self.num_bins), probs)

            # Move the bin axis to the last position.
            return probs.moveaxis(0, -1)

    def prob(self, x):
        """Return the probability of the bin that contains each value of ``x``."""
        # Bin index k in [0, K-1]
        class_idx = float_to_idx(x, self.num_bins)
        # Bin centre k_c
        centre = idx_to_float(class_idx, self.num_bins)
        # CDF(k_l), with k_l the left edge of the bin.
        cdf_lo = self.cts_dist.cdf(centre - self.half_bin_width)
        # CDF(k_r), with k_r the right edge of the bin.
        cdf_hi = self.cts_dist.cdf(centre + self.half_bin_width)
        
        if self.clip:
            # English rendering of the string statement below: clip the CDF so
            # that the lower value is 0 for the first bin and the upper value
            # is 1 for the last bin.
            '''对原来连续型分布的 CDF 做截断, 使得:
            CDF(k <= 0) = 0;
            CDF(k >= K-1) = 1'''
            
            cdf_lo = torch.where(class_idx <= 0, torch.zeros_like(centre), cdf_lo)
            cdf_hi = torch.where(class_idx >= (self.num_bins - 1), torch.ones_like(centre), cdf_hi)
            
            return cdf_hi - cdf_lo
        else:
            # English rendering of the string statement below: treat the bin
            # mass as a conditional probability given that the value lies in
            # [-1, 1]; CDF(1) - CDF(-1) is used to rescale it.
            '''以条件概率的思想来计算数据位于某个离散区间内的概率，其中的条件就是数据位于 [-1,1] 取值范围内.
            先计算原连续型分布在 1 和 -1 处的 CDF 值，将两者作差从而得到位于 [-1,1] 内的概率，以此作为条件对区间的概率进行缩放.'''
            
            cdf_min = self.cts_dist.cdf(torch.zeros_like(centre) - 1)
            cdf_max = self.cts_dist.cdf(torch.ones_like(centre))
            cdf_range = cdf_max - cdf_min
            
            # If the in-range mass is below min_prob, flag it and divide by 1
            # instead: dividing one tiny number by another tiny number would
            # otherwise produce a spurious value close to 1.
            cdf_mask = cdf_range < self.min_prob
            cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)
            prob = (cdf_hi - cdf_lo) / cdf_range
            
            # Where the in-range mass was negligible, fall back to the uniform value 1/K.
            return torch.where(cdf_mask, (prob * 0) + (1 / self.num_bins), prob)

    def log_prob(self, x):
        """Return the log-probability of the bin containing each value of ``x``.

        Where the bin probability is below ``min_prob``, the log-density of
        the wrapped distribution at the bin centre plus ``log(2/K)`` is used
        instead of the log of the CDF difference.
        """
        prob = self.prob(x)

        return torch.where(
            prob < self.min_prob,
            # Fallback: log p(k_c) + log(2/K), i.e. the log of (density at the
            # bin centre times the bin width).
            self.cts_dist.log_prob(quantize(x, self.num_bins)) + self.log_bin_width,
            safe_log(prob),
        )

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples and return them quantised to bin centres."""
        if self.clip:
            # Sample from the continuous distribution and quantise. Values
            # outside [-1, 1] end up at the centre of the first / last bin
            # because float_to_idx clamps the bin index.
            return quantize(self.cts_dist.sample(sample_shape), self.num_bins)
        else:
            # The wrapped distribution must expose an inverse CDF.
            assert hasattr(self.cts_dist, "icdf")
            
            # Evaluate CDF(-1) and CDF(1) so that sampling can be restricted
            # to the range [-1, 1].
            cdf_min = self.cts_dist.cdf(torch.zeros_like(self.cts_dist.mean) - 1)
            cdf_max = self.cts_dist.cdf(torch.ones_like(cdf_min))

            # Draw u uniformly between CDF(-1) and CDF(1) and map it back
            # through the inverse CDF.
            u = Uniform(cdf_min, cdf_max, validate_args=False).sample(sample_shape)
            cts_samp = self.cts_dist.icdf(u)

            # Quantise to bin centres. Unlike the clip branch, u was drawn
            # between CDF(-1) and CDF(1) before inversion.
            return quantize(cts_samp, self.num_bins)

In [None]:
CONST_exp_range = 10


def safe_exp(data: Tensor):
    """Exponentiate ``data`` after clamping it to ``[-CONST_exp_range, CONST_exp_range]``."""
    bounded = torch.clamp(data, min=-CONST_exp_range, max=CONST_exp_range)
    return torch.exp(bounded)


class DiscretizedNormal(DiscretizedCtsDistribution):
    """Normal distribution discretised into ``num_bins`` bins.

    The last axis of ``params`` must have size 2: the mean followed by the
    standard deviation (or its log when ``log_dev`` is true).
    """

    def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
        assert params.size(-1) == 2

        # A negative floor selects a floor that depends on the bin count.
        if min_std_dev < 0:
            min_std_dev = 1.0 / (num_bins * 5)

        loc, scale = torch.split(params, 1, dim=-1)[:2]
        if log_dev:
            # A log standard deviation was passed in; undo the log.
            scale = safe_exp(scale)
        scale = torch.clamp(scale, min=min_std_dev, max=max_std_dev)

        base_dist = Normal(loc.squeeze(-1), scale.squeeze(-1), validate_args=False)
        super().__init__(
            cts_dist=base_dist,
            num_bins=num_bins,
            device=params.device,
            # "Batch dims" is every axis of params except the trailing
            # parameter axis, not just the data batch size.
            batch_dims=params.ndim - 1,
            clip=clip,
            min_prob=min_prob,
        )

In [None]:
class DeltaDistribution(CtsDistribution):
    """Point-mass distribution located at ``mean``.

    ``mean`` is exposed as a plain instance attribute (read ``dist.mean``),
    while ``mode`` is a regular method (call ``dist.mode()``).
    """

    def __init__(self, mean, clip_range=1.0):
        # A non-positive clip_range disables clipping.
        if clip_range > 0:
            mean = mean.clip(min=-clip_range, max=clip_range)
        # Deliberately an instance attribute: a same-named ``mean`` method on
        # the class would always be shadowed by it and could never be called,
        # so no such method is defined.
        self.mean = mean

    def mode(self):
        """Return the location of the point mass."""
        return self.mean

    def sample(self, sample_shape=torch.Size([])):
        """Return ``mean``; ``sample_shape`` is accepted but not used."""
        return self.mean

In [None]:
class Bernoulli(DiscreteDistribution):
    """Two-class distribution delegating to ``torch.distributions.Bernoulli``."""

    def __init__(self, logits):
        self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False)

    @functools.cached_property
    def probs(self):
        """Probabilities ``[1 - p, p]`` concatenated on a new last axis."""
        p_one = self.bernoulli.probs.unsqueeze(-1)
        p_zero = 1 - p_one
        return torch.cat([p_zero, p_one], -1)

    @functools.cached_property
    def mode(self):
        """Mode reported by the wrapped torch distribution."""
        wrapped = self.bernoulli
        return wrapped.mode

    def log_prob(self, x):
        """Log-probability of ``x`` after casting it to float."""
        as_float = x.float()
        return self.bernoulli.log_prob(as_float)

    def sample(self, sample_shape=torch.Size([])):
        """Sample from the wrapped torch distribution."""
        wrapped = self.bernoulli
        return wrapped.sample(sample_shape)

In [None]:
class Categorical(DiscreteDistribution):
    """Multi-class distribution delegating to ``torch.distributions.Categorical``."""

    def __init__(self, logits):
        # The size of the last axis of the logits is the number of classes.
        self.n_classes = logits.shape[-1]
        self.categorical = torch_Categorical(logits=logits, validate_args=False)

    @functools.cached_property
    def probs(self):
        """Class probabilities of the wrapped torch distribution."""
        wrapped = self.categorical
        return wrapped.probs

    @functools.cached_property
    def mode(self):
        """Mode reported by the wrapped torch distribution."""
        wrapped = self.categorical
        return wrapped.mode

    def log_prob(self, x):
        """Log-probability of ``x`` under the wrapped torch distribution."""
        wrapped = self.categorical
        return wrapped.log_prob(x)

    def sample(self, sample_shape=torch.Size([])):
        """Sample from the wrapped torch distribution."""
        wrapped = self.categorical
        return wrapped.sample(sample_shape)

In [None]:
def noise_pred_params_to_data_pred_params(
    noise_pred_params: torch.Tensor, input_mean: torch.Tensor,
    t: torch.Tensor, min_variance: float, min_t=1e-6
):
    """Convert output parameters that predict the noise added to data, to parameters that predict the data.

    The size of the last axis of ``noise_pred_params`` selects the case:
    1 -> a predicted noise value; 2 -> noise mean and log std-dev; otherwise
    (must be a multiple of 3) -> mixture-weight logits, means and log
    std-devs. ``t`` may be a tensor or a scalar. Positions with ``t < min_t``
    get mean 0 and, when a std-dev is produced, std-dev 1. The result keeps
    the leading shape of ``noise_pred_params`` with the converted parameters
    on the last axis.
    """

    # Leading shape of the input without the parameter axis, e.g. (B,L,D).
    data_shape = list(noise_pred_params.shape)[:-1]
    # Middle axes collapsed, e.g. (B,L*D,NP); NP: number of parameters per datum.
    noise_pred_params = sandwich(noise_pred_params)
    # Flattened from axis 1 on, e.g. (B,L*D).
    input_mean = input_mean.flatten(start_dim=1)
    
    if torch.is_tensor(t):
        t = t.flatten(start_dim=1)
    else:
        # Turn a scalar t into a tensor shaped like input_mean.
        t = (input_mean * 0) + t
        
    # True where t is below min_t; a trailing axis is added for broadcasting.
    alpha_mask = (t < min_t).unsqueeze(-1)
    
    # \sigma_1^{2t}, with t floored at min_t
    posterior_var = torch.pow(min_variance, t.clamp(min=min_t))
    # \gamma(t) = 1 - \sigma_1^{2t}
    gamma = 1 - posterior_var

    # \frac{\mu}{\gamma(t)}
    A = (input_mean / gamma).unsqueeze(-1)
    # \sqrt{\frac{1-\gamma(t)}{\gamma(t)}}
    B = (posterior_var / gamma).sqrt().unsqueeze(-1)
    
    data_pred_params = []
    
    # Continuous-data case: the network predicts the noise vector itself.
    if noise_pred_params.size(-1) == 1:
        noise_pred_mean = noise_pred_params
    # Discretised-data case: the network predicts the mean and log std-dev of the noise distribution.
    elif noise_pred_params.size(-1) == 2:
        noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(2, -1)
    else:
        # Mixture case: three equal chunks (logits, means, log std-devs).
        assert noise_pred_params.size(-1) % 3 == 0
        mix_wt_logits, noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(3, -1)
        data_pred_params.append(mix_wt_logits)

    # Continuous data: x = \frac{\mu}{\gamma(t)} - \sqrt{\frac{1-\gamma(t)}{\gamma(t)}} \epsilon
    # Discretised data: \mu_{x} = \frac{\mu}{\gamma(t)} - \sqrt{\frac{1-\gamma(t)}{\gamma(t)}} \mu_{\epsilon}
    data_pred_mean = A - (B * noise_pred_mean)
    # A t below min_t is treated as the start time, where the prior is a
    # standard Gaussian, so the predicted mean is set to 0.
    data_pred_mean = torch.where(alpha_mask, 0 * data_pred_mean, data_pred_mean)
    data_pred_params.append(data_pred_mean)
    
    if noise_pred_params.size(-1) >= 2:
        # Undo the log: exp(ln(\sigma_{\epsilon})) -> \sigma_{\epsilon}
        noise_pred_dev = safe_exp(noise_pred_log_dev)
        # Convert the noise std-dev into the data std-dev:
        # \sigma_x = \sqrt{\frac{1-\gamma(t)}{\gamma(t)}} \sigma_{\epsilon}
        data_pred_dev = B * noise_pred_dev
        # A t below min_t is treated as the start time (standard Gaussian
        # prior), so the predicted std-dev is set to 1.
        data_pred_dev = torch.where(alpha_mask, 1 + (0 * data_pred_dev), data_pred_dev)
        data_pred_params.append(data_pred_dev)

    # Concatenate along the parameter axis, e.g. (B,L*D,NP).
    data_pred_params = torch.cat(data_pred_params, -1)
    # Restore the original leading shape, e.g. (B,L,D,NP).
    data_pred_params = data_pred_params.reshape(data_shape + [-1])
    
    return data_pred_params

In [None]:
class CtsDistributionFactory:
    """Interface for factories that build a ``CtsDistribution`` from parameters."""

    @abstractmethod
    def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> CtsDistribution:
        """Build the distribution described by ``params``.

        Note: input_params and t are not used but are kept for consistency
        with DiscreteDistributionFactory.
        """


class DiscreteDistributionFactory:
    """Interface for factories that build a ``DiscreteDistribution`` from parameters."""

    @abstractmethod
    def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> DiscreteDistribution:
        """Build the distribution described by ``params``.

        Note: input_params and t are only required by PredDistToDataDistFactory.
        """


class DiscretizedNormalFactory(DiscreteDistributionFactory):
    """Factory that stores ``DiscretizedNormal`` settings and applies them per call."""

    def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
        self.num_bins, self.clip = num_bins, clip
        self.min_std_dev, self.max_std_dev = min_std_dev, max_std_dev
        self.min_prob = min_prob
        self.log_dev = log_dev

    def get_dist(self, params, input_params=None, t=None):
        """Return a ``DiscretizedNormal`` for ``params``; ``input_params`` and ``t`` are unused."""
        # Attributes are read at call time, so changes made to them after
        # construction (e.g. ``log_dev``) take effect.
        settings = dict(
            num_bins=self.num_bins,
            clip=self.clip,
            min_std_dev=self.min_std_dev,
            max_std_dev=self.max_std_dev,
            min_prob=self.min_prob,
            log_dev=self.log_dev,
        )
        return DiscretizedNormal(params, **settings)


class DeltaFactory(CtsDistributionFactory):
    """Factory producing ``DeltaDistribution`` objects with a fixed ``clip_range``."""

    def __init__(self, clip_range=1.0):
        self.clip_range = clip_range

    def get_dist(self, params, input_params=None, t=None):
        """Return a ``DeltaDistribution``; ``input_params`` and ``t`` are unused."""
        # Drop the trailing axis of params to obtain the mean.
        point = params.squeeze(-1)
        return DeltaDistribution(point, self.clip_range)


class BernoulliFactory(DiscreteDistributionFactory):
    """Factory producing ``Bernoulli`` distributions from logits."""

    def get_dist(self, params, input_params=None, t=None):
        """Return a ``Bernoulli``; ``input_params`` and ``t`` are unused."""
        # Drop the trailing axis of params to obtain the logits.
        logits = params.squeeze(-1)
        return Bernoulli(logits=logits)


class CategoricalFactory(DiscreteDistributionFactory):
    """Factory producing ``Categorical`` distributions from logits."""

    def get_dist(self, params, input_params=None, t=None):
        """Return a ``Categorical``; ``input_params`` and ``t`` are unused."""
        logits = params
        return Categorical(logits=logits)
    
class PredDistToDataDistFactory(DiscreteDistributionFactory):
    """Wrap a data-distribution factory so that it accepts noise-prediction parameters."""

    def __init__(self, data_dist_factory, min_variance, min_t=1e-6):
        self.min_variance = min_variance
        self.min_t = min_t
        self.data_dist_factory = data_dist_factory
        # noise_pred_params_to_data_pred_params() already exponentiates the
        # log std-dev, so the wrapped factory must not do it again.
        # NOTE: this mutates the factory object that was passed in.
        self.data_dist_factory.log_dev = False

    def get_dist(self, params, input_params, t):
        """Convert noise-prediction ``params`` to data parameters and build the distribution.

        Only the first element of ``input_params`` is used.
        """
        input_mean = input_params[0]
        data_pred_params = noise_pred_params_to_data_pred_params(
            params, input_mean, t, self.min_variance, self.min_t
        )
        return self.data_dist_factory.get_dist(data_pred_params)

## BayesianFlow classes: CtsBayesianFlow and DiscreteBayesianFlow

In [None]:
class BayesianFlow(nn.Module, ABC):
    """Abstract interface shared by the Bayesian flow implementations."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, ...]:
        """Returns the initial input params (for a batch) at t=0. Used during sampling.

        For discrete data, the tuple has length 1 and contains the initial class probabilities.
        For continuous data, the tuple has length 2 and contains the mean and precision.
        These prior parameters are the model input at the start of sampling.
        """

    @abstractmethod
    def params_to_net_inputs(self, params: tuple[Tensor, ...]) -> Tensor:
        """Utility method to convert input distribution params to network inputs if needed.

        Example from the original note: input-distribution probabilities lie
        in [0, 1] and are rescaled to [-1, 1] before entering the network, so
        that the network does not only ever receive non-negative values.
        """

    @abstractmethod
    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> float:
        """Returns the alpha at step i of total n_steps according to the flow schedule. Used:

        a) during sampling, when i and alpha are the same for all samples in the batch.
        b) during discrete time loss computation, when i and alpha are different for samples in the batch.

        The accuracy of a discrete step is alpha_i = beta(t_i) - beta(t_{i-1}).
        """

    @abstractmethod
    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
        """Returns the sender distribution with accuracy alpha obtained by adding appropriate noise to the data x. Used:

        a) during sampling (same alpha for whole batch) to sample from the output distribution produced by the net.
        b) during discrete time loss computation when alpha are different for samples in the batch.
        """

    @abstractmethod
    def update_input_params(self, input_params: tuple[Tensor, ...], y: Tensor, alpha: float) -> tuple[Tensor, ...]:
        """Updates the distribution parameters using Bayes' theorem in light of noisy sample y.

        Used during sampling when alpha is the same for the whole batch.
        The posterior computed from the observation y replaces the prior.
        """

    @abstractmethod
    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, ...]:
        """Returns a sample from the Bayesian Flow distribution over input parameters at time t conditioned on data.

        Used during training when t (and thus accuracies) are different for different samples in the batch.
        For discrete data, the returned tuple has length 1 and contains the class probabilities.
        For continuous data, the returned tuple has length 2 and contains the mean and precision.
        """

In [None]:
class CtsBayesianFlow(BayesianFlow):
    """Bayesian flow for continuous / discretised data."""

    def __init__(
        self,
        min_variance: float = 1e-6,
    ):
        super().__init__()
        self.min_variance = min_variance

    @torch.no_grad()
    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, None]:
        """Sample the input-distribution mean from the Bayesian flow distribution at time ``t``.

        Returns ``(mean, None)``: the variance is not needed by the network,
        so its slot is ``None``.
        """
        # sigma_1^(2t)
        remaining_var = torch.pow(self.min_variance, t)
        # gamma(t) = 1 - sigma_1^(2t)
        gamma = 1 - remaining_var

        # Flow distribution: mean gamma(t) * x, std sqrt(gamma(t) * (1 - gamma(t))).
        flow_mean = gamma * data
        flow_std = (gamma * remaining_var).sqrt()

        # Reparameterised sample: mean + std * standard Gaussian noise.
        eps = torch.randn(flow_mean.shape, device=flow_mean.device)
        sampled_mean = flow_mean + (flow_std * eps)

        return (sampled_mean, None)

    def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
        """Return the network input; only the mean (first element) is used."""
        mean = params[0]
        return mean

    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, float]:
        """Return the prior at the start time: a zero mean tensor and precision 1.0."""
        prior_mean = torch.zeros(*data_shape, device=device)
        return prior_mean, 1.0

    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:
        """Accuracy of step ``i`` out of ``n_steps``: beta(t_i) - beta(t_{i-1}) with t_i = i / n."""
        sigma_1 = math.sqrt(self.min_variance)
        growth = sigma_1 ** (-2 * i / n_steps)
        step_gain = 1 - sigma_1 ** (2 / n_steps)
        return growth * step_gain

    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
        """Normal centred on ``x`` whose variance is the reciprocal of the accuracy ``alpha``."""
        std_dev = 1.0 / alpha**0.5
        return D.Normal(x, std_dev)

    def update_input_params(self, input_params: tuple[Tensor, float], y: Tensor, alpha: float) -> tuple[Tensor, float]:
        """Bayesian update of the input-distribution mean and precision given ``y``."""
        prev_mean, prev_precision = input_params
        # rho_i = rho_{i-1} + alpha
        posterior_precision = prev_precision + alpha
        # mu_i = (rho_{i-1} * mu_{i-1} + alpha * y) / rho_i
        weighted_sum = (prev_precision * prev_mean) + (alpha * y)
        posterior_mean = weighted_sum / posterior_precision

        return posterior_mean, posterior_precision

In [None]:
class DiscreteBayesianFlow(BayesianFlow):
    """Bayesian flow for discrete data with ``n_classes`` classes."""

    def __init__(
        self,
        n_classes: int,
        min_sqrt_beta: float = 1e-10,
        discretize: bool = False,
        epsilon: float = 1e-6,
        max_sqrt_beta: float = 1,
    ):
        super().__init__()

        # K, the number of classes.
        self.n_classes = n_classes
        # Small constant: forward() clamps t to at most 1 - epsilon, because
        # the flow distribution is used strictly before the final time.
        self.epsilon = epsilon

        # Whether forward() first converts the data to bin indices.
        self.discretize = discretize

        # Lower bound for sqrt(beta).
        self.min_sqrt_beta = min_sqrt_beta
        # sqrt(beta(1)).
        self.max_sqrt_beta = max_sqrt_beta

        # Entropy of the uniform distribution over K classes: ln(K).
        self.uniform_entropy = math.log(self.n_classes)

    @torch.no_grad()
    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor]:
        """Sample class probabilities from the Bayesian flow distribution at time ``t``.

        Returns a 1-tuple holding the probabilities.
        """
        if self.discretize:
            # This resolves to the module-level float_to_idx, not to the
            # static method of this class (class attributes are not in a
            # method's name scope).
            data = float_to_idx(data, self.n_classes)

        # sqrt(beta(t)), with t capped at 1 - epsilon.
        sqrt_beta = self.t_to_sqrt_beta(t.clamp(max=1 - self.epsilon))
        lo_beta = sqrt_beta < self.min_sqrt_beta
        sqrt_beta = sqrt_beta.clamp(min=self.min_sqrt_beta)
        # beta(t), with a trailing axis added for the class dimension.
        beta = sqrt_beta.square().unsqueeze(-1)

        # Sample from the sender distribution with accuracy beta(t) and use
        # the sample as logits.
        logits = self.count_sample(data, beta)
        probs = F.softmax(logits, -1)
        # Where the accuracy is below the floor, the sample carries too
        # little information about the data, so the uniform prior 1/K is used.
        probs = torch.where(lo_beta.unsqueeze(-1), torch.ones_like(probs) / self.n_classes, probs)
        if self.n_classes == 2:
            # With two classes, keep only the probability of one of them.
            probs = probs[..., :1]
            probs = probs.reshape_as(data)

        input_params = (probs,)

        return input_params

    def t_to_sqrt_beta(self, t):
        """Square root of the accuracy schedule: sqrt(beta(t)) = t * sqrt(beta(1))."""
        return t * self.max_sqrt_beta

    def count_dist(self, x, beta=None) -> D.Distribution:
        """Sender distribution underlying the Bayesian flow distribution.

        Without ``beta`` the mean is ``K * e_x - 1`` and the std-dev is
        ``sqrt(K)``; with ``beta`` the mean is scaled by ``beta`` and the
        std-dev by ``sqrt(beta)``.
        """
        # K * e_x - 1
        mean = (self.n_classes * F.one_hot(x.long(), self.n_classes)) - 1
        # sqrt(K)
        std_dev = math.sqrt(self.n_classes)

        if beta is not None:
            # beta(t) * (K * e_x - 1)
            mean = mean * beta
            # sqrt(beta(t) * K)
            std_dev = std_dev * beta.sqrt()

        return D.Normal(mean, std_dev, validate_args=False)

    def count_sample(self, x, beta):
        """Draw a reparameterised sample (``rsample``) from ``count_dist``.

        forward() passes the sample through a softmax as logits.
        """
        return self.count_dist(x, beta).rsample()

    @staticmethod
    def float_to_idx(data: Tensor, n_classes: int) -> Tensor:
        """Convert continuous data to discrete indices.

        Declared as a static method because it takes no ``self``; without the
        decorator, calling it on an instance would pass the instance as
        ``data``. Note that forward() does not use this helper: it calls the
        module-level ``float_to_idx``.
        """
        # Assuming data is normalized between 0 and 1
        return (data * n_classes).long().clamp(0, n_classes - 1)

    @torch.no_grad()
    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor]:
        """Initial prior: equal probability 1/K for every class."""
        # Returned as a 1-tuple for consistency with the continuous case.
        return (torch.ones(*data_shape, self.n_classes, device=device) / self.n_classes,)

    @torch.no_grad()
    def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
        """Convert the input-distribution probabilities into network inputs."""
        params = params[0]
        if self.n_classes == 2:
            # Per the original note this branch targets binarised MNIST: the
            # input is rescaled to [-1, 1].
            params = params * 2 - 1  # We scale-shift here for MNIST instead of in the network like for text
            # With two classes, keep only the first entry of the last axis.
            params = params[..., :1]

        return params

    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:
        """Accuracy of discrete step ``i``: alpha_i = beta(1) * (2i - 1) / n^2."""
        return ((self.max_sqrt_beta / n_steps) ** 2) * (2 * i - 1)

    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
        """Normal with mean ``alpha * (K * e_x - 1)`` and std-dev ``sqrt(K * alpha)``."""
        e_x = F.one_hot(x.long(), self.n_classes)
        # A tensor alpha gets a trailing axis so it broadcasts over classes.
        alpha = alpha.unsqueeze(-1) if isinstance(alpha, Tensor) else alpha
        dist = D.Normal(alpha * ((self.n_classes * e_x) - 1), (self.n_classes * alpha) ** 0.5)

        return dist

    def update_input_params(self, input_params: tuple[Tensor], y: Tensor, alpha: float) -> tuple[Tensor]:
        """Bayesian update: multiply the probabilities by ``exp(y)`` and renormalise."""
        new_input_params = input_params[0] * y.exp()
        new_input_params /= new_input_params.sum(-1, keepdims=True)

        # Returned as a 1-tuple.
        return (new_input_params,)

## Loss classes: continuous (CtsBayesianFlowLoss) and discrete (DiscreteBayesianFlowLoss)

In [None]:
class Loss(nn.Module, ABC):
    """Abstract interface for the Bayesian flow losses."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor) -> Tensor:
        """Returns the continuous time KL loss (and any other losses) at time t (between 0 and 1).

        The input params are only used when the network is parameterized to predict the noise for continuous data.
        """

    @abstractmethod
    def discrete_time_loss(
        self, data: Tensor,
        output_params: Tensor, input_params: Tensor,
        t: Tensor, n_steps: int, n_samples: int = 20
    ) -> Tensor:
        """Returns the discrete time KL loss for n_steps total of communication at time t (between 0 and 1) using
        n_samples for Monte Carlo estimation of the discrete loss.

        The input params are only used when the network is parameterized to predict the noise for continuous data.
        Monte Carlo estimation is used when the required KL divergence has no closed form.
        """

    @abstractmethod
    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        """Returns the reconstruction loss, i.e. the final cost of transmitting clean data.

        The input params are only used when the network is parameterized to predict the noise for continuous data.
        Per the original note, this loss does not take part in training.
        """

In [None]:
class CtsBayesianFlowLoss(Loss):
    """Losses used when modelling continuous / discretised data.

    Provides:
    - the discrete-time loss;
    - the continuous-time loss;
    - the reconstruction loss.
    """
    
    def __init__(
        self,
        bayesian_flow: CtsBayesianFlow,
        distribution_factory: Union[CtsDistributionFactory, DiscreteDistributionFactory],
        min_loss_variance: float = -1,
        noise_pred: bool = True,
    ):
        super().__init__()
        
        self.bayesian_flow = bayesian_flow
        # Factory that builds the output distribution.
        self.distribution_factory = distribution_factory
        # Lower bound for \sigma_1^{2t} when it is used as a denominator;
        # only applied when positive.
        self.min_loss_variance = min_loss_variance
        # -ln(\sigma_1)
        self.C = -0.5 * math.log(bayesian_flow.min_variance)
        
        # Whether the network predicts the noise (rather than the data directly).
        self.noise_pred = noise_pred
        if self.noise_pred:
            # NOTE: this mutates the factory object that was passed in.
            self.distribution_factory.log_dev = False
            # When predicting noise, wrap the factory so the predicted noise
            # parameters are converted into data (output) distribution parameters.
            self.distribution_factory = PredDistToDataDistFactory(
                self.distribution_factory, self.bayesian_flow.min_variance
            )

    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:
        """Continuous-time loss: C * ||x - E[output dist]||^2 / \\sigma_1^{2t}, elementwise."""
        # Network output, reshaped to 3 axes via sandwich().
        output_params = sandwich(output_params)

        
        t = t.flatten(start_dim=1).float()
        flat_target = data.flatten(start_dim=1)
        
        # \sigma_1^{2t}
        posterior_var = torch.pow(self.bayesian_flow.min_variance, t)
        if self.min_loss_variance > 0:
            # Floor the denominator to avoid overflow in the division below.
            posterior_var = posterior_var.clamp(min=self.min_loss_variance)
        
        # Output distribution.
        pred_dist = self.distribution_factory.get_dist(output_params, input_params, t)
        # Mean of the output distribution, E[P(\theta, t)].
        pred_mean = pred_dist.mean
        
        mse_loss = (pred_mean - flat_target).square()
        # -ln(\sigma_1) \sigma_1^{-2t} || x - E[P(\theta, t)] ||^2
        loss = self.C * mse_loss / posterior_var
        
        return loss
    def discrete_time_loss(
        self, data: Tensor,
        output_params: Tensor, input_params: Tensor,
        t: Tensor, n_steps: int, n_samples=10
    ) -> Tensor:
        """Discrete-time loss for ``n_steps`` steps at time ``t``.

        If the output distribution exposes ``probs`` (discretised case), the
        KL between sender and receiver is estimated with ``n_samples`` Monte
        Carlo samples; otherwise the closed form ``alpha * mse / 2`` is used.
        The result is multiplied by ``n_steps``.
        """
        # Reshaped to 3 axes via sandwich().
        output_params = sandwich(output_params)
        t = t.flatten(start_dim=1).float()
        
        output_dist = self.distribution_factory.get_dist(output_params, input_params, t)

        # Discretised-data case.
        if hasattr(output_dist, "probs"):  # output distribution is discretized normal
            t = t.flatten(start_dim=1)
            i = t * n_steps + 1  # since t = (i - 1) / n
            
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            
            flat_target = data.flatten(start_dim=1)
            # Sender distribution.
            sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)
            # The KL between sender and receiver is estimated by Monte Carlo,
            # so n_samples observations y are drawn from the sender
            # distribution (10 by default).
            y = sender_dist.sample(torch.Size([n_samples]))
            
            # Probabilities the model assigns to each bin, reshaped to 3 axes.
            receiver_mix_wts = sandwich(output_dist.probs)
            # The output distribution is categorical over the bins.
            receiver_mix_dist= D.Categorical(probs=receiver_mix_wts, validate_args=False)
            # One 1-D Gaussian per bin centre, with the same std-dev
            # 1/sqrt(alpha) as the sender distribution.
            receiver_components = D.Normal(
                output_dist.class_centres, (1.0 / alpha.sqrt()).unsqueeze(-1), validate_args=False
            )
            # Receiver distribution: a Gaussian mixture for every data dimension.
            receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components, validate_args=False)
            
            # One value per batch row after the final mean(1, keepdims=True).
            loss = (
                (sender_dist.log_prob(y) - receiver_dist.log_prob(y))  # log-density difference sender vs receiver
                .mean(0)  # average over the Monte Carlo samples
                .flatten(start_dim=1)
                .mean(1, keepdims=True)
            )
        # Continuous-data case.
        else:  # output distribution is normal
            pred_mean = output_dist.mean
            flat_target = data.flatten(start_dim=1)
            mse_loss = (pred_mean - flat_target).square()
            i = t * n_steps + 1
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            loss = alpha * mse_loss / 2
            
        return n_steps * loss
    
    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        """Negative log-probability of the data under the output distribution at t = 1."""
        output_params = sandwich(output_params)
        flat_data = data.flatten(start_dim=1)
        
        # The reconstruction loss is evaluated at the final time, so t = 1.
        t = torch.ones_like(data).flatten(start_dim=1).float()
        output_dist = self.distribution_factory.get_dist(output_params, input_params, t)
        
        if hasattr(output_dist, "probs"):  # output distribution is discretized normal
            reconstruction_loss = -output_dist.log_prob(flat_data)
        else:  # output distribution is normal, but we use discretized normal to make results comparable (see Sec. 7.2)
            if self.bayesian_flow.min_variance == 1e-3:  # used for 16 bin CIFAR10
                noise_dev = 0.7 * math.sqrt(self.bayesian_flow.min_variance)
                num_bins = 16
            else:
                noise_dev = math.sqrt(self.bayesian_flow.min_variance)
                num_bins = 256
                
            mean = output_dist.mean.flatten(start_dim=1)
            final_dist = D.Normal(mean, noise_dev)
            # Discretised normal distribution.
            final_dist = DiscretizedCtsDistribution(final_dist, num_bins, device=t.device, batch_dims=mean.ndim - 1)
            reconstruction_loss = -final_dist.log_prob(flat_data)
            
        return reconstruction_loss

In [None]:
class DiscreteBayesianFlowLoss(Loss):
    """Loss terms of a Bayesian Flow Network for discrete (categorical) data.

    Exposes the continuous-time loss, a Monte Carlo estimate of the
    discrete-time loss, and the reconstruction loss. All three are computed
    from the categorical output distribution built by ``distribution_factory``.
    """

    def __init__(
        self,
        bayesian_flow: DiscreteBayesianFlow,
        distribution_factory: DiscreteDistributionFactory,
    ):
        super().__init__()

        self.bayesian_flow = bayesian_flow
        self.distribution_factory = distribution_factory
        # The output distribution for discrete data is categorical; K is its number of classes.
        self.K = self.bayesian_flow.n_classes

    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:
        """Continuous-time loss: ``t * max_sqrt_beta**2 * K * ||one_hot(x) - probs||^2``.

        Args:
            data: Targets; flattened to (batch, -1). Converted to class indices
                first when ``bayesian_flow.discretize`` is set.
            output_params: Raw network outputs, reshaped with ``sandwich``.
            input_params: Unused here; kept for a uniform loss signature.
            t: Time values; flattened to (batch, -1) and cast to float.

        Returns:
            The loss with the class axis summed out (no reduction over the
            remaining axes is performed here).
        """
        flat_output = sandwich(output_params)
        # Probability the output distribution assigns to every class.
        pred_probs = self.distribution_factory.get_dist(flat_output).probs

        flat_target = data.flatten(start_dim=1)
        if self.bayesian_flow.discretize:
            # NOTE: inside a method the bare name resolves to the module-level
            # float_to_idx, not to the static method of this class.
            flat_target = float_to_idx(flat_target, self.K)

        tgt_mean = torch.nn.functional.one_hot(flat_target.long(), self.K)
        kl = self.K * ((tgt_mean - pred_probs).square()).sum(-1)
        t = t.flatten(start_dim=1).float()
        loss = t * (self.bayesian_flow.max_sqrt_beta**2) * kl

        return loss

    def discrete_time_loss(
        self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10
    ) -> Tensor:
        """Monte Carlo estimate of the discrete-time loss.

        Draws ``n_samples`` samples from the sender distribution and averages
        the log-density gap between sender and receiver (a per-dimension
        mixture weighted by the predicted class probabilities).

        Args:
            data: Targets; flattened to (batch, -1). Converted to class indices
                first when ``bayesian_flow.discretize`` is set.
            output_params: Raw network outputs, reshaped with ``sandwich``.
            input_params: Unused here; kept for a uniform loss signature.
            t: Start-of-interval times, i.e. ``t = (i - 1) / n_steps``.
            n_steps: Total number of discrete steps ``n``.
            n_samples: Number of Monte Carlo samples drawn from the sender.

        Returns:
            ``n_steps`` times the estimated KL, averaged over axis 1 with
            ``keepdims=True`` (shape (B, 1) per the original author's note).
        """
        flat_target = data.flatten(start_dim=1)
        if self.bayesian_flow.discretize:
            # Resolves to the module-level float_to_idx (see cts_time_loss).
            flat_target = float_to_idx(flat_target, self.K)

        # Invert t = (i - 1) / n to recover the step index i.
        i = t * n_steps + 1
        # Accuracy alpha_i of step i.
        alpha = self.bayesian_flow.get_alpha(i, n_steps).flatten(start_dim=1)

        # sandwich() yields (B, -1, last_dim); the author's note gives this as (B, D, K).
        flat_output = sandwich(output_params)
        # Predicted probability of each class.
        receiver_mix_wts = self.distribution_factory.get_dist(flat_output).probs
        # Author's note: every class contributes a component whose mean is a
        # K-dim one-hot vector, i.e. K independent normals per class and K x K
        # in total. The extra axis at -2 aligns the categorical weights with
        # those components.
        receiver_mix_dist = D.Categorical(probs=receiver_mix_wts.unsqueeze(-2))

        # Two leading singleton axes stand in for the batch and data dimensions.
        classes = torch.arange(self.K, device=flat_target.device).long().unsqueeze(0).unsqueeze(0)
        receiver_components = self.bayesian_flow.get_sender_dist(classes, alpha.unsqueeze(-1))
        # Receiver distribution: one mixture per data dimension.
        receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components)

        sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)
        # Sample from the sender to approximate its KL to the receiver by Monte Carlo.
        y = sender_dist.sample(torch.Size([n_samples]))

        # mean(0): over MC samples; sum(-1): over the last axis; mean(1): over data dims.
        loss = n_steps * (sender_dist.log_prob(y) - receiver_dist.log_prob(y)).mean(0).sum(-1).mean(1, keepdims=True)

        return loss

    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        """Negative log-probability of the flattened data under the output distribution.

        ``input_params`` is unused; it is kept for a uniform loss signature.
        """
        flat_outputs = sandwich(output_params)
        flat_data = data.flatten(start_dim=1)
        output_dist = self.distribution_factory.get_dist(flat_outputs)

        return -output_dist.log_prob(flat_data)

    @staticmethod
    def float_to_idx(data: Tensor, n_classes: int) -> Tensor:
        """Convert continuous data to discrete indices.

        Declared as a static method because it takes no ``self``; without the
        decorator a call through an instance would bind the instance to
        ``data``. Unlike the module-level ``float_to_idx`` used by the loss
        methods above, this variant assumes data normalised to [0, 1].
        """
        # Assuming data is normalized between 0 and 1
        return (data * n_classes).long().clamp(0, n_classes - 1)

## BFN (not network)

In [None]:
class BFN(nn.Module):
    """Bayesian Flow Network: couples a network, a Bayesian flow and a loss.

    ``forward`` produces a Monte Carlo estimate of the training loss and
    ``sample`` runs the iterative generation procedure.
    """

    def __init__(self, net: nn.Module, bayesian_flow: BayesianFlow, loss: Loss):
        super().__init__()

        self.net = net
        self.bayesian_flow = bayesian_flow
        self.loss = loss

    @staticmethod
    @torch.no_grad()
    def sample_t(data: Tensor, n_steps: Optional[int]) -> Tensor:
        """Sample one time value per batch element, broadcast to ``data``'s shape.

        Args:
            data: Batch whose shape and device the result follows.
            n_steps: ``None`` or ``0`` selects continuous time, t ~ U(0, 1).
                Otherwise a step is drawn uniformly from {0, ..., n_steps - 1}
                and divided by ``n_steps``, giving the start of an interval,
                t = (i - 1) / n.

        Returns:
            Tensor shaped like ``data``; all entries of one sample share the
            same t, different samples get independent draws.
        """
        # Continuous time needs no step count: draw from the continuous uniform U(0, 1).
        if n_steps == 0 or n_steps is None:
            # (B, 1)
            t = torch.rand(data.size(0), device=data.device).unsqueeze(-1)
        # Discrete time: draw a step from the discrete uniform U{0, n-1}, then divide by n.
        else:
            # (B, 1)
            t = torch.randint(0, n_steps, (data.size(0),), device=data.device).unsqueeze(-1) / n_steps
        # Expand to the data shape: one t per sample, shared across that sample's dimensions.
        t = (torch.ones_like(data).flatten(start_dim=1) * t).reshape_as(data)

        return t

    def forward(self, data: Tensor, t: Optional[Tensor] = None, n_steps: Optional[int] = None) -> Tensor:
        """
        Compute an MC estimate of the continuous (when n_steps=None or 0) or discrete time KL loss.
        t is sampled randomly if None. If t is not None, expect t.shape == data.shape.

        Steps:
        - sample the time variable;
        - sample input-distribution parameters from the Bayesian flow;
        - feed them to the network;
        - turn the network output into the continuous- or discrete-time loss.

        Returns:
            The loss averaged over all of its elements (a scalar tensor).
        """

        t = self.sample_t(data, n_steps) if t is None else t

        # Sample the input parameter flow (parameters after the Bayesian updates up to t).
        input_params = self.bayesian_flow(data, t)
        # Convert to whatever representation the network expects, if needed.
        net_inputs = self.bayesian_flow.params_to_net_inputs(input_params)
        # The network output is usually not the output-distribution parameters
        # themselves (e.g. it may be an estimated noise); the loss / distribution
        # factory post-processes it.
        output_params: Tensor = self.net(net_inputs, t)

        # Compute the KL loss in float32: autocast is explicitly disabled here.
        with torch.autocast(device_type=data.device.type if data.device.type != "mps" else "cpu", enabled=False):
            if n_steps == 0 or n_steps is None:
                loss = self.loss.cts_time_loss(data, output_params.float(), input_params, t)
            else:
                loss = self.loss.discrete_time_loss(data, output_params.float(), input_params, t, n_steps)

        # Original note: loss shape is (batch_size, 1); reduce it to a scalar.
        return loss.mean()

    @torch.inference_mode()
    def sample(self, data_shape: tuple, n_steps: int) -> Tensor:
        """Generate samples with ``n_steps`` Bayesian updates followed by a final prediction.

        Args:
            data_shape: Shape of the batch to generate.
            n_steps: Number of sender/update steps ``n``; steps i = 1..n are all run.

        Returns:
            The mode of the output distribution at t = 1, reshaped to ``data_shape``.
        """
        device = next(self.parameters()).device

        # Prior parameters at the initial time.
        input_params = self.bayesian_flow.get_prior_input_params(data_shape, device)
        distribution_factory = self.loss.distribution_factory

        # Run every step i = 1..n inclusive. Stopping at n - 1 would skip the
        # last Bayesian update, so the final prediction at t = 1 would be made
        # from parameters that have not yet seen step n.
        for i in range(1, n_steps + 1):
            # t_{i-1} = (i - 1) / n
            t = torch.ones(*data_shape, device=device) * (i - 1) / n_steps

            # Predict from the current input parameters and draw a sample from
            # the resulting output distribution.
            output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)
            output_sample = distribution_factory.get_dist(output_params, input_params, t).sample()
            output_sample = output_sample.reshape(*data_shape)

            # Accuracy alpha_i of this step.
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            # Draw the observation from the sender distribution.
            y = self.bayesian_flow.get_sender_dist(output_sample, alpha).sample()
            # Bayesian (posterior) update of the input parameters.
            input_params = self.bayesian_flow.update_input_params(input_params, y, alpha)

        # Final time t = 1.
        t = torch.ones(*data_shape, device=device)
        output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)
        # Use the mode of the output distribution as the generated sample.
        output_sample = distribution_factory.get_dist(output_params, input_params, t).mode
        output_sample = output_sample.reshape(*data_shape)

        return output_sample

## Other

In [None]:
class FourierFeatures(nn.Module):
    """Sinusoidal features ``sin/cos(2*pi * 2**n * x)`` for a range of exponents n.

    The exponents run from ``first`` to ``last`` (inclusive) in increments of
    ``step``. Each input channel yields one sine and one cosine channel per
    exponent.
    """

    def __init__(self, first=5.0, last=6.0, step=1.0):
        super().__init__()
        # The tiny epsilon keeps `last` itself inside the half-open arange.
        self.freqs_exponent = torch.arange(first, last + 1e-8, step)

    @property
    def num_features(self):
        """Number of output channels produced per input channel (sin + cos)."""
        return 2 * len(self.freqs_exponent)

    def forward(self, x):
        """Return the sin and cos features of ``x`` concatenated along dim 1.

        For an input of shape (B, C, ...) the result is (B, 2 * F * C, ...),
        where F is the number of exponents.
        """
        assert len(x.shape) >= 2

        # 2*pi * 2^n for every exponent n, matched to the input's dtype/device.
        exponents = self.freqs_exponent.to(dtype=x.dtype, device=x.device)  # (F,)
        angular = 2.0**exponents * 2 * torch.pi  # (F,)
        # Trailing singleton axes so the frequencies broadcast against x.
        broadcast_shape = (-1,) + (1,) * (x.dim() - 1)
        angular = angular.view(broadcast_shape)  # (F, 1, 1, ...)

        # (B, 1, C, ...) * (F, 1, ...) -> (B, F, C, ...), then merge F and C.
        phases = (angular * x.unsqueeze(1)).flatten(1, 2)  # (B, F * C, ...)

        # Sine block first, cosine block second.
        return torch.cat([phases.sin(), phases.cos()], dim=1)

In [None]:
def attention_inner_heads(qkv, num_heads):
    """Computes attention with heads inside of qkv in the channel dimension.

    Args:
        qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where:
            H = number of heads,
            C = number of channels per head.
        num_heads: number of heads.

    Returns:
        Attention output of shape (B, H*C, T).
    """

    bs, width, length = qkv.shape
    ch = width // (3 * num_heads)

    # Split into (q, k, v) of shape (B, H*C, T).
    q, k, v = qkv.chunk(3, dim=1)

    # Scaling q and k by d^(-1/4) each is the same as scaling their product
    # by 1/sqrt(d). The multiplication also makes q and k contiguous in
    # memory, which the .view() calls below rely on.
    scale = ch ** (-1 / 4)
    q = q * scale
    k = k * scale

    # Reshape qkv to (B*H, C, T). v is still a chunk view, hence reshape().
    new_shape = (bs * num_heads, ch, length)
    q = q.view(*new_shape)
    k = k.view(*new_shape)
    v = v.reshape(*new_shape)

    # Compute attention. The torch-qualified calls only need `import torch`,
    # instead of relying on bare `einsum` / `softmax` names being in scope.
    weight = torch.einsum("bct,bcs->bts", q, k)  # (B*H, T, T)
    # Softmax in float32 for numerical stability, then back to the input dtype.
    weight = torch.softmax(weight.float(), dim=-1).to(weight.dtype)  # (B*H, T, T)
    out = torch.einsum("bts,bcs->bct", weight, v)  # (B*H, C, T)

    return out.reshape(bs, num_heads * ch, length)  # (B, H*C, T)


class Attention(nn.Module):
    """Based on https://github.com/openai/guided-diffusion."""

    def __init__(self, n_heads):
        super().__init__()

        self.n_heads = n_heads

    def forward(self, qkv):
        """Apply multi-head attention over the flattened trailing dimensions of ``qkv``.

        ``qkv`` is (B, 3 * n_heads * C, ...); the output keeps the trailing
        dimensions and has n_heads * C channels.
        """
        assert qkv.dim() >= 3, qkv.dim()
        assert qkv.shape[1] % (3 * self.n_heads) == 0

        batch, channels, *spatial_dims = qkv.shape
        # Collapse every trailing dimension into a single token axis.
        tokens = qkv.view(batch, channels, -1)  # (B, 3*n_heads*C, T)
        attended = attention_inner_heads(tokens, self.n_heads)  # (B, n_heads*C, T)

        # Restore the original trailing dimensions.
        return attended.view(*attended.shape[:2], *spatial_dims).contiguous()


class AttentionBlock(nn.Module):
    """Residual self-attention block: GroupNorm -> 1x1 conv to QKV -> attention -> 1x1 conv."""

    def __init__(self, n_heads, n_channels, norm_groups):
        super().__init__()

        assert n_channels % n_heads == 0

        normalise = nn.GroupNorm(num_groups=norm_groups, num_channels=n_channels)
        # Triple the channel count so Q, K and V each get n_channels for the
        # Attention module that follows.
        to_qkv = nn.Conv2d(n_channels, 3 * n_channels, kernel_size=1)  # (B, 3 * C, H, W)
        attend = Attention(n_heads)
        # Output conv goes through zero_init; per the original note it starts at
        # all zeros, so this branch contributes nothing until it is trained.
        project_out = zero_init(nn.Conv2d(n_channels, n_channels, kernel_size=1))

        self.layers = nn.Sequential(normalise, to_qkv, attend, project_out)

    def forward(self, x):
        """Return the attention branch added to its input."""
        residual = x
        return self.layers(x) + residual

In [None]:
class ResnetBlock(nn.Module):
    """Two-stage convolutional residual block with an optional additive condition.

    Each stage is GroupNorm -> SiLU -> 3x3 conv (the second one also applies
    dropout). A projected condition vector can be added between the stages,
    and a 1x1 conv adapts the skip path when the channel count changes.
    """

    def __init__(
        self,
        ch_in,
        ch_out=None,
        condition_dim=None,
        dropout_prob=0.0,
        norm_groups=32,
    ):
        super().__init__()

        # Default to a channel-preserving block.
        if ch_out is None:
            ch_out = ch_in

        self.ch_out = ch_out
        self.condition_dim = condition_dim

        first_stage = [
            nn.GroupNorm(num_groups=norm_groups, num_channels=ch_in),
            nn.SiLU(),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1),
        ]
        self.net1 = nn.Sequential(*first_stage)

        # The projection only exists when a condition dimension is configured.
        if condition_dim is not None:
            self.cond_proj = zero_init(nn.Linear(condition_dim, ch_out, bias=False))

        second_stage = [
            nn.GroupNorm(num_groups=norm_groups, num_channels=ch_out),
            nn.SiLU(),
            nn.Dropout(dropout_prob),
            zero_init(nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)),
        ]
        self.net2 = nn.Sequential(*second_stage)

        # A 1x1 conv on the skip path is only needed when channel counts differ.
        if ch_in != ch_out:
            self.skip_conv = nn.Conv2d(ch_in, ch_out, kernel_size=1)

    def forward(self, x, condition):
        """Run the block on ``x``; ``condition`` may be ``None`` or (batch, condition_dim)."""
        hidden = self.net1(x)

        if condition is not None:
            assert condition.shape == (x.shape[0], self.condition_dim)

            # cond_proj went through zero_init; per the original note the
            # condition therefore has no effect until the parameters are updated.
            projected = self.cond_proj(condition)
            # Add two trailing singleton axes, (B, D) -> (B, D, 1, 1), so the
            # vector broadcasts over the spatial dimensions.
            hidden = hidden + projected[:, :, None, None]

        hidden = self.net2(hidden)

        shortcut = self.skip_conv(x) if x.shape[1] != self.ch_out else x
        assert shortcut.shape == hidden.shape

        return shortcut + hidden

In [None]:
class UpDownBlock(nn.Module):
    """A residual block optionally followed by an attention block."""

    def __init__(self, resnet_block, attention_block=None):
        super().__init__()

        self.resnet_block = resnet_block
        self.attention_block = attention_block

    def forward(self, x, cond):
        """Apply the resnet block to (x, cond), then the attention block if one was given."""
        features = self.resnet_block(x, cond)
        if self.attention_block is None:
            return features
        return self.attention_block(features)

In [None]:
def idx_to_float(idx: np.ndarray, num_bins: int):
    r"""Map discretisation bin indices k to the centres k_c of their bins in [-1, 1].

    Indices here are zero-based, unlike the paper where k runs from 1 to K.
    With K = ``num_bins``:

        k_c = \frac{2k + 1}{K} - 1,  where k \in [0, K - 1].

    The docstring is raw so the LaTeX backslashes are kept verbatim instead of
    being interpreted as string escapes.

    Args:
        idx: Bin indices (anything supporting ``+`` and ``/`` with floats).
        num_bins: Number of bins K.

    Returns:
        The bin centres, in the same container type as ``idx``.
    """

    # Centre of bin k on the unit interval, then stretched from [0, 1] to [-1, 1].
    flt_zero_one = (idx + 0.5) / num_bins
    return (2.0 * flt_zero_one) - 1.0


def float_to_idx(flt: Union[np.ndarray, Tensor], num_bins: int) -> Tensor:
    """Map values in [-1, 1] to their bin index k; the inverse of ``idx_to_float()``.

    Args:
        flt: Values to discretise. Tensors are used as they are; NumPy arrays
            and Python scalars are converted with ``torch.as_tensor`` so that
            the torch operations below accept them.
        num_bins: Number of bins K.

    Returns:
        Int64 tensor of indices in [0, K - 1]; out-of-range inputs are clamped
        to the first / last bin.
    """

    # as_tensor is a no-op for tensors and lets ndarray input work as annotated.
    flt = torch.as_tensor(flt)
    # Shift from [-1, 1] to [0, 1] before scaling by the number of bins.
    flt_zero_one = (flt / 2.0) + 0.5
    return torch.clamp(torch.floor(flt_zero_one * num_bins), min=0, max=num_bins - 1).long()


def quantize(flt, num_bins: int):
    """Quantise floats by replacing each value with the centre k_c of the bin it falls in."""
    bin_indices = float_to_idx(flt, num_bins)
    return idx_to_float(bin_indices, num_bins)


def rgb_image_transform(x, num_bins=256):
    """Discretise an RGB image whose values x lie in [0, 1].

    The image is rescaled to [-1, 1], quantised to ``num_bins`` bin centres,
    and its three axes are permuted with ``(1, 2, 0)`` (channels-first to
    channels-last for a (C, H, W) input) before being made contiguous.
    """
    # Rescale [0, 1] -> [-1, 1], the range quantize() operates on.
    return quantize((x * 2) - 1, num_bins).permute(1, 2, 0).contiguous()

## Train

In [None]:
def make_from_cfg(module, cfg, **parameters):
    """Instantiate ``module.<cfg.class_name>`` from a config node.

    The constructor receives ``cfg.parameters`` plus any extra keyword
    ``parameters``. Returns ``None`` when ``cfg`` is ``None``.
    """
    if cfg is None:
        return None

    factory = getattr(module, cfg.class_name)
    return factory(**cfg.parameters, **parameters)


def make_bfn(cfg: DictConfig):
    """Assemble a BFN (network, Bayesian flow and loss) from a model config.

    Every component is built with ``make_from_cfg`` from the corresponding
    entry of ``cfg``; a ``None`` entry yields a ``None`` component.
    """
    # Input/output adapters are handed to the network as a single dict.
    data_adapters = {
        key: make_from_cfg(adapters, getattr(cfg, key))
        for key in ("input_adapter", "output_adapter")
    }

    net = make_from_cfg(networks, cfg.net, data_adapters=data_adapters)
    bayesian_flow = make_from_cfg(model, cfg.bayesian_flow)
    distribution_factory = make_from_cfg(probability, cfg.distribution_factory)
    # The loss needs both the flow and the distribution factory.
    loss = make_from_cfg(
        model,
        cfg.loss,
        bayesian_flow=bayesian_flow,
        distribution_factory=distribution_factory,
    )

    return model.BFN(net=net, bayesian_flow=bayesian_flow, loss=loss)

In [None]:
def setup(cfg) -> Tuple[nn.Module, dict, optim.Optimizer]:
    """Build the model, the dataloaders and an AdamW optimizer from ``cfg``.

    Returns:
        ``(model, dataloaders, optimizer)``.
    """
    dataloaders = make_dataloaders(cfg)
    model = make_bfn(cfg.model)

    net = model.net
    # Use the network's own parameter groups (which separate decayed from
    # non-decayed parameters) when weight decay is configured and the network
    # provides them; otherwise optimise all network parameters uniformly.
    use_param_groups = "weight_decay" in cfg.optimizer.keys() and hasattr(net, "get_optim_groups")
    if use_param_groups:
        params = net.get_optim_groups(cfg.optimizer.weight_decay)
    else:
        params = net.parameters()

    # Hyper-parameters come straight from the optimizer section of the config.
    optimizer = optim.AdamW(params=params, **cfg.optimizer)

    return model, dataloaders, optimizer

In [None]:
import copy
import logging
import math

from collections import defaultdict
from pathlib import Path
from typing import Optional, Tuple

import torch
import neptune

from accelerate import Accelerator
from accelerate.logging import get_logger

from omegaconf import OmegaConf

from rich.logging import RichHandler
from rich.progress import Progress

from torch import nn, optim
from torch.utils.data import DataLoader

from model import BFN
from utils_train import (
    seed_everything, log_cfg,
    checkpoint_training_state,
    init_checkpointing,
    log,
    update_ema,
    ddict,
    make_infinite,
    make_progress_bar, make_config, make_dataloaders, make_bfn,
)


# Global numeric / backend settings, applied once at import time.
# "high" lets float32 matmuls use faster reduced-precision internals where available.
torch.set_float32_matmul_precision("high")
# Allow cuDNN to benchmark convolution algorithms and pick the fastest one.
torch.backends.cudnn.benchmark = True

# Root logging configuration: INFO level, message-only format, rendered by Rich
# with rich tracebacks enabled and the time column hidden.
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(rich_tracebacks=True, show_time=False)],
)

# Logger from accelerate.logging; it accepts `main_process_only=...` (used in main()).
logger = get_logger(__name__)


def ddict():
    """Return an infinitely nested defaultdict.

    Any missing key yields another ``ddict``. Used as a stand-in for a
    neptune run on non-main processes.
    """
    nested = defaultdict(ddict)
    return nested


def main(cfg):
    """Entry point of a training run.

    Seeds everything, builds the model / dataloaders / optimizer, optionally
    creates an EMA copy of the model and a neptune run on the main process,
    then hands over to ``train``.

    Args:
        cfg: Run configuration; reads ``cfg.training`` (accumulate, seed,
            ema_decay) and ``cfg.meta`` (neptune, debug).
    """
    acc = Accelerator(gradient_accumulation_steps=cfg.training.accumulate)

    cfg.training.seed = seed_everything(cfg.training.seed)
    logger.info(f"Seeded everything with seed {cfg.training.seed}", main_process_only=True)

    with acc.main_process_first():
        model, dataloaders, optimizer = setup(cfg)

    # EMA copy lives on the main process only, and only when EMA is enabled.
    ema = copy.deepcopy(model) if acc.is_main_process and cfg.training.ema_decay > 0 else None
    model, optimizer, dataloaders["train"] = acc.prepare(model, optimizer, dataloaders["train"])

    # ddict() is an infinitely nested defaultdict acting as a fake neptune run:
    # a placeholder for every non-main process. The main process replaces it
    # with a real neptune run below when neptune is configured.
    run = ddict()

    if acc.is_main_process:
        # ema is None when ema_decay <= 0, so it must not be moved unconditionally.
        if ema is not None:
            ema.to(acc.device)
        try:
            if cfg.meta.neptune:
                import neptune

                run = neptune.init_run(project=cfg.meta.neptune, mode="debug" if cfg.meta.debug else None)
                run["accelerate"] = dict(amp=acc.mixed_precision, nproc=acc.num_processes)
                log_cfg(cfg, run)
        except ImportError:
            logger.info("Did not find neptune installed. Logging will be disabled.")

    train(cfg.training, acc, model, ema, dataloaders, optimizer, run)


if __name__ == "__main__":
    # The config path is read from the OmegaConf command-line arguments
    # under the key `config_file`.
    cli_args = OmegaConf.from_cli()
    main(make_config(cli_args["config_file"]))