In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange #pip install einops
from typing import List
import random
import math
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from timm.utils import ModelEmaV3 
from tqdm import tqdm 
import matplotlib.pyplot as plt 
import torch.optim as optim
import numpy as np
import os  


  from .autonotebook import tqdm as notebook_tqdm


In [12]:
class SinusoidalEmbeddings(nn.Module):
    """
    Sinusoidal positional encoding for timesteps.

    Maps each discrete timestep index ``t`` in ``[0, time_steps)`` to a fixed
    (non-learned) vector of size ``embed_dim``: even columns hold sines and odd
    columns hold cosines of the timestep scaled by geometrically spaced
    frequencies. The whole table is precomputed once and stored as a buffer.

    Args:
        time_steps: Number of distinct timesteps (rows of the table).
        embed_dim: Size of each embedding vector. May be odd; in that case the
            last column is a sine column with no matching cosine column.
    """
    def __init__(self, time_steps: int, embed_dim: int):
        super().__init__()
        # Timestep indices as a column vector of shape [time_steps, 1].
        position = torch.arange(time_steps).unsqueeze(1).float()

        # Per-column frequency scale; one entry per even column index.
        div = torch.exp(
            torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim)
        )

        embeddings = torch.zeros(time_steps, embed_dim, requires_grad=False)
        embeddings[:, 0::2] = torch.sin(position * div)  # even columns: sine
        # Odd columns: cosine. For odd embed_dim there is one fewer odd column
        # than even columns, so trim `div` to avoid a shape mismatch.
        embeddings[:, 1::2] = torch.cos(position * div[: embed_dim // 2])
        self.register_buffer('embeddings', embeddings)

    def forward(self, x, t):
        """
        Look up the embeddings for the timestep indices ``t``.

        Args:
            x: Tensor used only to pick the output device.
            t: 1-D integer index tensor (one timestep per batch element).

        Returns:
            Tensor of shape ``(len(t), embed_dim, 1, 1)`` on ``x.device``; the
            two trailing singleton axes let it broadcast over the spatial
            dimensions of a ``(B, C, H, W)`` feature map.
        """
        # Indexing requires the index to live on the CPU or on the same device
        # as the table, so align it with the buffer before the lookup.
        if torch.is_tensor(t):
            t = t.to(self.embeddings.device)
        embeds = self.embeddings[t].to(x.device)

        # Append H=1 and W=1 axes so the result broadcasts over space.
        return embeds[:, :, None, None]

In [13]:
# Residual block with timestep-embedding injection, so the network can adapt
# its denoising behaviour to the current timestep t.
class ResBlock(nn.Module):
    """
    Two-convolution residual block conditioned on a timestep embedding.

    The embedding is added to the input first; the sum then goes through
    GroupNorm -> ReLU -> 3x3 conv -> dropout -> GroupNorm -> ReLU -> 3x3 conv
    and is finally added back to the embedded input. Channel count and spatial
    size are preserved.

    Args:
        C: Number of input (and output) channels.
        num_groups: Number of groups for both GroupNorm layers.
        dropout_prob: Dropout probability applied after the first convolution.
    """

    def __init__(self, C: int, num_groups: int, dropout_prob: float):
        super().__init__()
        # Submodule names and registration order are kept stable so that
        # state_dict keys and parameter ordering do not change.
        self.relu = nn.ReLU(inplace=True)
        self.gnorm1 = nn.GroupNorm(num_groups, C)
        self.gnorm2 = nn.GroupNorm(num_groups, C)
        self.conv1 = nn.Conv2d(C, C, 3, padding=1)
        self.conv2 = nn.Conv2d(C, C, 3, padding=1)
        self.dropout = nn.Dropout(dropout_prob, inplace=True)

    def forward(self, x, embeddings):
        """
        Args:
            x: Feature map; its size along dim 1 is the channel count.
            embeddings: 4-D timestep embedding. Only its first ``x.shape[1]``
                channels are used, so one wide embedding can serve blocks of
                different widths.

        Returns:
            Tensor with the same shape as ``x``.
        """
        channels = x.shape[1]
        conditioned = x + embeddings[:, :channels, :, :]

        # First norm -> activation -> conv stage, followed by dropout.
        hidden = self.gnorm1(conditioned)
        hidden = self.relu(hidden)
        hidden = self.conv1(hidden)
        hidden = self.dropout(hidden)

        # Second norm -> activation -> conv stage.
        hidden = self.gnorm2(hidden)
        hidden = self.relu(hidden)
        hidden = self.conv2(hidden)

        # Skip connection around both stages.
        return hidden + conditioned




In [14]:
# Treats every spatial position of the input feature map as a token and
# computes global dependencies between all of them.
class Attention(nn.Module):
    """
    Multi-head self-attention over the spatial positions of a feature map.

    Args:
        C: Number of channels of the input feature map (token width).
        num_heads: Number of attention heads. The fused q/k/v projection is
            split as ``(3, num_heads, head_dim)``, so ``C`` needs to be
            divisible by ``num_heads`` for the split to succeed.
        dropout_prob: Attention dropout probability, applied only while the
            module is in training mode.
    """
    def __init__(self, C: int, num_heads: int, dropout_prob: float):
        super().__init__()
        self.proj1 = nn.Linear(C, C * 3)  # fused q/k/v projection
        self.proj2 = nn.Linear(C, C)      # output projection
        self.num_heads = num_heads
        self.dropout_prob = dropout_prob

    def forward(self, x):
        """
        Args:
            x: Feature map of shape ``(b, c, h, w)``.

        Returns:
            Tensor of shape ``(b, c, h, w)``.
        """
        h, w = x.shape[2:]
        # Flatten the spatial grid into a token sequence of length h*w.
        x = rearrange(x, 'b c h w -> b (h w) c')
        # One linear layer produces q, k and v concatenated on the last axis.
        x = self.proj1(x)
        # Width of a single attention head.
        head_dim = x.shape[-1] // (3 * self.num_heads)

        x = rearrange(x, 'b L (K H C) -> K b H L C', K=3, H=self.num_heads, C=head_dim)
        q, k, v = x[0], x[1], x[2]
        # The functional SDPA has no notion of train/eval mode and applies
        # dropout whenever dropout_p > 0, so it must be switched off explicitly
        # during evaluation to keep inference deterministic.
        dropout_p = self.dropout_prob if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, is_causal=False, dropout_p=dropout_p)
        # Merge the heads back together and restore the spatial layout.
        x = rearrange(x, 'b H L C -> b L (H C)')
        x = rearrange(x, 'b (h w) C -> b h w C', h=h, w=w)
        x = self.proj2(x)
        return rearrange(x, 'b h w C -> b C h w')


